From faad8d939fee800aeb03a74d10fe069fb8c160c8 Mon Sep 17 00:00:00 2001
From: ZeguanXiao
Date: Sat, 27 Sep 2025 12:54:40 +0800
Subject: [PATCH 1/4] fix: random sampling in ForgetRetainDataset

---
 src/data/unlearn.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/data/unlearn.py b/src/data/unlearn.py
index 0cb0bada..1f6db300 100644
--- a/src/data/unlearn.py
+++ b/src/data/unlearn.py
@@ -33,14 +33,18 @@ def __len__(self):
 
     def __getitem__(self, idx):
         item = {}
+        g = torch.Generator()
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        seed = int(torch.empty((), dtype=torch.int64).random_().item() + rank)
+        g.manual_seed(seed)
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
             if self.retain:
-                retain_idx = torch.randint(0, len(self.retain), (1,)).item()
+                retain_idx = torch.randint(0, len(self.retain), (1,), generator=g).item()
                 item["retain"] = self.retain[retain_idx]
         elif self.anchor == "retain":
             item["retain"] = self.retain[idx]
             if self.forget:
-                forget_idx = torch.randint(0, len(self.forget), (1,)).item()
+                forget_idx = torch.randint(0, len(self.forget), (1,), generator=g).item()
                 item["forget"] = self.forget[forget_idx]
         return item
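Note: patch 1 moves the cross-dataset draw onto a dedicated torch.Generator, so it
no longer consumes or perturbs the global RNG state. The generator's seed, however,
is still drawn at random on every __getitem__ call via
torch.empty((), ...).random_(), so runs remain non-reproducible; patch 2 below fixes
that by threading in an explicit seed. A minimal standalone sketch of the difference
(helper names are illustrative, not from the repo):

    import torch

    def sample_random(n: int) -> int:
        # Patch-1 style: a fresh random seed on every call -> not reproducible.
        g = torch.Generator()
        g.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        return torch.randint(0, n, (1,), generator=g).item()

    def sample_seeded(n: int, seed: int) -> int:
        # Patch-2 style: the draw is a pure function of the given seed.
        g = torch.Generator()
        g.manual_seed(seed)
        return torch.randint(0, n, (1,), generator=g).item()

    assert sample_seeded(100, 0) == sample_seeded(100, 0)  # always holds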
""" self.forget = forget self.retain = retain self.anchor = anchor + self.seed = seed def __len__(self): """Ensures the sampled dataset matches the anchor dataset's length.""" @@ -35,8 +37,8 @@ def __getitem__(self, idx): item = {} g = torch.Generator() rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - seed = int(torch.empty((), dtype=torch.int64).random_().item() + rank) - g.manual_seed(seed) + rank_seed = self.seed + rank + g.manual_seed(rank_seed) if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: diff --git a/src/train.py b/src/train.py index a2f81c8d..4e6a0224 100644 --- a/src/train.py +++ b/src/train.py @@ -23,7 +23,7 @@ def main(cfg: DictConfig): # Load Dataset data_cfg = cfg.data data = get_data( - data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args + data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args, seed=cfg.trainer.args.seed ) # Load collator From 60e099a02ad8cbcb275b200ba86ca5dd96b94e44 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Thu, 9 Oct 2025 00:10:08 +0800 Subject: [PATCH 3/4] refactor: fix lint --- src/data/__init__.py | 4 +++- src/data/unlearn.py | 8 ++++++-- src/train.py | 6 +++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/data/__init__.py b/src/data/__init__.py index 93c092e5..aa39bcbb 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -56,7 +56,9 @@ def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs): return data elif mode == "unlearn": unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")} - unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor, seed=seed) + unlearn_dataset = ForgetRetainDataset( + **unlearn_splits, anchor=anchor, seed=seed + ) data["train"] = unlearn_dataset for split in unlearn_splits: data.pop(split) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index 9fd6e7f5..bbf745a1 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -42,11 +42,15 @@ def __getitem__(self, idx): if self.anchor == "forget": item["forget"] = self.forget[idx] if self.retain: - retain_idx = torch.randint(0, len(self.retain), (1,), generator=g).item() + retain_idx = torch.randint( + 0, len(self.retain), (1,), generator=g + ).item() item["retain"] = self.retain[retain_idx] elif self.anchor == "retain": item["retain"] = self.retain[idx] if self.forget: - forget_idx = torch.randint(0, len(self.forget), (1,), generator=g).item() + forget_idx = torch.randint( + 0, len(self.forget), (1,), generator=g + ).item() item["forget"] = self.forget[forget_idx] return item diff --git a/src/train.py b/src/train.py index 4e6a0224..5e8f6db5 100644 --- a/src/train.py +++ b/src/train.py @@ -23,7 +23,11 @@ def main(cfg: DictConfig): # Load Dataset data_cfg = cfg.data data = get_data( - data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args, seed=cfg.trainer.args.seed + data_cfg, + mode=mode, + tokenizer=tokenizer, + template_args=template_args, + seed=cfg.trainer.args.seed, ) # Load collator From 7a8b5fda95bea8b24636fcf7fe3adb782c921071 Mon Sep 17 00:00:00 2001 From: ZeguanXiao Date: Thu, 9 Oct 2025 00:20:30 +0800 Subject: [PATCH 4/4] fix: ensure unique random seed per item in ForgetRetainDataset --- src/data/unlearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/unlearn.py b/src/data/unlearn.py index bbf745a1..dff81be7 100644 --- a/src/data/unlearn.py +++ b/src/data/unlearn.py @@ -37,7 +37,7 @@ def __getitem__(self, idx): item = {} g = 
From 7a8b5fda95bea8b24636fcf7fe3adb782c921071 Mon Sep 17 00:00:00 2001
From: ZeguanXiao
Date: Thu, 9 Oct 2025 00:20:30 +0800
Subject: [PATCH 4/4] fix: ensure unique random seed per item in ForgetRetainDataset

---
 src/data/unlearn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/data/unlearn.py b/src/data/unlearn.py
index bbf745a1..dff81be7 100644
--- a/src/data/unlearn.py
+++ b/src/data/unlearn.py
@@ -37,7 +37,7 @@ def __getitem__(self, idx):
         item = {}
         g = torch.Generator()
         rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
-        rank_seed = self.seed + rank
+        rank_seed = self.seed + rank + idx
         g.manual_seed(rank_seed)
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
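Note: with the full series applied, the forget/retain pairing is a pure function of
(seed, rank, idx), so re-running with the same trainer seed reproduces the pairing
exactly. A minimal sketch of the resulting behaviour, using a toy stand-in class
(names and toy data are illustrative, not from the repo):

    import torch
    from torch.utils.data import Dataset

    class ToyForgetRetain(Dataset):
        """Mirrors the seeding scheme ForgetRetainDataset ends up with."""

        def __init__(self, forget, retain, seed=0, rank=0):
            self.forget, self.retain = forget, retain
            self.seed, self.rank = seed, rank

        def __len__(self):
            return len(self.forget)  # anchor="forget" semantics

        def __getitem__(self, idx):
            g = torch.Generator()
            g.manual_seed(self.seed + self.rank + idx)  # deterministic per item
            retain_idx = torch.randint(0, len(self.retain), (1,), generator=g).item()
            return {"forget": self.forget[idx], "retain": self.retain[retain_idx]}

    a = ToyForgetRetain(list(range(4)), list(range(100)), seed=42)
    b = ToyForgetRetain(list(range(4)), list(range(100)), seed=42)
    assert [a[i] for i in range(4)] == [b[i] for i in range(4)]  # reproducible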