From e6662c46aa19e4540eccf42c83ecbf07cf9397be Mon Sep 17 00:00:00 2001 From: Marc Verhagen Date: Fri, 3 May 2024 14:51:50 -0400 Subject: [PATCH 1/2] Initial pass at the CLI script, includes retiring the classifier config file --- app.py | 44 +++++------------ cli.py | 86 ++++++++++++++++++++++++++++++++++ metadata.py | 28 +++++++---- modeling/__init__.py | 6 +++ modeling/config/classifier.yml | 32 ------------- modeling/stitch.py | 7 ++- 6 files changed, 127 insertions(+), 76 deletions(-) create mode 100644 cli.py diff --git a/app.py b/app.py index 6b86823..aa168cc 100644 --- a/app.py +++ b/app.py @@ -20,7 +20,6 @@ from modeling import classify, stitch, negative_label, FRAME_TYPES -default_config_fname = Path(__file__).parent / 'modeling/config/classifier.yml' default_model_storage = Path(__file__).parent / 'modeling/models' @@ -28,7 +27,6 @@ class SwtDetection(ClamsApp): def __init__(self, preconf_fname: str = None, log_to_file: bool = False) -> None: super().__init__() - self.preconf = yaml.safe_load(open(preconf_fname)) if log_to_file: fh = logging.FileHandler(f'{self.__class__.__name__}.log') fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) @@ -41,8 +39,7 @@ def _appmetadata(self): def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: # parameters here is a "refined" dict, so hopefully its values are properly # validated and casted at this point. - self.parameters = parameters - self.configs = {**self.preconf, **parameters} + self.configs = parameters self._configure_model() self._configure_postbin() for k, v in self.configs.items(): @@ -76,7 +73,7 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: return mmif def _configure_model(self): - model_name = self.parameters["modelName"] + model_name = self.configs["modelName"] self.configs['model_file'] = default_model_storage / f'{model_name}.pt' self.configs['model_config_file'] = default_model_storage / f'{model_name}.yml' @@ -101,29 +98,18 @@ def _configure_postbin(self): that the underscore is replaced with a colon. This is not good if we intend there to be a dash. """ - # TODO: this is ugly, but I do not know a better way yet. The default value - # of the map parameter in metadata.py is an empty list. If the user sets those - # parameters during invocation (for example "?map=S:slate&map=B:bar") then in - # the user parameters we have ['S:slate', 'B:bar'] for map and in the refined - # parameters we get {'S': 'slate', 'B': 'bar'}. If the user adds no map - # parameters then there is no map value in the user parameters and the value - # is [] in the refined parameters (which is a bit inconsistent). - # Two experiments: - # 1. What if I set the default to a list like ['S:slate', 'B:bar']? - # Then the map value in refined parameters is that same list, which means - # that I have to turn it into a dictionary before I hand it off. - # 2. What if I set the default to a dictionary like {'S': 'slate', 'B': 'bar'}? - # Then the map value in the refined parameters is a list with one element, - # which is the wanted dictionary as a string: ["{'S': 'slate', 'B': 'bar'}"] - if type(self.parameters['map']) is list: + if type(self.configs['map']) is list: + # This needs to be done because when the default for the map parameters is + # a non-empty list then it will end up in the refined parameters as a list + # if no map parameters were specified when the user invoked the app (whereas + # when the user specifies the map parameter then the map will be a dictionary + # after refinement). newmap = {} - for kv in self.parameters['map']: + for kv in self.configs['map']: k, v = kv.split(':') newmap[k] = v - self.parameters['map'] = newmap self.configs['map'] = newmap - postbin = invert_mappings(self.parameters['map']) - self.configs['postbin'] = postbin + self.configs['postbin'] = invert_mappings(self.configs['map']) def _extract_images(self, video): open_video(video) @@ -159,7 +145,7 @@ def _classify(self, extracted: list, positions: list, total_ms: int): def _new_view(self, annotation_types: list, video, labels: list, mmif): view: View = mmif.new_view() - self.sign_view(view, self.parameters) + self.sign_view(view, self.configs) for annotation_type in annotation_types: view.new_contain( annotation_type, document=video.id, timeUnit='milliseconds', labelset=labels) @@ -197,8 +183,6 @@ def _add_stitcher_results_to_view(self, timeframes: list, view: View): def invert_mappings(mappings: dict) -> dict: - print('-'*80) - print(mappings) inverted_mappings = {} for in_label, out_label in mappings.items(): in_label = restore_colon(in_label) @@ -207,7 +191,7 @@ def invert_mappings(mappings: dict) -> dict: def restore_colon(label_in: str) -> str: - """Replace a dash with a colon.""" + """Replace dashes with colons.""" return label_in.replace('-', ':') @@ -230,13 +214,11 @@ def transform(classification: dict, postbin: dict): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", help="The YAML config file", default=default_config_fname) parser.add_argument("--port", action="store", default="5000", help="set port to listen") parser.add_argument("--production", action="store_true", help="run gunicorn server") - parsed_args = parser.parse_args() - app = SwtDetection(preconf_fname=parsed_args.config, log_to_file=False) + app = SwtDetection(log_to_file=False) http_app = Restifier(app, port=int(parsed_args.port)) # for running the application in production mode diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..9ef549b --- /dev/null +++ b/cli.py @@ -0,0 +1,86 @@ +import yaml +import pprint +import argparse +from pathlib import Path + +from metadata import appmetadata +from clams.app import ClamsApp +from app import SwtDetection + + + +default_config_fname = Path(__file__).parent / 'modeling/config/classifier.yml' + +json_type_map = { + "integer": int, + "number": float, + "string": str, + "boolean": bool, +} + +parameter_names = ( + 'metadata', 'map', 'minFrameCount', 'minFrameScore', 'minTimeframeScore', + 'modelName', 'pretty', 'sampleRate', 'startAt', 'stopAt', 'useStitcher') + + +def get_app(): + app = SwtDetection(preconf_fname=default_config_fname, log_to_file=False) + return app + + +def get_metadata(): + """Gets the metadata from the metadata.py filem, with the universal parameters added.""" + metadata = appmetadata() + for param in ClamsApp.universal_parameters: + metadata.add_parameter(**param) + return metadata + + +def create_argparser(metadata): + parser = argparse.ArgumentParser( + description=f"Command-Line Interface for {metadata.identifier}") + parser.add_argument( + "--metadata", + help="Return the apps metadata and exit", + action="store_true") + parser.add_argument("--input", help="The input MMIF file") + parser.add_argument("--output", help="The output MMIF file") + for parameter in metadata.parameters: + nargs = '*' if parameter.type == 'map' else '?' + parser.add_argument( + f"--{parameter.name}", + help=parameter.description, + nargs=nargs, + type=json_type_map.get(parameter.type), + choices=parameter.choices, + default=parameter.default, + action="store") + return parser + + +def print_parameters(metadata): + for parameter in metadata.parameters: + continue + print(f'\n{parameter.name}') + print(f' type={parameter.type}') + print(f' default={parameter.default}') + print(f' choices={parameter.choices}') + + +if __name__ == '__main__': + + app = get_app() + metadata = get_metadata() + + argparser = create_argparser(metadata) + args = argparser.parse_args() + + print(args) + print() + for arg in vars(args): + value = getattr(args, arg) + print(f'{arg:18s} {str(type(value)):15s} {value}') + + if args.metadata: + print(metadata.jsonify(pretty=args.pretty)) + diff --git a/metadata.py b/metadata.py index 6cafd9d..52c2d70 100644 --- a/metadata.py +++ b/metadata.py @@ -10,7 +10,7 @@ from clams.app import ClamsApp from clams.appmetadata import AppMetadata from mmif import DocumentTypes, AnnotationTypes -from app import default_model_storage, default_config_fname +from app import default_model_storage#, default_config_fname from modeling import FRAME_TYPES @@ -23,9 +23,19 @@ def appmetadata() -> AppMetadata: :return: AppMetadata object holding all necessary information. """ - preconf = yaml.safe_load(open(default_config_fname)) + available_models = default_model_storage.glob('*.pt') + # This was the most frequent label mapping from the now deprecated configuration file, + # which had default mappings for each model. + labelMap = [ + "B:bars", + "S:slate", "S-H:slate", "S-C:slate", "S-D:slate", "S-G:slate", + "W:other_opening", "L:other_opening", "O:other_opening", "M:other_opening", + "I:chyron", "N:chyron", "Y:chyron", + "C:credit", "R:credit", + "E:other_text", "K:other_text", "G:other_text", "T:other_text", "F:other_text" ] + metadata = AppMetadata( name="Scenes-with-text Detection", description="Detects scenes with text, like slates, chyrons and credits.", @@ -46,30 +56,30 @@ def appmetadata() -> AppMetadata: name='stopAt', type='integer', default=sys.maxsize, description='Number of milliseconds into the video to stop processing') metadata.add_parameter( - name='sampleRate', type='integer', default=preconf['sampleRate'], + name='sampleRate', type='integer', default=1000, description='Milliseconds between sampled frames') metadata.add_parameter( - name='minFrameScore', type='number', default=preconf['minFrameScore'], + name='minFrameScore', type='number', default=0.01, description='Minimum score for a still frame to be included in a TimeFrame') metadata.add_parameter( - name='minTimeframeScore', type='number', default=preconf['minTimeframeScore'], + name='minTimeframeScore', type='number', default=0.5, description='Minimum score for a TimeFrame') metadata.add_parameter( - name='minFrameCount', type='integer', default=preconf['minFrameCount'], + name='minFrameCount', type='integer', default=2, description='Minimum number of sampled frames required for a TimeFrame') metadata.add_parameter( name='modelName', type='string', - default=pathlib.Path(preconf['model_file']).stem, + default='20240409-091401.convnext_lg', choices=[m.stem for m in available_models], description='model name to use for classification') metadata.add_parameter( - name='useStitcher', type='boolean', default=preconf['useStitcher'], + name='useStitcher', type='boolean', default=True, description='Use the stitcher after classifying the TimePoints') metadata.add_parameter( # TODO: do we want to use the old default labelMap from the configuration here or # do we truly want an empty mapping and use the pass-through, as hinted at in the # description (which is now not in sync with the code). - name='map', type='map', default=preconf['labelMap'], + name='map', type='map', default=labelMap, description=( 'Mapping of a label in the input annotations to a new label. Must be formatted as ' 'IN_LABEL:OUT_LABEL (with a colon). To pass multiple mappings, use this parameter ' diff --git a/modeling/__init__.py b/modeling/__init__.py index 7491d72..41c04a0 100644 --- a/modeling/__init__.py +++ b/modeling/__init__.py @@ -1,5 +1,11 @@ negative_label = 'NEG' positive_label = 'POS' + # full typology from https://github.com/clamsproject/app-swt-detection/issues/1 FRAME_TYPES = ["B", "S", "S:H", "S:C", "S:D", "S:B", "S:G", "W", "L", "O", "M", "I", "N", "E", "P", "Y", "K", "G", "T", "F", "C", "R"] + +# These are time frames that are typically static (that is, the text does not +# move around or change as with rolling credits). These are frame names after +# the label mapping. +static_frames = ['bars', 'slate', 'chyron'] diff --git a/modeling/config/classifier.yml b/modeling/config/classifier.yml index c36a038..e69de29 100644 --- a/modeling/config/classifier.yml +++ b/modeling/config/classifier.yml @@ -1,32 +0,0 @@ -model_file: "modeling/models/20240409-091401.convnext_lg.pt" -model_config_file: "modeling/models/20240409-091401.convnext_lg.yml" - -# Milliseconds between sampled frames -sampleRate: 1000 - -# Minimum score for a frame to be included in a potential timeframe -minFrameScore: 0.01 - -# Minimum score for a timeframe to be selected -minTimeframeScore: 0.5 - -# Minimum number of sampled frames required for a timeframe to be included -minFrameCount: 2 - -# These are time frames that are typically static (that is, the text does not -# move around or change as with rolling credits). These are frame names after -# the label mapping. -staticFrames: [bars, slate, chyron] - -# Set to False to turn off the stitcher -useStitcher: True - -# This was the most frequent label mapping the previous configuration file, -# which had default mappings for each model. -labelMap: [ - "B:bars", - "S:slate", "S-H:slate", "S-C:slate", "S-D:slate", "S-G:slate", - "W:other_opening", "L:other_opening", "O:other_opening", "M:other_opening", - "I:chyron", "N:chyron", "Y:chyron", - "C:credit", "R:credit", - "E:other_text", "K:other_text", "G:other_text", "T:other_text", "F:other_text" ] diff --git a/modeling/stitch.py b/modeling/stitch.py index 4f3397a..9bb29af 100644 --- a/modeling/stitch.py +++ b/modeling/stitch.py @@ -13,7 +13,7 @@ import yaml -from modeling import train, negative_label +from modeling import train, negative_label, static_frames class Stitcher: @@ -25,7 +25,6 @@ def __init__(self, **config): self.min_frame_score = config.get("minFrameScore") self.min_timeframe_score = config.get("minTimeframeScore") self.min_frame_count = config.get("minFrameCount") - self.static_frames = self.config.get("staticFrames") self.model_label = train.pretraining_binned_label(self.model_config) self.stitch_label = config.get("postbin") self.debug = False @@ -120,7 +119,6 @@ def _score_for_label(self, label: str, prediction): class TimeFrame: def __init__(self, label: str, stitcher: Stitcher): - self.static_frames = stitcher.static_frames self.targets = [] self.label = label self.points = [] @@ -182,11 +180,12 @@ def set_representatives(self): couple of simple heuristics and the frame type.""" representatives = list(zip(self.points, self.scores)) timepoint, max_value = max(representatives, key=operator.itemgetter(1)) - if self.label in self.static_frames: + if self.label in static_frames: # for these just pick the one with the highest score self.representatives = [timepoint] else: # throw out the lower values + # TODO: this may throw out too many time points representatives = [(tp, val) for tp, val in representatives if val >= self.score] # pick every third frame, which corresponds roughly to one every five seconds # (expect when all below-average values bundled together at one end) From 43a3fbb05edc2e09a75ef0b1e728b04fae0a832f Mon Sep 17 00:00:00 2001 From: Marc Verhagen Date: Mon, 6 May 2024 12:25:51 -0400 Subject: [PATCH 2/2] Tweaking the CLI script --- app.py | 2 +- cli.py | 87 ++++++++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 68 insertions(+), 21 deletions(-) diff --git a/app.py b/app.py index aa168cc..13680bb 100644 --- a/app.py +++ b/app.py @@ -38,7 +38,7 @@ def _appmetadata(self): def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: # parameters here is a "refined" dict, so hopefully its values are properly - # validated and casted at this point. + # validated and casted at this point. self.configs = parameters self._configure_model() self._configure_postbin() diff --git a/cli.py b/cli.py index 9ef549b..30fae06 100644 --- a/cli.py +++ b/cli.py @@ -1,15 +1,32 @@ +"""cli.py + +Command Line Interface for the SWT app. + +This script can be called with the same arguments as the _annotate method on the app, +except that we add --input, --output and --metadata parameters. + +Example invocation: + +$ python cli.py \ + --modelName 20240409-093229.convnext_tiny + --input example-mmif-local.json + --output out.json + --map B:bars S:slate + --pretty true + +""" + + +import sys import yaml -import pprint import argparse -from pathlib import Path -from metadata import appmetadata +from mmif import Mmif from clams.app import ClamsApp -from app import SwtDetection - +from metadata import appmetadata +from app import SwtDetection -default_config_fname = Path(__file__).parent / 'modeling/config/classifier.yml' json_type_map = { "integer": int, @@ -23,13 +40,8 @@ 'modelName', 'pretty', 'sampleRate', 'startAt', 'stopAt', 'useStitcher') -def get_app(): - app = SwtDetection(preconf_fname=default_config_fname, log_to_file=False) - return app - - def get_metadata(): - """Gets the metadata from the metadata.py filem, with the universal parameters added.""" + """Gets the metadata from the metadata.py file, with the universal parameters added.""" metadata = appmetadata() for param in ClamsApp.universal_parameters: metadata.add_parameter(**param) @@ -67,20 +79,55 @@ def print_parameters(metadata): print(f' choices={parameter.choices}') +def print_args(args): + print(args) + print() + for arg in vars(args): + value = getattr(args, arg) + print(f'{arg:18s} {str(type(value)):15s} {value}') + + +def build_app_parameters(args): + parameters = {} + for arg in vars(args): + if arg in ('input', 'output', 'metadata'): + continue + value = getattr(args, arg) + parameters[arg] = value + return parameters + + +def adjust_parameters(parameters, args): + # Adding the empty directory makes the app code work, but it still won't be able + # to print the parameters as given by the user on the command line. So we loop + # over the arguments to populate the raw parameters dictionary. + parameters[ClamsApp._RAW_PARAMS_KEY] = {} + for arg in sys.argv[1:]: + if arg.startswith('--'): + argname = arg[2:] + argval = vars(args)[argname] + argval = argval if type(argval) is list else [str(argval)] + parameters[ClamsApp._RAW_PARAMS_KEY][argname] = argval + + + if __name__ == '__main__': - app = get_app() + app = SwtDetection() metadata = get_metadata() argparser = create_argparser(metadata) args = argparser.parse_args() - - print(args) - print() - for arg in vars(args): - value = getattr(args, arg) - print(f'{arg:18s} {str(type(value)):15s} {value}') if args.metadata: print(metadata.jsonify(pretty=args.pretty)) - + else: + parameters = build_app_parameters(args) + # Simply calling _annotate() breaks when we try to create the view and copy the + # parameters into it because the CLAMS code expects there to be raw parameters. + # So we first adjust the parameters to match what the CLAMS code expects. + adjust_parameters(parameters, args) + mmif = Mmif(open(args.input).read()) + app._annotate(mmif, **parameters) + with open(args.output, 'w') as fh: + fh.write(mmif.serialize(pretty=args.pretty))