
Unexpected Behavior when Replacing Layers in CustomModel built with Subclass API #19268


Description

@ariG23498

I have created a CustomModel using the subclassing API:

import keras

class CustomModel(keras.Model):
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.stem = keras.layers.Conv2D(32, 3, strides=2, padding='same', name="stem")
        self.head = keras.layers.Dense(num_classes, name="head")

    def call(self, inputs):
        x = self.stem(inputs)
        x = keras.layers.Flatten()(x)
        x = self.head(x)
        return x

I wanted to replace (i.e., swap out) a layer of the custom model with a new layer.

model = CustomModel()
model.head = keras.layers.Dense(100, name="head")

Now when I call model.summary(), I see both layers listed in the model.

model.summary()

Model: "custom_model_3"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ stem (Conv2D)                        │ ?                           │     0 (unbuilt) │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ head (Dense)                         │ ?                           │     0 (unbuilt) │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ head (Dense)                         │ ?                           │     0 (unbuilt) │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 0 (0.00 B)
 Trainable params: 0 (0.00 B)
 Non-trainable params: 0 (0.00 B)
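Interestingly, the attribute itself does get replaced — call() uses the new head; it is only the layer tracker that keeps the stale entry. A quick probe (my own check, not from the original report; the input shape is arbitrary):

import numpy as np
import keras

model = CustomModel()
model.head = keras.layers.Dense(100, name="head")

# 32x32x3 input picked arbitrarily just to trigger a build.
out = model(np.zeros((1, 32, 32, 3)))
print(out.shape)          # (1, 100): the forward pass appears to use the new head
print(len(model.layers))  # 3: stem plus both heads remain tracked, matching the summary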

When I take a similar approach with torch.nn.Module, the layer gets swapped out as expected.

import torch

class CustomModel(torch.nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.stem = torch.nn.Conv2d(32, 32, 3, stride=2)
        self.head = torch.nn.Linear(32, num_classes)

    def forward(self, inputs):
        x = self.stem(inputs)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

model = CustomModel()
model.head = torch.nn.Linear(32, 100)
print(model)
CustomModel(
  (stem): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
  (head): Linear(in_features=32, out_features=100, bias=True)
)
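For reference, torch.nn.Module.__setattr__ deregisters the old submodule before registering the new one, which is why the swap is clean on the PyTorch side. One way to see this, poking at the internal _modules dict (my own check):

import torch

model = CustomModel()
print(list(model._modules))   # ['stem', 'head']
model.head = torch.nn.Linear(32, 100)
print(list(model._modules))   # still ['stem', 'head']: the old head was replaced, not appended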

I know that I can use setattr to set the keras.Model attribute, but I wanted to check whether this is a supported workflow.
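In the meantime, the workaround I have settled on is to avoid reassigning tracked attributes altogether and to inject the replacement layer at construction time instead — a minimal sketch (the head keyword argument is my own addition, not part of the original class, and this rebuilds the model rather than mutating it in place):

import keras

class CustomModel(keras.Model):
    def __init__(self, num_classes=10, head=None):
        super().__init__()
        self.num_classes = num_classes
        self.stem = keras.layers.Conv2D(32, 3, strides=2, padding='same', name="stem")
        # Accept an externally built head so "swapping" never reassigns a tracked attribute.
        self.head = head if head is not None else keras.layers.Dense(num_classes, name="head")

    def call(self, inputs):
        x = self.stem(inputs)
        x = keras.layers.Flatten()(x)
        x = self.head(x)
        return x

# "Swapping" the head = constructing a fresh model around the new layer.
model = CustomModel(num_classes=100, head=keras.layers.Dense(100, name="head"))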

Any help on this would be greatly appreciated.

My opinion: we should warn the user when they reassign a layer attribute of a model. Silently appending the new layer alongside the old one (which is not intuitive at all) introduces a silent bug.
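To illustrate the kind of warning I have in mind, here is a rough, hypothetical sketch (WarningModel and its __setattr__ override are my own illustration, not anything that exists in Keras):

import warnings
import keras

class WarningModel(keras.Model):
    # Hypothetical: warn when an attribute that already holds a Layer is reassigned.
    def __setattr__(self, name, value):
        old = getattr(self, name, None)
        if (isinstance(old, keras.layers.Layer)
                and isinstance(value, keras.layers.Layer)
                and old is not value):
            warnings.warn(
                f"Reassigning layer attribute '{name}': the previous layer "
                "may still be tracked by the model."
            )
        super().__setattr__(name, value)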
