Replace hasattr with getattr in faiar_cair/simmc/simmc_github/mm_dst/gpt2_dst/scripts/run_language_modeling.py

Summary:
The pattern
```
X.Y if hasattr(X, "Y") else Z
```
can be replaced with
```
getattr(X, "Y", Z)
```

The [getattr](https://www.w3schools.com/python/ref_func_getattr.asp) function gives more succinct code than the [hasattr](https://www.w3schools.com/python/ref_func_hasattr.asp) function. Please use it when appropriate.
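
As a quick illustration, here is a minimal, self-contained sketch of the two forms (the `Config` class and its attributes are hypothetical and exist only for this example):

```python
class Config:
    """Hypothetical object used only to illustrate the pattern."""
    dropout = 0.1  # present attribute; note there is no `hidden_size` attribute

cfg = Config()

# Verbose form: check for the attribute, then read it or fall back to a default.
dropout = cfg.dropout if hasattr(cfg, "dropout") else 0.0
hidden_size = cfg.hidden_size if hasattr(cfg, "hidden_size") else 768

# Succinct form: getattr with a default covers both cases in one call.
assert getattr(cfg, "dropout", 0.0) == dropout          # attribute present -> its value (0.1)
assert getattr(cfg, "hidden_size", 768) == hidden_size  # attribute absent  -> the default (768)
```

One caveat: `getattr` always evaluates its default argument, whereas the conditional form only evaluates the fallback when the attribute is missing; for plain fallback values like the ones in this diff, the two are equivalent.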

**This diff is very low risk. Green tests indicate that you can safely Accept & Ship.**

Differential Revision: D44886469

fbshipit-source-id: d5819d8d710624b012fbb28139fe42c1aa2f5432
r-barnes authored and facebook-github-bot committed Apr 12, 2023
1 parent 456b85d commit bec604b
Showing 1 changed file with 3 additions and 3 deletions.
mm_dst/gpt2_dst/scripts/run_language_modeling.py: 3 additions & 3 deletions
```diff
@@ -230,7 +230,7 @@ def collate(examples: List[torch.Tensor]):
     else:
         t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
 
-    model = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
+    model = getattr(model, "module", model)  # Take care of distributed/parallel training
     model.resize_token_embeddings(len(tokenizer))
 
     # Prepare optimizer and schedule (linear warmup and decay)
```
```diff
@@ -370,7 +370,7 @@ def collate(examples: List[torch.Tensor]):
                     output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                     os.makedirs(output_dir, exist_ok=True)
                     model_to_save = (
-                        model.module if hasattr(model, "module") else model
+                        getattr(model, "module", model)
                     )  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
                     tokenizer.save_pretrained(output_dir)
```
```diff
@@ -762,7 +762,7 @@ def main():
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         model_to_save = (
-            model.module if hasattr(model, "module") else model
+            getattr(model, "module", model)
         )  # Take care of distributed/parallel training
         model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
```
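
For context, the rewritten lines all unwrap a model that may have been wrapped by `torch.nn.DataParallel` or `DistributedDataParallel`, both of which expose the underlying model via a `.module` attribute. A minimal sketch of what the new `getattr` form does, using a toy `torch.nn.Linear` model rather than anything from this script:

```python
import torch

model = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(model)  # wrapper exposes the original model as .module

# When the model is wrapped, getattr returns the underlying model;
# when it is not wrapped, the default (the model itself) is returned.
assert getattr(wrapped, "module", wrapped) is model
assert getattr(model, "module", model) is model
```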