Doesn't work with DataParallel #4
Comments
Use accelerate: wrap the model with accelerator.prepare(model).
@SKDDJ
Furthermore, it breaks when tracking gradients with wandb.watch(): the grad.data object sent to wandb is None, which indicates that the gradients are not being backpropagated properly. I'm currently using PyTorch 2.2.0; can you say which version you tried?
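One way to confirm the symptom described above without wandb is to list the trainable parameters whose .grad is still None after the backward pass. The helper below is only a minimal sketch: the name params_missing_grad is hypothetical and not part of this project, and it assumes it is given (name, parameter) pairs like those yielded by a PyTorch module's named_parameters().

```python
from typing import Any, Iterable, List, Tuple


def params_missing_grad(named_params: Iterable[Tuple[str, Any]]) -> List[str]:
    """Return the names of trainable parameters whose .grad is still None.

    Hypothetical helper: expects (name, parameter) pairs such as those
    yielded by torch.nn.Module.named_parameters(). Call it after
    loss.backward() so that gradients have had a chance to be populated.
    """
    return [
        name
        for name, param in named_params
        # Frozen parameters (requires_grad=False) never receive a gradient,
        # so only trainable ones are reported.
        if getattr(param, "requires_grad", False)
        and getattr(param, "grad", None) is None
    ]
```

After loss.backward(), calling params_missing_grad(model.named_parameters()) should return an empty list if every trainable parameter received a gradient; any names it does return are the parameters the backward pass did not reach.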
@niklasbubeck Hi, my torch version is 2.2.2. Maybe you can try the latest torch to see if this works :)
@niklasbubeck Note that you should call accelerator.prepare once you have loaded your model:
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    # use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    use_fast=model_args.use_fast_tokenizer,
    revision=model_args.model_revision,
    # use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_args.model_name_or_path,
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    # use_auth_token=True if model_args.use_auth_token else None,
    ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
model, tokenizer = accelerator.prepare(model, tokenizer)
# ... your other code
apply_lora(model)  # add LoRA here, after calling prepare(model)
Hope this helps.
Minimum example