Skip to content

Commit

Permalink
[Fix] Fix ddp bugs caused by out_type. (open-mmlab#1570)
Browse files · Browse the repository at this point in the history
* set out_type to be 'raw'

* update test
Loading branch information
fangyixiao18 authored May 17, 2023
1 parent 034919d commit 770eb8e
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion mmpretrain/models/selfsup/beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'avg_featmap',
out_type: str = 'raw',
frozen_stages: int = -1,
use_abs_pos_emb: bool = False,
use_rel_pos_bias: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion mmpretrain/models/selfsup/cae.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
bias: bool = 'qv_bias',
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'avg_featmap',
out_type: str = 'raw',
frozen_stages: int = -1,
use_abs_pos_emb: bool = True,
use_rel_pos_bias: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion mmpretrain/models/selfsup/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'avg_featmap',
out_type: str = 'raw',
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
Expand Down
2 changes: 1 addition & 1 deletion mmpretrain/models/selfsup/maskfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(self,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'avg_featmap',
out_type: str = 'raw',
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_selfsup/test_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_beit_pretrain_vit(self):

# test without mask
fake_outputs = beit_backbone(fake_inputs, None)
assert fake_outputs[0].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 197, 768])

@pytest.mark.skipif(
platform.system() == 'Windows', reason='Windows mem limit')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_selfsup/test_cae.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_cae_vit():

# test without mask
fake_outputs = cae_backbone(fake_inputs, None)
assert fake_outputs[0].shape == torch.Size([1, 192])
assert fake_outputs[0].shape == torch.Size([1, 197, 192])


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_selfsup/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_mae_vit():

# test without mask
fake_outputs = mae_backbone(fake_inputs, None)
assert fake_outputs[0].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 197, 768])


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_selfsup/test_maskfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_maskfeat_vit():

# test without mask
fake_outputs = maskfeat_backbone(fake_inputs, None)
assert fake_outputs[0].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 197, 768])


@pytest.mark.skipif(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_selfsup/test_milan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_milan_vit():

# test without mask
fake_outputs = milan_backbone(fake_inputs, None)
assert fake_outputs[0].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 197, 768])


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
Expand Down

0 comments on commit 770eb8e

Please sign in to comment.