File tree Expand file tree Collapse file tree 7 files changed +8
-10
lines changed
torchmultimodal/models/flava Expand file tree Collapse file tree 7 files changed +8
-10
lines changed Original file line number Diff line number Diff line change @@ -60,7 +60,6 @@ def main():
6060 callbacks = [
6161 LearningRateMonitor (logging_interval = "step" ),
6262 ],
63- strategy = "ddp" ,
6463 )
6564 trainer .fit (model , datamodule = datamodule )
6665 trainer .validate (model , datamodule = datamodule )
Original file line number Diff line number Diff line change 88
99import torch
1010from pytorch_lightning import LightningModule
11- from torchmultimodal .models .flava import (
11+ from torchmultimodal .models .flava . flava_model import (
1212 flava_model_for_classification ,
1313 flava_model_for_pretraining ,
1414)
Original file line number Diff line number Diff line change 77import argparse
88
99import torch
10- from torchmultimodal .models .flava import flava_model_for_pretraining
10+ from torchmultimodal .models .flava . flava_model import flava_model_for_pretraining
1111
1212KEY_REPLACEMENTS = {
1313 "image_encoder.module" : "image_encoder" ,
Original file line number Diff line number Diff line change 99import torch
1010from test .test_utils import assert_expected
1111from torch import nn
12- from torchmultimodal .models .flava import (
12+ from torchmultimodal .models .flava . flava_model import (
1313 flava_image_encoder ,
1414 flava_model_for_classification ,
1515 flava_model_for_pretraining ,
Original file line number Diff line number Diff line change 88
99import torch
1010from test .test_utils import assert_expected , set_rng_seed
11- from torchmultimodal .models .flava import flava_image_encoder
1211from torchmultimodal .modules .layers .transformer import (
1312 FLAVASelfAttention ,
1413 FLAVATransformerEncoder ,
@@ -31,11 +30,6 @@ def test_flava_self_attention_value_error(self):
3130 with self .assertRaises (ValueError ):
3231 _ = FLAVASelfAttention (hidden_size = 3 , num_attention_heads = 2 )
3332
34- def test_flava_transformer_without_embeddings_value_error (self ):
35- with self .assertRaises (ValueError ):
36- encoder = flava_image_encoder ()
37- _ = encoder ()
38-
3933 def test_flava_encoder_forward (self ):
4034 output = self .encoder (self .test_input )
4135
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
File renamed without changes.
You can’t perform that action at this time.
0 commit comments