Skip to content

Commit

Permalink
updated requirements file and train script
Browse files Browse the repository at this point in the history
  • Loading branch information
statscol committed May 11, 2022
1 parent ab4f8be commit 9f81466
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
8 changes: 8 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ipywidgets
datasets
torch
transformers
huggingface_hub
jiwer
torchaudio -f https://download.pytorch.org/whl/torch_stable.html
wandb
13 changes: 9 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def prepare_dataset(batch):
return batch




##apply it for every audio

@dataclass
Expand Down Expand Up @@ -135,19 +137,21 @@ def train_model(tr:float,tst:float):
training_args = TrainingArguments(
output_dir=REPO_OUT,
group_by_length=True,
per_device_train_batch_size=12,
per_device_train_batch_size=18,
gradient_accumulation_steps=2,
evaluation_strategy="steps",
num_train_epochs=10,
gradient_checkpointing=True,
num_train_epochs=20,
fp16=True,
gradient_checkpointing=True,
save_steps=800,
eval_steps=400,
logging_steps=400,
learning_rate=2e-4,
warmup_steps=300,
save_total_limit=30,
push_to_hub=True,
report_to="wandb",
run_name="wav2vec-large-noLM"
)

trainer = Trainer(
Expand All @@ -162,9 +166,10 @@ def train_model(tr:float,tst:float):

###
trainer.train()
trainer.push_to_hub()

if __name__=='__main__':
import wandb
wandb.init(project="wav2vec-spanish")

parser = argparse.ArgumentParser(description = 'ASR Parser')
parser.add_argument('-tr',type=float,help="train sample ratio",dest="tr_size")
Expand Down

0 comments on commit 9f81466

Please sign in to comment.