Soft Dynamic Time Warping
Module Interface
- class torchmetrics.timeseries.SoftDTW(distance_fn=None, gamma=1.0, reduction='mean', **kwargs)
Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.
This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). It replaces the hard minimum in the DTW recursion with a soft minimum based on a log-sum-exp formulation:
\[\text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)\]

The Soft-DTW recurrence is then defined as:
\[R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})\]

where \(D_{i,j}\) is the pairwise distance between sequence elements \(x_i\) and \(y_j\). It can be computed with any differentiable distance function, such as the squared Euclidean or cosine distance.
The final Soft-DTW distance is \(R_{N,M}\).
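In code, the recurrence amounts to a simple double loop. The following unbatched sketch illustrates it with a squared Euclidean distance; soft_min and soft_dtw_single are hypothetical helper names, not part of the torchmetrics API (the library implementation operates on whole batches):

>>> import torch
>>> def soft_min(a, b, c, gamma):
...     # log-sum-exp soft minimum over the three DTW predecessors
...     return -gamma * torch.logsumexp(torch.stack([a, b, c]) / -gamma, dim=0)
>>> def soft_dtw_single(x, y, gamma=1.0):
...     # x: [N, D], y: [M, D]; dist[i, j] is the squared Euclidean distance
...     dist = torch.cdist(x, y) ** 2
...     n, m = dist.shape
...     # pad R with +inf so the first row/column never wins the soft minimum
...     r = torch.full((n + 1, m + 1), float("inf"))
...     r[0, 0] = 0.0
...     for i in range(1, n + 1):
...         for j in range(1, m + 1):
...             r[i, j] = dist[i - 1, j - 1] + soft_min(r[i - 1, j], r[i, j - 1], r[i - 1, j - 1], gamma)
...     return r[n, m]  # final Soft-DTW distance R_{N,M}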
- Parameters:
- gamma (float) – Smoothing parameter (\(\gamma > 0\)). Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface.
- distance_fn (Optional[Callable]) – Optional callable (x, y) -> [B, N, M] defining the pairwise distance matrix. If None, defaults to squared Euclidean distance.
- reduction (Literal['sum', 'mean', 'none']) – Indicates how to reduce over the batch dimension. Choose between ['sum', 'mean', 'none'].
- Raises:
- ValueError – If reduction is not one of ['sum', 'mean', 'none'].
- ValueError – If gamma is not a positive float.
- ValueError – If the input tensors passed to update are not 3-dimensional with the same batch size and feature dimension.
Example
>>> from torch import randn
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW(gamma=0.1)
>>> x = randn(10, 50, 2)
>>> y = randn(10, 60, 2)
>>> metric(x, y)
tensor(110.8406)
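With reduction="none" the batch dimension is kept, so each sequence pair gets its own distance; a quick sketch of the expected output shape:

>>> from torch import randn
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW(gamma=0.1, reduction="none")
>>> metric(randn(10, 50, 2), randn(10, 60, 2)).shape
torch.Size([10])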
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
- val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
- ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW()
>>> metric.update(torch.randn(10, 100, 2), torch.randn(10, 50, 2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.timeseries.soft_dtw(preds, target, gamma=1.0, distance_fn=None, reduction='mean')
Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.
This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). It replaces the hard minimum in the DTW recursion with a soft minimum based on a log-sum-exp formulation:
\[\text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)\]

The Soft-DTW recurrence is then defined as:
\[R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})\]

where \(D_{i,j}\) is the pairwise distance between sequence elements \(x_i\) and \(y_j\). It can be computed with any differentiable distance function, such as the squared Euclidean or cosine distance.
The final Soft-DTW distance is \(R_{N,M}\).
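For intuition about \(\gamma\), the soft minimum can be evaluated on its own; the helper below is illustrative only (not part of the torchmetrics API) and shows the value approaching the hard minimum as \(\gamma \to 0\):

>>> import torch
>>> def softmin(vals, gamma):
...     # -gamma * log(sum_i exp(-vals_i / gamma))
...     return -gamma * torch.logsumexp(-vals / gamma, dim=0)
>>> vals = torch.tensor([1.0, 2.0, 3.0])
>>> softmin(vals, gamma=1.0)  # smooth, and below the hard minimum of 1.0
tensor(0.5924)
>>> torch.allclose(softmin(vals, gamma=0.01), vals.min(), atol=1e-3)
True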
- Parameters:
- preds (Tensor) – Tensor of shape [B, N, D] containing the batch of input sequences.
- target (Tensor) – Tensor of shape [B, M, D] containing the batch of target sequences.
- gamma (float) – Smoothing parameter (\(\gamma > 0\)). Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface.
- distance_fn (Optional[Callable]) – Optional callable (x, y) -> [B, N, M] defining the pairwise distance matrix. If None, defaults to squared Euclidean distance.
- reduction (Literal['sum', 'mean', 'none']) – Indicates how to reduce over the batch dimension. Choose between ['sum', 'mean', 'none']. Defaults to mean.
- Return type:
Tensor
- Returns:
A scalar tensor if reduction is 'sum' or 'mean', or a tensor of shape [B] containing the Soft-DTW distance for each sequence pair in the batch if reduction is 'none'.
- Raises:
- ValueError – If reduction is not one of ['sum', 'mean', 'none'].
- ValueError – If gamma is not a positive float.
- ValueError – If preds and target are not 3-dimensional tensors with the same batch size and feature dimension.
- Example:

>>> import torch
>>> from torchmetrics.functional.timeseries import soft_dtw
>>>
>>> x = torch.tensor([[[0.0], [1.0], [2.0]]])  # [B, N, D]
>>> y = torch.tensor([[[0.0], [2.0], [3.0]]])  # [B, M, D]
>>> soft_dtw(x, y, gamma=0.1)
tensor(1.8901)
- Example (custom distance function):

>>> import torch
>>> from torchmetrics.functional.timeseries import soft_dtw
>>> def cosine_dist(a, b):
...     a = torch.nn.functional.normalize(a, dim=-1)
...     b = torch.nn.functional.normalize(b, dim=-1)
...     return 1 - torch.bmm(a, b.transpose(1, 2))
>>>
>>> x = torch.randn(2, 5, 3)
>>> y = torch.randn(2, 6, 3)
>>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist)
tensor(3.3724)
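Finally, because every operation in the recursion is smooth, the functional form can be used directly as a differentiable training loss; a minimal sketch, assuming gradients with respect to preds are wanted:

>>> import torch
>>> from torchmetrics.functional.timeseries import soft_dtw
>>> x = torch.randn(2, 5, 3, requires_grad=True)
>>> y = torch.randn(2, 6, 3)
>>> loss = soft_dtw(x, y, gamma=0.1)  # scalar with the default 'mean' reduction
>>> loss.backward()
>>> x.grad.shape
torch.Size([2, 5, 3])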