Skip to content

Commit

Permalink
Fix bugs when collecting results from mp.spawn
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenglongMa committed Sep 22, 2023
1 parent 96eb311 commit d0ba4d4
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 107 deletions.
1 change: 1 addition & 0 deletions recbole/quick_start/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from recbole.quick_start.quick_start import (
run,
run_recbole,
objective_function,
load_data_and_model,
Expand Down
64 changes: 62 additions & 2 deletions recbole/quick_start/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,62 @@
)


def run(
    model,
    dataset,
    config_file_list=None,
    config_dict=None,
    saved=True,
    nproc=1,
    world_size=-1,
    ip="localhost",
    port="5678",
    group_offset=0,
):
    r"""Run a model on a dataset, either in-process or through ``mp.spawn``.

    When ``nproc == 1`` and ``world_size <= 0`` the call is forwarded directly
    to :func:`run_recbole`. Otherwise ``nproc`` worker processes are spawned
    with ``run_recboles`` as the entry point, and the result is read back from
    a queue shared with the workers.

    Args:
        model (str): Model name, forwarded to the runner.
        dataset (str): Dataset name, forwarded to the runner.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters.
            It is never modified by this function. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
        nproc (int, optional): Number of processes to spawn. Defaults to ``1``.
        world_size (int, optional): Total number of processes; ``-1`` means "use ``nproc``". Defaults to ``-1``.
        ip (str, optional): Stored under the ``"ip"`` config key for the workers. Defaults to ``"localhost"``.
        port (str, optional): Stored under the ``"port"`` config key for the workers. Defaults to ``"5678"``.
        group_offset (int, optional): Stored under the ``"offset"`` config key for the workers. Defaults to ``0``.

    Returns:
        The value returned by :func:`run_recbole` in the single-process case;
        otherwise the first item found in the result queue, or ``None`` when
        no worker put anything into it.
    """
    if nproc == 1 and world_size <= 0:
        return run_recbole(
            model=model,
            dataset=dataset,
            config_file_list=config_file_list,
            config_dict=config_dict,
            saved=saved,
        )

    if world_size == -1:
        world_size = nproc
    import torch.multiprocessing as mp

    # mp.spawn does not hand back the workers' return values, so the result
    # travels through a queue created from the "spawn" context.
    # Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2
    # https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2
    queue = mp.get_context("spawn").SimpleQueue()

    # Work on a copy: the distributed settings must not leak into the dict
    # owned by the caller (who may reuse it for subsequent runs).
    spawn_config = dict(config_dict) if config_dict else {}
    spawn_config.update(
        {
            "world_size": world_size,
            "ip": ip,
            "port": port,
            "nproc": nproc,
            "offset": group_offset,
        }
    )
    # NOTE(review): `saved` is not forwarded on this path -- confirm whether
    # run_recboles is expected to receive it through these kwargs.
    kwargs = {
        "config_dict": spawn_config,
        "queue": queue,
    }

    mp.spawn(
        run_recboles,
        args=(model, dataset, config_file_list, kwargs),
        nprocs=nproc,
        join=True,
    )

    # Normally, there should be only one item in the queue
    return None if queue.empty() else queue.get()


