[FEATURE] timm.models.adapt_input_conv: beyond RGB weights #2445

Open
adamjstewart opened this issue Feb 25, 2025 · 6 comments
Labels
enhancement New feature or request

Comments

@adamjstewart
Contributor

Is your feature request related to a problem? Please describe.

TorchGeo provides a number of model weights pre-trained on non-RGB imagery (e.g., Sentinel-2, 13 channels). Oftentimes, when dealing with time-series data, we would like to stack images along the channel dimension so that we end up with $B \times TC \times H \times W$ inputs. However, we don't yet have an easy way to adapt our pre-trained weights to match.
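
For example, a minimal sketch of the kind of stacking we do (shapes are just illustrative):

import torch

B, T, C, H, W = 2, 4, 13, 64, 64   # hypothetical 4-step Sentinel-2 time series
x = torch.randn(B, T, C, H, W)
x = x.reshape(B, T * C, H, W)      # stack time along channels -> B x TC x H x W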

Describe the solution you'd like

timm.models.adapt_input_conv provides a powerful tool for repeating and scaling weights to adapt to changing in_chans, but only seems to support 3-channel weights if in_chans > 1. I would like to extend this to support any number of channels. Would this be as simple as replacing 3 with I throughout the function?
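
For reference, this is roughly how we use it today (model and layer names are only illustrative):

import timm
from timm.models import adapt_input_conv

model = timm.create_model('resnet50', pretrained=False)
weight = model.conv1.weight.detach()       # shape (64, 3, 7, 7), an RGB stem
new_weight = adapt_input_conv(13, weight)  # repeats and rescales the 3-channel weights to 13 channels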

Describe alternatives you've considered

We could write our own functionality in TorchGeo, but figured this would be useful to the broader timm community.

Additional context

@isaaccorley @keves1 may also be interested in this.

adamjstewart added the enhancement (New feature or request) label on Feb 25, 2025
@rwightman
Collaborator

@adamjstewart So if you had a model with 13 channels, you might want to convert it so those 3 channels are repeated several times and pass it say a 13*4=52 channel image?

@adamjstewart
Contributor Author

Correct, other than a typo (3 -> 13).

Another use case is when a model is trained on top-of-atmosphere data (13 channels) but we want to use it for surface reflectance data (11 channels), or vice versa. In this case, we know which specific channels were removed, but we would be satisfied with a solution that just naively drops the last 2 channels and rescales things.
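
A sketch of the naive 13 -> 11 conversion we'd be happy with (the weight here is just a stand-in):

import torch

weight = torch.randn(64, 13, 3, 3)          # hypothetical stem weight trained on 13 TOA bands
in_chans = 11
new_weight = weight[:, :in_chans]           # naively drop the last 2 channels
new_weight = new_weight * (weight.shape[1] / in_chans)  # rescale to roughly preserve activation magnitude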

@rwightman
Collaborator

@adamjstewart k, I think it's pretty straightforward to support that with an extra arg that covers the 'base' or default channels. Below I added a base_chans arg... if you set it to 13 you should get the behaviour described. The 'naive' dropping is the default. For dropping specific channels, an arg could be added with indices to drop.

Conceivably you could drop before or after the repeat, say if you had 11 of 13 channels you wanted to keep, AND wanted to repeat to get a 33-channel input... hmm.

import math

from torch import Tensor


def adapt_input_conv(in_chans: int, conv_weight: Tensor, base_chans: int = 3) -> Tensor:
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > base_chans:
            assert conv_weight.shape[1] % base_chans == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // base_chans, base_chans, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != base_chans:
        if I != base_chans:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / base_chans))
            scale = base_chans / float(in_chans)
            if repeat > 1:
                conv_weight = conv_weight.repeat(1, repeat, 1, 1)
            conv_weight = conv_weight[:, :in_chans, :, :]  # drops last channels
            conv_weight *= scale
    conv_weight = conv_weight.to(conv_type)
    return conv_weight
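
For example (illustrative only, continuing from the sketch above):

import torch

w13 = torch.randn(64, 13, 3, 3)                  # hypothetical 13-channel pretrained stem weight
w52 = adapt_input_conv(52, w13, base_chans=13)   # repeat 4x for a 52-channel stacked time series
w11 = adapt_input_conv(11, w13, base_chans=13)   # naively drop the last 2 channels and rescale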

@adamjstewart
Contributor Author

Would it be simpler to use base_chans = I? I know your implementation better supports 13 -> 11 -> 33, but I think that specific use case will be uncommon. I'm mostly concerned about 13 -> 26 or 13 -> 11 or 11 -> 13.

@rwightman
Collaborator

@adamjstewart removing the base_chans arg would require adding a space2depth multiplier arg to resolve that ambiguity to support monochrome use with tresnet models...
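
Rough illustration of the ambiguity (numbers are only an example):

import torch

# tresnet-style space2depth stem: a 3-channel image is rearranged into 4x4 blocks,
# so the first conv actually sees 3 * 16 = 48 input channels
w = torch.randn(64, 48, 3, 3)

# with base_chans=3 the in_chans == 1 path reshapes to (O, 16, 3, kH, kW) and sums
# over the RGB dim, leaving the 16 channels a monochrome space2depth input produces
mono = w.reshape(64, 48 // 3, 3, 3, 3).sum(dim=2)   # -> (64, 16, 3, 3)

# inferring base_chans from I alone can't distinguish this from a genuine 48-channel
# conv, hence the explicit base_chans (or a space2depth multiplier) arg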

@adamjstewart
Contributor Author

I don't know what most of those words mean so I'll trust you. However, your implementation doesn't seem to support anything other than in_chans == 1 or base_chans == I, which is why I'm wondering why we even need a base_chans parameter at all. If base_chans = I, then most of the if statements and raises disappear.
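
Something like this is what I'm picturing (just a sketch, and it drops the space2depth handling you mentioned):

import math

from torch import Tensor


def adapt_input_conv_simple(in_chans: int, conv_weight: Tensor) -> Tensor:
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != I:
        repeat = int(math.ceil(in_chans / I))
        if repeat > 1:
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)
        conv_weight = conv_weight[:, :in_chans]             # drop any extra channels
        conv_weight = conv_weight * (I / float(in_chans))   # rescale
    return conv_weight.to(conv_type)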
