Skip to content

Commit

Permalink
Fix export SAX model with polymorphic_seq_len_exclusion and multiple …
Browse files Browse the repository at this point in the history
…methods

PiperOrigin-RevId: 656488542
Change-Id: I331c1f33dc6e1021cc620434daf671622aabd590
  • Loading branch information
miaout17 authored and copybara-github committed Jul 26, 2024
1 parent 18757af commit 0198c25
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion saxml/server/pax/lm/servable_lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,17 @@ def model_fn_input_polymorphic_shape(self) -> pytypes.Nested[str]:
# Do not apply polymorphic seq len to extra inputs.
if self.default_extra_inputs:
polymorphic_seq_len_exclusion |= self.default_extra_inputs.keys()

if polymorphic_seq_len_exclusion:
for key in polymorphic_seq_len_exclusion:
shape_patterns[key] = f'{batch_pattern}, ...'
# Only override the shape pattern if the key already exists.
#
# In the upstream configs, `polymorphic_seq_len_exclusion` is usually
# set at model level instead of methods level. Without this filtering,
# a shape pattern might be added to a method that does not have the
# corresponding input.
if key in shape_patterns:
shape_patterns[key] = f'{batch_pattern}, ...'

return shape_patterns

Expand Down

0 comments on commit 0198c25

Please sign in to comment.