From 853bb7957f519f35875000d1301d82cc4fb17584 Mon Sep 17 00:00:00 2001 From: Julien Lefaucheur Date: Wed, 2 Oct 2024 15:26:33 +0000 Subject: [PATCH] Remove debugging leftover --- src/ai_models/model.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 6d98d8b..e643b5a 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -43,7 +43,8 @@ def __exit__(self, *args): class ArchiveCollector: - UNIQUE = {"date", "hdate", "time", "referenceDate", "type", "stream", "expver"} + UNIQUE = {"date", "hdate", "time", + "referenceDate", "type", "stream", "expver"} def __init__(self) -> None: self.expect = 0 @@ -55,7 +56,8 @@ def add(self, field): self.request[k].add(str(v)) if k in self.UNIQUE: if len(self.request[k]) > 1: - raise ValueError(f"Field {field} has different values for {k}: {self.request[k]}") + raise ValueError( + f"Field {field} has different values for {k}: {self.request[k]}") class Model: @@ -160,7 +162,8 @@ def json_default(obj): raise TypeError print( - json.dumps(json_requests, separators=(",", ":"), default=json_default, sort_keys=True), + json.dumps(json_requests, separators=( + ",", ":"), default=json_default, sort_keys=True), file=f, ) @@ -170,7 +173,8 @@ def download_assets(self, **kwargs): if not os.path.exists(asset): os.makedirs(os.path.dirname(asset), exist_ok=True) LOG.info("Downloading %s", asset) - download(self.download_url.format(file=file), asset + ".download") + download(self.download_url.format( + file=file), asset + ".download") os.rename(asset + ".download", asset) @property @@ -443,7 +447,8 @@ def _requests(self): def filter_constant(request): # We check for 'sfc' because param 'z' can be ambiguous if request.get("levtype") == "sfc": - param = set(self.constant_fields) & set(request.get("param", [])) + param = set(self.constant_fields) & set( + request.get("param", [])) if param: request["param"] = list(param) return True @@ -454,7 +459,8 @@ def filter_prognostic(request): # TODO: We assume here that prognostic fields are # the ones that are not constant. This may not always be true if request.get("levtype") == "sfc": - param = set(request.get("param", [])) - set(self.constant_fields) + param = set(request.get("param", [])) - \ + set(self.constant_fields) if param: request["param"] = list(param) return True @@ -496,7 +502,8 @@ def peek_into_checkpoint(self, path): def parse_model_args(self, args): if args: - raise NotImplementedError(f"This model does not accept arguments {args}") + raise NotImplementedError( + f"This model does not accept arguments {args}") def provenance(self): from .provenance import gather_provenance_info @@ -542,8 +549,6 @@ def write_input_fields( if ignore is None: ignore = [] - fields.save("input.grib") - with self.timer("Writing step 0"): for field in fields: if field.metadata("shortName") in ignore: @@ -592,7 +597,8 @@ def write_input_fields( """ template = base64.b64decode(template) - accumulations_template = ekd.from_source("memory", template)[0] + accumulations_template = ekd.from_source( + "memory", template)[0] for param in accumulations: self.write(