
Demba3/micrograd-rs

 
 


Building core Deep Learning algorithms in Rust.

It's kinda like the middle child of karpathy/micrograd and geohot/tinygrad.
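At its core, karpathy/micrograd is a scalar reverse-mode automatic differentiation engine. Each operation records its inputs along with the local derivative with respect to each input. backward() then visits the graph in reverse topological order and applies the chain rule. The sketch below shows that idea in plain Rust under our own assumptions. The Value type, its methods, and the Rc<RefCell<...>> graph layout are hypothetical and are not micrograd-rs's actual API.

```rust
use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;

// One graph node: a scalar, its gradient, and (parent, d(self)/d(parent)) pairs.
// Hypothetical layout, for illustration only.
struct Node {
    data: f64,
    grad: f64,
    parents: Vec<(Value, f64)>,
}

#[derive(Clone)]
struct Value(Rc<RefCell<Node>>);

impl Value {
    fn new(data: f64) -> Value {
        Value::from_op(data, vec![])
    }

    fn from_op(data: f64, parents: Vec<(Value, f64)>) -> Value {
        Value(Rc::new(RefCell::new(Node { data, grad: 0.0, parents })))
    }

    fn data(&self) -> f64 {
        self.0.borrow().data
    }

    fn grad(&self) -> f64 {
        self.0.borrow().grad
    }

    // d(a + b)/da = 1, d(a + b)/db = 1
    fn add(&self, other: &Value) -> Value {
        Value::from_op(self.data() + other.data(), vec![(self.clone(), 1.0), (other.clone(), 1.0)])
    }

    // d(a * b)/da = b, d(a * b)/db = a
    fn mul(&self, other: &Value) -> Value {
        let (a, b) = (self.data(), other.data());
        Value::from_op(a * b, vec![(self.clone(), b), (other.clone(), a)])
    }

    // d(tanh x)/dx = 1 - tanh(x)^2
    fn tanh(&self) -> Value {
        let t = self.data().tanh();
        Value::from_op(t, vec![(self.clone(), 1.0 - t * t)])
    }

    // Reverse-mode pass: topologically sort via DFS, then propagate gradients
    // from the output back to the leaves. Grads are not zeroed first.
    fn backward(&self) {
        fn visit(v: &Value, seen: &mut HashSet<*const RefCell<Node>>, order: &mut Vec<Value>) {
            if seen.insert(Rc::as_ptr(&v.0)) {
                for (p, _) in v.0.borrow().parents.iter() {
                    visit(p, seen, order);
                }
                order.push(v.clone());
            }
        }
        let mut seen = HashSet::new();
        let mut order = Vec::new();
        visit(self, &mut seen, &mut order);

        self.0.borrow_mut().grad = 1.0;
        for v in order.iter().rev() {
            let node = v.0.borrow();
            let g = node.grad;
            for (p, local) in node.parents.iter() {
                // Chain rule: accumulate local derivative times upstream gradient.
                p.0.borrow_mut().grad += local * g;
            }
        }
    }
}

fn main() {
    // y = a * b + a, so dy/da = b + 1 and dy/db = a
    let a = Value::new(2.0);
    let b = Value::new(-3.0);
    let y = a.mul(&b).add(&a);
    y.backward();
    // prints: y = -4, dy/da = -2, dy/db = 2
    println!("y = {}, dy/da = {}, dy/db = {}", y.data(), a.grad(), b.grad());
}
```

Gradients accumulate with `+=`, so a value used twice (like `a` above) receives contributions from both uses. A consequence is that calling `backward()` again without resetting the gradients would double-count.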


Contributing

Any type of contribution is welcome as long as it adds value, e.g.:

  • Bug fixes accompanied by tests that ensure the bug never resurfaces
  • Improving code readability or runtime/memory efficiency
  • Completing a To-Do task
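For bug fixes in gradient code, one common kind of regression test compares an analytic derivative against a central finite difference, (f(x + h) - f(x - h)) / 2h. The sketch below shows that pattern for tanh. `tanh_grad` and `numeric_grad` are hypothetical helpers written for this illustration, not functions from micrograd-rs, and the tolerance is an assumption that suits f64 with h = 1e-5.

```rust
// Analytic derivative of tanh, as a backward pass would compute it.
fn tanh_grad(x: f64) -> f64 {
    let t = x.tanh();
    1.0 - t * t
}

// Central finite difference approximation of f'(x).
fn numeric_grad(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tanh_grad_matches_finite_difference() {
        for &x in &[-3.0, -0.5, 0.0, 0.7, 2.5] {
            let analytic = tanh_grad(x);
            let numeric = numeric_grad(|v| v.tanh(), x, 1e-5);
            assert!((analytic - numeric).abs() < 1e-6, "x = {x}: {analytic} vs {numeric}");
        }
    }
}

fn main() {
    let x = 0.7;
    let diff = (tanh_grad(x) - numeric_grad(|v| v.tanh(), x, 1e-5)).abs();
    println!("|analytic - numeric| at x = {x}: {diff:e}");
}
```

Running `cargo test` executes the `#[test]` function. Checking several points, including ones where tanh saturates, makes it harder for a wrong derivative to pass by coincidence.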

To-Do

micrograd_rs can't unpickle PyTorch tensors, so a PyTorch state dict must first be converted with the convert_state_dict() function, which turns every tensor into a plain list of floats. Once converted and pickled, the state dict can be loaded by micrograd_rs.

# pickle serializes the converted dict in a format micrograd_rs can read
import pickle

# converts a PyTorch-generated state dict into a micrograd_rs state dict:
# each tensor is cast to float32 and flattened into a list of Python floats
# (the tensor's shape is not stored)
def convert_state_dict(state_dict):
    new_state_dict = {}
    for name, tensor in state_dict.items():
        new_state_dict[name] = tensor.float().flatten().tolist()
    return new_state_dict

# `model` is your trained torch.nn.Module; `path` is the output file path
new_state_dict = convert_state_dict(model.state_dict())

# store the new state dict
with open(path, "wb") as f:
    pickle.dump(new_state_dict, f)

Then, in Rust:

// load the converted model in Rust
model.load_state_dict(path);
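Because convert_state_dict() flattens every tensor, the Rust side has to know each parameter's shape to rebuild it. PyTorch's `flatten()` emits elements in row-major order, and an `nn.Linear` weight has shape [out_features, in_features]. The sketch below assumes both facts to turn a flat list back into a matrix and run a forward pass y = Wx + b. The helpers `unflatten` and `linear_forward`, the `HashMap<String, Vec<f64>>` in-memory form, and the keys "fc.weight"/"fc.bias" are hypothetical, not micrograd-rs's loader.

```rust
use std::collections::HashMap;

// Rebuild a row-major (rows x cols) matrix from a flat parameter list.
fn unflatten(flat: &[f64], rows: usize, cols: usize) -> Result<Vec<Vec<f64>>, String> {
    if cols == 0 || flat.len() != rows * cols {
        return Err(format!("cannot reshape {} values into {}x{}", flat.len(), rows, cols));
    }
    Ok(flat.chunks(cols).map(|row| row.to_vec()).collect())
}

// y_i = sum_j W[i][j] * x[j] + b[i], with W laid out as [out_features][in_features].
fn linear_forward(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(x).map(|(wij, xj)| wij * xj).sum::<f64>() + bias)
        .collect()
}

fn main() {
    // Hypothetical converted state dict for a tiny nn.Linear(3, 2) named `fc`.
    let mut state: HashMap<String, Vec<f64>> = HashMap::new();
    state.insert("fc.weight".to_string(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    state.insert("fc.bias".to_string(), vec![0.5, -0.5]);

    let w = unflatten(&state["fc.weight"], 2, 3).expect("shape mismatch");
    let y = linear_forward(&w, &state["fc.bias"], &[1.0, 1.0, 1.0]);
    println!("{:?}", y); // [6.5, 14.5]
}
```

Returning an error on a length mismatch catches the most likely failure with flattened weights: pairing a parameter with the wrong shape.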

About

Micrograd but in Rust


Languages

  • Rust 100.0%