Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Fix manual merge for ViT
Browse files Browse the repository at this point in the history
Summary: My manual cherrypick for https://www.internalfb.com/intern/pick/request/permalink/1290074624677854/ introduced errors which I'm fixing here.

Reviewed By: lauragustafson

Differential Revision: D24996390

fbshipit-source-id: de7b7227f055ce3a491ae55a2fe094afa6d424f6
  • Loading branch information
mannatsingh authored and facebook-github-bot committed Nov 16, 2020
1 parent 07d4ffa commit 8bc1903
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion classy_vision/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
"""

import copy
import logging
import math
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from classy_vision.models import ClassyModel, register_model

from .lecun_normal_init import lecun_normal_init


Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__(
classifier="token",
):
super().__init__()
assert image_size % patch_size == 0, "Input shape indivisble by patch size"
assert image_size % patch_size == 0, "Input shape indivisible by patch size"
assert classifier in ["token", "gap"], "Unexpected classifier mode"
self.image_size = image_size
self.patch_size = patch_size
Expand Down Expand Up @@ -178,6 +180,7 @@ def __init__(
)
self.trunk_output = nn.Identity()

self.seq_length = seq_length
self.init_weights()

def init_weights(self):
Expand Down

0 comments on commit 8bc1903

Please sign in to comment.