MLX version #1

Open
lukasugar wants to merge 3 commits into master
Conversation

@lukasugar commented on Jul 18, 2024

MLX port

With small changes, it's possible to port the pytorch version to MLX.

Notes:

  • Adam optimizer in MLX doesn't support weight_decay (ref), so I've used AdamW; see the sketch after this list. We'll probably want to switch to AdamW in pytorch as well?
  • Most of the boilerplate code is copied from the pytorch file. It could be moved to a separate file, but I didn't want to change the original pytorch code.
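For reference, a minimal sketch of the optimizer swap. The learning rate matches the logs below; the weight_decay value is illustrative, not the one from the training script:

```python
import mlx.optimizers as optim

# MLX's Adam takes no weight_decay argument, so the port uses AdamW,
# which applies decoupled weight decay. weight_decay here is illustrative.
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.1)

# The analogous change on the pytorch side would be:
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
```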

Potential follow up work:

  • use random = RNG(1337) to initialize the weights deterministically
  • optimize parts with @mx.compile (see the sketch after this list)
  • plot the losses (I've just copied parts of the training logs below; see the plotting sketch under Results)
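A minimal sketch of what the @mx.compile follow-up could look like; loss_fn here is a stand-in for illustration, not the actual loss from this PR:

```python
import mlx.core as mx

# mx.compile traces the function on first call and fuses the compute graph,
# so subsequent calls with same-shaped inputs run the optimized version.
@mx.compile
def loss_fn(w, x, y):
    pred = x @ w
    return mx.mean((pred - y) ** 2)
```

Note that a training step that mutates model/optimizer state would need the captured state passed explicitly, e.g. mx.compile(step, inputs=state, outputs=state).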

Installation

The mlx package needs to be installed (it works on Apple silicon only):

```
pip install mlx
```
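A quick sanity check after installing; on Apple silicon this should report the GPU as the default device:

```python
import mlx.core as mx

# Should print something like Device(gpu, 0) on Apple silicon.
print(mx.default_device())
```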

Results

I've gotten results pretty similar to the pytorch version.

Pytorch loss:

```
step 0 | train_loss 3.3022 | val_loss 3.3016 | lr 1.000000e-03
step 200 | train_loss 2.3989 | val_loss 2.4304 | lr 9.999605e-04
step 400 | train_loss 2.3376 | val_loss 2.3581 | lr 9.998421e-04
step 600 | train_loss 2.3081 | val_loss 2.3243 | lr 9.996447e-04
step 800 | train_loss 2.2954 | val_loss 2.2981 | lr 9.993685e-04
step 1000 | train_loss 2.2958 | val_loss 2.2891 | lr 9.990134e-04
...
step 49000 | train_loss 1.9856 | val_loss 2.0581 | lr 9.866358e-07
step 49200 | train_loss 1.9856 | val_loss 2.0581 | lr 6.315217e-07
step 49400 | train_loss 1.9856 | val_loss 2.0581 | lr 3.552637e-07
step 49600 | train_loss 1.9856 | val_loss 2.0581 | lr 1.579054e-07
step 49800 | train_loss 1.9856 | val_loss 2.0581 | lr 3.947790e-08
step 49999 | train_loss 1.9856 | val_loss 2.0581 | lr 9.869605e-13
```

MLX loss:

```
step 0 | train_loss 3.3019 | val_loss 3.3007 | lr 1.000000e-03
step 200 | train_loss 2.3795 | val_loss 2.3943 | lr 9.999605e-04
step 400 | train_loss 2.3325 | val_loss 2.3581 | lr 9.998421e-04
step 600 | train_loss 2.3142 | val_loss 2.3354 | lr 9.996447e-04
step 800 | train_loss 2.3076 | val_loss 2.3212 | lr 9.993685e-04
step 1000 | train_loss 2.3113 | val_loss 2.3150 | lr 9.990134e-04
...
step 49000 | train_loss 2.0180 | val_loss 2.0747 | lr 9.866358e-07
step 49200 | train_loss 2.0180 | val_loss 2.0747 | lr 6.315217e-07
step 49400 | train_loss 2.0180 | val_loss 2.0747 | lr 3.552637e-07
step 49600 | train_loss 2.0180 | val_loss 2.0747 | lr 1.579054e-07
step 49800 | train_loss 2.0180 | val_loss 2.0747 | lr 3.947790e-08
step 49999 | train_loss 2.0180 | val_loss 2.0747 | lr 9.869605e-13
```
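For the "plot the losses" follow-up, a sketch of how the curves could be plotted from saved logs. The file names pytorch.log and mlx.log are hypothetical; substitute wherever the runs were saved:

```python
import re
import matplotlib.pyplot as plt

# Parse lines of the form shown above:
#   step 0 | train_loss 3.3022 | val_loss 3.3016 | lr 1.000000e-03
pattern = re.compile(r"step (\d+) \| train_loss ([\d.]+) \| val_loss ([\d.]+)")

def parse_log(path):
    steps, train, val = [], [], []
    with open(path) as f:
        for line in f:
            m = pattern.search(line)
            if m:
                steps.append(int(m.group(1)))
                train.append(float(m.group(2)))
                val.append(float(m.group(3)))
    return steps, train, val

# Hypothetical log file names.
for label, path in [("pytorch", "pytorch.log"), ("mlx", "mlx.log")]:
    steps, _, val = parse_log(path)
    plt.plot(steps, val, label=f"{label} val_loss")

plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()
```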

@lukasugar (Author) commented:

The random seed is now used for model initialization, so repeated runs produce exactly the same results:
[image: training logs from two seeded runs, showing identical losses]

(Each run was only 1000 steps, just to show that the results are identical.)
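A sketch of the seeding, assuming MLX's global PRNG is seeded before the model is built; the model constructor is a placeholder, not the actual class from this PR:

```python
import mlx.core as mx

# Seed MLX's global PRNG before constructing the model so that the
# initial weights, and hence whole runs, are reproducible.
mx.random.seed(1337)

# model = Transformer(config)  # placeholder for the script's model class
```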
