Soft Dynamic Time Warping¶
Module Interface¶
- class torchmetrics.timeseries.SoftDTW(distance_fn=None, gamma=1.0, reduction='mean', **kwargs)[source]¶
Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.
This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). It replaces the hard minimum in the DTW recursion with a soft minimum based on a log-sum-exp formulation:
\[\text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)\]
The Soft-DTW recurrence is then defined as:
\[R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})\]
where \(D_{i,j}\) is the pairwise distance between sequence elements \(x_i\) and \(y_j\). It can be computed with any differentiable distance function, such as the squared Euclidean or cosine distance.
The final Soft-DTW distance is \(R_{N,M}\).
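To make the recurrence concrete, the following is a minimal, unoptimized reference sketch for a single pair of sequences. The helper name soft_dtw_naive is hypothetical and for illustration only; the metric itself operates on batches and is what should be used in practice:

>>> import torch
>>> def soft_dtw_naive(x, y, gamma=1.0):
...     # x: [N, D], y: [M, D]; returns the scalar Soft-DTW distance R[N, M]
...     n, m = x.shape[0], y.shape[0]
...     dist = torch.cdist(x, y) ** 2  # squared Euclidean D[i, j]
...     r = torch.full((n + 1, m + 1), float("inf"), dtype=x.dtype)
...     r[0, 0] = 0.0
...     for i in range(1, n + 1):
...         for j in range(1, m + 1):
...             prev = torch.stack([r[i - 1, j], r[i, j - 1], r[i - 1, j - 1]])
...             # softmin_gamma(a, b, c) = -gamma * logsumexp(-[a, b, c] / gamma)
...             r[i, j] = dist[i - 1, j - 1] - gamma * torch.logsumexp(-prev / gamma, dim=0)
...     return r[n, m]
>>> d = soft_dtw_naive(torch.randn(5, 2), torch.randn(7, 2), gamma=0.1)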
- Parameters:
  - gamma (float) – Smoothing parameter (\(\gamma > 0\)). Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface.
  - distance_fn (Optional[Callable]) – Optional callable (x, y) -> [B, N, M] defining the pairwise distance matrix. If None, defaults to squared Euclidean distance.
  - reduction (Literal['sum', 'mean', 'none']) – Indicates how to reduce over the batch dimension. Choose between ['sum', 'mean', 'none'].
- Raises:
  - ValueError – If reduction is not one of ['sum', 'mean', 'none'].
  - ValueError – If gamma is not a positive float.
  - ValueError – If the input tensors to update are not 3-dimensional with the same batch size and feature dimension.
Example
>>> from torch import randn
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW(gamma=0.1)
>>> x = randn(10, 50, 2)
>>> y = randn(10, 60, 2)
>>> metric(x, y)
tensor(43.2051)
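The reduction argument controls how the per-pair distances are aggregated over the batch; a usage sketch with reduction='none' (random inputs, so only the output shape is checked):

>>> metric = SoftDTW(gamma=0.1, reduction="none")
>>> scores = metric(randn(10, 50, 2), randn(10, 60, 2))
>>> scores.shape
torch.Size([10])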
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
  - val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
  - ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW()
>>> metric.update(torch.randn(10, 100, 2), torch.randn(10, 50, 2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.timeseries import SoftDTW
>>> metric = SoftDTW()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.timeseries.soft_dtw(preds, target, gamma=1.0, distance_fn=None, reduction='mean')[source]¶
Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.
This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). It replaces the hard minimum in the DTW recursion with a soft minimum based on a log-sum-exp formulation:
\[\text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)\]
The Soft-DTW recurrence is then defined as:
\[R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})\]
where \(D_{i,j}\) is the pairwise distance between sequence elements \(x_i\) and \(y_j\). It can be computed with any differentiable distance function, such as the squared Euclidean or cosine distance.
The final Soft-DTW distance is \(R_{N,M}\).
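As a quick numeric check of the smoothing behaviour (an illustration, not part of the documented API): \(\text{softmin}_\gamma\) is a smooth lower bound on the hard minimum that tightens as \(\gamma \to 0\):

>>> import torch
>>> def softmin(v, gamma):
...     return -gamma * torch.logsumexp(-v / gamma, dim=0)
>>> vals = torch.tensor([1.0, 2.0, 4.0])  # hard minimum is 1.0
>>> loose = softmin(vals, 1.0)   # ~0.65, noticeably below the hard minimum
>>> tight = softmin(vals, 0.01)  # ~1.0, nearly the hard minimum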
- Parameters:
  - preds (Tensor) – Tensor of shape [B, N, D] containing the batch of input sequences.
  - target (Tensor) – Tensor of shape [B, M, D] containing the batch of target sequences.
  - gamma (float) – Smoothing parameter (\(\gamma > 0\)). Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface.
  - distance_fn (Optional[Callable]) – Optional callable (x, y) -> [B, N, M] defining the pairwise distance matrix. If None, defaults to squared Euclidean distance.
  - reduction (Literal['sum', 'mean', 'none']) – Indicates how to reduce over the batch dimension. Choose between ['sum', 'mean', 'none']. Defaults to 'mean'.
- Return type:
  Tensor
- Returns:
  A tensor of shape [B] containing the Soft-DTW distance for each sequence pair in the batch.
- Raises:
  - ValueError – If reduction is not one of ['sum', 'mean', 'none'].
  - ValueError – If gamma is not a positive float.
  - ValueError – If preds and target are not 3-dimensional tensors with the same batch size and feature dimension.
- Example::
>>> import torch
>>> from torchmetrics.functional.timeseries import soft_dtw
>>>
>>> x = torch.tensor([[[0.0], [1.0], [2.0]]])  # [B, N, D]
>>> y = torch.tensor([[[0.0], [2.0], [3.0]]])  # [B, M, D]
>>> soft_dtw(x, y, gamma=0.1)
tensor([0.4003])
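Because the recurrence is built entirely from smooth operations, the result is differentiable with respect to the inputs; a minimal gradient sketch (assuming the default squared Euclidean distance, with only the gradient shape checked):

>>> x = torch.randn(2, 10, 3, requires_grad=True)
>>> y = torch.randn(2, 12, 3)
>>> soft_dtw(x, y, gamma=0.1).sum().backward()
>>> x.grad.shape
torch.Size([2, 10, 3])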
- Example (custom distance function)::
>>> def cosine_dist(a, b): ... a = torch.nn.functional.normalize(a, dim=-1) ... b = torch.nn.functional.normalize(b, dim=-1) ... return 1 - torch.bmm(a, b.transpose(1, 2)) >>> >>> x = torch.randn(2, 5, 3) >>> y = torch.randn(2, 6, 3) >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist) tensor([2.8301, 3.0128])