def run_recbole(
model=None, dataset=None, config_file_list=None, config_dict=None, saved=True
model=None, dataset=None, config_file_list=None, config_dict=None, saved=True, queue=None
):
r"""A fast running api, which includes the complete process of
training and testing a model on a specified dataset
Expand All @@ -51,6 +105,7 @@ def run_recbole(
config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
saved (bool, optional): Whether to save the model. Defaults to ``True``.
queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.
"""
# configurations initialization
config = Config(
Expand Down Expand Up @@ -104,13 +159,18 @@ def run_recbole(
logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
logger.info(set_color("test result", "yellow") + f": {test_result}")

return {
result = {
"best_valid_score": best_valid_score,
"valid_score_bigger": config["valid_metric_bigger"],
"best_valid_result": best_valid_result,
"test_result": test_result,
}

if config["local_rank"] == 0 and queue is not None:
queue.put(result) # for multiprocessing, e.g., mp.spawn

return result # for the single process


def run_recboles(rank, *args):
ip, port, world_size, nproc, offset = args[3:]
Expand Down
36 changes: 11 additions & 25 deletions run_recbole.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
# @Email : [email protected], [email protected], [email protected]

import argparse
from ast import arg

from recbole.quick_start import run_recbole, run_recboles
from recbole.quick_start import run

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -44,26 +43,13 @@
args.config_files.strip().split(" ") if args.config_files else None
)

if args.nproc == 1 and args.world_size <= 0:
run_recbole(
model=args.model, dataset=args.dataset, config_file_list=config_file_list
)
else:
if args.world_size == -1:
args.world_size = args.nproc
import torch.multiprocessing as mp

mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
)
run(
args.model,
args.dataset,
config_file_list=config_file_list,
nproc=args.nproc,
world_size=args.world_size,
ip=args.ip,
port=args.port,
group_offset=args.group_offset,
)
44 changes: 11 additions & 33 deletions run_recbole_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,10 @@


import argparse
from ast import arg

from recbole.quick_start import run_recbole, run_recboles
from recbole.quick_start import run
from recbole.utils import list_to_latex


def run(args, model, config_file_list):
    """Run one model on ``args.dataset``, in-process or through ``mp.spawn``.

    Args:
        args: Parsed command-line namespace. The attributes ``dataset``,
            ``nproc``, ``world_size``, ``ip``, ``port`` and ``group_offset``
            are read; ``world_size`` is overwritten with ``nproc`` when it
            is ``-1`` and the distributed branch is taken.
        model: Name of the model to run for this call.
        config_file_list: Config files forwarded to the runner.

    Returns:
        The value returned by ``run_recbole`` in the single-process case.
        In the distributed case, whatever ``mp.spawn`` returns.
        NOTE(review): per the PyTorch documentation ``mp.spawn`` returns
        ``None`` when ``join`` is true (its default), so callers that index
        into the result will fail on this path -- confirm and collect the
        result another way (e.g. a queue shared with the workers).
    """
    if args.nproc == 1 and args.world_size <= 0:
        return run_recbole(
            model=model,
            dataset=args.dataset,
            config_file_list=config_file_list,
        )

    if args.world_size == -1:
        args.world_size = args.nproc
    import torch.multiprocessing as mp

    return mp.spawn(
        run_recboles,
        args=(
            # Use the per-call model, not args.model: this function is
            # invoked once per model and must honour its `model` argument,
            # exactly as the single-process branch does.
            model,
            args.dataset,
            config_file_list,
            args.ip,
            args.port,
            args.world_size,
            args.nproc,
            args.group_offset,
        ),
        nprocs=args.nproc,
    )


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -92,7 +61,16 @@ def run(args, model, config_file_list):

valid_res_dict = {"Model": model}
test_res_dict = {"Model": model}
result = run(args, model, config_file_list)
result = run(
model,
args.dataset,
config_file_list=config_file_list,
nproc=args.nproc,
world_size=args.world_size,
ip=args.ip,
port=args.port,
group_offset=args.group_offset,
)
valid_res_dict.update(result["best_valid_result"])
test_res_dict.update(result["test_result"])
bigger_flag = result["valid_score_bigger"]
Expand Down
98 changes: 51 additions & 47 deletions significance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,41 @@
# @Email :

import argparse
from ast import arg
import random
import sys
from collections import defaultdict
from scipy import stats

from recbole.quick_start import run_recbole, run_recboles

from scipy import stats

def run(args, seed):
if args.nproc == 1 and args.world_size <= 0:
res = run_recbole(
model=args.model,
dataset=args.dataset,
config_file_list=config_file_list,
from recbole.quick_start import run


def run_test(
model,
dataset,
config_files,
seeds,
nproc,
world_size,
ip,
port,
group_offset,
):
results = defaultdict(list)
for seed in seeds:
res = run(
model,
dataset,
config_files,
config_dict={"seed": seed},
nproc=nproc,
world_size=world_size,
ip=ip,
port=port,
group_offset=group_offset,
)
else:
if args.world_size == -1:
args.world_size = args.nproc
import torch.multiprocessing as mp

res = mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
)
return res
for _key, _value in res["test_result"].items():
results[_key].append(_value)
return results


if __name__ == "__main__":
Expand Down Expand Up @@ -101,24 +99,30 @@ def run(args, seed):
random.seed(args.st_seed)
random_seeds = [random.randint(0, 2**32 - 1) for _ in range(args.run_times)]

result_ours = defaultdict(list)
result_baseline = defaultdict(list)

config_file_ours, config_file_baseline = config_file_list

args.model = args.model_ours
args.config_file_list = [result_ours]
for seed in random_seeds:
res = run(args, seed)
for key, value in res["test_result"].items():
result_ours[key].append(value)

args.model = args.model_baseline
args.config_file_list = [config_file_baseline]
for seed in random_seeds:
res = run(args, seed)
for key, value in res["test_result"].items():
result_baseline[key].append(value)
result_ours = run_test(
args.model_ours,
args.dataset,
[config_file_ours],
random_seeds,
args.nproc,
args.world_size,
args.ip,
args.port,
args.group_offset,
)
result_baseline = run_test(
args.model_baseline,
args.dataset,
[config_file_baseline],
random_seeds,
args.nproc,
args.world_size,
args.ip,
args.port,
args.group_offset,
)

final_result = {}
for key, value in result_ours.items():
Expand Down

0 comments on commit d0ba4d4

Please sign in to comment.