Soft Dynamic Time Warping

Module Interface

class torchmetrics.timeseries.SoftDTW(distance_fn=None, gamma=1.0, reduction='mean', **kwargs)[source]

Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.

This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). It replaces the hard minimum in the DTW recursion with a soft minimum based on a log-sum-exp formulation:

\[\text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)\]
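
In code, this soft minimum is a scaled log-sum-exp, which keeps the computation numerically stable; a minimal sketch (illustrative only, not the library's internal implementation):

>>> import torch
>>> def softmin(a, b, c, gamma):
...     # -gamma * log(exp(-a/gamma) + exp(-b/gamma) + exp(-c/gamma))
...     vals = torch.stack([a, b, c])
...     return -gamma * torch.logsumexp(-vals / gamma, dim=0)
>>> a, b, c = torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)
>>> val = softmin(a, b, c, gamma=0.01)  # approaches min(a, b, c) = 1.0 as gamma -> 0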

The Soft-DTW recurrence is then defined as:

\[R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})\]

where \(D_{i,j}\) is the pairwise distance between sequence elements \(x_i\) and \(y_j\). It can be computed with any differentiable distance function, such as the squared Euclidean distance or the cosine distance.

The final Soft-DTW distance is \(R_{N,M}\).
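
To make the recurrence concrete, here is a minimal unbatched reference sketch in plain PyTorch, using an O(N·M) double loop and the squared Euclidean distance (for illustration only; the metric itself operates on batched inputs):

>>> import torch
>>> def soft_dtw_reference(x, y, gamma=1.0):
...     # x: [N, D], y: [M, D]
...     N, M = x.shape[0], y.shape[0]
...     D = torch.cdist(x, y) ** 2  # pairwise squared Euclidean distances, [N, M]
...     R = torch.full((N + 1, M + 1), float("inf"))
...     R[0, 0] = 0.0
...     for i in range(1, N + 1):
...         for j in range(1, M + 1):
...             prev = torch.stack([R[i - 1, j], R[i, j - 1], R[i - 1, j - 1]])
...             R[i, j] = D[i - 1, j - 1] - gamma * torch.logsumexp(-prev / gamma, dim=0)
...     return R[N, M]
>>> d = soft_dtw_reference(torch.randn(5, 2), torch.randn(7, 2), gamma=0.1)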

Parameters:
  • gamma (float) – Smoothing parameter (\(\gamma > 0\)). Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface.

  • distance_fn (Optional[Callable]) – Optional callable (x, y) -> [B, N, M] defining the pairwise distance matrix. If None, defaults to the squared Euclidean distance (see the sketch after this list).

  • reduction (Literal['sum', 'mean', 'none']) – Indicates how to reduce over the batch dimension. Choose between ['sum', 'mean', 'none'].
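
As noted for distance_fn above, any callable producing a [B, N, M] matrix can be plugged in; a short sketch with a Manhattan (L1) distance, assuming the module applies the callable as documented (l1_dist is a hypothetical helper defined here):

>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> def l1_dist(x, y):
...     # x: [B, N, D], y: [B, M, D] -> [B, N, M] pairwise L1 distances
...     return (x.unsqueeze(2) - y.unsqueeze(1)).abs().sum(dim=-1)
>>> metric = SoftDTW(gamma=0.5, distance_fn=l1_dist)
>>> val = metric(torch.randn(4, 20, 3), torch.randn(4, 25, 3))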

Raises:
  • ValueError – If reduction is not one of [sum, mean, none].

  • ValueError – If gamma is not a positive float.

  • ValueError – If the input tensors passed to update are not 3-dimensional with the same batch size and feature dimension.

Example

>>> from torch import randn
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW(gamma=0.1)
>>> x = randn(10, 50, 2)
>>> y = randn(10, 60, 2)
>>> metric(x, y)
tensor(43.2051)
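
With reduction='none', the metric instead returns one distance per sequence pair; a short sketch, assuming reduction behaves as documented:

>>> metric = SoftDTW(gamma=0.1, reduction='none')
>>> val = metric(randn(10, 50, 2), randn(10, 60, 2))  # shape [10], one distance per pair
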
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW()
>>> metric.update(torch.randn(10, 100, 2), torch.randn(10, 50, 2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.timeseries.soft_dtw(preds, target, gamma=1.0, distance_fn=None, reduction='mean')[source]

Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.

This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). It replaces the hard minimum in the DTW recursion with a soft minimum based on a log-sum-exp formulation:

\[\text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)\]

The Soft-DTW recurrence is then defined as:

\[R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})\]

where \(D_{i,j}\) is the pairwise distance between sequence elements \(x_i\) and \(y_j\). It can be computed with any differentiable distance function, such as the squared Euclidean distance or the cosine distance.

The final Soft-DTW distance is \(R_{N,M}\).

Parameters:
  • preds (Tensor) – Tensor of shape [B, N, D] — batch of input sequences.

  • target (Tensor) – Tensor of shape [B, M, D] — batch of target sequences.

  • gamma (float) – Smoothing parameter (\(\gamma > 0\)). Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface.

  • distance_fn (Optional[Callable]) – Optional callable (x, y) -> [B, N, M] defining the pairwise distance matrix. If None, defaults to the squared Euclidean distance.

  • reduction (Literal['sum', 'mean', 'none']) – Indicates how to reduce over the batch dimension. Choose between ['sum', 'mean', 'none']. Defaults to 'mean'.

Return type:

Tensor

Returns:

The Soft-DTW distance. With reduction='none', a tensor of shape [B] containing the distance for each sequence pair in the batch; with 'sum' or 'mean', the corresponding reduction over the batch.

Raises:
  • ValueError – If reduction is not one of [sum, mean, none].

  • ValueError – If gamma is not a positive float.

  • ValueError – If preds and target are not 3-dimensional tensors with the same batch size and feature dimension.

Example::
>>> import torch
>>> from torchmetrics.functional.timeseries import soft_dtw
>>>
>>> x = torch.tensor([[[0.0], [1.0], [2.0]]])  # [B, N, D]
>>> y = torch.tensor([[[0.0], [2.0], [3.0]]])  # [B, M, D]
>>> soft_dtw(x, y, gamma=0.1)
tensor([0.4003])
Example (custom distance function)::
>>> def cosine_dist(a, b):
...     a = torch.nn.functional.normalize(a, dim=-1)
...     b = torch.nn.functional.normalize(b, dim=-1)
...     return 1 - torch.bmm(a, b.transpose(1, 2))
>>>
>>> x = torch.randn(2, 5, 3)
>>> y = torch.randn(2, 6, 3)
>>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist)
tensor([2.8301, 3.0128])
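
Because the soft minimum is differentiable, soft_dtw can be backpropagated through and used directly as a training loss; a minimal sketch (the .sum() makes the example independent of the chosen reduction):

>>> x = torch.randn(2, 5, 3, requires_grad=True)
>>> y = torch.randn(2, 6, 3)
>>> loss = soft_dtw(x, y, gamma=0.1).sum()
>>> loss.backward()
>>> x.grad.shape
torch.Size([2, 5, 3])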