diff --git a/.gitignore b/.gitignore index 2fbe7c9..5b91735 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/* vendor/libtorch *.zip nimcache/* +*.pt diff --git a/flambeau/raw_bindings/serialize.nim b/flambeau/raw_bindings/serialize.nim index 41769d9..a2d6ff7 100644 --- a/flambeau/raw_bindings/serialize.nim +++ b/flambeau/raw_bindings/serialize.nim @@ -5,7 +5,10 @@ # * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). # at your option. This file may not be copied, modified, or distributed except according to those terms. -import ./tensors, ./neural_nets +import + ../cpp/std_cpp, + ./tensors, + ./neural_nets # (Almost) raw bindings to PyTorch serialization # ----------------------------------------------------------------------- @@ -31,4 +34,4 @@ import ./tensors, ./neural_nets # ####################################################################### # libtorch/include/torch/csrc/api/include/torch/optim/optimizer.h -proc save*(module: Module, path: cstring){.sideeffect, importcpp:"torch::save(@)".} +proc save*[T](module: CppSharedPtr[T], path: cstring){.sideeffect, importcpp:"torch::save(@)".} diff --git a/proof_of_concepts/poc09_end_to_end.nim b/proof_of_concepts/poc09_end_to_end.nim index 9615183..89d6b8d 100644 --- a/proof_of_concepts/poc09_end_to_end.nim +++ b/proof_of_concepts/poc09_end_to_end.nim @@ -5,24 +5,14 @@ import ../flambeau, std/[enumerate, strformat] -# Argh, need Linear{nullptr} in the codegen -# so we cheat by inlining C++ -# -# type Net {.pure.} = object of Module -# -# fc1: Linear -# fc2: Linear -# fc3: Linear +# Net is defined in poc_09_end_to_end_types.nim.hpp +# to work around https://github.com/nim-lang/Nim/issues/16664 +# which workarounds https://github.com/nim-lang/Nim/issues/4687 -{.emit:[""" -struct Net: public torch::nn::Module { - torch::nn::Linear fc1{nullptr}; - torch::nn::Linear fc2{nullptr}; - torch::nn::Linear fc3{nullptr}; -}; -"""].} - -type Net{.pure, importcpp.} = object of Module +type Net + {.pure, importcpp, + header:"poc09_end_to_end_types.nim.hpp".} + = object of Module fc1: Linear fc2: Linear fc3: Linear @@ -72,6 +62,6 @@ proc main() = if batch_index mod 100 == 0: echo &"Epoch: {epoch} | Batch: {batch_index} | Loss: {loss.item(float32)}" # Serialize your model periodically as a checkpoint. - net.save("net.pt") + save(net, "net.pt") main() diff --git a/proof_of_concepts/poc09_end_to_end_types.nim.hpp b/proof_of_concepts/poc09_end_to_end_types.nim.hpp new file mode 100644 index 0000000..6586198 --- /dev/null +++ b/proof_of_concepts/poc09_end_to_end_types.nim.hpp @@ -0,0 +1,22 @@ +// We need Linear{nullptr} in the codegen +// so we would like to cheat by inlining C++ +// +// type Net {.pure.} = object of Module +// +// fc1: Linear +// fc2: Linear +// fc3: Linear +// +// https://github.com/nim-lang/Nim/issues/4687 +// +// However +// due to https://github.com/nim-lang/Nim/issues/16664 +// it needs to be in its own file + +#include + +struct Net: public torch::nn::Module { + torch::nn::Linear fc1{nullptr}; + torch::nn::Linear fc2{nullptr}; + torch::nn::Linear fc3{nullptr}; +};