Skip to content

Commit

Permalink
more renaming of config keys and files
Browse files Browse the repository at this point in the history
  • Loading branch information
keighrim committed Nov 20, 2023
1 parent 3234ec6 commit 81e0db2
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 18 deletions.
11 changes: 6 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
from typing import Union

import yaml
from clams import ClamsApp, Restifier
from mmif import Mmif, View, AnnotationTypes, DocumentTypes

Expand All @@ -22,7 +23,7 @@ class SwtDetection(ClamsApp):

def __init__(self, configs):
    """Set up the CLAMS app and build the underlying frame classifier.

    :param configs: dict of classifier configuration (parsed from the YAML
        config file); expanded into keyword arguments for
        ``classify.Classifier`` — keys must match its parameter names.
    """
    super().__init__()
    # single assignment: the stale pre-rename positional call
    # (``classify.Classifier(configs)``) is dropped in favor of the
    # keyword-expansion form introduced by this commit
    self.classifier = classify.Classifier(**configs)

def _appmetadata(self):
# see https://sdk.clams.ai/autodoc/clams.app.html#clams.app.ClamsApp._load_appmetadata
Expand Down Expand Up @@ -60,14 +61,14 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="The YAML config file")
parser.add_argument("--port", action="store", default="5000", help="set port to listen" )
parser.add_argument("-c", "--config", help="The YAML config file", default='modeling/config/classifier.yaml')
parser.add_argument("--port", action="store", default="5000", help="set port to listen")
parser.add_argument("--production", action="store_true", help="run gunicorn server")

parsed_args = parser.parse_args()
CONFIGS = parsed_args.configs
classifier_configs = yaml.safe_load(parsed_args.configs)

app = SwtDetection(CONFIGS)
app = SwtDetection(classifier_configs)

http_app = Restifier(app, port=int(parsed_args.port))
# for running the application in production mode
Expand Down
7 changes: 0 additions & 7 deletions example-config.yml

This file was deleted.

33 changes: 33 additions & 0 deletions modeling/config/classifier.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
model_file: "modeling/models/20231102-165708.kfold_000.pt"
sample_rate: 1000
minimum_score: 1.01
score_mapping: {0.01: 0, 0.25: 1, 0.50: 2, 0.75: 3, 1.01: 4}

img_enc_name: "convnext_lg"
pos_enc_name: "sinusoidal-concat"
pos_enc_dim: 512
max_input_length: 5640000
position_unit: 60000

feature_dim: 2048
num_layers: 2
dropouts: 0.2

prebin:
- bars
- slate
- chyron
- text-not-chyron
- person-not-chyron
- credits
- other
labels: [ "slate", "chyron", "credit", "other" ]
postbin:
0: 3
1: 0
2: 1
3: 3
4: 3
5: 2
6: 3

File renamed without changes.
18 changes: 12 additions & 6 deletions modeling/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,24 @@ def get_sinusoidal_embeddings(self, n_pos, dim):
self.__class__.sinusoidal_embeddings[(n_pos, dim)] = matrix
return matrix

def get_img_vectors(self, img_vec):
img_vec = self.img_encoder.preprocess(img_vec)
def get_img_vector(self, raw_img, as_numpy=True):
    """Encode one raw image into a feature vector with the image encoder.

    :param raw_img: a single raw image, in whatever form the encoder's
        ``preprocess`` accepts — TODO confirm expected type with callers
    :param as_numpy: when True return a numpy array, otherwise a CPU tensor
    :return: the encoder's output for a batch of one (batch dim retained)
    """
    # preprocess and add a leading batch dimension (batch of one)
    batch = self.img_encoder.preprocess(raw_img).unsqueeze(0)
    if torch.cuda.is_available():
        # run on the GPU when one is present; Module.to() moves the
        # model in place, the tensor must be rebound
        self.img_encoder.model.to('cuda')
        batch = batch.to('cuda')
    # inference only — no autograd bookkeeping needed
    with torch.no_grad():
        features = self.img_encoder.model(batch)
    features = features.cpu()
    return features.numpy() if as_numpy else features

def encode_position(self, cur_time, tot_time, img_vec):
pos = cur_time / tot_time
if isinstance(img_vec, np.ndarray):
img_vec = torch.from_numpy(img_vec)
img_vec = img_vec.squeeze(0)
if self.pos_encoder is None:
return img_vec
elif self.pos_encoder == 'fractional':
Expand All @@ -131,8 +137,8 @@ def feature_vector_dim(self):
elif self.pos_encoder == 'fractional':
return self.img_encoder.dim + 1

def get_full_feature_vectors(self, img_vec, cur_time, tot_time):
img_vecs = self.get_img_vectors(img_vec)
def get_full_feature_vectors(self, raw_img, cur_time, tot_time):
    """Encode an image and fold in its temporal position.

    Runs the image through the encoder — kept as a torch tensor so the
    position encoder can consume it directly — then attaches positional
    information for the frame at ``cur_time`` within a ``tot_time``-long video.
    """
    feature_vec = self.get_img_vector(raw_img, as_numpy=False)
    return self.encode_position(cur_time, tot_time, feature_vec)


Expand Down Expand Up @@ -169,7 +175,7 @@ def process_video(self,
frame_metadata['duration'] = frame.total_time

for extractor in self.models:
frame_vecs[extractor.img_encoder.name].append(extractor.get_img_vectors(frame.image))
frame_vecs[extractor.img_encoder.name].append(extractor.get_img_vector(frame.image, as_numpy=True))
frame_dict = {k: v for k, v in frame.__dict__.items() if k != "image" and k != "guid" and k != "total_time"}
frame_dict['vec_idx'] = i
frame_metadata["frames"].append(frame_dict)
Expand Down

0 comments on commit 81e0db2

Please sign in to comment.