Add wandb to distill
Flova committed Feb 9, 2025
1 parent af70a87 commit e2b2299
Showing 1 changed file with 18 additions and 6 deletions.
ddlitlab2024/ml/training/distill.py (18 additions, 6 deletions)
@@ -9,6 +9,8 @@
from torch.utils.data import DataLoader
from tqdm import tqdm


+import wandb
from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset, worker_init_fn
from ddlitlab2024.ml import logger
from ddlitlab2024.ml.model import End2EndDiffusionTransformer
@@ -24,7 +26,6 @@
if __name__ == "__main__":
logger.info("Starting training")
logger.info(f"Using device {device}")
-# TODO wandb

# Parse the command line arguments
parser = argparse.ArgumentParser(description="Distills the multi-step diffusion model into a single-step model")
@@ -61,6 +62,9 @@
# Flag the student model as distilled
params["distilled_decoder"] = True

+# Initialize the weights and biases logging
+run = wandb.init(entity="bitbots", project="ddlitlab-2024", config=params)

# Load the dataset (primarily for example conditioning)
logger.info("Create dataset objects")
dataset = DDLITLab2024Dataset(
@@ -76,7 +80,7 @@
use_images=params["use_images"],
use_game_state=params["use_gamestate"],
)
-num_workers = 5
+num_workers = 32
dataloader = DataLoader(
dataset,
batch_size=params["batch_size"],
@@ -124,6 +128,9 @@
# Clone the model
student_model = End2EndDiffusionTransformer(**model_config).to(device)

+# Log gradients and parameters to wandb
+run.watch(student_model)

# Load the same checkpoint into the student model
# I load it from disk to avoid any potential issues when copying the model
logger.info(f"Loading student model from teacher checkpoint")
@@ -146,7 +153,7 @@
# Iterate over the dataset
for i, batch in enumerate(pbar := tqdm(dataloader)):
# Move the data to the device
-batch = {k: v.to(device) for k, v in asdict(batch).items()}
+batch = {k: v.to(device, non_blocking=True) for k, v in asdict(batch).items()}

# Extract the target actions
joint_targets = batch["joint_command"]
@@ -192,9 +199,11 @@
optimizer.step()
lr_scheduler.step()

-pbar.set_postfix_str(
-    f"Epoch {epoch}, Loss: {mean_loss / (i + 1):.05f}, LR: {lr_scheduler.get_last_lr()[0]:0.7f}"
-)
+if i % 20 == 0:
+    pbar.set_postfix_str(
+        f"Epoch {epoch}, Loss: {mean_loss / (i + 1):.05f}, LR: {lr_scheduler.get_last_lr()[0]:0.7f}"
+    )
+run.log({"loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}, step=(i + epoch * len(dataloader)))

# Save the model
checkpoint = {
@@ -205,3 +214,6 @@
"current_epoch": epoch,
}
torch.save(checkpoint, args.output)

+# Finish the run cleanly
+run.finish()
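
For readers unfamiliar with the Weights & Biases API, the sketch below condenses the run lifecycle this commit wires into distill.py: wandb.init with the hyperparameter dict as config, run.watch for gradient and parameter logging, a per-step run.log with a global step counter, and run.finish at the end. The entity and project names are copied from the diff; the model, config values, loss, and steps_per_epoch are placeholders rather than the repository's actual training objects.

    # Minimal sketch of the wandb usage added in this commit (placeholder model/data).
    import torch
    from torch import nn

    import wandb

    params = {"batch_size": 64, "lr": 1e-4}  # hypothetical config values
    model = nn.Linear(16, 16)                # stand-in for the student model

    run = wandb.init(entity="bitbots", project="ddlitlab-2024", config=params)
    run.watch(model)  # record gradients and parameter histograms during training

    steps_per_epoch = 100  # stands in for len(dataloader)
    for epoch in range(2):
        for i in range(steps_per_epoch):
            loss = torch.rand(1)  # placeholder for the distillation loss
            # The global step keeps increasing across epochs, mirroring
            # step=(i + epoch * len(dataloader)) in the diff above.
            run.log({"loss": loss.item(), "lr": params["lr"]}, step=i + epoch * steps_per_epoch)

    run.finish()  # flush all data and mark the run as finished

Note that wandb.init as written requires access to the bitbots entity; outside that team the entity argument would need to be changed or dropped.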

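A side note on the hunk at @@ -146,7 +153,7 @@: non_blocking=True only makes the host-to-device copy asynchronous when the source tensors live in pinned (page-locked) host memory, which a DataLoader provides via pin_memory=True. Whether the repository's DataLoader enables pin_memory is not visible in this diff, so the sketch below is an assumption about how the two settings work together, using placeholder data.

    # Sketch only: pairing a pinned-memory DataLoader with non_blocking transfers.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == "__main__":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        dataset = TensorDataset(torch.randn(1024, 8))  # placeholder data
        dataloader = DataLoader(
            dataset,
            batch_size=64,
            num_workers=4,    # the commit raises the real script's worker count to 32
            pin_memory=True,  # page-locked buffers are what make non_blocking copies asynchronous
        )

        for (x,) in dataloader:
            # With pinned source memory this copy can overlap with GPU compute;
            # without pinning, non_blocking=True silently degrades to a synchronous copy.
            x = x.to(device, non_blocking=True)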