Skip to content

Commit

Permalink
Improve error message for too long reactions
Browse files Browse the repository at this point in the history
  • Loading branch information
avaucher committed Sep 19, 2024
1 parent 89df26c commit 01451a8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
8 changes: 8 additions & 0 deletions rxnmapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ def convert_batch_to_attns(
return_tensors="pt",
)
parsed_input = {k: v.to(self.device) for k, v in encoded_ids.items()}

max_input_length = parsed_input["input_ids"].shape[1]
max_supported_by_model = self.model.config.max_position_embeddings
if max_input_length > max_supported_by_model:
raise ValueError(
f"Reaction SMILES has {max_input_length} tokens, should be at most {max_supported_by_model}."
)

with torch.no_grad():
output = self.model(**parsed_input)
attentions = output[2]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,15 @@ def test_reaction_with_asterisks(rxn_mapper: RXNMapper):

results = rxn_mapper.get_attention_guided_atom_maps(rxns, canonicalize_rxns=False)
assert_correct_maps(results, expected)


def test_too_long_reaction_smiles_produce_exception_with_understandable_error_message(
rxn_mapper: RXNMapper,
):
# dummy reaction with 1 + 3 + 500 * 2 + 3 + 1 = 1008 tokens
rxn = "C=C" + "[C+][C-]" * 500 + ">>CC"

with pytest.raises(ValueError) as excinfo:
_ = rxn_mapper.get_attention_guided_atom_maps([rxn], canonicalize_rxns=False)

assert "Reaction SMILES has 1008 tokens, should be" in str(excinfo.value)

0 comments on commit 01451a8

Please sign in to comment.