Hi, we are trying to use the mup tool to tune the Swin Transformer v2 model.
I modified the Swin Transformer v2 code to adapt it to mup and ran the "save base shape" and "coordinate check" steps.
The results of the "coordinate check" show that the model does not meet the requirements of mup.
Does mup support the Swin Transformer v2 model?
In "swin_transformer_v2.py", I made the following changes (because Swin Transformer v2 doesn't use "1/sqrt(d) attention scaling", I left the attention code unmodified):
replaced the output layer nn.Linear with MuReadout
replaced std normal init with mup normal init
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        ### muP: replace nn.Linear with MuReadout
        self.head = MuReadout(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        for bly in self.layers:
            bly._init_respostnorm()

    def _init_weights(self, m, readout_zero_init=False, query_zero_init=False):
        ### muP: swap constant std normal init with normal_ from `mup.init`.
        ### Because `_init_weights` is called in `__init__`, before `infshape` is set,
        ### we need to manually call `self.apply(self._init_weights)` after calling
        ### `set_base_shape(model, base)`
        if isinstance(m, nn.Linear):
            if isinstance(m, MuReadout) and readout_zero_init:
                m.weight.data.zero_()
            else:
                if hasattr(m.weight, 'infshape'):
                    normal_(m.weight, mean=0.0, std=.02)
                else:
                    trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        ### End muP
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
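The note above says Swin Transformer v2 does not use "1/sqrt(d) attention scaling"; it computes attention logits from a cosine similarity instead. The following is a minimal sketch, in plain Python with no Swin or mup code, of why that matters for width scaling: for a fully correlated query and key, a dot product divided by sqrt(d) grows with the head dimension d, while the cosine stays bounded. The helper names `dot` and `cosine_logit` are made up for this illustration, and the sketch says nothing about whether other parts of the model need changes.

```python
import math
import random


def dot(q, k):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(q, k))


def cosine_logit(q, k):
    """Cosine similarity of q and k; always lies in [-1, 1]."""
    return dot(q, k) / (math.sqrt(dot(q, q)) * math.sqrt(dot(k, k)))


random.seed(0)
rows = {}
for d in (4, 16, 64, 256):  # hypothetical head dimensions
    q = [random.gauss(0.0, 1.0) for _ in range(d)]
    k = list(q)  # fully correlated query and key (k equals q)
    rows[d] = {
        "dot/sqrt(d)": dot(q, k) / math.sqrt(d),  # grows roughly like sqrt(d)
        "dot/d": dot(q, k) / d,                   # stays near 1
        "cosine": cosine_logit(q, k),             # exactly bounded by 1
    }
    print(d, rows[d])
```

Printing the rows shows the "dot/sqrt(d)" column increasing with d while the other two columns stay of order one.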
In "main.py" of Swin Transformer, I added the "save base shape" and "coordinate check" functions.
def main(config, args):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))

    ### muP
    if args.save_base_shapes:
        print(f'saving base shapes at {args.save_base_shapes}')
        base_shapes = get_shapes(model)
        delta_config = copy.deepcopy(config)
        delta_config.defrost()
        delta_config.MODEL.SWINV2.EMBED_DIM *= 2  # Modify SwinV2 embed dim
        delta_config.MODEL.SWIN.EMBED_DIM *= 2  # Modify Swin embed dim
        # delta_config.MODEL.SWIN_MOE.EMBED_DIM *= 2  # Modify Swin_moe embed dim
        delta_config.MODEL.SWIN_MLP.EMBED_DIM *= 2  # Modify Swin_mlp embed dim
        delta_shapes = get_shapes(
            # just need to change whatever dimension(s) we are scaling
            build_model(delta_config)
        )
        make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
        print('done and exit')
        import sys
        sys.exit()
    # Note: the function exported by mup is `set_base_shape` (singular).
    if args.load_base_shapes:
        print(f'loading base shapes from {args.load_base_shapes}')
        set_base_shape(model, args.load_base_shapes)
        print('done')
    else:
        print('using own shapes')
        set_base_shape(model, None)
        print('done')
    ### muP
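As a conceptual aid for the "save base shape" step above: the base and delta models differ only in the dimensions being scaled, so comparing their parameter shapes tells you which dimensions count as "width". The sketch below is plain Python and is not mup's implementation. The helper `infer_width_dims` is hypothetical, and the example shapes are assumptions based on an embed dim of 96 (base) versus 192 (delta) with 1000 classes.

```python
def infer_width_dims(base_shapes, delta_shapes):
    """For each parameter, flag the dimensions that differ between base and delta.

    A True entry marks a dimension that scales with width; False marks a
    dimension that stays fixed (e.g. number of classes, kernel size).
    """
    width_dims = {}
    for name, base in base_shapes.items():
        delta = delta_shapes[name]
        width_dims[name] = tuple(b != d for b, d in zip(base, delta))
    return width_dims


# Hypothetical parameter shapes, assumed for illustration only.
base_shapes = {
    "patch_embed.proj.weight": (96, 3, 4, 4),
    "norm.weight": (768,),
    "head.weight": (1000, 768),
}
delta_shapes = {
    "patch_embed.proj.weight": (192, 3, 4, 4),
    "norm.weight": (1536,),
    "head.weight": (1000, 1536),
}

width_dims = infer_width_dims(base_shapes, delta_shapes)
for name, flags in width_dims.items():
    print(name, flags)
```

In this sketch the head weight has a fixed output dimension and a width-scaled input dimension, which is the situation the MuReadout replacement is meant for.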
def coord_check(mup, config, lr, optimizer, nsteps, nseeds, args, plotdir='', legend=False):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    def gen(w, standparam=False):
        # Returns a factory that builds a fresh model of width `w` on each call.
        def f():
            delta_config = copy.deepcopy(config)
            delta_config.defrost()
            delta_config.MODEL.SWINV2.EMBED_DIM = w  # Modify SwinV2 embed dim
            delta_config.MODEL.SWIN.EMBED_DIM = w  # Modify Swin embed dim
            # delta_config.MODEL.SWIN_MOE.EMBED_DIM = w  # Modify Swin_moe embed dim
            delta_config.MODEL.SWIN_MLP.EMBED_DIM = w  # Modify Swin_mlp embed dim
            model = build_model(delta_config)
            # Note: the function exported by mup is `set_base_shape` (singular).
            if standparam:
                set_base_shape(model, None)
            else:
                assert args.load_base_shapes, 'load_base_shapes needs to be nonempty'
                set_base_shape(model, args.load_base_shapes)
            return model
        return f

    optimizer = optimizer.replace('mu', '')
    widths = (12, 24, 48, 96, 192)
    models = {w: gen(w, standparam=not mup) for w in widths}
    # train_data = batchify(corpus.train, batch_size, device=args.device)
    df = get_coord_data(models, data_loader_train, mup=mup, lr=lr, optimizer=optimizer, flatten_output=True,
                        nseeds=nseeds, nsteps=nsteps, lossfn='xent')
    prm = 'muP' if mup else 'SP'
    return plot_coord_data(df, legend=legend,
                           save_to=os.path.join(plotdir, f'{prm.lower()}_trsfmr_{optimizer}_coord.png'),
                           suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
                           face_color='xkcd:light grey' if not mup else None)
The results of the "coordinate check" show only a small difference between "mup" and "SP". Sorry, I can't upload pictures.
Could you please help us check whether mup supports the Swin Transformer v2 model, or whether there is some other cause? Thanks a lot.
@shiyf129 I also think the snippets look reasonable. I have done coord checks on Swin as well, and I attach the plots here. Echoing Edward's suggestion, the widths tested are typically 256, 512, 1024, and 2048. Have you tried larger widths, and could you attach your coord check plots?
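When plots cannot be shared, one way to describe a coord check in text is to report how activation size changes with width. A coord check passes when the per-layer activation size stays roughly flat as width grows. The sketch below fits a slope in log-log space, where a slope near 0 means flat and a clearly positive slope means growth with width. The helper `loglog_slope` is hypothetical, and both value series are synthetic numbers generated from formulas, not measurements from Swin or from mup.

```python
import math


def loglog_slope(widths, values):
    """Least-squares slope of log2(value) against log2(width)."""
    xs = [math.log2(w) for w in widths]
    ys = [math.log2(v) for v in values]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var


widths = (256, 512, 1024, 2048)

# Synthetic series: one width-independent, one growing like sqrt(width).
flat_values = [0.5 for _ in widths]
growing_values = [0.5 * math.sqrt(w / 256) for w in widths]

slope_flat = loglog_slope(widths, flat_values)
slope_growing = loglog_slope(widths, growing_values)
print("flat series slope:", slope_flat)
print("growing series slope:", slope_growing)
```

With real data, the `values` list would come from the per-layer activation statistics that the coord check records at each width, computed separately for each layer and training step.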