Skip to content

Commit

Permalink
Update model instantiators definitions test file in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 7, 2024
1 parent 722dfa2 commit a7d800f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/jax/test_jax_model_instantiators_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_parse_input(capsys):
assert item not in output, f"'{item}' in '{output}'"
# Mixed operation
input = 'OBSERVATIONS["joint"] + concatenate([net * ACTIONS[:, -3:]])'
statement = 'inputs["states"]["joint"] + jnp.concatenate([net * inputs["taken_actions"][:, -3:]], axis=-1)'
statement = 'states["joint"] + jnp.concatenate([net * taken_actions[:, -3:]], axis=-1)'
output = _parse_input(str(input))
assert output.replace("'", '"') == statement, f"'{output}' != '{statement}'"

Expand Down

0 comments on commit a7d800f

Please sign in to comment.