Skip to content

Commit

Permalink
Fixed test so it should pass on the CI now.
Browse files Browse the repository at this point in the history
  • Loading branch information
finbarrtimbers committed Nov 3, 2024
1 parent 3af8285 commit bef6ef2
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions llamax/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# use_scaled_rope=True,
)

# This is the number of params in Llama 3.2 1B.
# This is the number of params in Llama 3.2 1B, from Huggingface.
NUM_WEIGHTS = 1_498_482_688
MAX_LENGTH = 32

Expand Down Expand Up @@ -91,10 +91,6 @@ def checkpoint_exists():
class IntegrationTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Without the checkpoint, we're not testing anything, so exit early.
if not checkpoint_exists():
return

cls.prompt = "Hello, world!"
cls.model = "meta-llama/Llama-3.1-8B"
cls.tokenizer = transformers.AutoTokenizer.from_pretrained(
Expand All @@ -119,26 +115,29 @@ def setUpClass(cls):
)

cls.torch_model = reference_model_torch.Transformer(cls.config)
checkpoint = torch.load(
CHECKPOINT_PATH,
map_location="cpu",
weights_only=True,
)

# Convert checkpoint to double precision
checkpoint = {k: v.double() for k, v in checkpoint.items()}
if checkpoint_exists():
checkpoint = torch.load(
CHECKPOINT_PATH,
map_location="cpu",
weights_only=True,
)

jax.tree.map(
lambda x, y: np.testing.assert_array_equal(x.shape, y.shape),
dict(cls.torch_model.state_dict()),
checkpoint,
)
cls.torch_model.load_state_dict(checkpoint)
cls.torch_model.double() # Ensure model is in double precision
# Convert checkpoint to double precision
checkpoint = {k: v.double() for k, v in checkpoint.items()}

jax.tree.map(
lambda x, y: np.testing.assert_array_equal(x.shape, y.shape),
dict(cls.torch_model.state_dict()),
checkpoint,
)
cls.torch_model.load_state_dict(checkpoint)

# Convert params to float64
cls.params = model.transformer_params_from_module(cls.torch_model)
# Ensure model is in double precision. Probably not needed?
cls.torch_model.double()

# Finally, create the jax model:
cls.params = model.transformer_params_from_module(cls.torch_model)
cls.flax_model = model.Transformer(cls.config)

def tearDown(self):
Expand Down

0 comments on commit bef6ef2

Please sign in to comment.