Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/flava/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def main():
callbacks=[
LearningRateMonitor(logging_interval="step"),
],
strategy="ddp",
)
trainer.fit(model, datamodule=datamodule)
trainer.validate(model, datamodule=datamodule)
Expand Down
2 changes: 1 addition & 1 deletion examples/flava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from pytorch_lightning import LightningModule
from torchmultimodal.models.flava import (
from torchmultimodal.models.flava.flava_model import (
flava_model_for_classification,
flava_model_for_pretraining,
)
Expand Down
2 changes: 1 addition & 1 deletion examples/flava/tools/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import argparse

import torch
from torchmultimodal.models.flava import flava_model_for_pretraining
from torchmultimodal.models.flava.flava_model import flava_model_for_pretraining

KEY_REPLACEMENTS = {
"image_encoder.module": "image_encoder",
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace_packages = True
install_types = True

# TODO (T116951827): Remove after fixing FLAVA type check errors
exclude = models/flava.py|modules/losses/flava.py
exclude = models/flava/flava_model.py|modules/losses/flava.py

[mypy-PIL.*]
ignore_missing_imports = True
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from test.test_utils import assert_expected
from torch import nn
from torchmultimodal.models.flava import (
from torchmultimodal.models.flava.flava_model import (
flava_image_encoder,
flava_model_for_classification,
flava_model_for_pretraining,
Expand Down
6 changes: 0 additions & 6 deletions test/modules/layers/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import torch
from test.test_utils import assert_expected, set_rng_seed
from torchmultimodal.models.flava import flava_image_encoder
from torchmultimodal.modules.layers.transformer import (
FLAVASelfAttention,
FLAVATransformerEncoder,
Expand All @@ -31,11 +30,6 @@ def test_flava_self_attention_value_error(self):
with self.assertRaises(ValueError):
_ = FLAVASelfAttention(hidden_size=3, num_attention_heads=2)

def test_flava_transformer_without_embeddings_value_error(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering, why remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its an unrelated test that got checked in by mistake while separating out transformers

with self.assertRaises(ValueError):
encoder = flava_image_encoder()
_ = encoder()

def test_flava_encoder_forward(self):
output = self.encoder(self.test_input)

Expand Down
5 changes: 5 additions & 0 deletions torchmultimodal/models/flava/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.