Does Zuko allow exporting to ONNX? #45
-
Hello again,

Is it possible to export Zuko flows to ONNX? If so, do you have any example? If it is not readily possible, do you have any idea of how much effort it would take? I would be interested in trying that out.

Best,
Replies: 3 comments 13 replies
-
Hello @CaioDaumann,

I have never tried, but I think it should be possible. Looking at https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html, you will probably need to wrap the flow as a "pure function", that is, something that takes tensors as input and returns tensors as output. This is not the case of the `Flow` objects, which take a tensor as input and return a `Distribution`. A very thin wrapper module that takes both $c$ and $x$ as input and returns $\log p(x | c)$ is probably enough.

Also, I don't know how ONNX handles randomness, but sampling from the flow requires first sampling from the base distribution, and then transforming it. It might be easier to only export the transform to ONNX.

In any case, if you succeed, consider contributing a tutorial to the repo!
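For reference, a minimal sketch of such a wrapper (the flow configuration, class name, and shapes below are illustrative assumptions, not from this discussion):

```python
import torch
import zuko


class LogProbWrapper(torch.nn.Module):
    """Thin wrapper exposing the flow as a pure function (c, x) -> log p(x | c)."""

    def __init__(self, flow: zuko.flows.Flow):
        super().__init__()
        self.flow = flow

    def forward(self, c: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Conditioning on c returns a Distribution; log_prob maps tensors to tensors.
        return self.flow(c).log_prob(x)


# Hypothetical usage: a 2-feature flow conditioned on 3 context features.
flow = zuko.flows.NSF(features=2, context=3, transforms=3, hidden_features=(64, 64))
wrapper = LogProbWrapper(flow)
log_p = wrapper(torch.randn(8, 3), torch.randn(8, 2))  # shape (8,)
```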
-
Hi @francois-rozet, @dpbigler,

Getting back to this, I spent some time trying to export a Zuko model to ONNX and I came up with something like this:

```python
import torch
import torch.utils.data as data
import zuko
import numpy as np
import onnxruntime as ort


class WrappedNSF(torch.nn.Module):
    def __init__(self):
        super(WrappedNSF, self).__init__()
        self.flow = zuko.flows.NSF(features=2, transforms=3, hidden_features=(64, 64))

    def forward(self, x):
        result = self.flow().transform(x)
        return result


def two_moons(n: int, sigma: float = 1e-1):
    theta = 2 * torch.pi * torch.rand(n)
    label = (theta > torch.pi).float()

    x = torch.stack((
        torch.cos(theta) + label - 1 / 2,
        torch.sin(theta) + label / 2 - 1 / 4,
    ), axis=-1)

    return torch.normal(x, sigma), label


samples, labels = two_moons(16384)
samples_tensor = samples.clone().detach()

trainset = data.TensorDataset(samples, labels)
trainloader = data.DataLoader(trainset, batch_size=64, shuffle=True)

model = WrappedNSF()
model.eval()

dummy_input = torch.randn(1, 2)
output = model(dummy_input)
print("Sample output:", output)

try:
    torch.onnx.export(
        model,                      # Wrapped model instance
        dummy_input,                # Model input (or a tuple for multiple inputs)
        "wrapped_flow_model.onnx",  # Output ONNX file path
        export_params=True,         # Store the trained parameter weights inside the model file
        opset_version=17,           # ONNX version to export the model to
        do_constant_folding=True,   # Optimization: constant folding
        input_names=['input'],      # Model's input names
        output_names=['output'],    # Model's output names
        dynamic_axes={'input': {0: 'batch_size'},    # Variable length axes
                      'output': {0: 'batch_size'}},
    )
    print("Model exported successfully.")
except Exception as e:
    print("Failed to export model:", str(e))
```

But this returns the following error:
Any ideas here, or should I open an issue in PyTorch as the error message suggests?
-
Hi @francois-rozet,

Coming back to this, I implemented a custom search-sorted function that is ONNX-friendly, and now I can convert the "custom" NSF model to ONNX. The current implementation is as follows:

```python
import torch
from torch import nn, Tensor, LongTensor
import math
from math import pi
from torch.distributions import Transform, constraints
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Union
import torch.nn.functional as F
from zuko.flows import MAF
def broadcast(*tensors: Tensor, ignore: Union[int, Sequence[int]] = 0) -> List[Tensor]:
    r"""Broadcasts tensors together.

    The term broadcasting describes how PyTorch treats tensors with different shapes
    during arithmetic operations. In short, if possible, dimensions that have
    different sizes are expanded (without making copies) to be compatible.

    Arguments:
        tensors: The tensors to broadcast.
        ignore: The number(s) of dimensions not to broadcast.

    Returns:
        The broadcasted tensors.

    Example:
        >>> x = torch.rand(3, 1, 2)
        >>> y = torch.rand(4, 5)
        >>> x, y = broadcast(x, y, ignore=1)
        >>> x.shape
        torch.Size([3, 4, 2])
        >>> y.shape
        torch.Size([3, 4, 5])
    """

    if isinstance(ignore, int):
        ignore = [ignore] * len(tensors)

    dims = [t.dim() - i for t, i in zip(tensors, ignore)]
    common = torch.broadcast_shapes(*(t.shape[:i] for t, i in zip(tensors, dims)))

    return [torch.broadcast_to(t, common + t.shape[i:]) for t, i in zip(tensors, dims)]
def onnx_friendly_searchsorted(seq: Tensor, value: Tensor) -> LongTensor:
    """Custom implementation replacing torch.searchsorted, which is not ONNX "friendly"
    (torch.searchsorted(seq, value).squeeze(dim=-1)).

    Compatible with ONNX and reproduces the exact behavior and output shapes.

    Args:
        seq (Tensor): Sorted tensor of shape (..., S).
        value (Tensor): Tensor containing values to insert, of shape (..., 1).

    Returns:
        LongTensor: Indices of shape (...), matching torch.searchsorted(seq, value).squeeze(dim=-1).
    """

    # Ensure tensors are contiguous
    seq = seq.contiguous()
    value = value.contiguous()

    # Use torch.sum to count the number of elements in seq less than value.
    # The comparison seq < value results in a boolean tensor of shape (..., S);
    # summing over the last dimension (-1) gives indices of shape (...).
    indices = torch.sum(seq < value, dim=-1)

    return indices
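
# Note (added remark, not part of the original snippet): for a sorted `seq`, counting
# the entries strictly smaller than `value` equals the left insertion index, so this
# matches torch.searchsorted(seq, value).squeeze(dim=-1) with the default right=False:
#   seq = torch.tensor([[0., 1., 2., 3.]]); value = torch.tensor([[1.5]])
#   torch.sum(seq < value, dim=-1)                  # tensor([2])
#   torch.searchsorted(seq, value).squeeze(dim=-1)  # tensor([2])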
class MonotonicRQSTransform_(Transform):
    r"""Creates a monotonic rational-quadratic spline (RQS) transformation.

    References:
        | Neural Spline Flows (Durkan et al., 2019)
        | https://arxiv.org/abs/1906.04032

    Arguments:
        widths: The unconstrained bin widths, with shape :math:`(*, K)`.
        heights: The unconstrained bin heights, with shape :math:`(*, K)`.
        derivatives: The unconstrained knot derivatives, with shape :math:`(*, K - 1)`.
        bound: The spline's (co)domain bound :math:`B`.
        slope: The minimum slope of the transformation.
    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __init__(
        self,
        widths: Tensor,
        heights: Tensor,
        derivatives: Tensor,
        bound: float = 5.0,
        slope: float = 1e-4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        widths = widths / (1 + abs(2 * widths / math.log(slope)))
        heights = heights / (1 + abs(2 * heights / math.log(slope)))
        derivatives = derivatives / (1 + abs(derivatives / math.log(slope)))

        widths = F.pad(F.softmax(widths, dim=-1), (1, 0), value=0)
        heights = F.pad(F.softmax(heights, dim=-1), (1, 0), value=0)
        derivatives = F.pad(derivatives, (1, 1), value=0)

        self.horizontal = bound * (2 * torch.cumsum(widths, dim=-1) - 1)
        self.vertical = bound * (2 * torch.cumsum(heights, dim=-1) - 1)
        self.derivatives = torch.exp(derivatives)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bins={self.bins})"

    @property
    def bins(self) -> int:
        return self.horizontal.shape[-1] - 1

    def bin(self, k: LongTensor) -> Tuple[Tensor, ...]:
        mask = torch.logical_and(0 <= k, k < self.bins)

        k = k % self.bins
        k0_k1 = torch.stack((k, k + 1))

        k0_k1, hs, vs, ds = broadcast(
            k0_k1[..., None],
            self.horizontal,
            self.vertical,
            self.derivatives,
            ignore=1,
        )

        x0, x1 = hs.gather(-1, k0_k1).squeeze(dim=-1)
        y0, y1 = vs.gather(-1, k0_k1).squeeze(dim=-1)
        d0, d1 = ds.gather(-1, k0_k1).squeeze(dim=-1)

        s = (y1 - y0) / (x1 - x0)

        return mask, x0, x1, y0, y1, d0, d1, s
    @staticmethod
    def searchsorted(seq: Tensor, value: Tensor) -> LongTensor:
        seq, value = broadcast(seq, value.unsqueeze(dim=-1), ignore=1)
        seq = seq.contiguous()

        # Uses a non-torch implementation of searchsorted that enables export to ONNX
        return onnx_friendly_searchsorted(seq, value)

    def _call(self, x: Tensor) -> Tensor:
        k = self.searchsorted(self.horizontal, x) - 1
        mask, x0, x1, y0, y1, d0, d1, s = self.bin(k)

        z = mask * (x - x0) / (x1 - x0)
        y = y0 + (y1 - y0) * (s * z**2 + d0 * z * (1 - z)) / (s + (d0 + d1 - 2 * s) * z * (1 - z))

        return torch.where(mask, y, x)

    def _inverse(self, y: Tensor) -> Tensor:
        k = self.searchsorted(self.vertical, y) - 1
        mask, x0, x1, y0, y1, d0, d1, s = self.bin(k)

        y_ = mask * (y - y0)

        a = (y1 - y0) * (s - d0) + y_ * (d0 + d1 - 2 * s)
        b = (y1 - y0) * d0 - y_ * (d0 + d1 - 2 * s)
        c = -s * y_

        z = 2 * c / (-b - (b**2 - 4 * a * c).sqrt())
        x = x0 + z * (x1 - x0)

        return torch.where(mask, x, y)

    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
        _, ladj = self.call_and_ladj(x)
        return ladj

    def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        k = self.searchsorted(self.horizontal, x) - 1
        mask, x0, x1, y0, y1, d0, d1, s = self.bin(k)

        z = mask * (x - x0) / (x1 - x0)
        y = y0 + (y1 - y0) * (s * z**2 + d0 * z * (1 - z)) / (s + (d0 + d1 - 2 * s) * z * (1 - z))

        jacobian = (
            s**2
            * (2 * s * z * (1 - z) + d0 * (1 - z) ** 2 + d1 * z**2)
            / (s + (d0 + d1 - 2 * s) * z * (1 - z)) ** 2
        )

        return torch.where(mask, y, x), mask * jacobian.log()
class NSF_(MAF):
    r"""Creates a neural spline flow (NSF) with monotonic rational-quadratic spline
    transformations.

    By default, transformations are fully autoregressive. Coupling transformations
    can be obtained by setting :py:`passes=2`.

    Warning:
        Spline transformations are defined over the domain :math:`[-5, 5]`. Any feature
        outside of this domain is not transformed. It is recommended to standardize
        features (zero mean, unit variance) before training.

    See also:
        :class:`zuko.transforms.MonotonicRQSTransform`

    References:
        | Neural Spline Flows (Durkan et al., 2019)
        | https://arxiv.org/abs/1906.04032

    Arguments:
        features: The number of features.
        context: The number of context features.
        bins: The number of bins :math:`K`.
        kwargs: Keyword arguments passed to :class:`zuko.flows.autoregressive.MAF`.
    """

    def __init__(
        self,
        features: int,
        context: int = 0,
        bins: int = 8,
        **kwargs,
    ):
        super().__init__(
            features=features,
            context=context,
            univariate=MonotonicRQSTransform_,
            shapes=[(bins,), (bins,), (bins - 1,)],
            **kwargs,
        )
```

And I wrapped the model as I did before, and it works. Here are some performance comparisons between the Zuko model and the "custom" ONNX-friendly NSF:

Can I have your opinion on this? I can happily produce more validations/tests and write a more detailed tutorial about exporting it to ONNX. Let me know if I can help with this.