from transformers import TrainerCallback

class LossLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            loss = logs.get("loss")
            if loss is not None:
                print(f"Step {state.global_step}: Loss: {loss}")
                # Optionally, store the loss in a file or a list for further processing
                with open("training_loss_log.txt", "a") as log_file:
                    log_file.write(f"Step {state.global_step}: Loss: {loss}\n")
Then add the callback to the trainer:
from transformers import TrainingArguments

training_args = TrainingArguments(
    self.experiment_dir,
    num_train_epochs=self.epochs,
    per_device_train_batch_size=self.batch_size,
    save_strategy="no",
    **self.train_hyperparameters
)

# Create the custom callback
loss_logging_callback = LossLoggingCallback()

# Create the trainer with the callback
tabula_trainer = TabulaTrainer(
    self.model,
    training_args,
    train_dataset=tabula_ds,
    tokenizer=self.tokenizer,
    data_collator=TabulaDataCollator(self.tokenizer),
    callbacks=[loss_logging_callback]  # Add the callback here
)
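For a quick sanity check outside a full training run, the callback's logging logic can be exercised directly. The sketch below is a minimal stand-in, assuming the same `on_log` body as above; `FakeState` is a hypothetical stub replacing `transformers.TrainerState` so the example runs without a real trainer:

```python
import os
import tempfile
from dataclasses import dataclass

# Hypothetical stub for transformers.TrainerState, only for testing the logic
@dataclass
class FakeState:
    global_step: int

class LossLoggingCallback:
    """Same logging logic as the TrainerCallback above, file path made configurable."""
    def __init__(self, log_path):
        self.log_path = log_path

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            loss = logs.get("loss")
            if loss is not None:
                with open(self.log_path, "a") as log_file:
                    log_file.write(f"Step {state.global_step}: Loss: {loss}\n")

log_path = os.path.join(tempfile.gettempdir(), "training_loss_log.txt")
open(log_path, "w").close()  # start with an empty log file

cb = LossLoggingCallback(log_path)
cb.on_log(None, FakeState(global_step=10), None, logs={"loss": 0.5})
cb.on_log(None, FakeState(global_step=20), None, logs={"eval_loss": 0.4})  # no "loss" key: nothing written

with open(log_path) as f:
    print(f.read().strip())  # Step 10: Loss: 0.5
```

Note that `on_log` fires only when the trainer actually logs (controlled by `logging_steps` in `TrainingArguments`), and training logs carry the `"loss"` key while evaluation logs do not, which is why the second call above writes nothing.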
You can implement it, and if it is bug-free, you can create a PR and I will merge it.