diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 3e9c01a83..5c4ddba66 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -1,6 +1,12 @@
 ## TorchSharp Release Notes
 
 Releases, starting with 9/2/2021, are listed with the most recent release at the top.
+# NuGet Version 0.105.2
+
+__API Changes__:
+
+Fix torch.jit.ScriptModule.zero_grad.
+
 # NuGet Version 0.105.1
 
 __Bug Fixes__:
diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp
index a0a4a5d0c..8a40aa75b 100644
--- a/src/Native/LibTorchSharp/THSJIT.cpp
+++ b/src/Native/LibTorchSharp/THSJIT.cpp
@@ -68,6 +68,23 @@ int THSJIT_Module_is_training(JITModule module)
     return (*module)->is_training();
 }
 
+void THSJIT_Module_zero_grad(const JITModule module, bool set_to_none)
+{
+    // According to https://github.com/pytorch/pytorch/issues/27144,
+    // torch::jit::Module has no zero_grad().
+    // As a workaround, manually loop over the parameters and zero them out like optimizer does;
+    // https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/api/src/optim/optimizer.cpp#L123
+    for (const auto& p : (*module)->parameters()) {
+        if (p.mutable_grad().defined()) {
+            p.mutable_grad().detach_();
+            if (set_to_none)
+                p.mutable_grad().reset();
+            else
+                p.mutable_grad().zero_();
+        }
+    }
+}
+
 void THSJIT_Module_train(JITModule module, bool on)
 {
     (*module)->train(on);
diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h
index 81e6d51ad..25d7cea32 100644
--- a/src/Native/LibTorchSharp/THSJIT.h
+++ b/src/Native/LibTorchSharp/THSJIT.h
@@ -44,6 +44,7 @@ EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name,
 EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t idx);
 
 EXPORT_API(int) THSJIT_Module_is_training(JITModule module);
+EXPORT_API(void) THSJIT_Module_zero_grad(const JITModule module, bool set_to_none);
 
 EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on);
 EXPORT_API(void) THSJIT_Module_eval(JITModule module);
diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs
index 14e5d4773..7166febeb 100644
--- a/src/TorchSharp/JIT/ScriptModule.cs
+++ b/src/TorchSharp/JIT/ScriptModule.cs
@@ -143,6 +143,23 @@ public override bool training {
                     }
                 }
 
+                public override void zero_grad(bool set_to_none = true)
+                {
+                    THSJIT_Module_zero_grad(handle, set_to_none);
+                    CheckForErrors();
+
+                    foreach (var (_, p) in named_parameters()) {
+                        using var grad = p.grad;
+                        if (grad is not null) {
+                            if (set_to_none) {
+                                p.grad = null;
+                            } else {
+                                grad.zero_();
+                            }
+                        }
+                    }
+                }
+
                 protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking)
                 {
                     if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
index 074fcc247..4cdc25e82 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
@@ -57,6 +57,9 @@ internal static partial class NativeMethods
         [return: MarshalAs(UnmanagedType.U1)]
         internal static extern bool THSJIT_Module_is_training(torch.nn.Module.HType module);
 
+        [DllImport("LibTorchSharp")]
+        internal static extern void THSJIT_Module_zero_grad(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool set_to_none);
+
         [DllImport("LibTorchSharp")]
         internal static extern void THSJIT_Module_to_device(torch.nn.Module.HType module, long deviceType, long deviceIndex);
 
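
For reviewers, a minimal C# sketch of how the fixed method might be exercised from user code. The file name `model.pt`, the tensor shapes, and the choice of `mse_loss` are illustrative assumptions and not part of this change:

```csharp
// Hypothetical usage sketch (not part of the diff): assumes "model.pt" is a
// TorchScript module mapping a float32 tensor of shape [8, 10] to the same shape.
using TorchSharp;
using static TorchSharp.torch;

var module = torch.jit.load<Tensor, Tensor>("model.pt");
module.train();

var input = randn(8, 10);
var target = randn(8, 10);

// Run a forward/backward pass so that gradients accumulate on the parameters.
var loss = nn.functional.mse_loss(module.forward(input), target);
loss.backward();

// The call fixed by this change: with set_to_none (the default) the gradients
// are dropped (p.grad becomes null); with set_to_none: false they are zeroed in place.
module.zero_grad();
// module.zero_grad(set_to_none: false);
```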