[FEATURE] timm.models.adapt_input_conv: beyond RGB weights #2445
@adamjstewart So if you had a model with 13 channels, you might want to convert it so those 3 channels are repeated several times and pass it, say, a 13*4 = 52 channel image?
Correct, other than a typo (3 -> 13). Another use case is when a model is trained on top-of-atmosphere data (13 channels) but we want to use it for surface reflectance data (11 channels), or vice versa. In this case, we know which specific channels were removed, but would be satisfied with a solution that just naively drops the last 2 channels and rescales things.
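In tensor terms, the two adaptations described above amount to something like the following rough sketch (plain PyTorch, not timm API; shapes are illustrative):

```python
import torch

w13 = torch.randn(64, 13, 7, 7)           # first-conv weight trained on 13 bands

# 13 -> 52 channels: tile the weights 4x along the input-channel dim, then
# rescale so the expected activation magnitude is roughly preserved.
w52 = w13.repeat(1, 4, 1, 1) * (13 / 52)

# 13 -> 11 channels: naively drop the last 2 input channels and rescale.
w11 = w13[:, :11, :, :] * (13 / 11)
```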
@adamjstewart k, I think it's pretty straightforward to support that with an extra arg that covers the 'base' or default channels. Below I added a base_chans arg... if you set it to 13 you should get the behaviour described. The 'naive' dropping is the default. For dropping specific channels, an arg could be added with indices to drop. Conceivably you could drop before or after the repeat, say if you had 11 of 13 channels you wanted to keep AND wanted to repeat to have a 33-channel input... hmm.

```python
import math

from torch import Tensor


def adapt_input_conv(in_chans: int, conv_weight: Tensor, base_chans: int = 3) -> Tensor:
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > base_chans:
            assert conv_weight.shape[1] % base_chans == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // base_chans, base_chans, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != base_chans:
        if I != base_chans:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / base_chans))
            scale = base_chans / float(in_chans)
            if repeat > 1:
                conv_weight = conv_weight.repeat(1, repeat, 1, 1)
            conv_weight = conv_weight[:, :in_chans, :, :]  # drops last channels
            conv_weight *= scale
    conv_weight = conv_weight.to(conv_type)
    return conv_weight
```
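A hypothetical usage sketch for the Sentinel-2 cases discussed above, assuming the function lands as drafted (the `base_chans=13` argument is the proposed addition):

```python
import torch

w13 = torch.randn(64, 13, 7, 7)  # pretrained first-conv weight with 13 input channels

# 13 -> 52 channels (4 stacked timesteps): weights are tiled 4x and rescaled by 13/52
w52 = adapt_input_conv(52, w13, base_chans=13)
assert w52.shape == (64, 52, 7, 7)

# 13 -> 11 channels: the last 2 channels are naively dropped, weights rescaled by 13/11
w11 = adapt_input_conv(11, w13, base_chans=13)
assert w11.shape == (64, 11, 7, 7)
```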
Would it be simpler to use `I`?
@adamjstewart Removing the `base_chans` arg would require adding a space2depth multiplier arg to resolve that ambiguity and support monochrome use with TResNet models...
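To illustrate that ambiguity (a hypothetical example, not from this thread): a TResNet-style space2depth stem folds a 4x4 spatial block into channels, so its first conv sees 3 * 16 = 48 input channels even though the model is fundamentally an RGB model.

```python
import torch

# TResNet-style space2depth stem over RGB: the 4x4 spatial fold means the
# first conv has I = 3 * 16 = 48 input channels, although base_chans is really 3.
s2d_rgb_weight = torch.randn(64, 48, 3, 3)

# A genuine 48-channel model would have an identically shaped weight, so the
# base channel count can't be inferred from I alone; hence the explicit arg.
native_48ch_weight = torch.randn(64, 48, 3, 3)
```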
I don't know what most of those words mean so I'll trust you. However, your implementation doesn't seem to support anything other than `I == base_chans`.
Is your feature request related to a problem? Please describe.
TorchGeo provides a number of model weights pre-trained on non-RGB imagery (e.g., Sentinel-2, 13 channels). Oftentimes, when dealing with time-series data, we would like to stack images along the channel dimension so that we end up with $B \times TC \times H \times W$ inputs. However, we don't yet have an easy way to adapt our pre-trained weights to match.
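For concreteness, the stacking described above is just a reshape along the channel dimension (illustrative shapes, not TorchGeo API):

```python
import torch

B, T, C, H, W = 2, 4, 13, 64, 64          # e.g. 4 Sentinel-2 timesteps, 13 bands each
series = torch.randn(B, T, C, H, W)
stacked = series.reshape(B, T * C, H, W)  # B x TC x H x W input for a 2D backbone
```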
Describe the solution you'd like
`timm.models.adapt_input_conv` provides a powerful tool for repeating and scaling weights to adapt to a changing `in_chans`, but it only seems to support 3-channel weights if `in_chans > 1`. I would like to extend this to support any number of channels. Would this be as simple as replacing 3 with `I` throughout the function?

Describe alternatives you've considered
We could write our own functionality in TorchGeo, but figured this would be useful to the broader timm community.
Additional context
@isaaccorley @keves1 may also be interested in this.