Skip to content

Commit

Permalink
fixed model configs not thoroughly passed to the feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
keighrim committed Jul 22, 2024
1 parent 7be4b81 commit 371f600
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 19 deletions.
6 changes: 1 addition & 5 deletions modeling/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ def __init__(self, model_stem, logger_name=None):
model_checkpoint = f"{model_stem}.pt"
model_config = yaml.safe_load(open(model_config_file))
self.training_labels = train.pretraining_binned_label(model_config)
self.featurizer = data_loader.FeatureExtractor(
img_enc_name=model_config["img_enc_name"],
pos_enc_dim=model_config.get("pos_enc_dim", 0),
pos_length=model_config.get("pos_length", 0),
pos_unit=model_config.get("pos_unit", 0))
self.featurizer = data_loader.FeatureExtractor(**model_config)
label_count = len(FRAME_TYPES) + 1
if 'bins' in model_config:
label_count = len(model_config['bins'].keys()) + 1
Expand Down
10 changes: 4 additions & 6 deletions modeling/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,25 @@ class FeatureExtractor(object):
on a weighted sum of the two.
"""
img_encoder: backbones.ExtractorModel
pos_enc_dim: int
pos_length: int
pos_unit: int
pos_abs_th_front: int
pos_abs_th_end: int
pos_vec_coeff: float
sinusoidal_embeddings: ClassVar[Dict[Tuple[int, int], torch.Tensor]] = {}

def __init__(self, img_enc_name: str,
pos_enc_dim: int = 512,
def __init__(self,
img_enc_name: str,
pos_length: int = 6000000,
pos_unit: int = 60000,
pos_abs_th_front: int = 3,
pos_abs_th_end: int = 10,
pos_vec_coeff: float = 0.5):
pos_vec_coeff: float = 0.5,
**kwargs): # to catch unexpected arguments
"""
Initializes the FeatureExtractor object.
:param img_enc_name: a name of backbone model (e.g. CNN) to use for image vector extraction
:param pos_enc_dim: dimension of positional embedding, when not given use 512
:param pos_length: "width" of positional encoding matrix, actual number of matrix columns is calculated by
pos_length / pos_unit (with default values, that is 100 minutes)
:param pos_unit: unit of positional encoding in milliseconds (e.g., 60000 for minutes, 1000 for seconds)
Expand All @@ -87,7 +86,6 @@ def __init__(self, img_enc_name: str,
raise ValueError("A image vector model must be specified")
else:
self.img_encoder: backbones.ExtractorModel = backbones.model_map[img_enc_name]()
self.pos_enc_dim = pos_enc_dim
self.pos_unit = pos_unit
self.pos_abs_th_front = pos_abs_th_front
self.pos_abs_th_end = pos_abs_th_end
Expand Down
12 changes: 5 additions & 7 deletions modeling/train.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse
import copy
import json
import logging
import os
import platform
import shutil
import time
from pathlib import Path
import copy
from typing import Union

import numpy as np
Expand Down Expand Up @@ -119,12 +120,7 @@ def prepare_datasets(indir, train_guids, validation_guids, configs):
pre_bin_size = len(FRAME_TYPES) + 1
train_vimg = valid_vimg = 0

extractor = data_loader.FeatureExtractor(
img_enc_name=configs.get('img_enc_name'),
pos_unit=configs['pos_unit'] if configs and 'pos_unit' in configs else 3600000,
pos_enc_dim=configs['pos_enc_dim'] if 'pos_enc_dim' in configs else 512,
pos_length=configs.get('pos_length')
)
extractor = data_loader.FeatureExtractor(**config)

for j in Path(indir).glob('*.json'):
guid = j.with_suffix("").name
Expand Down Expand Up @@ -152,6 +148,8 @@ def prepare_datasets(indir, train_guids, validation_guids, configs):


def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y%m%d-%H%M%S")):
os.makedirs(outdir, exist_ok=True)

# need to implement "whitelist"?
guids = get_guids(indir)
configs = load_config(configs) if not isinstance(configs, dict) else configs
Expand Down
1 change: 0 additions & 1 deletion test/test_pos_enc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TestPosAbsTh(unittest.TestCase):
def prep_extractor(th_front, th_end, cols=100):
extractor = data_loader.FeatureExtractor(
img_enc_name="mock_model_name",
pos_enc_dim=256,
pos_length=6000000,
pos_unit=60000,
pos_abs_th_front=th_front,
Expand Down

0 comments on commit 371f600

Please sign in to comment.