Changes from all commits (3 files changed)
src/data/__init__.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -46,7 +46,7 @@ def get_datasets(dataset_cfgs: Union[Dict, DictConfig], **kwargs):
     return dataset
 
 
-def get_data(data_cfg: DictConfig, mode="train", **kwargs):
+def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs):
     data = {}
     data_cfg = dict(data_cfg)
     anchor = data_cfg.pop("anchor", "forget")
@@ -56,7 +56,7 @@ def get_data(data_cfg: DictConfig, mode="train", **kwargs):
         return data
     elif mode == "unlearn":
         unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")}
-        unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor)
+        unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor, seed=seed)
         data["train"] = unlearn_dataset
         for split in unlearn_splits:
            data.pop(split)
```
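For context, a minimal sketch of what `get_data` now builds in unlearn mode. `ToyDataset` and the import path are hypothetical stand-ins, not part of this PR:

```python
from torch.utils.data import Dataset

from data.unlearn import ForgetRetainDataset  # assumes src/ is on PYTHONPATH


class ToyDataset(Dataset):
    # Hypothetical stand-in for the real splits produced by get_datasets.
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx


# Mirrors the call inside get_data: the new seed kwarg is forwarded here.
train_ds = ForgetRetainDataset(
    forget=ToyDataset(8), retain=ToyDataset(100), anchor="forget", seed=0
)
print(len(train_ds))  # 8: length follows the anchor (forget) split
print(train_ds[0])    # {'forget': 0, 'retain': <deterministic random sample>}
```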
src/data/unlearn.py (13 changes: 10 additions & 3 deletions)

Note: re-seeding a fresh generator with only `seed + rank` would return the same companion index on every `__getitem__` call; the index is therefore mixed into the seed so pairings vary per item while staying reproducible.

```diff
@@ -4,17 +4,19 @@
 
 class ForgetRetainDataset(Dataset):
     # https://github.com/OPTML-Group/SOUL/blob/main/src/dataset/Base.py
-    def __init__(self, forget, retain, anchor="forget"):
+    def __init__(self, forget, retain, anchor="forget", seed=0):
         """Wraps the forget and retain datasets into an unlearning dataset.
 
         Args:
             forget (Dataset): Forget Dataset
             retain (Dataset): Retain Dataset
             anchor (str, optional): Specifies which dataset to anchor while randomly sampling from the other dataset. Defaults to 'forget'.
+            seed (int, optional): Random seed for reproducible sampling. Defaults to 0.
         """
         self.forget = forget
         self.retain = retain
         self.anchor = anchor
+        self.seed = seed
 
     def __len__(self):
         """Ensures the sampled dataset matches the anchor dataset's length."""
@@ -33,14 +35,19 @@ def __len__(self):
 
     def __getitem__(self, idx):
         item = {}
+        g = torch.Generator()
+        # Seed per rank *and* per index: a constant per-call seed would draw
+        # the same companion sample for every idx.
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        g.manual_seed(self.seed + rank + idx)
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
             if self.retain:
-                retain_idx = torch.randint(0, len(self.retain), (1,)).item()
+                retain_idx = torch.randint(0, len(self.retain), (1,), generator=g).item()
                 item["retain"] = self.retain[retain_idx]
         elif self.anchor == "retain":
             item["retain"] = self.retain[idx]
             if self.forget:
-                forget_idx = torch.randint(0, len(self.forget), (1,)).item()
+                forget_idx = torch.randint(0, len(self.forget), (1,), generator=g).item()
                 item["forget"] = self.forget[forget_idx]
         return item
```
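A quick standalone check of the sampling logic above (single process, so rank 0). The helper below is illustrative only; it mirrors the seed derivation in `__getitem__`, so the same (seed, rank, idx) triple always maps to the same companion index across runs:

```python
import torch


def sample_companion_idx(seed: int, idx: int, n: int, rank: int = 0) -> int:
    # Mirrors __getitem__: a fresh generator seeded per rank and per index.
    g = torch.Generator()
    g.manual_seed(seed + rank + idx)
    return torch.randint(0, n, (1,), generator=g).item()


run_a = [sample_companion_idx(seed=0, idx=i, n=1000) for i in range(4)]
run_b = [sample_companion_idx(seed=0, idx=i, n=1000) for i in range(4)]
assert run_a == run_b  # identical across dataloader re-instantiations
print(run_a)           # varies with idx, so pairings are not constant
```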
src/train.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -23,7 +23,7 @@ def main(cfg: DictConfig):
     # Load Dataset
     data_cfg = cfg.data
     data = get_data(
-        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args
+        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args, seed=cfg.trainer.args.seed
     )
 
     # Load collator
```
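The seed is read from the trainer arguments, so one config value (or a CLI override such as `trainer.args.seed=42`) controls both trainer seeding and the paired-sampling determinism. A minimal sketch of the access pattern, assuming a Hydra/OmegaConf config shaped like the repo's:

```python
from omegaconf import OmegaConf

# Hypothetical minimal config; the real cfg is composed by Hydra at runtime.
cfg = OmegaConf.create({"trainer": {"args": {"seed": 42}}})
print(cfg.trainer.args.seed)  # 42, the value get_data now receives
```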