
Commit

Support one big lora, beta warm start
phoebeklett committed Feb 20, 2024
1 parent 3dc9cae commit 029afbc
Showing 6 changed files with 218 additions and 142 deletions.
136 changes: 0 additions & 136 deletions experiments/laplace-visuals.py

This file was deleted.

24 changes: 19 additions & 5 deletions experiments/laplace_lora/lora_transformer.py
@@ -21,6 +21,7 @@ def __init__(self, config: FrozenConfigDict):

         self.pretrained_model_name_or_path = config.pretrained_model_name_or_path
         self.lr = config.lr
+        self.betas_warmstart = config.betas_warmstart

         # These need to updated with the correct values before calling trainer.fit
         self.prior_mean = None
@@ -119,15 +120,26 @@ def sub_param_to_log_posterior(
         return log_lik + log_prior / self.num_data, output

     def configure_optimizers(self):
-        self.opt = AdamW(self.sub_params.values(), lr=self.lr, maximize=True)
+        self.opt = AdamW(
+            self.sub_params.values(),
+            lr=self.lr,
+            maximize=True,
+            betas=(0.9 ** self.betas_warmstart[0], 0.999 ** self.betas_warmstart[1]),
+        )
         self.sub_params = tree_map(lambda x: x.to(self.device), self.sub_params)
         self.prior_mean = tree_map(lambda x: x.to(self.device), self.prior_mean)
         self.prior_sd = tree_map(lambda x: x.to(self.device), self.prior_sd)

     def training_step(self, batch, batch_idx):
+        batch_data = (
+            {key: torch.cat([d[key] for d in batch], dim=0) for key in batch[0]}
+            if isinstance(batch, list)
+            else batch
+        )
+
         self.opt.zero_grad()

-        log_post, out = self.sub_param_to_log_posterior(self.sub_params, batch)
+        log_post, out = self.sub_param_to_log_posterior(self.sub_params, batch_data)
         log_post.backward()

         self.log("log_post", log_post)
@@ -152,11 +164,13 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         self.log(f"val_loss_task_{dataloader_idx}", output.loss)

         self.log_metrics(
-            dataloader_idx, {f"val_loss_{dataloader_idx}": output.loss}, step=batch_idx
+            dataloader_idx,
+            {f"val_loss_task_{dataloader_idx}": output.loss},
+            step=batch_idx,
         )
         return output.loss

     def on_validation_epoch_end(self, dataloader_idx=0):
         acc = self.val_accuracy.compute()
-        self.log(f"val_accuracy_{dataloader_idx}", acc)
-        self.log_metrics(dataloader_idx, {f"val_accuracy_{dataloader_idx}": acc})
+        self.log("val_accuracy", acc)
+        self.log_metrics(dataloader_idx, {"val_accuracy": acc})
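Note on the warm start above: a minimal sketch of how the betas_warmstart exponents translate into AdamW momentum coefficients. The [10, 10] value is taken from the new lora_forget.yaml below; the dummy parameter and learning rate are purely illustrative, not the repository's code.

import torch
from torch.optim import AdamW

betas_warmstart = [10, 10]  # value from lora_forget.yaml; illustrative here
effective_betas = (0.9 ** betas_warmstart[0], 0.999 ** betas_warmstart[1])
# Roughly (0.349, 0.990): smaller betas give the moment estimates a shorter
# memory, so recent gradients dominate early in a run.
print(effective_betas)

dummy_param = torch.nn.Parameter(torch.zeros(4))  # stand-in for the LoRA sub-params
opt = AdamW([dummy_param], lr=1e-3, betas=effective_betas, maximize=True)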
120 changes: 120 additions & 0 deletions experiments/laplace_plots.py
@@ -0,0 +1,120 @@
import pandas as pd
import matplotlib.pyplot as plt


# Define a function to apply a moving average to a DataFrame column
def smooth_dataframe_column(df, column_name, window_size):
    return (
        df[column_name].rolling(window=window_size, center=True, min_periods=1).mean()
    )


def produce_plot_A(df_base, df, window_size, save_dir):
    # Plot validation loss for all tasks, partitioned by training stage
    fig, axs = plt.subplots(3, 1, figsize=(10, 8), sharex=True)

    axs[0].plot(
        df_base[df_base["metric_name"] == "val_loss_0"]["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean(),
        label="SGD",
    )
    axs[0].plot(
        df[df["metric_name"] == "val_loss_0"]["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean(),
        label="Laplace",
    )
    axs[0].set_title("Task A")
    axs[0].set_ylabel("Val Loss")

    axs[1].plot(
        df_base[df_base["metric_name"] == "val_loss_1"]["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean(),
        label="SGD",
    )
    axs[1].plot(
        df[df["metric_name"] == "val_loss_1"]["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean(),
        label="Laplace",
    )
    axs[1].set_title("Task B")
    axs[1].set_ylabel("Val Loss")

    axs[2].plot(
        df_base[df_base["metric_name"] == "val_loss_2"]["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean(),
        label="SGD",
    )
    axs[2].plot(
        df[df["metric_name"] == "val_loss_2"]["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean(),
        label="Laplace",
    )
    axs[2].set_title("Task C")
    axs[2].set_ylabel("Val Loss")
    axs[2].set_xlabel("Training epoch")

    # Adding vertical lines for training stages
    for ax in axs:
        ax.axvline(x=45, linestyle="--", color="grey")
        ax.axvline(x=135, linestyle="--", color="grey")

    # Adding legend
    axs[0].legend()

    plt.tight_layout()
    plt.savefig(f"{save_dir}/plot_A.png", dpi=300)  # Save as PNG file with 300 DPI
    plt.close()


def produce_plot_B(df_base, df, window_size, save_dir):
    # Plot average validation loss over entire training time
    losses = ["val_loss_0", "val_loss_1", "val_loss_2"]
    plt.plot(
        df_base[df_base["metric_name"].isin(losses)]
        .groupby(["task", "epoch", "step"])["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean()
        .reset_index()["metric_value"],
        label="SGD",
    )
    plt.plot(
        df[df["metric_name"].isin(losses)]
        .groupby(["task", "epoch", "step"])["metric_value"]
        .rolling(window=window_size, center=True, min_periods=1)
        .mean()
        .reset_index()["metric_value"],
        label="Laplace",
    )

    plt.xlabel("Training time")
    plt.ylabel("Validation Loss")
    # Adding vertical lines for training stages
    plt.axvline(x=30, linestyle="--", color="grey")
    plt.axvline(x=105, linestyle="--", color="grey")

    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{save_dir}/plot_B.png", dpi=300)


# Path to your log file
LAPLACE_LOG_FILE_PATH = "/home/paperspace/Developer/uqlib/experiments/runs/lora_sam/2024-02-20T18-42-16_lora_sam/eval_metrics.txt"
BASELINE_LOG_FILE_PATH = "/home/paperspace/Developer/uqlib/experiments/runs/lora_sam/2024-02-20T19-52-31_lora_sam/eval_metrics.txt"

WINDOW_SIZE = 10
SAVE_DIR = "/home/paperspace/Developer/uqlib/experiments/runs/pictures"


if __name__ == "__main__":
    # Read the log file into a pandas DataFrame
    df = pd.read_csv(LAPLACE_LOG_FILE_PATH)
    df_base = pd.read_csv(BASELINE_LOG_FILE_PATH)

    produce_plot_A(df_base=df_base, df=df, window_size=WINDOW_SIZE, save_dir=SAVE_DIR)
    produce_plot_B(df_base=df_base, df=df, window_size=WINDOW_SIZE, save_dir=SAVE_DIR)
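For reference, a minimal sketch of the table layout the plotting script above appears to assume: the eval_metrics files are read with pd.read_csv, filtered on a metric_name column, smoothed over metric_value, and grouped by task, epoch and step. The toy rows below are invented for illustration, not taken from a real run.

import pandas as pd

toy = pd.DataFrame(
    {
        "task": [0, 0, 0, 0],
        "epoch": [0, 1, 2, 3],
        "step": [0, 10, 20, 30],
        "metric_name": ["val_loss_0"] * 4,
        "metric_value": [2.10, 1.85, 1.72, 1.68],
    }
)

# Same smoothing pattern as produce_plot_A
smoothed = (
    toy[toy["metric_name"] == "val_loss_0"]["metric_value"]
    .rolling(window=2, center=True, min_periods=1)
    .mean()
)
print(smoothed.tolist())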
6 changes: 5 additions & 1 deletion experiments/lora_for_laplace.py
@@ -101,7 +101,11 @@
     # Train for MAP
     trainer.fit(
         model,
-        train_dataloaders[book_ind],
+        train_dataloaders=(
+            train_dataloaders[book_ind]
+            if config.sequential
+            else [train_dataloaders[i] for i in range(book_ind + 1)]
+        ),
         val_dataloaders=[test_dataloaders[i] for i in range(book_ind + 1)],
     )
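The non-sequential branch above is presumably the "one big lora" behaviour named in the commit title: instead of fitting task book_ind alone, trainer.fit receives every dataloader seen so far. A minimal sketch of how the matching training_step change merges the resulting per-task batches, assuming Lightning delivers one batch dict per dataloader as a list; the toy tensors are illustrative only.

import torch

# One hypothetical batch per task, as a multi-dataloader fit might deliver them
batch = [
    {"input_ids": torch.ones(1, 8, dtype=torch.long)},
    {"input_ids": torch.zeros(1, 8, dtype=torch.long)},
]

# Same merge as the new training_step: concatenate along the batch dimension
# so one forward/backward pass covers all tasks at once
batch_data = (
    {key: torch.cat([d[key] for d in batch], dim=0) for key in batch[0]}
    if isinstance(batch, list)
    else batch
)
print(batch_data["input_ids"].shape)  # torch.Size([2, 8])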

37 changes: 37 additions & 0 deletions experiments/utils/configs/lora_forget.yaml
@@ -0,0 +1,37 @@
# File dirs
base_dir: &base_path "./experiments/"
logs_dir: &logs_path "./experiments/runs/lora_forget/"
data_dir: &data_path "./experiments/lora_forget/data/"

experiment_name: "continuous_lora"
num_tasks: 3
lambda_param: 0.
average_priors: false
sequential: false

# Model
model_config: &model_params
  pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
  lr: 1e-3
  first_prior_sd: 1e10
  betas_warmstart: [10,10]

# LoRA
lora_config: &lora_params
  target_modules: "last_layer"
  r: 8
  alpha: 32
  dropout: 0.


# Dataset
dataset_path: "./experiments/data/pg19-even-smaller.json"
train_batch_size: 1
laplace_batch_size: 1
drop_last: true
train_proportion: 0.85
shuffle: true
num_workers: 11
tokenizer_pretrained_model_name_or_path: "meta-llama/Llama-2-7b-hf"
stride_length: 4096
stride_overlap: 2048
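A minimal sketch of reading this config, assuming PyYAML and that the model_config block is wrapped in an ml_collections FrozenConfigDict, as the type hint in lora_transformer.py suggests; the loading code here is illustrative, not the repository's actual entry point.

import yaml
from ml_collections import FrozenConfigDict

with open("experiments/utils/configs/lora_forget.yaml") as f:
    raw = yaml.safe_load(f)

model_config = FrozenConfigDict(raw["model_config"])
print(model_config.betas_warmstart)              # [10, 10]
print(model_config.pretrained_model_name_or_path)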
