Skip to content

Commit d85e4c0

Browse files
committed
Moving flava model to its own folder
ghstack-source-id: b6f4983 Pull Request resolved: #96
1 parent 3beffd9 commit d85e4c0

File tree

7 files changed

+8
-10
lines changed

7 files changed

+8
-10
lines changed

examples/flava/finetune.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def main():
6060
callbacks=[
6161
LearningRateMonitor(logging_interval="step"),
6262
],
63-
strategy="ddp",
6463
)
6564
trainer.fit(model, datamodule=datamodule)
6665
trainer.validate(model, datamodule=datamodule)

examples/flava/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010
from pytorch_lightning import LightningModule
11-
from torchmultimodal.models.flava import (
11+
from torchmultimodal.models.flava.flava_model import (
1212
flava_model_for_classification,
1313
flava_model_for_pretraining,
1414
)

examples/flava/tools/convert_weights.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import argparse
88

99
import torch
10-
from torchmultimodal.models.flava import flava_model_for_pretraining
10+
from torchmultimodal.models.flava.flava_model import flava_model_for_pretraining
1111

1212
KEY_REPLACEMENTS = {
1313
"image_encoder.module": "image_encoder",

test/models/test_flava.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from test.test_utils import assert_expected
1111
from torch import nn
12-
from torchmultimodal.models.flava import (
12+
from torchmultimodal.models.flava.flava_model import (
1313
flava_image_encoder,
1414
flava_model_for_classification,
1515
flava_model_for_pretraining,

test/modules/layers/test_transformer.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import torch
1010
from test.test_utils import assert_expected, set_rng_seed
11-
from torchmultimodal.models.flava import flava_image_encoder
1211
from torchmultimodal.modules.layers.transformer import (
1312
FLAVASelfAttention,
1413
FLAVATransformerEncoder,
@@ -31,11 +30,6 @@ def test_flava_self_attention_value_error(self):
3130
with self.assertRaises(ValueError):
3231
_ = FLAVASelfAttention(hidden_size=3, num_attention_heads=2)
3332

34-
def test_flava_transformer_without_embeddings_value_error(self):
35-
with self.assertRaises(ValueError):
36-
encoder = flava_image_encoder()
37-
_ = encoder()
38-
3933
def test_flava_encoder_forward(self):
4034
output = self.encoder(self.test_input)
4135

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
File renamed without changes.

0 commit comments

Comments
 (0)