
Commit: update trainer and its args to main
JingyaHuang committed Oct 12, 2023
1 parent 5a8624e commit a6c53e0
Showing 2 changed files with 12 additions and 13 deletions.
optimum/onnxruntime/trainer.py (5 changes: 1 addition & 4 deletions)

@@ -213,17 +213,16 @@ class ORTTrainer(Trainer):
     def __init__(
         self,
         model: Union[PreTrainedModel, nn.Module] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         args: ORTTrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-        onnx_model_path: Union[str, os.PathLike] = None,
     ):
         super().__init__(
             model=model,
@@ -249,8 +248,6 @@ def __init__(

         self.model = model
 
-        self.onnx_model_path = onnx_model_path
-        self.exported_with_loss = False
         if self.args.local_rank:
             torch.cuda.set_device(self.args.local_rank)
 
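Note on the trainer.py change: `ORTTrainer.__init__` drops the ONNX-specific `onnx_model_path` argument (and the `exported_with_loss` flag), and `tokenizer` moves to the slot it occupies in `transformers.Trainer`, after `eval_dataset`. Positional call sites written against the old order can therefore break. A minimal sketch of the keyword-argument style that survives the reordering; the `model`, dataset, and `tokenizer` objects are assumed placeholders built elsewhere:

# Minimal sketch, assuming the post-commit signature; `model`, `train_ds`,
# `eval_ds`, and `tokenizer` are placeholders you would construct yourself.
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

args = ORTTrainingArguments(output_dir="ort_out", per_device_train_batch_size=8)
trainer = ORTTrainer(
    model=model,            # keyword arguments are immune to parameter reordering
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,    # now declared after eval_dataset, as in transformers.Trainer
)
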
optimum/onnxruntime/training_args.py (20 changes: 11 additions & 9 deletions)

@@ -139,8 +139,9 @@ def __post_init__(self):
         if self.load_best_model_at_end:
             if self.evaluation_strategy != self.save_strategy:
                 raise ValueError(
-                    "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
-                    f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}"
+                    "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+                    "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps "
+                    f"{self.save_steps} and eval_steps {self.eval_steps}."
                 )
             if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
                 if self.eval_steps < 1 or self.save_steps < 1:
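The updated message encodes the rule that, with step-based strategies, `save_steps` must be a round multiple of `eval_steps` so that every checkpoint coincides with an evaluation; the multiple cannot be guaranteed when a ratio (a value below 1) is mixed with an absolute step count. A small illustration of the arithmetic behind the check, with invented values:

# Invented values, illustrating the constraint behind the new error message.
save_steps, eval_steps = 1000, 250
assert save_steps % eval_steps == 0  # ok: each save coincides with an eval

save_steps, eval_steps = 500, 0.1    # absolute count mixed with a ratio
print(save_steps % eval_steps)       # ~0.0999..., not 0: floating point makes
                                     # the multiple impossible to guarantee
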
@@ -191,14 +192,15 @@ def __post_init__(self):
             self.half_precision_backend = self.fp16_backend
 
         if self.bf16 or self.bf16_full_eval:
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.use_cpu and not is_torch_bf16_cpu_available():
                 # cpu
                 raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
-            elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
-                # gpu
-                raise ValueError(
-                    "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
-                )
+            elif not self.use_cpu:
+                if torch.cuda.is_available() and not is_torch_bf16_gpu_available():
+                    # gpu
+                    raise ValueError(
+                        "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
+                    )
 
         if self.fp16 and self.bf16:
             raise ValueError("At most one of fp16 and bf16 can be True, but not both")
@@ -307,7 +309,7 @@ def __post_init__(self):
             # no need to assert on else
 
         # if training args is specified, it will override the one specified in the accelerate config
-        if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
+        if self.half_precision_backend != "apex":
             mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
             if self.fp16:
                 mixed_precision_dtype = "fp16"
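With the `sharded_ddp` clause gone, the training arguments now override the accelerate config's mixed-precision choice whenever the backend is not apex. A simplified, runnable sketch of that precedence; the `bf16` branch and the final environment write are assumptions inferred from the collapsed context below this hunk:

import os

# Simplified sketch: precision flags from the training arguments win over
# whatever ACCELERATE_MIXED_PRECISION was set to by the accelerate config.
fp16, bf16, half_precision_backend = True, False, "auto"

if half_precision_backend != "apex":
    mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
    if fp16:
        mixed_precision_dtype = "fp16"
    elif bf16:
        mixed_precision_dtype = "bf16"  # assumed branch, mirrors the fp16 one
    os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype

print(os.environ["ACCELERATE_MIXED_PRECISION"])  # -> fp16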
