Skip to content

Commit

Permalink
Update extract_embeddings.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jackzhu727 authored Aug 5, 2022
1 parent b58070c commit 32c0551
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion feature_extraction/extract_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def load_pretrained(net, model_dir):
feature_extractor = InceptionV4(num_classes=256)
load_pretrained(feature_extractor, args.feature_extractor_dir)
feature_extractor.to('cuda')
featurizer = nn.DataParallel(model, device_ids=device_ids)
feature_extractor = nn.DataParallel(feature_extractor, device_ids=device_ids)

subtype_model = InceptionV4(num_classes=2).to('cuda')
subtype_model.load_state_dict(torch.load(args.subtype_model_dir))
Expand Down

0 comments on commit 32c0551

Please sign in to comment.