Skip to content

Commit

Permalink
feat: add device_target config and set default to Ascend. (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
IASZHT authored Jul 26, 2024
1 parent 6e2a4c7 commit 0143736
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
6 changes: 5 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def create_parser():
help='Interval for print training log. Unit: step (default=100)')
group.add_argument('--seed', type=int, default=42,
help='Seed value for determining randomness in numpy, random, and mindspore (default=42)')
group.add_argument('--device_target', type=str, default='Ascend',
help='Device target for computing, which can be Ascend, GPU or CPU. (default=Ascend)')

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
Expand Down Expand Up @@ -94,7 +96,7 @@ def create_parser():
'Example: "randaug-m10-n2-w0-mstd0.5-mmax10-inc0", "autoaug-mstd0.5" or autoaugr-mstd0.5.')
group.add_argument('--aug_splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 3 (currently, only support 3 splits))'
'it should be set with one auto_augment')
'it should be set with one auto_augment')
group.add_argument('--re_prob', type=float, default=0.0,
help='Probability of performing erasing (default=0.0)')
group.add_argument('--re_scale', type=tuple, default=(0.02, 0.33),
Expand Down Expand Up @@ -267,6 +269,8 @@ def create_parser():
help='Whether to shuffle the evaluation data (default=False)')

return parser_config, parser


# fmt: on


Expand Down
9 changes: 6 additions & 3 deletions tests/tasks/test_train_val_imagenet_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_train(mode, val_while_train, model="resnet18"):
DownLoad().download_and_extract_archive(dataset_url, root_dir)

# ---------------- test running train.py using the toy data ---------
device_target = "CPU"
dataset = "imagenet"
num_classes = 2
ckpt_dir = "./tests/ckpt_tmp"
Expand All @@ -48,7 +49,8 @@ def test_train(mode, val_while_train, model="resnet18"):
f"python {train_file} --dataset={dataset} --num_classes={num_classes} --model={model} "
f"--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} --loss=CE "
f"--weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train --batch_size={batch_size} "
f"--pretrained --num_parallel_workers=2 --val_while_train={val_while_train} --val_split=val --val_interval=1"
f"--pretrained --num_parallel_workers=2 --val_while_train={val_while_train} --val_split=val --val_interval=1 "
f"--device_target={device_target}"
)

print(f"Running command: \n{cmd}")
Expand All @@ -57,10 +59,11 @@ def test_train(mode, val_while_train, model="resnet18"):

# --------- Test running validate.py using the trained model ------------- #
# begin_ckpt = os.path.join(ckpt_dir, f'{model}-1_1.ckpt')
end_ckpt = os.path.join(ckpt_dir, f"{model}-{num_epochs}_{num_samples//batch_size}.ckpt")
end_ckpt = os.path.join(ckpt_dir, f"{model}-{num_epochs}_{num_samples // batch_size}.ckpt")
cmd = (
f"python validate.py --model={model} --dataset={dataset} --val_split=val --data_dir={data_dir} "
f"--num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2"
f"--num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2 "
f"--device_target={device_target}"
)
# ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
print(f"Running command: \n{cmd}")
Expand Down
1 change: 1 addition & 0 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def check_batch_size(num_samples, ori_batch_size=32, refine=True):


def validate(args):
ms.set_context(device_target=args.device_target)
ms.set_context(mode=args.mode)
if args.mode == ms.GRAPH_MODE:
ms.set_context(jit_config={"jit_level": "O2"})
Expand Down

0 comments on commit 0143736

Please sign in to comment.