Skip to content

Commit

Permalink
make style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ra9hur committed Nov 20, 2024
1 parent be0621d commit 0696597
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
17 changes: 7 additions & 10 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummyDecisionTransformerInputGenerator,
DummyDecoderTextInputGenerator,
DummyEncodecInputGenerator,
DummyFluxTransformerTextInputGenerator,
DummyFluxTransformerVisionInputGenerator,
Expand Down Expand Up @@ -265,19 +265,16 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):


class DecisionTransformerOnnxConfig(GPT2OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyDecisionTransformerInputGenerator,
)
DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:

return {
'timesteps': {0: 'batch_size', 1: 'sequence_length'},
'returns_to_go': {0: 'batch_size', 1: 'sequence_length'},
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
'actions': {0: 'batch_size', 1: 'sequence_length', 2: 'act_dim'},
'states': {0: 'batch_size', 1: 'sequence_length', 2: 'state_dim'},
"timesteps": {0: "batch_size", 1: "sequence_length"},
"returns_to_go": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"actions": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
"states": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
}


Expand Down
4 changes: 3 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ class TasksManager:
"multiple-choice": "AutoModelForMultipleChoice",
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"reinforcement-learning": ("AutoModel",), # multiple auto model families can be used for reinforcement-learning
"reinforcement-learning": (
"AutoModel",
), # multiple auto model families can be used for reinforcement-learning
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
"text-generation": "AutoModelForCausalLM",
Expand Down
16 changes: 8 additions & 8 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,11 @@ class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
"""

SUPPORTED_INPUT_NAMES = (
'actions',
'timesteps',
'attention_mask',
'returns_to_go',
'states',
"actions",
"timesteps",
"attention_mask",
"returns_to_go",
"states",
)

def __init__(self, *args, **kwargs):
Expand All @@ -531,15 +531,15 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
shape = [self.batch_size, self.sequence_length, self.state_dim]
elif input_name == "actions":
shape = [self.batch_size, self.sequence_length, self.act_dim]
elif input_name == 'returns_to_go':
elif input_name == "returns_to_go":
shape = [self.batch_size, self.sequence_length, 1]
elif input_name == "attention_mask":
shape = [self.batch_size, self.sequence_length]
elif input_name == 'timesteps':
elif input_name == "timesteps":
shape = [self.batch_size, self.sequence_length]
return self.random_int_tensor(shape=shape, max_value=self.max_ep_len, framework=framework, dtype=int_dtype)

return self.random_float_tensor(shape, min_value=-2., max_value=2., framework=framework, dtype=float_dtype)
return self.random_float_tensor(shape, min_value=-2.0, max_value=2.0, framework=framework, dtype=float_dtype)


class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
Expand Down

0 comments on commit 0696597

Please sign in to comment.