Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Dec 24, 2024
1 parent c9726f7 commit 20eac8a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions etna/libs/timesfm/timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __post_init__(self):
torch.cuda.is_available() and self.backend == "gpu") else "cpu")
self._median_index = -1

def _set_horizon(self, horizon):
def _set_horizon(self, horizon): # changed: added to change horizon after initialization
self.horizon_len = horizon

def load_from_checkpoint(
Expand All @@ -216,10 +216,10 @@ def load_from_checkpoint(
"""Loads a checkpoint and compiles the decoder."""
checkpoint_path = checkpoint.path
repo_id = checkpoint.huggingface_repo_id
if not os.path.exists(checkpoint_path):
if not os.path.exists(checkpoint_path): # changed: make loading similar to chronos
checkpoint_path = path.join(snapshot_download(checkpoint_path, cache_dir=checkpoint.local_dir), "torch_model.ckpt")
self._model = ppd.PatchedTimeSeriesDecoder(self._model_config)
loaded_checkpoint = torch.load(checkpoint_path) # TODO weights_only=True
loaded_checkpoint = torch.load(checkpoint_path) # changed: remove weights_only=True due to attribute absence in low torch versions
logging.info("Loading checkpoint from %s", checkpoint_path)
self._model.load_state_dict(loaded_checkpoint)
logging.info("Sending checkpoint to device %s", f"{self._device}")
Expand Down
8 changes: 4 additions & 4 deletions etna/libs/timesfm/timesfm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def moving_average(arr, window_size):

def freq_map(freq: Optional[str]):
"""Returns the frequency map for the given frequency string."""
if freq is None:
if freq is None: # changed: added this case to handle int timestamps during forecasting with exogenous features
warnings.warn("Frequency is None. Mapping it to 0, that can be not optimal. Better to set it to known frequency")
return 0
freq = str.upper(freq)
Expand Down Expand Up @@ -764,7 +764,7 @@ def forecast_on_df(
uids = []
if num_jobs == 1:
if verbose:
logging.info("Processing dataframe with single process.")
logging.info("Processing dataframe with single process.") # changed: replace print
for key, group in df_sorted.groupby("unique_id"):
inp, uid = process_group(
key,
Expand All @@ -778,7 +778,7 @@ def forecast_on_df(
if num_jobs == -1:
num_jobs = multiprocessing.cpu_count()
if verbose:
logging.info("Processing dataframe with multiple processes.")
logging.info("Processing dataframe with multiple processes.") # changed: replace print
with multiprocessing.Pool(processes=num_jobs) as pool:
results = pool.starmap(
process_group,
Expand All @@ -787,7 +787,7 @@ def forecast_on_df(
)
new_inputs, uids = zip(*results)
if verbose:
logging.info("Finished preprocessing dataframe.")
logging.info("Finished preprocessing dataframe.") # changed: replace print
freq_inps = [freq_map(freq)] * len(new_inputs)
_, full_forecast = self.forecast(new_inputs,
freq=freq_inps,
Expand Down
2 changes: 1 addition & 1 deletion etna/models/nn/timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def forecast(
else:
if ts.freq is None:
raise NotImplementedError(
"Data with None frequency isn't currently implemented for forecasting without exogenous features."
"Forecasting misaligned data with freq=None without exogenous features isn't currently implemented."
)

target = ts.to_pandas(flatten=True, features=["target"]).dropna()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models/test_nn/test_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def expected_ts_increasing_integers():


@pytest.mark.smoke
def test_chronos_url(tmp_path):
def test_url(tmp_path):
model_name = "timesfm-1.0-200m-pytorch.ckpt"
url = f"http://etna-github-prod.cdn-tinkoff.ru/timesfm/{model_name}"
_ = TimesFMModel(path_or_url=url, cache_dir=tmp_path)
assert os.path.exists(tmp_path / model_name)


@pytest.mark.smoke
def test_chronos_custom_cache_dir(tmp_path):
def test_cache_dir(tmp_path):
path_or_url = "google/timesfm-1.0-200m-pytorch"
model_name = path_or_url.split("/")[-1]
_ = TimesFMModel(path_or_url=path_or_url, cache_dir=tmp_path)
Expand All @@ -55,7 +55,7 @@ def test_context_size():


@pytest.mark.smoke
def test_chronos_get_model(example_tsds):
def test_get_model(example_tsds):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")
assert isinstance(model.get_model(), TimesFmTorch)

Expand Down

0 comments on commit 20eac8a

Please sign in to comment.