diff --git a/paddleseg/models/backbones/vision_transformer.py b/paddleseg/models/backbones/vision_transformer.py index d9aa79fa98..57741898aa 100644 --- a/paddleseg/models/backbones/vision_transformer.py +++ b/paddleseg/models/backbones/vision_transformer.py @@ -103,7 +103,8 @@ def __init__(self, self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): - N, C = x.shape[1:] + x_shape = paddle.shape(x) + N, C = x_shape[1], x_shape[2] qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).transpose((2, 0, 3, 1, 4))