From 9f894e6d9351e3fdccd55db90c76758969f8e5bb Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 00:18:18 +0300 Subject: [PATCH 01/20] add timesfm model --- README.md | 3 +- docs/source/api_reference/models.rst | 3 +- docs/source/installation.rst | 3 +- etna/libs/timesfm/__init__.py | 155 +++ etna/libs/timesfm/patched_decoder.py | 948 ++++++++++++++++++ etna/libs/timesfm/timesfm.py | 325 ++++++ etna/libs/timesfm/timesfm_base.py | 806 +++++++++++++++ etna/libs/timesfm/xreg_lib.py | 635 ++++++++++++ etna/models/nn/__init__.py | 3 + etna/models/nn/timesfm.py | 274 +++++ etna/settings.py | 21 + examples/202-NN_examples.ipynb | 190 +++- poetry.lock | 137 ++- pyproject.toml | 11 +- .../test_inference/test_forecast.py | 55 +- .../test_inference/test_predict.py | 10 + tests/test_models/test_nn/test_timesfm.py | 145 +++ 17 files changed, 3702 insertions(+), 22 deletions(-) create mode 100644 etna/libs/timesfm/__init__.py create mode 100644 etna/libs/timesfm/patched_decoder.py create mode 100644 etna/libs/timesfm/timesfm.py create mode 100644 etna/libs/timesfm/timesfm_base.py create mode 100644 etna/libs/timesfm/xreg_lib.py create mode 100644 etna/models/nn/timesfm.py create mode 100644 tests/test_models/test_nn/test_timesfm.py diff --git a/README.md b/README.md index baae3b461..64016e1ad 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,8 @@ Available user extensions are the following: * `auto`: adds AutoML functionality, * `statsforecast`: adds models from [statsforecast](https://nixtla.github.io/statsforecast/), * `classiciation`: adds time series classification functionality, -* `chronos`: adds Chronos-like pretrained models. +* `chronos`: adds Chronos-like pretrained models, +* `timesfm`: adds TimesFM pretrained models. Install extension: ```bash diff --git a/docs/source/api_reference/models.rst b/docs/source/api_reference/models.rst index daea65539..ef06daee4 100644 --- a/docs/source/api_reference/models.rst +++ b/docs/source/api_reference/models.rst @@ -122,4 +122,5 @@ Pretrained neural network models: :template: class.rst nn.ChronosModel - nn.ChronosBoltModel \ No newline at end of file + nn.ChronosBoltModel + nn.TimesFMModel \ No newline at end of file diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 86cd18123..90ac3d214 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -24,7 +24,8 @@ Available user extensions are the following: - ``auto``: adds AutoML functionality, - ``statsforecast``: adds models from `statsforecast `_, - ``classiciation``: adds time series classification functionality, -- ``chronos``: adds Chronos-like pretrained models. +- ``chronos``: adds Chronos-like pretrained models, +- ``timesfm``: adds TimesFM pretrained models. Install extension: diff --git a/etna/libs/timesfm/__init__.py b/etna/libs/timesfm/__init__.py new file mode 100644 index 000000000..2346aa8a2 --- /dev/null +++ b/etna/libs/timesfm/__init__.py @@ -0,0 +1,155 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + + +from etna.libs.timesfm.timesfm import TimesFmTorch +from etna.libs.timesfm.timesfm_base import TimesFmHparams, TimesFmCheckpoint diff --git a/etna/libs/timesfm/patched_decoder.py b/etna/libs/timesfm/patched_decoder.py new file mode 100644 index 000000000..53cd5c6c1 --- /dev/null +++ b/etna/libs/timesfm/patched_decoder.py @@ -0,0 +1,948 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/patched_decoder.py) + +"""Pytorch version of patched decoder.""" + +import dataclasses +import math +from typing import List, Tuple, Optional +import torch +from torch import nn +import torch.nn.functional as F + + +def _create_quantiles() -> List[float]: + return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + +@dataclasses.dataclass +class TimesFMConfig: + """Config for initializing timesfm patched_decoder class.""" + + # The number of blocks in the model. + num_layers: int = 20 + # The number of attention heads used in the attention layers of the model. + num_heads: int = 16 + # The number of key-value heads for implementing attention. + num_kv_heads: int = 16 + # The hidden size of the model. + hidden_size: int = 1280 + # The dimension of the MLP representations. + intermediate_size: int = 1280 + # The number of head dimensions. + head_dim: int = 80 + # The epsilon used by the rms normalization layers. + rms_norm_eps: float = 1e-6 + # Patch length + patch_len: int = 32 + # Horizon length + horizon_len: int = 128 + # quantiles + quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles) + # Padding value + pad_val: float = 1123581321.0 + # Tolerance + tolerance: float = 1e-6 + # The dtype of the weights. + dtype: str = "bfloat32" + # use positional embedding + use_positional_embedding: bool = True + + +def _masked_mean_std( + inputs: torch.Tensor, + padding: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded + values. + """ + # Selecting the first patch with more than 3 unpadded values. + pad_sum = torch.sum(1 - padding, dim=2) + + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.where( + num_valid_elements == 0, + torch.tensor(1, + dtype=num_valid_elements.dtype, + device=num_valid_elements.device), + num_valid_elements, + ) + + # Calculate the masked sum and squared sum + masked_sum = torch.sum(arr * mask, dim=1) + masked_squared_sum = torch.sum((arr * mask)**2, dim=1) + + # Calculate the masked mean and standard deviation + masked_mean = masked_sum / num_valid_elements + masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 + masked_var = torch.where( + masked_var < 0.0, + torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), + masked_var, + ) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + +def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + Returns the shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = (torch.arange(num_seq).to( + seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, + feature_dim)) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + +def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Applies a floating-point mask to a set of logits. + + Args: + logits: A torch.Tensor of logit values. + mask: A torch.Tensor (float32) of mask values with the encoding described + in the function documentation. + + Returns: + Masked logits. + """ + + min_value = get_large_negative_number(logits.dtype) + + return torch.where((mask >= min_value * 0.5), logits, min_value) + + +def convert_paddings_to_mask(paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Converts binary paddings to a logit mask ready to add to attention matrix. + + Args: + paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding + token. + dtype: data type of the input. + + Returns: + A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. + """ + attention_mask = paddings.detach().clone() + attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis + attention_mask *= get_large_negative_number(dtype) + return attention_mask + + +def causal_mask(input_t: torch.Tensor) -> torch.Tensor: + """Computes and returns causal mask. + + Args: + input_t: A torch.Tensor of shape [B, T, D]. + + Returns: + An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has + already been converted to large negative values. + """ + assert input_t.dtype.is_floating_point, input_t.dtype + large_negative_number = get_large_negative_number(input_t.dtype) + t = input_t.shape[1] + col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) + row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) + mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number + return mask.unsqueeze(0).unsqueeze(0).to(input_t.device) # Equivalent to jnp.newaxis + + +def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Merges 2 masks. + + logscale mask is expected but 0/1 mask is also fine. + + Args: + a: torch.Tensor of shape [1|B, 1, 1|T, S]. + b: torch.Tensor of shape [1|B, 1, 1|T, S]. + + Returns: + torch.Tensor of shape [1|B, 1, 1|T, S]. + """ + + def expand_t(key_mask): + query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose + return torch.minimum(query_mask, key_mask) + + if a.shape[2] != b.shape[2]: + if a.shape[2] == 1: + a = expand_t(a) + else: + assert b.shape[2] == 1 + b = expand_t(b) + + assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." + return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum + + +class ResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super(ResidualBlock, self).__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.hidden_layer(x) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class RMSNorm(torch.nn.Module): + """Pax rms norm in pytorch.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + if self.add_unit_offset: + output = output * (1 + self.weight.float()) + else: + output = output * self.weight.float() + return output.type_as(x) + + +class TransformerMLP(nn.Module): + """Pax transformer MLP in pytorch.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFMAttention(nn.Module): + """Implements the attention used in TimesFM.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = nn.Parameter( + torch.empty((self.head_dim,), dtype=torch.float32),) + + self.qkv_proj = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: + # [batch_size, n_local_heads, input_len, head_dim] + r_softplus_0 = 1.442695041 + softplus_func = torch.nn.Softplus() + scale = r_softplus_0 / math.sqrt(self.head_dim) + scale = scale * softplus_func(self.scaling) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xq = self._per_dim_scaling(xq) + + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + else: + key = xk + value = xv + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + # return scores, output.transpose(1, 2).contiguous() + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return scores, output + + +class TimesFMDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.self_attn = TimesFMAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + ) + self.mlp = TransformerMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + scores, hidden_states = self.self_attn( + hidden_states=hidden_states, + mask=mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class StackedDecoder(nn.Module): + """Stacked transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + num_layers: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TimesFMDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + )) + + def forward( + self, + hidden_states: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: Optional[torch.Tensor] = None, + kv_caches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> torch.Tensor: + padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) + atten_mask = causal_mask(hidden_states) + mask = merge_masks(padding_mask, atten_mask) + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = kv_caches[i] if kv_caches is not None else None + _, hidden_states = layer( + hidden_states=hidden_states, + mask=mask, + paddings=paddings, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + return hidden_states + + +class PositionalEmbedding(torch.nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__( + self, + embedding_dims: int, + min_timescale: int = 1, + max_timescale: int = 10_000, + ) -> None: + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + assert seq_length is not None + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + assert position.ndim == 2, position.shape + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale)) / max( + num_timescales - 1, 1) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * + -log_timescale_increment) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze( + 0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +class PatchedTimeSeriesDecoder(nn.Module): + """Patched time-series decoder.""" + + def __init__(self, config: TimesFMConfig): + super().__init__() + self.config = config + self.input_ff_layer = ResidualBlock( + input_dims=2 * config.patch_len, + output_dims=config.hidden_size, + hidden_dims=config.intermediate_size, + ) + self.freq_emb = nn.Embedding(num_embeddings=3, + embedding_dim=config.hidden_size) + self.horizon_ff_layer = ResidualBlock( + input_dims=config.hidden_size, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.intermediate_size, + ) + self.stacked_transformer = StackedDecoder( + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size, + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + num_layers=self.config.num_layers, + rms_norm_eps=self.config.rms_norm_eps, + ) + if self.config.use_positional_embedding: + self.position_emb = PositionalEmbedding(self.config.hidden_size) + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = _masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(self.config.pad_val, + dtype=outputs.dtype, + device=outputs.device), + outputs, + ) + return outputs, (mu, sigma) + + def _reverse_transform(self, outputs: torch.Tensor, stats: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] + + def _preprocess_input( + self, + input_ts: torch.Tensor, + input_padding: torch.Tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Optional[Tuple[torch.Tensor, torch.Tensor]], + torch.Tensor, + ]: + """Preprocess input for stacked transformer.""" + + # Reshape into patches (using view for efficiency) + bsize = input_ts.shape[0] + patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) + patched_pads = input_padding.view(bsize, -1, self.config.patch_len) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, + dtype=patched_inputs.dtype, + device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, + patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, + dim=-1)[0] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = _shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + return model_input, patched_padding, stats, patched_inputs + + def _postprocess_output( + self, + model_output: torch.Tensor, + num_outputs: int, + stats: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) + + return self._reverse_transform(output_ts, stats) + + def forward( + self, + input_ts: torch.Tensor, + input_padding: torch.LongTensor, + freq: torch.Tensor, + ) -> torch.Tensor: + num_outputs = len(self.config.quantiles) + 1 + model_input, patched_padding, stats, _ = self._preprocess_input( + input_ts=input_ts, + input_padding=input_padding, + ) + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + model_output = self.stacked_transformer(model_input, patched_padding) + + output_ts = self._postprocess_output(model_output, num_outputs, stats) + return output_ts + + def decode( + self, + input_ts: torch.Tensor, + paddings: torch.Tensor, + freq: torch.LongTensor, + horizon_len: int, + output_patch_len: Optional[int] = None, + max_len: int = 512, + return_forecast_on_context: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Auto-regressive decoding without caching. + + Args: + input_ts: input time-series and paddings. Time-series shape B x C. + paddings: padding shape B x (C + H) where H is the prediction length. + freq: frequency shape B x 1 + horizon_len: prediction length. + output_patch_len: output length to be fetched from one step of + auto-regressive decoding. + max_len: maximum training context length. + return_forecast_on_context: whether to return the model forecast on the + context except the first input patch. + + Returns: + Tuple of two forecasting results: + - Point (mean) output predictions as a tensor with shape B x H'. + - Full predictions (mean and quantiles) as a tensor with shape + B x H' x (1 + # quantiles). + In particular, if return_forecast_on_context is True, H' is H plus + the forecastable context length, i.e. context_len - (first) patch_len. + """ + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}") + if output_patch_len is None: + output_patch_len = self.config.horizon_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = paddings[:, 0:final_out.shape[1]] + input_ts = final_out[:, -max_len:] + input_padding = current_padding[:, -max_len:] + fprop_outputs = self(input_ts, input_padding, freq) + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :] + new_full_ts = fprop_outputs.view(new_full_ts.size(0), -1, + new_full_ts.size(3)) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concat([final_out, new_ts], dim=-1) # TODO torch.concatenate(axis) => torch.concat(dim) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. + full_outputs = torch.concat( # TODO torch.concatenate(axis) => torch.concat(dim) + full_outputs, + dim=1)[:, :(context_len - self.config.patch_len + horizon_len), :] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = torch.concat(full_outputs, dim=1)[:, 0:horizon_len, :] # TODO torch.concatenate(axis) => torch.concat(dim) + + return full_outputs[:, :, 0], full_outputs diff --git a/etna/libs/timesfm/timesfm.py b/etna/libs/timesfm/timesfm.py new file mode 100644 index 000000000..f707879fe --- /dev/null +++ b/etna/libs/timesfm/timesfm.py @@ -0,0 +1,325 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_torch.py) +# Add method to change horizon after initialization. +# Minor logic change of loading model. + +"""TimesFM pytorch forecast API for inference.""" + +import logging +from os import path +from typing import Any, Sequence, Optional, Tuple +import os +import numpy as np +import torch +from huggingface_hub import snapshot_download +from etna.libs.timesfm import timesfm_base + +from etna.libs.timesfm import patched_decoder as ppd + +_TOL = 1e-6 + + +class TimesFmTorch(timesfm_base.TimesFmBase): + """TimesFM forecast API for inference.""" + + def __post_init__(self): + self._model_config = ppd.TimesFMConfig( + num_layers=self.num_layers, + num_heads=self.num_heads, + hidden_size=self.model_dims, + intermediate_size=self.model_dims, + patch_len=self.input_patch_len, + horizon_len=self.output_patch_len, + head_dim=self.model_dims // self.num_heads, + quantiles=self.quantiles, + use_positional_embedding=self.use_pos_emb, + ) + self._model = None + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device("cuda:0" if ( + torch.cuda.is_available() and self.backend == "gpu") else "cpu") + self._median_index = -1 + + def _set_horizon(self, horizon): + self.horizon_len = horizon + + def load_from_checkpoint( + self, + checkpoint: timesfm_base.TimesFmCheckpoint, + ) -> None: + """Loads a checkpoint and compiles the decoder.""" + checkpoint_path = checkpoint.path + repo_id = checkpoint.huggingface_repo_id + if not os.path.exists(checkpoint_path): + checkpoint_path = path.join(snapshot_download(checkpoint_path, cache_dir=checkpoint.local_dir), "torch_model.ckpt") + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + loaded_checkpoint = torch.load(checkpoint_path) # TODO weights_only=True + logging.info("Loading checkpoint from %s", checkpoint_path) + self._model.load_state_dict(loaded_checkpoint) + logging.info("Sending checkpoint to device %s", f"{self._device}") + self._model.to(self._device) + self._model.eval() + # TODO: add compilation. + + def _forecast( + self, + inputs: Sequence[Any], + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + + Returns: + A tuple for JTensors: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + if not self._model: + raise ValueError( + "Checkpoint not loaded. Call `load_from_checkpoint` before" + " `forecast`.") + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(timesfm_base.moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array(input_ts[i * self.global_batch_size:(i + 1) * + self.global_batch_size], + dtype=np.float32)).to(self._device) + input_padding_in = torch.from_numpy( + np.array(input_padding[i * self.global_batch_size:(i + 1) * + self.global_batch_size], + dtype=np.float32)).to(self._device) + inp_freq_in = torch.from_numpy( + np.array(inp_freq[ + i * self.global_batch_size:(i + 1) * self.global_batch_size, + :, + ], + dtype=np.int32)).long().to(self._device) + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] + + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + return mean_outputs, full_outputs \ No newline at end of file diff --git a/etna/libs/timesfm/timesfm_base.py b/etna/libs/timesfm/timesfm_base.py new file mode 100644 index 000000000..400cc6bfe --- /dev/null +++ b/etna/libs/timesfm/timesfm_base.py @@ -0,0 +1,806 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_base.py) + +"""Base class for TimesFM inference. This will be common to PAX and Pytorch.""" + +import collections +import dataclasses +import logging +import multiprocessing +from typing import Any, Literal, Sequence, Optional, Tuple, List, Dict + +import numpy as np +import pandas as pd + +from utilsforecast.processing import make_future_dataframe + +from etna.libs.timesfm import xreg_lib + +Category = xreg_lib.Category +XRegMode = xreg_lib.XRegMode + +_TOL = 1e-6 +DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) + + +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), "valid") / + window_size) + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: str): + """Returns the frequency map for the given frequency string.""" + freq = str.upper(freq) + if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or + freq.endswith("D") or freq.endswith("B") or freq.endswith("U")): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + +# Per time series normalization: forward. +def _normalize(batch): + stats = [ + (np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch + ] + new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] + return new_batch, stats + + +# Per time series normalization: inverse. +def _renormalize(batch, stats): + return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] + + +@dataclasses.dataclass() +class TimesFmHparams: + """Hparams used to initialize a TimesFM model for inference. + + These are the sufficient subset of hparams to configure TimesFM inference + agnostic to the checkpoint version, and are not necessarily the same as the + hparams used to train the checkpoint. + + Attributes: + context_len: Largest context length the model allows for each decode call. + This technically can be any large, but practically should set to the + context length the checkpoint was trained with. + horizon_len: Forecast horizon. + input_patch_len: Input patch len. + output_patch_len: Output patch len. How many timepoints is taken from a + single step of autoregressive decoding. Can be set as the training horizon + of the checkpoint. + num_layers: Number of transformer layers in the model. + model_dims: Model dimension. + per_core_batch_size: Batch size on each core for data parallelism. + backend: One of "cpu", "gpu" or "tpu". + quantiles: Which quantiles are output by the model. + """ + + context_len: int = 512 + horizon_len: int = 128 + input_patch_len: int = 32 + output_patch_len: int = 128 + num_layers: int = 20 + num_heads: int = 16 + model_dims: int = 1280 + per_core_batch_size: int = 32 + backend: Literal["cpu", "gpu", "tpu"] = "cpu" + quantiles: Optional[Sequence[float]] = DEFAULT_QUANTILES + use_positional_embedding: bool = True + # Hparams beyond the model. + point_forecast_mode: Literal["mean", "median"] = "median" + + +@dataclasses.dataclass() +class TimesFmCheckpoint: + """Checkpoint used to initialize a TimesFM model for inference. + + Attributes: + version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. + The factory will create the corresponding TimesFm inference class based on + this version. + path: Path to the checkpoint. + type: If provided, type of the checkpoint used by the specific checkpoint + loader per version. + step: If provided, step of the checkpoint. + """ + + version: str = "jax" + path: Optional[str] = None + huggingface_repo_id: Optional[str] = None + type: Any = None + step: Optional[int] = None + local_dir: Optional[str] = None + + +class TimesFmBase: + """Base TimesFM forecast API for inference. + + This class is the scaffolding for calling TimesFM forecast. To properly use: + 1. Create an instance with the correct hyperparameters of a TimesFM model. + 2. Call `load_from_checkpoint` to load a compatible checkpoint. + 3. Call `forecast` for inference. + """ + + def _logging(self, s): + print(s) + + def __post_init__(self) -> None: + """Additional initialization for subclasses before checkpoint loading.""" + pass + + def __init__(self, hparams: TimesFmHparams, + checkpoint: TimesFmCheckpoint) -> None: + """Initializes the TimesFM forecast API. + + Args: + hparams: Hyperparameters of the model. + checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide + which TimesFM version to use. + """ + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.input_patch_len + self.output_patch_len = hparams.output_patch_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dims + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + self.use_pos_emb = hparams.use_positional_embedding + + # Rewrite these values in __post_init__ for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + + self._horizon_start = self.context_len - self.input_patch_len + self.__post_init__() + self.load_from_checkpoint(checkpoint) + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.ndarray], + freq: Sequence[int]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d JTensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the frequency of each input time series. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. + """ + + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ((len(inputs) - 1) // self.global_batch_size + + 1) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], + axis=0) + padding = np.concatenate( + [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) + elif input_len > self.context_len: + ts = ts[-self.context_len:] + padding = padding[-(self.context_len + self.horizon_len):] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def _forecast( + self, + inputs: Sequence[Any], + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + + Returns: + A tuple for np.array: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + raise NotImplementedError("`_forecast` is not implemented.") + + def forecast( + self, + inputs: Sequence[Any], + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + normalize: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + normalize: If True, then we normalize the inputs before forecasting and + the outputs are then renormalized to the original scale. + + Returns: + A tuple for np.array: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + stats = None + if normalize: + inputs, stats = _normalize(inputs) + mean_forecast, quantile_forecast = self._forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context, + ) + if stats is not None: + stats = np.array(stats) + mu = stats[:, 0] + sigma = stats[:, 1] + mean_forecast = mean_forecast * sigma[:, None] + mu[:, None] + quantile_forecast = (quantile_forecast * sigma[:, None, None] + + mu[:, None, None]) + if self.hparams.point_forecast_mode == "mean": + return mean_forecast, quantile_forecast + elif self.hparams.point_forecast_mode == "median": + if self._median_index == -1: + for i, quantile in enumerate(self.quantiles): + if quantile == 0.5: + self._median_index = i + break + if self._median_index == -1: + raise ValueError("Median (0.5) is not found in the model quantiles:" + f" {self.quantiles}. Please check the hparams.") + return ( + quantile_forecast[:, :, 1 + self._median_index], + quantile_forecast, + ) + else: + raise ValueError( + "Unsupported point forecast mode:" + f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.") + + def forecast_with_covariates( + self, + inputs: List[Sequence[float]], + dynamic_numerical_covariates: Optional[Dict[str, Sequence[Sequence[float]]]] = None, + dynamic_categorical_covariates: Optional[Dict[str, Sequence[Sequence[Category]]]] = None, + static_numerical_covariates: Optional[Dict[str, Sequence[float]]] = None, + static_categorical_covariates: Optional[Dict[str, Sequence[Category]]]= None, + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int]= None, + xreg_mode: XRegMode = "xreg + timesfm", + normalize_xreg_target_per_input: bool = True, + ridge: float = 0.0, + max_rows_per_col: int = 0, + force_on_cpu: bool = False, + ): + """Forecasts on a list of time series with covariates. + + To optimize inference speed, avoid string valued categorical covariates. + + Args: + inputs: A list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + dynamic_numerical_covariates: A dict of dynamic numerical covariates. + dynamic_categorical_covariates: A dict of dynamic categorical covariates. + static_numerical_covariates: A dict of static numerical covariates. + static_categorical_covariates: A dict of static categorical covariates. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "xreg + timesfm" + fits a model on the residuals of the TimesFM forecast. "timesfm + xreg" + fits a model on the targets then forecasts on the residuals via TimesFM. + normalize_xreg_target_per_input: whether to normalize the xreg target per + input in the given batch. + ridge: ridge penalty for the linear model. + max_rows_per_col: max number of rows per column for the linear model. + force_on_cpu: whether to force running on cpu for the linear model. + + Returns: + A tuple of two lists. The first is the outputs of the model. The second is + the outputs of the xreg. + """ + + # Verify and bookkeep covariates. + if not (dynamic_numerical_covariates or dynamic_categorical_covariates or + static_numerical_covariates or static_categorical_covariates): + raise ValueError( + "At least one of dynamic_numerical_covariates," + " dynamic_categorical_covariates, static_numerical_covariates," + " static_categorical_covariates must be set.") + + # Track the lengths of (1) each input, (2) the part that can be used in the + # linear model, and (3) the horizon. + input_lens, train_lens, test_lens = [], [], [] + + for i, input_ts in enumerate(inputs): + input_len = len(input_ts) + input_lens.append(input_len) + + if xreg_mode == "timesfm + xreg": + # For fitting residuals, no TimesFM forecast on the first patch. + train_lens.append(max(0, input_len - self.input_patch_len)) + elif xreg_mode == "xreg + timesfm": + train_lens.append(input_len) + else: + raise ValueError(f"Unsupported mode: {xreg_mode}") + + if dynamic_numerical_covariates: + test_lens.append( + len(list(dynamic_numerical_covariates.values())[0][i]) - input_len) + elif dynamic_categorical_covariates: + test_lens.append( + len(list(dynamic_categorical_covariates.values())[0][i]) - + input_len) + else: + test_lens.append(self.horizon_len) + + if test_lens[-1] > self.horizon_len: + raise ValueError( + "Forecast requested longer horizon than the model definition " + f"supports: {test_lens[-1]} vs {self.horizon_len}.") + + # Prepare the covariates into train and test. + train_dynamic_numerical_covariates = collections.defaultdict(list) + test_dynamic_numerical_covariates = collections.defaultdict(list) + train_dynamic_categorical_covariates = collections.defaultdict(list) + test_dynamic_categorical_covariates = collections.defaultdict(list) + for covariates, train_covariates, test_covariates in ( + ( + dynamic_numerical_covariates, + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates, + ), + ( + dynamic_categorical_covariates, + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates, + ), + ): + if not covariates: + continue + for covariate_name, covariate_values in covariates.items(): + for input_len, train_len, covariate_value in zip( + input_lens, train_lens, covariate_values): + train_covariates[covariate_name].append( + covariate_value[(input_len - train_len):input_len]) + test_covariates[covariate_name].append(covariate_value[input_len:]) + + # Fit models. + if xreg_mode == "timesfm + xreg": + # Forecast via TimesFM then fit a model on the residuals. + mean_outputs, _ = self.forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + targets = [ + (np.array(input_ts)[-train_len:] - + mean_output[(self._horizon_start - train_len):self._horizon_start]) + for input_ts, mean_output, train_len in zip(inputs, mean_outputs, + train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = _normalize(targets) + xregs = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates= + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates= + test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=False, + assert_covariates=True, + assert_covariate_shapes=True, + ) + if normalize_xreg_target_per_input: + xregs = _renormalize(xregs, per_instance_stats) + outputs = [ + (mean_output[self._horizon_start:(self._horizon_start + test_len)] + + xreg) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + + else: + # Fit a model on the targets then forecast on the residuals via TimesFM. + targets = [ + np.array(input_ts)[-train_len:] + for input_ts, train_len in zip(inputs, train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = _normalize(targets) + xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates= + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates= + test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=True, + assert_covariates=True, + assert_covariate_shapes=True, + ) + mean_outputs, _ = self.forecast( + [ + target - xreg_on_context + for target, xreg_on_context in zip(targets, xregs_on_context) + ], + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + outputs = [ + (mean_output[self._horizon_start:(self._horizon_start + test_len)] + + xreg) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + if normalize_xreg_target_per_input: + outputs = _renormalize(outputs, per_instance_stats) + + return outputs, xregs + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: Optional[int] = None, + num_jobs: int = 1, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. + + Args: + inputs: A pd.DataFrame of all time series. The dataframe should have a + `unique_id` column for identifying the time series, a `ds` column for + timestamps and a value column for the time series values. + freq: string valued `freq` of data. Notice this is different from the + `freq` required by `forecast`. See `freq_map` for allowed values. + forecast_context_len: If provided none zero, we take the last + `forecast_context_len` time-points from each series as the forecast + context instead of the `context_len` set by the model. + value_name: The name of the value column. + model_name: name of the model to be written into future df. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + num_jobs: number of parallel processes to use for dataframe processing. + verbose: output model states in terminal. + + Returns: + Future forecasts dataframe. + """ + if not ("unique_id" in inputs.columns and "ds" in inputs.columns and + value_name in inputs.columns): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns.") + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + print("Processing dataframe with single process.") + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + print("Processing dataframe with multiple processes.") + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [(key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id")], + ) + new_inputs, uids = zip(*results) + if verbose: + print("Finished preprocessing dataframe.") + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast(new_inputs, + freq=freq_inps, + window_size=window_size) + if verbose: + print("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, + ) + fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0:self.horizon_len, + 1 + i].reshape(-1, 1) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df \ No newline at end of file diff --git a/etna/libs/timesfm/xreg_lib.py b/etna/libs/timesfm/xreg_lib.py new file mode 100644 index 000000000..1cbb39781 --- /dev/null +++ b/etna/libs/timesfm/xreg_lib.py @@ -0,0 +1,635 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/xreg_lib.py) + +"""Helper functions for in-context covariates and regression.""" + +import itertools +import math +from typing import Any, Iterable, Literal, Mapping, Sequence, Union, Optional, Tuple, List + +import jax +import jax.numpy as jnp +import numpy as np +from sklearn import preprocessing + +Category = Union[int, str] + +_TOL = 1e-6 +XRegMode = Literal["timesfm + xreg", "xreg + timesfm"] + + +def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray: + return np.array(list(itertools.chain.from_iterable(nested))) + + +def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray: + return np.array( + list( + itertools.chain.from_iterable(map(itertools.repeat, elements, + counts)))) + + +def _to_padded_jax_array(x: np.ndarray) -> jax.Array: + if x.ndim == 1: + (i,) = x.shape + di = 2**math.ceil(math.log2(i)) - i + return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0) + elif x.ndim == 2: + i, j = x.shape + di = 2**math.ceil(math.log2(i)) - i + dj = 2**math.ceil(math.log2(j)) - j + return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0) + else: + raise ValueError(f"Unsupported array shape: {x.shape}") + + +class BatchedInContextXRegBase: + """Helper class for in-context regression covariate formatting. + + Attributes: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the static + categorical covariates of each forecast task. + """ + + def __init__( + self, + targets: Sequence[Sequence[float]], + train_lens: Sequence[int], + test_lens: Sequence[int], + train_dynamic_numerical_covariates: Optional[Mapping[str, Sequence[Sequence[float]]]] = None, + train_dynamic_categorical_covariates: Optional[Mapping[str, Sequence[Sequence[Category]]]] = None, + test_dynamic_numerical_covariates: Optional[Mapping[str, Sequence[Sequence[float]]]] = None, + test_dynamic_categorical_covariates: Optional[Mapping[str, Sequence[Sequence[Category]]]] = None, + static_numerical_covariates: Optional[Mapping[str, Sequence[float]]] = None, + static_categorical_covariates: Optional[Mapping[str, Sequence[Category]]] = None, + ) -> None: + """Initializes with the exogenous covariate inputs. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. We assume batched inputs. To properly format the + request: + + - `train_lens` represents the contexts in the batch. Targets and all train + dynamic covariates should have the same lengths as the corresponding + elements + in `train_lens`. Notice each `train_len` can be different from the exact + length of the corresponding context depending on how much of the context is + used for fitting the in-context model. + - `test_lens` represents the horizon lengths in the batch. All tesdt + dynamic + covariates should have the same lengths as the corresponding elements in + `test_lens`. + - Static covariates should be one for each input. + - For train and test dynamic covariates, they should have the same + covariate + names. + + Pass an empty dict {} for a covariate type if it is not present. + + Example: + Here is a set of valid inputs whose schema can be used for reference. + ``` + targets = [ + [0.0, 0.1, 0.2], + [0.0, 0.1, 0.2, 0.3], + ] # Two inputs in this batch. + train_lens = [3, 4] + test_lens = [2, 5] # Forecast horizons 2 and 5 respectively. + train_dynamic_numerical_covariates = { + "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]], + "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]], + } # Each train dynamic covariate has 3 and 4 elements respectively. + test_dynamic_numerical_covariates = { + "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]], + "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]], + } # Each test dynamic covariate has 2 and 5 elements respectively. + train_dynamic_categorical_covariates = { + "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]], + "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", + "bad"]], + } + test_dynamic_categorical_covariates = { + "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]], + "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]], + } + static_numerical_covariates = { + "cov_1_sn": [0.0, 3.0], + "cov_2_sn": [2.0, 1.0], + "cov_3_sn": [1.0, 2.0], + } # Each static covariate has 1 element for each input. + static_categorical_covariates = { + "cov_1_sc": ["apple", "orange"], + "cov_2_sc": [2, 3], + } + ``` + + Args: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the context. + Their lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the horizon. + Their lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the + static categorical covariates of each forecast task. + """ + self.targets = targets + self.train_lens = train_lens + self.test_lens = test_lens + self.train_dynamic_numerical_covariates = ( + train_dynamic_numerical_covariates or {}) + self.train_dynamic_categorical_covariates = ( + train_dynamic_categorical_covariates or {}) + self.test_dynamic_numerical_covariates = (test_dynamic_numerical_covariates + or {}) + self.test_dynamic_categorical_covariates = ( + test_dynamic_categorical_covariates or {}) + self.static_numerical_covariates = static_numerical_covariates or {} + self.static_categorical_covariates = static_categorical_covariates or {} + + def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None: + """Verifies the validity of the covariate inputs.""" + + # Check presence. + if (self.train_dynamic_numerical_covariates and + not self.test_dynamic_numerical_covariates) or ( + not self.train_dynamic_numerical_covariates and + self.test_dynamic_numerical_covariates): + raise ValueError( + "train_dynamic_numerical_covariates and" + " test_dynamic_numerical_covariates must be both present or both" + " absent.") + + if (self.train_dynamic_categorical_covariates and + not self.test_dynamic_categorical_covariates) or ( + not self.train_dynamic_categorical_covariates and + self.test_dynamic_categorical_covariates): + raise ValueError( + "train_dynamic_categorical_covariates and" + " test_dynamic_categorical_covariates must be both present or both" + " absent.") + + # Check keys. + for dict_a, dict_b, dict_a_name, dict_b_name in ( + ( + self.train_dynamic_numerical_covariates, + self.test_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + "test_dynamic_numerical_covariates", + ), + ( + self.train_dynamic_categorical_covariates, + self.test_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + "test_dynamic_categorical_covariates", + ), + ): + if w := set(dict_a.keys()) - set(dict_b.keys()): + raise ValueError( + f"{dict_a_name} has keys not present in {dict_b_name}: {w}") + if w := set(dict_b.keys()) - set(dict_a.keys()): + raise ValueError( + f"{dict_b_name} has keys not present in {dict_a_name}: {w}") + + # Check shapes. + if assert_covariate_shapes: + if len(self.targets) != len(self.train_lens): + raise ValueError( + "targets and train_lens must have the same number of elements.") + + if len(self.train_lens) != len(self.test_lens): + raise ValueError( + "train_lens and test_lens must have the same number of elements.") + + for i, (target, train_len) in enumerate(zip(self.targets, + self.train_lens)): + if len(target) != train_len: + raise ValueError( + f"targets[{i}] has length {len(target)} != expected {train_len}.") + + for key, values in self.static_numerical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_numerical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}.") + + for key, values in self.static_categorical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_categorical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}.") + + for lens, dict_cov, dict_cov_name in ( + ( + self.train_lens, + self.train_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + ), + ( + self.train_lens, + self.train_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_numerical_covariates, + "test_dynamic_numerical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_categorical_covariates, + "test_dynamic_categorical_covariates", + ), + ): + for key, cov_values in dict_cov.items(): + if len(cov_values) != len(lens): + raise ValueError( + f"{dict_cov_name} has key {key} with number of examples" + f" {len(cov_values)} != expected {len(lens)}.") + for i, cov_value in enumerate(cov_values): + if len(cov_value) != lens[i]: + raise ValueError( + f"{dict_cov_name} has key {key} with its {i}-th example" + f" length {len(cov_value)} != expected {lens[i]}.") + + def create_covariate_matrix( + self, + one_hot_encoder_drop: Optional[str]= "first", + use_intercept: bool = True, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Creates target vector and covariate matrices for in context regression. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. + + Args: + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + A tuple of the target vector, the covariate matrix for the context, and + the covariate matrix for the horizon. + """ + if assert_covariates: + self._assert_covariates(assert_covariate_shapes) + + x_train, x_test = [], [] + + # Numerical features. + for name in sorted(self.train_dynamic_numerical_covariates): + x_train.append( + _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]) + x_test.append( + _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]) + + for covs in self.static_numerical_covariates.values(): + x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis]) + x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis]) + + if x_train: + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + # Normalize for robustness. + x_mean = np.mean(x_train, axis=0, keepdims=True) + x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, + 1.0) + x_train = [(x_train - x_mean) / x_std] + x_test = [(x_test - x_mean) / x_std] + + # Categorical features. Encode one by one. + one_hot_encoder = preprocessing.OneHotEncoder( + drop=one_hot_encoder_drop, + sparse_output=False, + handle_unknown="ignore", + ) + for name in sorted(self.train_dynamic_categorical_covariates.keys()): + ohe_train = _unnest( + self.train_dynamic_categorical_covariates[name])[:, np.newaxis] + ohe_test = _unnest( + self.test_dynamic_categorical_covariates[name])[:, np.newaxis] + x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train))) + x_test.append(np.array(one_hot_encoder.transform(ohe_test))) + + for covs in self.static_categorical_covariates.values(): + ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis]) + x_train.append(_repeat(ohe, self.train_lens)) + x_test.append(_repeat(ohe, self.test_lens)) + + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + if use_intercept: + x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0) + x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0) + + return _unnest(self.targets), x_train, x_test + + def fit(self) -> Any: + raise NotImplementedError("Fit is not implemented.") + + +class BatchedInContextXRegLinear(BatchedInContextXRegBase): + """Linear in-context regression model.""" + + def fit( + self, + ridge: float = 0.0, + one_hot_encoder_drop: Optional[str] = "first", + use_intercept: bool = True, + force_on_cpu: bool = False, + max_rows_per_col: int = 0, + max_rows_per_col_sample_seed: int = 42, + debug_info: bool = False, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray], jax.Array, jax.Array, jax.Array]]: + """Fits a linear model for in-context regression. + + Args: + ridge: A non-negative value for specifying the ridge regression penalty. + If 0 is provided, fallback to ordinary least squares. Note this penalty + is added to the normalized covariate matrix. + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + force_on_cpu: Whether to force execution on cpu for accelerator machines. + max_rows_per_col: How many rows to subsample per column. 0 for no + subsampling. This is for speeding up model fitting. + max_rows_per_col_sample_seed: The seed for the subsampling if needed by + `max_rows_per_col`. + debug_info: Whether to return debug info. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + If `debug_info` is False: + The linear fits on the horizon. + If `debug_info` is True: + A tuple of: + - the linear fits on the horizon, + - the linear fits on the context, + - the flattened target vector, + - the covariate matrix for the context, and + - the covariate matrix for the horizon. + """ + flat_targets, x_train_raw, x_test = self.create_covariate_matrix( + one_hot_encoder_drop=one_hot_encoder_drop, + use_intercept=use_intercept, + assert_covariates=assert_covariates, + assert_covariate_shapes=assert_covariate_shapes, + ) + + x_train = x_train_raw.copy() + if max_rows_per_col: + nrows, ncols = x_train.shape + if nrows > (w := ncols * max_rows_per_col): + subsample = jax.random.choice( + jax.random.PRNGKey(max_rows_per_col_sample_seed), + nrows, + (w,), + replace=False, + ) + x_train = x_train[subsample] + flat_targets = flat_targets[subsample] + + device = jax.devices("cpu")[0] if force_on_cpu else None + # Runs jitted version of the solvers which are quicker at the cost of + # running jitting during the first time calling. Re-jitting happens whenever + # new (padded) shapes are encountered. + # Ocassionally it helps with the speed and the accuracy if we force single + # thread execution on cpu for accelerator machines: + # 1. Avoid moving data to accelarator memory. + # 2. Avoid precision loss if any. + with jax.default_device(device): + x_train_raw = _to_padded_jax_array(x_train_raw) + x_train = _to_padded_jax_array(x_train) + flat_targets = _to_padded_jax_array(flat_targets) + x_test = _to_padded_jax_array(x_test) + beta_hat = (jnp.linalg.pinv( + x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]), + hermitian=True, + ) @ x_train.T @ flat_targets) + y_hat = x_test @ beta_hat + y_hat_context = x_train_raw @ beta_hat if debug_info else None + + outputs = [] + outputs_context = [] + + # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits. + train_index, test_index = 0, 0 + for train_index_delta, test_index_delta in zip(self.train_lens, + self.test_lens): + outputs.append(np.array(y_hat[test_index:(test_index + + test_index_delta)])) + if debug_info: + outputs_context.append( + np.array(y_hat_context[train_index:(train_index + + train_index_delta)])) + train_index += train_index_delta + test_index += test_index_delta + + if debug_info: + return outputs, outputs_context, flat_targets, x_train, x_test + else: + return outputs \ No newline at end of file diff --git a/etna/models/nn/__init__.py b/etna/models/nn/__init__.py index b972e2aab..4512ac420 100644 --- a/etna/models/nn/__init__.py +++ b/etna/models/nn/__init__.py @@ -16,3 +16,6 @@ if SETTINGS.chronos_required: from etna.models.nn.chronos import ChronosBoltModel from etna.models.nn.chronos import ChronosModel + +if SETTINGS.timesfm_required: + from etna.models.nn.timesfm import TimesFMModel diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py new file mode 100644 index 000000000..c687d371d --- /dev/null +++ b/etna/models/nn/timesfm.py @@ -0,0 +1,274 @@ +import os +import warnings +from pathlib import Path +from typing import Dict +from typing import List +from urllib import request + +import pandas as pd + +from etna import SETTINGS +from etna.datasets import TSDataset +from etna.distributions import BaseDistribution +from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel + +if SETTINGS.timesfm_required: + from etna.libs.timesfm import TimesFmCheckpoint + from etna.libs.timesfm import TimesFmHparams + from etna.libs.timesfm import TimesFmTorch + +_DOWNLOAD_PATH = Path.home() / ".etna" / "timesfm" + + +class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): + """ + Class for pretrained timesfm models. + + This model is only for zero-shot forecasting: it doesn't support training on data during ``fit``. + + Official implementation: https://github.com/google-research/timesfm + + Warning + ------- + This model doesn't support forecasting on misaligned data with `freq=None`. + + Note + ---- + This model requires ``timesfm`` extension to be installed. + Read more about this at :ref:`installation page `. + """ + + def __init__( + self, + path_or_url: str, + encoder_length: int = 512, + device: str = "cpu", + batch_size: int = 128, + cache_dir: Path = _DOWNLOAD_PATH, + ): + """ + Init TimesFM model. + + Parameters + ---------- + path_or_url: + Path to the model. It can be huggingface repository, local path or external url. + + - If huggingface repository, the available models are: + + - 'google/timesfm-1.0-200m-pytorch'. + During the first initialization model is downloaded from huggingface and saved to local ``cache_dir``. + All following initializations model will be loaded from ``cache_dir``. + - If local path, it should be a file with model weights, that can be loaded by :py:func:`torch.load`. + - If external url, it must be a file with model weights, that can be loaded by :py:func:`torch.load`. Model will be downloaded to ``cache_dir``. + device: + Device type. Can be "cpu", "gpu". + encoder_length: + Number of last timestamps to use as a context. It needs to be a multiplier of 32. + batch_size: + Batch size. It can be useful when inference is done on gpu. + cache_dir: + Local path to save model from huggingface during first model initialization. All following class initializations appropriate model version will be downloaded from this path. + """ + self.path_or_url = path_or_url + self.encoder_length = encoder_length + self.device = device + self.batch_size = batch_size + self.cache_dir = cache_dir + + self._set_pipeline() + + def _set_pipeline(self): + """Set ``tfm`` attribute.""" + if self._is_url(): + full_model_path = self._download_model_from_url() + self.tfm = TimesFmTorch( + hparams=TimesFmHparams( + context_len=self.encoder_length, per_core_batch_size=self.batch_size, backend=self.device + ), + checkpoint=TimesFmCheckpoint(path=full_model_path), + ) + else: + self.tfm = TimesFmTorch( + hparams=TimesFmHparams( + context_len=self.encoder_length, per_core_batch_size=self.batch_size, backend=self.device + ), + checkpoint=TimesFmCheckpoint(path=self.path_or_url, local_dir=self.cache_dir), + ) + + def _is_url(self): + """Check whether ``path_or_url`` is url.""" + return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://") + + def _download_model_from_url(self) -> str: + """Download model from url to local cache_dir.""" + model_file = self.path_or_url.split("/")[-1] + full_model_path = f"{self.cache_dir}/{model_file}" + if not os.path.exists(full_model_path): + request.urlretrieve(url=self.path_or_url, filename=full_model_path) + return full_model_path + + @property + def context_size(self) -> int: + """Context size for model.""" + return self.encoder_length + + def get_model(self) -> TimesFmTorch: + """Get model.""" + return self.tfm + + def fit(self, ts: TSDataset): + """Fit model. + + For this model, fit does nothing. + + Parameters + ---------- + ts: + Dataset with features. + + Returns + ------- + : + Model after fit + """ + return self + + def predict( + self, + ts: TSDataset, + prediction_size: int, + return_components: bool = False, + ) -> TSDataset: + """Make predictions using true values as autoregression context (teacher forcing). + + Parameters + ---------- + ts: + Dataset with features. + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context. + return_components: + If True additionally returns forecast components. + + Returns + ------- + : + Dataset with predictions. + """ + raise NotImplementedError("Method predict isn't currently implemented!") + + def forecast( + self, + ts: TSDataset, + prediction_size: int, + return_components: bool = False, + ) -> TSDataset: + """Make autoregressive forecasts. + + Parameters + ---------- + ts: + Dataset with features. + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context. + return_components: + If True additionally returns forecast components. + + Returns + ------- + : + Dataset with predictions. + + Raises + ------ + NotImplementedError: + if return_components mode is used. + ValueError: + if dataset doesn't have any context timestamps. + NotImplementedError: + if data with None frequency is used. + """ + if ts.freq is None: + raise NotImplementedError("Data with None frequency isn't currently implemented!") + + if return_components: + raise NotImplementedError("This mode isn't currently implemented!") + + max_context_size = len(ts.index) - prediction_size + if max_context_size <= 0: + raise ValueError("Dataset doesn't have any context timestamps.") + + if max_context_size < self.context_size: + warnings.warn("Actual length of a dataset is less that context size. All history will be used as context.") + available_context_size = min(max_context_size, self.context_size) + + self.tfm._set_horizon(prediction_size) + + target = ts.df.loc[ts.index[:-prediction_size], pd.IndexSlice[:, "target"]] + target = target.iloc[-available_context_size:] + df = TSDataset.to_flatten(target).dropna() + df = df.rename(columns={"segment": "unique_id", "timestamp": "ds"}) + + predictions = self.tfm.forecast_on_df(df, freq=ts.freq, value_name="target") + + end_idx = len(ts.index) + future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx) + predictions = predictions.rename(columns={"unique_id": "segment", "ds": "timestamp", "timesfm": "target"}) + predictions = TSDataset.to_dataset(predictions) + future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = predictions.loc[ + :, pd.IndexSlice[:, "target"] + ].values # .values is needed to cast predictions type of initial target type in ts + + return future_ts + + @staticmethod + def list_models() -> List[str]: + """ + Return a list of available pretrained timesfm models. + + Returns + ------- + : + List of available pretrained chronos models. + """ + return ["google/timesfm-1.0-200m-pytorch"] + + def save(self, path: Path): + """Save the model. This method doesn't save model's weights. + + During ``load`` weights are loaded from the path where they were saved during ``init`` + + Parameters + ---------- + path: + Path to save object to. + """ + self._save(path=path, skip_attributes=["tfm"]) + + @classmethod + def load(cls, path: Path): + """Load the model. + + Parameters + ---------- + path: + Path to load object from. + """ + obj: TimesFMModel = super().load(path=path) + obj._set_pipeline() + return obj + + def params_to_tune(self) -> Dict[str, BaseDistribution]: + """Get default grid for tuning hyperparameters. + + This grid is empty. + + Returns + ------- + : + Grid to tune. + """ + return {} diff --git a/etna/settings.py b/etna/settings.py index 987525548..afbcc5d8a 100644 --- a/etna/settings.py +++ b/etna/settings.py @@ -52,6 +52,21 @@ def _is_chronos_available(): return False +def _is_timesfm_available(): + true_case = ( + _module_available("torch") + & _module_available("jax") + & _module_available("jaxlib") + & _module_available("huggingface_hub") + & _module_available("utilsforecast") + ) + if true_case: + return True + else: + warnings.warn("etna[timesfm] is not available, to install it, run `pip install etna[timesfm]`") + return False + + def _is_wandb_available(): if _module_available("wandb"): return True @@ -112,6 +127,7 @@ def __init__( # noqa: D107 self, torch_required: Optional[bool] = None, chronos_required: Optional[bool] = None, + timesfm_required: Optional[bool] = None, prophet_required: Optional[bool] = None, wandb_required: Optional[bool] = None, classification_required: Optional[bool] = None, @@ -131,6 +147,11 @@ def __init__( # noqa: D107 _is_chronos_available, "etna[chronos] is not available, to install it, run `pip install etna[chronos]`.", ) + self.timesfm_required: bool = _get_optional_value( + timesfm_required, + _is_timesfm_available, + "etna[timesfm] is not available, to install it, run `pip install etna[timesfm]`.", + ) self.wandb_required: bool = _get_optional_value( wandb_required, _is_wandb_available, "wandb is not available, to install it, " "run `pip install wandb`." ) diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index 29bdeee07..d8ad46123 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -33,7 +33,8 @@ " * [N-BEATS Model](#section_3_9)\n", " * [PatchTS Model](#section_3_10)\n", " * [Chronos Model](#section_3_11)\n", - " * [Chronos Bolt Model](#section_3_12)" + " * [Chronos Bolt Model](#section_3_12)\n", + " * [TimesFM Model](#section_3_13)" ] }, { @@ -43,12 +44,12 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"etna[torch,chronos]\" -q" + "!pip install \"etna[torch,chronos,timesfm]\" -q" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "0eb5e69e", "metadata": { "tags": [] @@ -62,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "a1a1a571", "metadata": { "pycharm": { @@ -95,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "7851ddcc", "metadata": { "tags": [] @@ -128,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "bfe220cb", "metadata": { "tags": [] @@ -204,7 +205,7 @@ "4 2019-01-05 segment_a 279" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -224,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "c1f7de68", "metadata": { "tags": [] @@ -326,7 +327,7 @@ "2019-01-05 279 137 104 384" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -369,7 +370,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "f3ae2de0", "metadata": { "tags": [] @@ -381,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "af6a6035", "metadata": { "tags": [] @@ -444,7 +445,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "fbb2c279-f505-4f1b-b0e3-5e94369f9673", "metadata": { "tags": [] @@ -5035,6 +5036,169 @@ "score = metrics_chronos_bolt[\"SMAPE\"].mean()\n", "print(f\"Average SMAPE for Chronos Bolt small with long context: {score:.3f}\")" ] + }, + { + "cell_type": "markdown", + "id": "708ecd8f", + "metadata": {}, + "source": [ + "### 3.13 TimesFm Model \n", + "\n", + "`TimesFMModel` is one more pretrained model for zero-shot forecasting. It has similar interface to `ChronosBoltModel` and `ChronosModel`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "95832e68", + "metadata": {}, + "outputs": [], + "source": [ + "from etna.models.nn import TimesFMModel" + ] + }, + { + "cell_type": "markdown", + "id": "4130cab8", + "metadata": {}, + "source": [ + "Now only one model is available." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "24a4a369", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['google/timesfm-1.0-200m-pytorch']" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "TimesFMModel.list_models()" + ] + }, + { + "cell_type": "markdown", + "id": "2173dc9f", + "metadata": {}, + "source": [ + "Be careful. `encoder_length` needs to be a multiplier of 32." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0473e740", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cbe5ceeb717c4b25a6261ac67bb69a4f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 3 files: 0%| | 0/3 [00:00=0.5.2)", "pipreqs", "requirementslib" plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] +[[package]] +name = "jax" +version = "0.4.13" +description = "Differentiate, compile, and transform Numpy code." +optional = true +python-versions = ">=3.8" +files = [ + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, +] + +[package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" +opt_einsum = "*" +scipy = ">=1.7" + +[package.extras] +australis = ["protobuf (>=3.13,<4)"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] + +[[package]] +name = "jaxlib" +version = "0.4.13" +description = "XLA library for JAX" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jaxlib-0.4.13-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac"}, + {file = "jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8"}, + {file = "jaxlib-0.4.13-cp310-cp310-win_amd64.whl", hash = "sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649"}, + {file = "jaxlib-0.4.13-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89"}, + {file = "jaxlib-0.4.13-cp311-cp311-win_amd64.whl", hash = "sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0"}, + {file = "jaxlib-0.4.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc"}, + {file = "jaxlib-0.4.13-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5"}, + {file = "jaxlib-0.4.13-cp39-cp39-win_amd64.whl", hash = "sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e"}, +] + +[package.dependencies] +ml-dtypes = ">=0.1.0" +numpy = ">=1.21" +scipy = ">=1.7" + +[package.extras] +cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] + [[package]] name = "jedi" version = "0.18.2" @@ -2673,6 +2736,41 @@ files = [ {file = "mistune-2.0.5.tar.gz", hash = "sha256:0246113cb2492db875c6be56974a7c893333bf26cd92891c85f63151cee09d34"}, ] +[[package]] +name = "ml-dtypes" +version = "0.2.0" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "ml_dtypes-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd"}, + {file = "ml_dtypes-0.2.0.tar.gz", hash = "sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, + {version = ">1.20", markers = "python_version <= \"3.9\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "msgpack" version = "1.0.5" @@ -3215,6 +3313,17 @@ files = [ antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" +[[package]] +name = "opt-einsum" +version = "3.4.0" +description = "Path optimization of einsum functions." +optional = true +python-versions = ">=3.8" +files = [ + {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"}, + {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"}, +] + [[package]] name = "optuna" version = "2.10.1" @@ -6031,6 +6140,27 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "utilsforecast" +version = "0.2.10" +description = "Forecasting utilities" +optional = true +python-versions = ">=3.8" +files = [ + {file = "utilsforecast-0.2.10-py3-none-any.whl", hash = "sha256:ee7860f18a6df5dd695b51e603f3866a00dd0da0eb6c12d07f052e54390cf1a7"}, + {file = "utilsforecast-0.2.10.tar.gz", hash = "sha256:6058ca1a00b7e9dc02346a071a8e3f4dabe2a01f6f6b5a563c6c849754a86d3e"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +pandas = ">=1.1.1" + +[package.extras] +dev = ["black", "datasetsforecast (==0.0.8)", "nbdev (<2.3.26)", "numba (>=0.58.0)", "pandas[plot]", "plotly", "plotly-resampler", "polars[numpy]", "pyarrow", "scipy"] +plotting = ["pandas[plot]", "plotly", "plotly-resampler"] +polars = ["polars[numpy]"] + [[package]] name = "wandb" version = "0.12.21" @@ -6348,8 +6478,8 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["accelerate", "einops", "huggingface-hub", "optuna", "prophet", "pytorch-forecasting", "pytorch-lightning", "pyts", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "wandb"] -all-dev = ["GitPython", "Sphinx", "accelerate", "black", "click", "click", "codespell", "einops", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "huggingface-hub", "ipywidgets", "isort", "jupyter", "mypy", "myst-parser", "nbconvert", "nbqa", "nbsphinx", "optuna", "pep8-naming", "prophet", "pydata-sphinx-theme", "pytest", "pytest-cov", "pytest-shard", "pytorch-forecasting", "pytorch-lightning", "pyts", "semver", "semver", "sphinx-design", "sphinx-mathjax-offline", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "types-PyYAML", "types-setuptools", "wandb"] +all = ["accelerate", "einops", "huggingface-hub", "jax", "jaxlib", "optuna", "prophet", "pytorch-forecasting", "pytorch-lightning", "pyts", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "utilsforecast", "wandb"] +all-dev = ["GitPython", "Sphinx", "accelerate", "black", "click", "click", "codespell", "einops", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "huggingface-hub", "ipywidgets", "isort", "jax", "jaxlib", "jupyter", "mypy", "myst-parser", "nbconvert", "nbqa", "nbsphinx", "optuna", "pep8-naming", "prophet", "pydata-sphinx-theme", "pytest", "pytest-cov", "pytest-shard", "pytorch-forecasting", "pytorch-lightning", "pyts", "semver", "semver", "sphinx-design", "sphinx-mathjax-offline", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "types-PyYAML", "types-setuptools", "utilsforecast", "wandb"] auto = ["optuna", "sqlalchemy"] chronos = ["accelerate", "huggingface-hub", "torch", "transformers"] classification = ["pyts", "tsfresh"] @@ -6360,10 +6490,11 @@ release = ["click", "semver"] statsforecast = ["statsforecast"] style = ["black", "codespell", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "isort", "mypy", "nbqa", "pep8-naming", "types-PyYAML", "types-setuptools"] tests = ["pytest", "pytest-cov", "pytest-shard"] +timesfm = ["huggingface-hub", "jax", "jaxlib", "torch", "utilsforecast"] torch = ["einops", "pytorch-forecasting", "pytorch-lightning", "torch"] wandb = ["wandb"] [metadata] lock-version = "2.0" python-versions = ">=3.8.0, <3.11.0" -content-hash = "e77d4814a86e46ab4009fd8d0ba2900a3099263a8aad209befbdd764447e4691" +content-hash = "560122b4e3d4cbbe17fa501f1eb961c3a9605cf52a40b1f9df7129d3ab58f5cf" diff --git a/pyproject.toml b/pyproject.toml index c00f004ac..2e4cd2239 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,10 @@ transformers = {version = "<5", optional = true} accelerate = {version = "<1", optional = true} huggingface-hub = {version = ">=0.21,<0.23", optional = true} +jax = {version = "<1", optional = true} +jaxlib = {version = "<1", optional = true} +utilsforecast = {version = ">=0.1.10,<1", optional = true} + sphinx-mathjax-offline = {version = "^0.0.2", optional = true} nbsphinx = {version = "^0.9.0", optional = true} Sphinx = {version = "^6.2", optional = true} @@ -132,6 +136,7 @@ auto = ["optuna", "sqlalchemy"] classification = ["pyts", "tsfresh"] statsforecast = ["statsforecast"] chronos = ["torch", "transformers", "accelerate", "huggingface-hub"] +timesfm = ["torch", "jax", "jaxlib", "huggingface-hub", "utilsforecast"] # dev deps release = ["click", "semver"] @@ -147,7 +152,8 @@ all = [ "optuna", "sqlalchemy", "pyts", "tsfresh", "statsforecast", - "transformers", "accelerate", "huggingface-hub" + "transformers", "accelerate", "huggingface-hub", + "jax", "jaxlib", "utilsforecast" ] all-dev = [ @@ -163,7 +169,8 @@ all-dev = [ "jupyter", "nbconvert", "ipywidgets", "pyts", "tsfresh", "statsforecast", - "transformers", "accelerate", "huggingface-hub" + "transformers", "accelerate", "huggingface-hub", + "jax", "jaxlib", "utilsforecast" ] [tool.poetry.scripts] diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index bc3b9da02..7c9120114 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -46,6 +46,7 @@ from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TFTNativeModel +from etna.models.nn import TimesFMModel from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import LagTransform @@ -245,6 +246,17 @@ def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transform with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): self._test_forecast_in_sample_full_no_target(ts, model, transforms) + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + ], + ) + def test_forecast_in_sample_full_no_target_failed_timesfm(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): + self._test_forecast_in_sample_full_no_target(ts, model, transforms) + class TestForecastInSampleFull: """Test forecast on full train dataset. @@ -402,7 +414,18 @@ def test_forecast_in_sample_full_not_implemented(self, model, transforms, datase (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), ], ) - def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transforms, dataset_name, request): + def test_forecast_in_sample_full_failed_chronos(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): + _test_prediction_in_sample_full(ts, model, transforms, method_name="forecast") + + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + ], + ) + def test_forecast_in_sample_full_failed_timesfm(self, model, transforms, dataset_name, request): ts = request.getfixturevalue(dataset_name) with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): _test_prediction_in_sample_full(ts, model, transforms, method_name="forecast") @@ -493,6 +516,7 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_in_sample_suffix_no_target(self, model, transforms, dataset_name, request): @@ -617,6 +641,7 @@ class TestForecastInSampleSuffix: (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_in_sample_suffix(self, model, transforms, dataset_name, request): @@ -792,6 +817,7 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset_name, request): @@ -891,6 +917,18 @@ def test_forecast_out_sample_int_timestamp(self, model, transforms, dataset_name ts_int_timestamp = convert_ts_to_int_timestamp(ts, shift=10) self._test_forecast_out_sample(ts_int_timestamp, model, transforms) + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + ], + ) + def test_forecast_out_sample_int_timestamp_failed_timesfm(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + ts_int_timestamp = convert_ts_to_int_timestamp(ts, shift=10) + with pytest.raises(NotImplementedError, match="Data with None frequency isn't currently implemented!"): + self._test_forecast_out_sample(ts_int_timestamp, model, transforms) + @pytest.mark.parametrize( "model, transforms, dataset_name", [ @@ -1047,6 +1085,7 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_out_sample_prefix(self, model, transforms, dataset_name, request): @@ -1255,6 +1294,17 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data with pytest.raises(AssertionError): self._test_forecast_out_sample_suffix(ts, model, transforms) + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + ], + ) + def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + with pytest.raises(AssertionError): + self._test_forecast_out_sample_suffix(ts, model, transforms) + @to_be_fixed( raises=NotImplementedError, match="This model can't make forecast on out-of-sample data that goes after training data with a gap", @@ -1395,6 +1445,7 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50 (NBeatsGenericModel(input_size=7, output_size=55, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_mixed_in_out_sample(self, model, transforms, dataset_name, request): @@ -1557,6 +1608,7 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic ), (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_subset_segments(self, model, transforms, dataset_name, request): @@ -1734,6 +1786,7 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_forecast_new_segments(self, model, transforms, dataset_name, request): diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index f53019df5..601703b36 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -44,6 +44,7 @@ from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TFTNativeModel +from etna.models.nn import TimesFMModel from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import LagTransform @@ -195,6 +196,7 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_in_sample_full_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -326,6 +328,7 @@ def test_predict_in_sample_suffix_datetime_timestamp(self, model, transforms, da (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_in_sample_suffix_datetime_timestamp_failed_not_implemented_predict( @@ -472,6 +475,7 @@ def test_predict_in_sample_suffix_int_timestamp_failed(self, model, transforms, (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_in_sample_suffix_int_timestamp_failed_not_implemented_predict( @@ -614,6 +618,7 @@ def test_predict_out_sample(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -773,6 +778,7 @@ def test_predict_out_sample_prefix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_out_sample_prefix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -946,6 +952,7 @@ def test_predict_out_sample_suffix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_out_sample_suffix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1124,6 +1131,7 @@ def test_predict_mixed_in_out_sample(self, model, transforms, dataset_name, requ (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_mixed_in_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1290,6 +1298,7 @@ def test_predict_subset_segments(self, model, transforms, dataset_name, request) (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_subset_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1421,6 +1430,7 @@ def test_predict_new_segments(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), ], ) def test_predict_new_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py new file mode 100644 index 000000000..cfbb88b87 --- /dev/null +++ b/tests/test_models/test_nn/test_timesfm.py @@ -0,0 +1,145 @@ +import os +from pathlib import Path + +import pytest +from pandas.testing import assert_frame_equal + +from etna.datasets import TSDataset +from etna.datasets import generate_ar_df +from etna.libs.timesfm import TimesFmTorch +from etna.models.nn import TimesFMModel +from etna.pipeline import Pipeline + + +@pytest.fixture +def ts_increasing_integers(): + n = 128 + df = generate_ar_df(start_time="2001-01-01", periods=n, n_segments=2) + df["target"] = list(range(n)) + list(range(100, 100 + n)) + ts = TSDataset(df, freq="D") + return ts + + +@pytest.fixture +def expected_ts_increasing_integers(): + df = generate_ar_df(start_time="2001-03-06", periods=1, n_segments=2) + df["target"] = [128.0] + [228.0] + ts = TSDataset(df, freq="D") + return ts + + +@pytest.mark.smoke +def test_chronos_url(tmp_path): + model_name = "timesfm-1.0-200m-pytorch.ckpt" + url = f"http://etna-github-prod.cdn-tinkoff.ru/timesfm/{model_name}" + _ = TimesFMModel(path_or_url=url, cache_dir=tmp_path) + assert os.path.exists(tmp_path / model_name) + + +@pytest.mark.smoke +def test_chronos_custom_cache_dir(tmp_path): + path_or_url = "google/timesfm-1.0-200m-pytorch" + model_name = path_or_url.split("/")[-1] + _ = TimesFMModel(path_or_url=path_or_url, cache_dir=tmp_path) + assert os.path.exists(tmp_path / f"models--google--{model_name}") + + +@pytest.mark.smoke +def test_context_size(): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=10) + assert model.context_size == 10 + + +@pytest.mark.smoke +def test_chronos_get_model(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + assert isinstance(model.get_model(), TimesFmTorch) + + +@pytest.mark.smoke +def test_fit(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + model.fit(example_tsds) + + +@pytest.mark.smoke +def test_predict(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + with pytest.raises(NotImplementedError, match="Method predict isn't currently implemented!"): + model.predict(ts=example_tsds, prediction_size=1) + + +@pytest.mark.smoke +def test_forecast_warns_big_context_size(ts_increasing_integers): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512) + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(ts_increasing_integers) + with pytest.warns(UserWarning, match="Actual length of a dataset is less that context size."): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("encoder_length", [32, 64, 128]) +def test_forecast(ts_increasing_integers, expected_ts_increasing_integers, encoder_length): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(ts_increasing_integers) + forecast = pipeline.forecast() + assert_frame_equal(forecast.df, expected_ts_increasing_integers.df, atol=2) + + +@pytest.mark.smoke +def test_forecast_failed_int_timestamps(example_tsds_int_timestamp): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(example_tsds_int_timestamp) + with pytest.raises(NotImplementedError, match="Data with None frequency isn't currently implemented!"): + _ = pipeline.forecast() + + +@pytest.mark.smoke +@pytest.mark.parametrize("encoder_length", [16, 33]) +def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(ts_increasing_integers) + with pytest.raises(RuntimeError, match=r"shape .+ is invalid for input of size \d+"): + _ = pipeline.forecast() + + +@pytest.mark.smoke +def test_forecast_without_fit(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) + pipeline = Pipeline(model=model, horizon=1) + _ = pipeline.forecast(example_tsds) + + +@pytest.mark.smoke +def test_forecast_fails_components(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + pipeline = Pipeline(model=model, horizon=1) + with pytest.raises(NotImplementedError, match="This mode isn't currently implemented!"): + pipeline.forecast(ts=example_tsds, return_components=True) + + +@pytest.mark.smoke +def test_list_models(): + assert TimesFMModel.list_models() == ["google/timesfm-1.0-200m-pytorch"] + + +@pytest.mark.smoke +def test_save_load(tmp_path, ts_increasing_integers): + path = Path(tmp_path) / "tmp.zip" + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) + model.save(path) + loaded_model = TimesFMModel.load(path) + + pipeline = Pipeline(model=loaded_model, horizon=1) + pipeline.fit(ts_increasing_integers) + _ = pipeline.forecast() + assert isinstance(loaded_model, TimesFMModel) + + +@pytest.mark.smoke +def test_params_to_tune(): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + assert len(model.params_to_tune()) == 0 From 2d7d02cc52cd193c53d994138f35f3fd8f7e729f Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 12:56:58 +0300 Subject: [PATCH 02/20] add exogenous features --- etna/libs/timesfm/timesfm_base.py | 24 +-- etna/libs/timesfm/xreg_lib.py | 12 +- etna/models/nn/timesfm.py | 101 ++++++++++--- examples/202-NN_examples.ipynb | 169 ++++++++++++++++------ tests/test_models/test_nn/test_timesfm.py | 60 +++++++- 5 files changed, 284 insertions(+), 82 deletions(-) diff --git a/etna/libs/timesfm/timesfm_base.py b/etna/libs/timesfm/timesfm_base.py index 400cc6bfe..ade80e6c3 100644 --- a/etna/libs/timesfm/timesfm_base.py +++ b/etna/libs/timesfm/timesfm_base.py @@ -165,6 +165,9 @@ # limitations under the License. # Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_base.py) +# replace print with logging + +import warnings """Base class for TimesFM inference. This will be common to PAX and Pytorch.""" @@ -172,8 +175,8 @@ import dataclasses import logging import multiprocessing -from typing import Any, Literal, Sequence, Optional, Tuple, List, Dict - +from typing import Any, Literal, Sequence, Optional, Tuple, List, Dict, Union +from pathlib import Path import numpy as np import pandas as pd @@ -202,8 +205,11 @@ def moving_average(arr, window_size): return [smoothed_arr, arr - smoothed_arr] -def freq_map(freq: str): +def freq_map(freq: Optional[str]): """Returns the frequency map for the given frequency string.""" + if freq is None: + warnings.warn("Frequency is None. Mapping it to 0, that can be not optimal. Better to set it to known frequency") + return 0 freq = str.upper(freq) if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or freq.endswith("D") or freq.endswith("B") or freq.endswith("U")): @@ -284,11 +290,11 @@ class TimesFmCheckpoint: """ version: str = "jax" - path: Optional[str] = None + path: Optional[Union[str, Path]] = None huggingface_repo_id: Optional[str] = None type: Any = None step: Optional[int] = None - local_dir: Optional[str] = None + local_dir: Optional[Union[str, Path]] = None class TimesFmBase: @@ -758,7 +764,7 @@ def forecast_on_df( uids = [] if num_jobs == 1: if verbose: - print("Processing dataframe with single process.") + logging.info("Processing dataframe with single process.") for key, group in df_sorted.groupby("unique_id"): inp, uid = process_group( key, @@ -772,7 +778,7 @@ def forecast_on_df( if num_jobs == -1: num_jobs = multiprocessing.cpu_count() if verbose: - print("Processing dataframe with multiple processes.") + logging.info("Processing dataframe with multiple processes.") with multiprocessing.Pool(processes=num_jobs) as pool: results = pool.starmap( process_group, @@ -781,13 +787,13 @@ def forecast_on_df( ) new_inputs, uids = zip(*results) if verbose: - print("Finished preprocessing dataframe.") + logging.info("Finished preprocessing dataframe.") freq_inps = [freq_map(freq)] * len(new_inputs) _, full_forecast = self.forecast(new_inputs, freq=freq_inps, window_size=window_size) if verbose: - print("Finished forecasting.") + logging.info("Finished forecasting.") fcst_df = make_future_dataframe( uids=uids, last_times=df_sorted.groupby("unique_id")["ds"].tail(1), diff --git a/etna/libs/timesfm/xreg_lib.py b/etna/libs/timesfm/xreg_lib.py index 1cbb39781..ba7ab69ce 100644 --- a/etna/libs/timesfm/xreg_lib.py +++ b/etna/libs/timesfm/xreg_lib.py @@ -165,7 +165,7 @@ # limitations under the License. # Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/xreg_lib.py) - +# add check of sklearn version for OHE """Helper functions for in-context covariates and regression.""" import itertools @@ -176,6 +176,7 @@ import jax.numpy as jnp import numpy as np from sklearn import preprocessing +from sklearn import __version__ as sklearn_version Category = Union[int, str] @@ -496,11 +497,18 @@ def create_covariate_matrix( x_train = [(x_train - x_mean) / x_std] x_test = [(x_test - x_mean) / x_std] + sklearn_version_tuple = tuple(map(int, sklearn_version.split("."))) + encoder_params = {} + if sklearn_version_tuple < (1, 2): + encoder_params["sparse"] = False + else: + encoder_params["sparse_output"] = False + # Categorical features. Encode one by one. one_hot_encoder = preprocessing.OneHotEncoder( drop=one_hot_encoder_drop, - sparse_output=False, handle_unknown="ignore", + **encoder_params ) for name in sorted(self.train_dynamic_categorical_covariates.keys()): ohe_train = _unnest( diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index c687d371d..abf018ca1 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -3,8 +3,11 @@ from pathlib import Path from typing import Dict from typing import List +from typing import Literal +from typing import Optional from urllib import request +import numpy as np import pandas as pd from etna import SETTINGS @@ -16,6 +19,7 @@ from etna.libs.timesfm import TimesFmCheckpoint from etna.libs.timesfm import TimesFmHparams from etna.libs.timesfm import TimesFmTorch + from etna.libs.timesfm.timesfm_base import freq_map _DOWNLOAD_PATH = Path.home() / ".etna" / "timesfm" @@ -32,6 +36,8 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): ------- This model doesn't support forecasting on misaligned data with `freq=None`. + Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill NaNs for stable behaviour. + Note ---- This model requires ``timesfm`` extension to be installed. @@ -42,8 +48,12 @@ def __init__( self, path_or_url: str, encoder_length: int = 512, - device: str = "cpu", + device: Literal["cpu", "gpu"] = "cpu", batch_size: int = 128, + static_reals: Optional[List[str]] = None, + static_categoricals: Optional[List[str]] = None, + time_varying_reals: Optional[List[str]] = None, + time_varying_categoricals: Optional[List[str]] = None, cache_dir: Path = _DOWNLOAD_PATH, ): """ @@ -62,11 +72,19 @@ def __init__( - If local path, it should be a file with model weights, that can be loaded by :py:func:`torch.load`. - If external url, it must be a file with model weights, that can be loaded by :py:func:`torch.load`. Model will be downloaded to ``cache_dir``. device: - Device type. Can be "cpu", "gpu". + Device type. Can be "cpu" or "gpu". encoder_length: Number of last timestamps to use as a context. It needs to be a multiplier of 32. batch_size: Batch size. It can be useful when inference is done on gpu. + static_reals: + Continuous features that have one unique feature value for the whole series. The first value in the series will be used for each feature. + static_categoricals: + Categorical features that have one unique feature value for the whole series. The first value in the series will be used for each feature. + time_varying_reals: + Time varying continuous features known for future. + time_varying_categoricals: + Time varying categorical features known for future. cache_dir: Local path to save model from huggingface during first model initialization. All following class initializations appropriate model version will be downloaded from this path. """ @@ -74,6 +92,10 @@ def __init__( self.encoder_length = encoder_length self.device = device self.batch_size = batch_size + self.static_reals = static_reals + self.static_categoricals = static_categoricals + self.time_varying_reals = time_varying_reals + self.time_varying_categoricals = time_varying_categoricals self.cache_dir = cache_dir self._set_pipeline() @@ -189,11 +211,8 @@ def forecast( ValueError: if dataset doesn't have any context timestamps. NotImplementedError: - if data with None frequency is used. + if forecasting is done without exogenous features and dataset has None frequency. """ - if ts.freq is None: - raise NotImplementedError("Data with None frequency isn't currently implemented!") - if return_components: raise NotImplementedError("This mode isn't currently implemented!") @@ -203,25 +222,71 @@ def forecast( if max_context_size < self.context_size: warnings.warn("Actual length of a dataset is less that context size. All history will be used as context.") - available_context_size = min(max_context_size, self.context_size) self.tfm._set_horizon(prediction_size) - target = ts.df.loc[ts.index[:-prediction_size], pd.IndexSlice[:, "target"]] - target = target.iloc[-available_context_size:] - df = TSDataset.to_flatten(target).dropna() - df = df.rename(columns={"segment": "unique_id", "timestamp": "ds"}) + end_idx = len(ts.index) - predictions = self.tfm.forecast_on_df(df, freq=ts.freq, value_name="target") + static_reals_dict = ( + {column: ts.df.loc[ts.index[0], pd.IndexSlice[:, column]].values.tolist() for column in self.static_reals} + if self.static_reals is not None + else None + ) + static_categoricals_dict = ( + { + column: ts.df.loc[ts.index[0], pd.IndexSlice[:, column]].values.tolist() + for column in self.static_categoricals + } + if self.static_categoricals is not None + else None + ) + time_varying_reals_dict = ( + { + column: ts.df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() + for column in self.time_varying_reals + } + if self.time_varying_reals is not None + else None + ) + time_varying_categoricals_dict = ( + { + column: ts.df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() + for column in self.time_varying_categoricals + } + if self.time_varying_categoricals is not None + else None + ) - end_idx = len(ts.index) future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx) - predictions = predictions.rename(columns={"unique_id": "segment", "ds": "timestamp", "timesfm": "target"}) - predictions = TSDataset.to_dataset(predictions) - future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = predictions.loc[ - :, pd.IndexSlice[:, "target"] - ].values # .values is needed to cast predictions type of initial target type in ts + if static_reals_dict or static_categoricals_dict or time_varying_reals_dict or time_varying_categoricals_dict: + target = ts.df.loc[:, pd.IndexSlice[:, "target"]].dropna().values.swapaxes(1, 0).tolist() + + complex_forecast, _ = self.tfm.forecast_with_covariates( + inputs=target, + dynamic_numerical_covariates=time_varying_reals_dict, + dynamic_categorical_covariates=time_varying_categoricals_dict, + static_numerical_covariates=static_reals_dict, + static_categorical_covariates=static_categoricals_dict, + freq=[freq_map(ts.freq)] * len(ts.segments), + ) + future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = np.vstack(complex_forecast).swapaxes(1, 0) + else: + if ts.freq is None: + raise NotImplementedError( + "Data with None frequency isn't currently implemented for forecasting without exogenous features." + ) + + target = ts.to_pandas(flatten=True, features=["target"]).dropna() + target = target.rename(columns={"segment": "unique_id", "timestamp": "ds"}) + + predictions = self.tfm.forecast_on_df(target, freq=ts.freq, value_name="target") + + predictions = predictions.rename(columns={"unique_id": "segment", "ds": "timestamp", "timesfm": "target"}) + predictions = TSDataset.to_dataset(predictions) + future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = predictions.loc[ + :, pd.IndexSlice[:, "target"] + ].values # .values is needed to cast predictions type of initial target type in ts return future_ts @staticmethod diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index d8ad46123..bdedfdc57 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -5039,7 +5039,7 @@ }, { "cell_type": "markdown", - "id": "708ecd8f", + "id": "3d1c5cdc", "metadata": {}, "source": [ "### 3.13 TimesFm Model \n", @@ -5049,8 +5049,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "id": "95832e68", + "execution_count": 9, + "id": "b1cddb5d", "metadata": {}, "outputs": [], "source": [ @@ -5059,7 +5059,7 @@ }, { "cell_type": "markdown", - "id": "4130cab8", + "id": "62e2a14a", "metadata": {}, "source": [ "Now only one model is available." @@ -5067,8 +5067,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "id": "24a4a369", + "execution_count": 10, + "id": "c9728978", "metadata": {}, "outputs": [ { @@ -5077,7 +5077,7 @@ "['google/timesfm-1.0-200m-pytorch']" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -5088,7 +5088,7 @@ }, { "cell_type": "markdown", - "id": "2173dc9f", + "id": "2cff6154", "metadata": {}, "source": [ "Be careful. `encoder_length` needs to be a multiplier of 32." @@ -5096,14 +5096,14 @@ }, { "cell_type": "code", - "execution_count": 13, - "id": "0473e740", + "execution_count": 11, + "id": "4482e56d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cbe5ceeb717c4b25a6261ac67bb69a4f", + "model_id": "f01e125a40134bd680b31057fb7fa78e", "version_major": 2, "version_minor": 0 }, @@ -5120,47 +5120,105 @@ "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", - "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n" + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.5s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] - }, + } + ], + "source": [ + "set_seed()\n", + "\n", + "model_timesfm = TimesFMModel(path_or_url=\"google/timesfm-1.0-200m-pytorch\", encoder_length=32)\n", + "\n", + "pipeline_timesfm = Pipeline(model=model_timesfm, horizon=HORIZON, transforms=[])\n", + "\n", + "metrics_timesfm, forecast_timesfm, fold_info_timesfm = pipeline_timesfm.backtest(\n", + " ts, metrics=metrics, n_folds=3, n_jobs=1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "04f62aad", + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Processing dataframe with single process.\n", - "Finished preprocessing dataframe.\n", - "Finished forecasting.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n" + "Average SMAPE for TimesFM: 5.249\n" ] - }, + } + ], + "source": [ + "score = metrics_timesfm[\"SMAPE\"].mean()\n", + "print(f\"Average SMAPE for TimesFM: {score:.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "149a91eb", + "metadata": {}, + "outputs": [], + "source": [ + "plot_backtest(forecast_timesfm, ts, history_len=20)" + ] + }, + { + "cell_type": "markdown", + "id": "bb035a81", + "metadata": {}, + "source": [ + "Model can work with exogenous features." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "c0588f39", + "metadata": {}, + "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processing dataframe with single process.\n", - "Finished preprocessing dataframe.\n", - "Finished forecasting.\n", - "Processing dataframe with single process.\n", - "Finished preprocessing dataframe.\n", - "Finished forecasting.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "20fed636fead4ca6b8f533edcde1a686", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 3 files: 0%| | 0/3 [00:00 Date: Tue, 24 Dec 2024 13:13:59 +0300 Subject: [PATCH 03/20] chronos minor fix --- etna/models/nn/chronos/base.py | 5 +- examples/202-NN_examples.ipynb | 243 +++++++++++++++++++++------------ 2 files changed, 159 insertions(+), 89 deletions(-) diff --git a/etna/models/nn/chronos/base.py b/etna/models/nn/chronos/base.py index a20514a1a..7505b150c 100644 --- a/etna/models/nn/chronos/base.py +++ b/etna/models/nn/chronos/base.py @@ -202,10 +202,9 @@ def _forecast( if max_context_size < self.context_size: warnings.warn("Actual length of a dataset is less that context size. All history will be used as context.") - available_context_size = min(max_context_size, self.context_size) - target = ts.df.loc[:, pd.IndexSlice[:, "target"]] - context = torch.tensor(target.values.T[:, :available_context_size]) + target = ts.df.loc[:, pd.IndexSlice[:, "target"]].dropna() + context = torch.tensor(target.values.T) if prediction_interval: quantiles_forecast, target_forecast = self.pipeline.predict_quantiles( diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index bdedfdc57..1a03ef37e 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -37,6 +37,16 @@ " * [TimesFM Model](#section_3_13)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "id": "dfdc181e", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -49,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "0eb5e69e", "metadata": { "tags": [] @@ -63,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "a1a1a571", "metadata": { "pycharm": { @@ -96,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "7851ddcc", "metadata": { "tags": [] @@ -129,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "bfe220cb", "metadata": { "tags": [] @@ -205,7 +215,7 @@ "4 2019-01-05 segment_a 279" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -225,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "c1f7de68", "metadata": { "tags": [] @@ -327,7 +337,7 @@ "2019-01-05 279 137 104 384" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -370,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "f3ae2de0", "metadata": { "tags": [] @@ -382,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "af6a6035", "metadata": { "tags": [] @@ -445,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "fbb2c279-f505-4f1b-b0e3-5e94369f9673", "metadata": { "tags": [] @@ -4651,6 +4661,64 @@ "Chronos is pretrained model for zero-shot forecasting." ] }, + { + "cell_type": "code", + "execution_count": 80, + "id": "13f5f221", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "17d7b955", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "1c214441", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "085b45de", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "934b244c", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a247b7ef", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 84, @@ -4718,15 +4786,15 @@ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -4816,15 +4884,15 @@ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.2s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -4878,20 +4946,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 3.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 7.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 11.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 11.4s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.9s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.9s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -4989,20 +5057,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -5039,7 +5107,7 @@ }, { "cell_type": "markdown", - "id": "3d1c5cdc", + "id": "83d6d047", "metadata": {}, "source": [ "### 3.13 TimesFm Model \n", @@ -5049,8 +5117,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "id": "b1cddb5d", + "execution_count": 98, + "id": "eee62c5e", "metadata": {}, "outputs": [], "source": [ @@ -5059,7 +5127,7 @@ }, { "cell_type": "markdown", - "id": "62e2a14a", + "id": "999d5479", "metadata": {}, "source": [ "Now only one model is available." @@ -5067,8 +5135,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "id": "c9728978", + "execution_count": 99, + "id": "d19f4212", "metadata": {}, "outputs": [ { @@ -5077,7 +5145,7 @@ "['google/timesfm-1.0-200m-pytorch']" ] }, - "execution_count": 10, + "execution_count": 99, "metadata": {}, "output_type": "execute_result" } @@ -5088,7 +5156,7 @@ }, { "cell_type": "markdown", - "id": "2cff6154", + "id": "ae998579", "metadata": {}, "source": [ "Be careful. `encoder_length` needs to be a multiplier of 32." @@ -5096,14 +5164,14 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "4482e56d", + "execution_count": 100, + "id": "8697927d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f01e125a40134bd680b31057fb7fa78e", + "model_id": "20c4816d7a1f43fd86b152a6257ea4cf", "version_major": 2, "version_minor": 0 }, @@ -5120,14 +5188,14 @@ "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s finished\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.7s finished\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", @@ -5150,8 +5218,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "id": "04f62aad", + "execution_count": 101, + "id": "83422d02", "metadata": {}, "outputs": [ { @@ -5169,17 +5237,28 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "149a91eb", + "execution_count": 102, + "id": "391044d7", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "plot_backtest(forecast_timesfm, ts, history_len=20)" ] }, { "cell_type": "markdown", - "id": "bb035a81", + "id": "334b9a6b", "metadata": {}, "source": [ "Model can work with exogenous features." @@ -5187,14 +5266,14 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "c0588f39", + "execution_count": 103, + "id": "de5f4d5d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "20fed636fead4ca6b8f533edcde1a686", + "model_id": "eb26373760e0483abd16481719195154", "version_major": 2, "version_minor": 0 }, @@ -5212,13 +5291,13 @@ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.5s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.9s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.9s finished\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", @@ -5233,14 +5312,14 @@ "transforms = [\n", " SegmentEncoderTransform(),\n", " LagTransform(in_column=\"target\", lags=[HORIZON], out_column=\"lag\"),\n", - " DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, is_weekend=False, out_column=\"flag\"),\n", + " DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, is_weekend=False, out_column=\"dateflag\"),\n", "]\n", "model_timesfm = TimesFMModel(\n", " path_or_url=\"google/timesfm-1.0-200m-pytorch\",\n", " encoder_length=32,\n", " static_categoricals=[\"segment_code\"],\n", " time_varying_reals=[f\"lag_{HORIZON}\"],\n", - " time_varying_categoricals=[\"flag_day_number_in_week\"],\n", + " time_varying_categoricals=[\"dateflag_day_number_in_week\"],\n", ")\n", "\n", "pipeline_timesfm = Pipeline(model=model_timesfm, horizon=HORIZON, transforms=transforms)\n", @@ -5252,15 +5331,15 @@ }, { "cell_type": "code", - "execution_count": 23, - "id": "4467a8c0", + "execution_count": 104, + "id": "daeddd44", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Average SMAPE for TimesFM: 6.784\n" + "Average SMAPE for TimesFM with exogenous features: 6.784\n" ] } ], @@ -5268,14 +5347,6 @@ "score = metrics_timesfm[\"SMAPE\"].mean()\n", "print(f\"Average SMAPE for TimesFM with exogenous features: {score:.3f}\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6d6a8774", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From c609ab7b038e13826c5954f4ca9187a52c3abff4 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 13:19:40 +0300 Subject: [PATCH 04/20] update changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7d0a3b9a..0d28abf23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `ChronosModel` ([#511](https://github.com/etna-team/etna/pull/511)) - Add `ChronosBoltModel` ([#511](https://github.com/etna-team/etna/pull/511)) - Add usage example of `ChronosModel` and `ChronosBoltModel` in `202-NN_examples` notebook ([#511](https://github.com/etna-team/etna/pull/511)) -- -- +- Add `TimesFMModel` ([#544](https://github.com/etna-team/etna/pull/544)) +- Add usage example of `TimesFMModel` in `202-NN_examples` notebook ([#544](https://github.com/etna-team/etna/pull/544)) - - - Add `MissingCounter` metric ([#520](https://github.com/etna-team/etna/pull/520)) From 2faa926ed4c5faa7d7c5e781d8ee5bba869abdca Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 13:25:50 +0300 Subject: [PATCH 05/20] minor docstring fix --- etna/models/nn/timesfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index abf018ca1..0f6bdf396 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -34,7 +34,7 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): Warning ------- - This model doesn't support forecasting on misaligned data with `freq=None`. + This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features. Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill NaNs for stable behaviour. From 5d537b0cb85382b3f68ebde6b86411bf781c5af1 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 13:34:01 +0300 Subject: [PATCH 06/20] lints --- etna/libs/timesfm/xreg_lib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/etna/libs/timesfm/xreg_lib.py b/etna/libs/timesfm/xreg_lib.py index ba7ab69ce..4521009bf 100644 --- a/etna/libs/timesfm/xreg_lib.py +++ b/etna/libs/timesfm/xreg_lib.py @@ -605,9 +605,9 @@ def fit( # Runs jitted version of the solvers which are quicker at the cost of # running jitting during the first time calling. Re-jitting happens whenever # new (padded) shapes are encountered. - # Ocassionally it helps with the speed and the accuracy if we force single + # Occasionally it helps with the speed and the accuracy if we force single # thread execution on cpu for accelerator machines: - # 1. Avoid moving data to accelarator memory. + # 1. Avoid moving data to accelerator memory. # 2. Avoid precision loss if any. with jax.default_device(device): x_train_raw = _to_padded_jax_array(x_train_raw) From c9726f711b113941149fe35de6bdc9be1468f58d Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 13:38:45 +0300 Subject: [PATCH 07/20] minor notebook fix --- examples/202-NN_examples.ipynb | 90 +++++----------------------------- 1 file changed, 11 insertions(+), 79 deletions(-) diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index 1a03ef37e..7af8f9d14 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -37,16 +37,6 @@ " * [TimesFM Model](#section_3_13)" ] }, - { - "cell_type": "code", - "execution_count": 1, - "id": "dfdc181e", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, { "cell_type": "code", "execution_count": 1, @@ -4661,64 +4651,6 @@ "Chronos is pretrained model for zero-shot forecasting." ] }, - { - "cell_type": "code", - "execution_count": 80, - "id": "13f5f221", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "id": "17d7b955", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "id": "1c214441", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "id": "085b45de", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "id": "934b244c", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a247b7ef", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 84, @@ -5107,7 +5039,7 @@ }, { "cell_type": "markdown", - "id": "83d6d047", + "id": "e39ebb25", "metadata": {}, "source": [ "### 3.13 TimesFm Model \n", @@ -5118,7 +5050,7 @@ { "cell_type": "code", "execution_count": 98, - "id": "eee62c5e", + "id": "f7cae59a", "metadata": {}, "outputs": [], "source": [ @@ -5127,7 +5059,7 @@ }, { "cell_type": "markdown", - "id": "999d5479", + "id": "2e999f64", "metadata": {}, "source": [ "Now only one model is available." @@ -5136,7 +5068,7 @@ { "cell_type": "code", "execution_count": 99, - "id": "d19f4212", + "id": "a268ee77", "metadata": {}, "outputs": [ { @@ -5156,7 +5088,7 @@ }, { "cell_type": "markdown", - "id": "ae998579", + "id": "e57e1f23", "metadata": {}, "source": [ "Be careful. `encoder_length` needs to be a multiplier of 32." @@ -5165,7 +5097,7 @@ { "cell_type": "code", "execution_count": 100, - "id": "8697927d", + "id": "cdc514fa", "metadata": {}, "outputs": [ { @@ -5219,7 +5151,7 @@ { "cell_type": "code", "execution_count": 101, - "id": "83422d02", + "id": "625ef289", "metadata": {}, "outputs": [ { @@ -5238,7 +5170,7 @@ { "cell_type": "code", "execution_count": 102, - "id": "391044d7", + "id": "32f13834", "metadata": {}, "outputs": [ { @@ -5258,7 +5190,7 @@ }, { "cell_type": "markdown", - "id": "334b9a6b", + "id": "095326bd", "metadata": {}, "source": [ "Model can work with exogenous features." @@ -5267,7 +5199,7 @@ { "cell_type": "code", "execution_count": 103, - "id": "de5f4d5d", + "id": "10dede77", "metadata": {}, "outputs": [ { @@ -5332,7 +5264,7 @@ { "cell_type": "code", "execution_count": 104, - "id": "daeddd44", + "id": "379309bd", "metadata": {}, "outputs": [ { From 20eac8aa83e2523276efaf5d45932ffb45eb6643 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 17:34:39 +0300 Subject: [PATCH 08/20] minor fix --- etna/libs/timesfm/timesfm.py | 6 +++--- etna/libs/timesfm/timesfm_base.py | 8 ++++---- etna/models/nn/timesfm.py | 2 +- tests/test_models/test_nn/test_timesfm.py | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/etna/libs/timesfm/timesfm.py b/etna/libs/timesfm/timesfm.py index f707879fe..d46782e3c 100644 --- a/etna/libs/timesfm/timesfm.py +++ b/etna/libs/timesfm/timesfm.py @@ -206,7 +206,7 @@ def __post_init__(self): torch.cuda.is_available() and self.backend == "gpu") else "cpu") self._median_index = -1 - def _set_horizon(self, horizon): + def _set_horizon(self, horizon): # changed: added to change horizon after initialization self.horizon_len = horizon def load_from_checkpoint( @@ -216,10 +216,10 @@ def load_from_checkpoint( """Loads a checkpoint and compiles the decoder.""" checkpoint_path = checkpoint.path repo_id = checkpoint.huggingface_repo_id - if not os.path.exists(checkpoint_path): + if not os.path.exists(checkpoint_path): # changed: make loading similar to chronos checkpoint_path = path.join(snapshot_download(checkpoint_path, cache_dir=checkpoint.local_dir), "torch_model.ckpt") self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) - loaded_checkpoint = torch.load(checkpoint_path) # TODO weights_only=True + loaded_checkpoint = torch.load(checkpoint_path) # changed: remove weights_only=True due to attribute absence in low torch versions logging.info("Loading checkpoint from %s", checkpoint_path) self._model.load_state_dict(loaded_checkpoint) logging.info("Sending checkpoint to device %s", f"{self._device}") diff --git a/etna/libs/timesfm/timesfm_base.py b/etna/libs/timesfm/timesfm_base.py index ade80e6c3..14755cbb3 100644 --- a/etna/libs/timesfm/timesfm_base.py +++ b/etna/libs/timesfm/timesfm_base.py @@ -207,7 +207,7 @@ def moving_average(arr, window_size): def freq_map(freq: Optional[str]): """Returns the frequency map for the given frequency string.""" - if freq is None: + if freq is None: # changed: added this case to handle int timestamps during forecasting with exogenous features warnings.warn("Frequency is None. Mapping it to 0, that can be not optimal. Better to set it to known frequency") return 0 freq = str.upper(freq) @@ -764,7 +764,7 @@ def forecast_on_df( uids = [] if num_jobs == 1: if verbose: - logging.info("Processing dataframe with single process.") + logging.info("Processing dataframe with single process.") # changed: replace print for key, group in df_sorted.groupby("unique_id"): inp, uid = process_group( key, @@ -778,7 +778,7 @@ def forecast_on_df( if num_jobs == -1: num_jobs = multiprocessing.cpu_count() if verbose: - logging.info("Processing dataframe with multiple processes.") + logging.info("Processing dataframe with multiple processes.") # changed: replace print with multiprocessing.Pool(processes=num_jobs) as pool: results = pool.starmap( process_group, @@ -787,7 +787,7 @@ def forecast_on_df( ) new_inputs, uids = zip(*results) if verbose: - logging.info("Finished preprocessing dataframe.") + logging.info("Finished preprocessing dataframe.") # changed: replace print freq_inps = [freq_map(freq)] * len(new_inputs) _, full_forecast = self.forecast(new_inputs, freq=freq_inps, diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index 0f6bdf396..8a3bcc155 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -274,7 +274,7 @@ def forecast( else: if ts.freq is None: raise NotImplementedError( - "Data with None frequency isn't currently implemented for forecasting without exogenous features." + "Forecasting misaligned data with freq=None without exogenous features isn't currently implemented." ) target = ts.to_pandas(flatten=True, features=["target"]).dropna() diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 4a91630c0..739d51afc 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -33,7 +33,7 @@ def expected_ts_increasing_integers(): @pytest.mark.smoke -def test_chronos_url(tmp_path): +def test_url(tmp_path): model_name = "timesfm-1.0-200m-pytorch.ckpt" url = f"http://etna-github-prod.cdn-tinkoff.ru/timesfm/{model_name}" _ = TimesFMModel(path_or_url=url, cache_dir=tmp_path) @@ -41,7 +41,7 @@ def test_chronos_url(tmp_path): @pytest.mark.smoke -def test_chronos_custom_cache_dir(tmp_path): +def test_cache_dir(tmp_path): path_or_url = "google/timesfm-1.0-200m-pytorch" model_name = path_or_url.split("/")[-1] _ = TimesFMModel(path_or_url=path_or_url, cache_dir=tmp_path) @@ -55,7 +55,7 @@ def test_context_size(): @pytest.mark.smoke -def test_chronos_get_model(example_tsds): +def test_get_model(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") assert isinstance(model.get_model(), TimesFmTorch) From 2b89f55bed543e2635a1521020de9b9c040dd437 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Tue, 24 Dec 2024 17:45:42 +0300 Subject: [PATCH 09/20] minor fix tests --- tests/test_models/test_nn/test_timesfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 739d51afc..4131f5d73 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -118,7 +118,7 @@ def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): pipeline.fit(example_tsds_int_timestamp) with pytest.raises( NotImplementedError, - match="Data with None frequency isn't currently implemented for forecasting without exogenous features.", + match="Forecasting misaligned data with freq=None without exogenous features isn't currently implemented.", ): _ = pipeline.forecast() From 10d60b4fb807ce93b0e0639a271cf2e4d18ff40d Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Wed, 25 Dec 2024 09:59:05 +0300 Subject: [PATCH 10/20] fix notebook, tests --- examples/202-NN_examples.ipynb | 164 +++++++++++++----- .../test_inference/test_forecast.py | 1 + tests/test_models/test_nn/test_timesfm.py | 2 - 3 files changed, 117 insertions(+), 50 deletions(-) diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index 7af8f9d14..71d116237 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -37,6 +37,16 @@ " * [TimesFM Model](#section_3_13)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "id": "05682333", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -4651,6 +4661,64 @@ "Chronos is pretrained model for zero-shot forecasting." ] }, + { + "cell_type": "code", + "execution_count": 80, + "id": "d37a75b8", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "17322770", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "3b147434", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "7e37a4b5", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "e596bf7e", + "metadata": {}, + "outputs": [], + "source": [ + "a = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1be903d8", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 84, @@ -4720,8 +4788,8 @@ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", @@ -4761,27 +4829,6 @@ "print(f\"Average SMAPE for Chronos tiny: {score:.3f}\")" ] }, - { - "cell_type": "code", - "execution_count": 88, - "id": "8334cd6a", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plot_backtest(forecast_chronos, ts, history_len=20)" - ] - }, { "cell_type": "markdown", "id": "e78ef2ad", @@ -4792,7 +4839,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 88, "id": "cfff335c", "metadata": {}, "outputs": [], @@ -4802,7 +4849,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 89, "id": "5b07ab78", "metadata": {}, "outputs": [ @@ -4842,7 +4889,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 90, "id": "cd7238b2", "metadata": {}, "outputs": [ @@ -4869,7 +4916,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 91, "id": "bc2fde04", "metadata": {}, "outputs": [ @@ -4878,13 +4925,13 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.3s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.9s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.9s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", @@ -4909,7 +4956,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 92, "id": "7b5d7d58", "metadata": {}, "outputs": [ @@ -4926,6 +4973,27 @@ "print(f\"Average SMAPE for Chronos small with long context: {score:.3f}\")" ] }, + { + "cell_type": "code", + "execution_count": 93, + "id": "1666d622", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_backtest(forecast_chronos, ts, history_len=20)" + ] + }, { "cell_type": "markdown", "id": "6bbea6f5", @@ -4989,10 +5057,10 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", @@ -5039,7 +5107,7 @@ }, { "cell_type": "markdown", - "id": "e39ebb25", + "id": "2337e9ce", "metadata": {}, "source": [ "### 3.13 TimesFm Model \n", @@ -5050,7 +5118,7 @@ { "cell_type": "code", "execution_count": 98, - "id": "f7cae59a", + "id": "dfea3ca5", "metadata": {}, "outputs": [], "source": [ @@ -5059,7 +5127,7 @@ }, { "cell_type": "markdown", - "id": "2e999f64", + "id": "5a30860f", "metadata": {}, "source": [ "Now only one model is available." @@ -5068,7 +5136,7 @@ { "cell_type": "code", "execution_count": 99, - "id": "a268ee77", + "id": "876014e9", "metadata": {}, "outputs": [ { @@ -5088,7 +5156,7 @@ }, { "cell_type": "markdown", - "id": "e57e1f23", + "id": "85dc8d64", "metadata": {}, "source": [ "Be careful. `encoder_length` needs to be a multiplier of 32." @@ -5097,7 +5165,7 @@ { "cell_type": "code", "execution_count": 100, - "id": "cdc514fa", + "id": "4cea77dd", "metadata": {}, "outputs": [ { @@ -5151,7 +5219,7 @@ { "cell_type": "code", "execution_count": 101, - "id": "625ef289", + "id": "92fd49c4", "metadata": {}, "outputs": [ { @@ -5170,7 +5238,7 @@ { "cell_type": "code", "execution_count": 102, - "id": "32f13834", + "id": "ba48b07d", "metadata": {}, "outputs": [ { @@ -5190,7 +5258,7 @@ }, { "cell_type": "markdown", - "id": "095326bd", + "id": "6a6d6cd7", "metadata": {}, "source": [ "Model can work with exogenous features." @@ -5199,7 +5267,7 @@ { "cell_type": "code", "execution_count": 103, - "id": "10dede77", + "id": "c92e1d41", "metadata": {}, "outputs": [ { @@ -5264,7 +5332,7 @@ { "cell_type": "code", "execution_count": 104, - "id": "379309bd", + "id": "1a98104a", "metadata": {}, "outputs": [ { diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index 7c9120114..573ea4ce7 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -1301,6 +1301,7 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data ], ) def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request): + """This test is expected to fail due to patch strategy in TimesFM""" ts = request.getfixturevalue(dataset_name) with pytest.raises(AssertionError): self._test_forecast_out_sample_suffix(ts, model, transforms) diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 4131f5d73..af5608dbc 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -73,7 +73,6 @@ def test_predict(example_tsds): model.predict(ts=example_tsds, prediction_size=1) -@pytest.mark.smoke def test_forecast_warns_big_context_size(ts_increasing_integers): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512) pipeline = Pipeline(model=model, horizon=1) @@ -142,7 +141,6 @@ def test_forecast_exog_int_timestamps(example_tsds_int_timestamp): _ = pipeline.forecast() -@pytest.mark.smoke @pytest.mark.parametrize("encoder_length", [16, 33]) def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) From e99e6872adceb82e11e2f30c133743142f4c2cbd Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Thu, 26 Dec 2024 14:19:13 +0300 Subject: [PATCH 11/20] raise exception when NaNs are in features --- etna/models/nn/timesfm.py | 95 ++++++++++------- examples/202-NN_examples.ipynb | 118 +++++----------------- tests/test_models/test_nn/test_timesfm.py | 74 ++++++++++++-- 3 files changed, 149 insertions(+), 138 deletions(-) diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index 8a3bcc155..a869ec491 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -30,14 +30,12 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): This model is only for zero-shot forecasting: it doesn't support training on data during ``fit``. - Official implementation: https://github.com/google-research/timesfm - - Warning - ------- This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features. Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill NaNs for stable behaviour. + Official implementation: https://github.com/google-research/timesfm + Note ---- This model requires ``timesfm`` extension to be installed. @@ -181,6 +179,14 @@ def predict( """ raise NotImplementedError("Method predict isn't currently implemented!") + def _get_exog_features(self) -> List[str]: + static_reals = [] if self.static_reals is None else self.static_reals + static_categoricals = [] if self.static_categoricals is None else self.static_categoricals + time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals + time_varying_categoricals = [] if self.time_varying_categoricals is None else self.time_varying_categoricals + + return static_reals + static_categoricals + time_varying_reals + time_varying_categoricals + def forecast( self, ts: TSDataset, @@ -210,6 +216,8 @@ def forecast( if return_components mode is used. ValueError: if dataset doesn't have any context timestamps. + ValueError: + if there are NaNs in the middle or end of the time series. NotImplementedError: if forecasting is done without exogenous features and dataset has None frequency. """ @@ -227,40 +235,55 @@ def forecast( end_idx = len(ts.index) - static_reals_dict = ( - {column: ts.df.loc[ts.index[0], pd.IndexSlice[:, column]].values.tolist() for column in self.static_reals} - if self.static_reals is not None - else None - ) - static_categoricals_dict = ( - { - column: ts.df.loc[ts.index[0], pd.IndexSlice[:, column]].values.tolist() - for column in self.static_categoricals - } - if self.static_categoricals is not None - else None - ) - time_varying_reals_dict = ( - { - column: ts.df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() - for column in self.time_varying_reals - } - if self.time_varying_reals is not None - else None - ) - time_varying_categoricals_dict = ( - { - column: ts.df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() - for column in self.time_varying_categoricals - } - if self.time_varying_categoricals is not None - else None - ) + all_exog = self._get_exog_features() + df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]] + first_valid_index = df_slice.isna().any(axis=1).idxmin() + + target_df = df_slice.loc[first_valid_index : ts.index[-prediction_size - 1], pd.IndexSlice[:, "target"]] + if target_df.isna().any().any(): + raise ValueError("There are NaNs in the middle or end of the time series.") future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx) - if static_reals_dict or static_categoricals_dict or time_varying_reals_dict or time_varying_categoricals_dict: - target = ts.df.loc[:, pd.IndexSlice[:, "target"]].dropna().values.swapaxes(1, 0).tolist() + if len(all_exog) > 0: + target = target_df.values.swapaxes(1, 0).tolist() + + exog_df = df_slice.loc[first_valid_index:, pd.IndexSlice[:, all_exog]] + if exog_df.isna().any().any(): + raise ValueError("There are NaNs in the middle or end of the exogenous features.") + + static_reals_dict = ( + { + column: exog_df.loc[exog_df.index[0], pd.IndexSlice[:, column]].values.tolist() + for column in self.static_reals + } + if self.static_reals is not None + else None + ) + static_categoricals_dict = ( + { + column: exog_df.loc[exog_df.index[0], pd.IndexSlice[:, column]].values.tolist() + for column in self.static_categoricals + } + if self.static_categoricals is not None + else None + ) + time_varying_reals_dict = ( + { + column: exog_df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() + for column in self.time_varying_reals + } + if self.time_varying_reals is not None + else None + ) + time_varying_categoricals_dict = ( + { + column: exog_df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() + for column in self.time_varying_categoricals + } + if self.time_varying_categoricals is not None + else None + ) complex_forecast, _ = self.tfm.forecast_with_covariates( inputs=target, @@ -277,7 +300,7 @@ def forecast( "Forecasting misaligned data with freq=None without exogenous features isn't currently implemented." ) - target = ts.to_pandas(flatten=True, features=["target"]).dropna() + target = TSDataset.to_flatten(df=target_df) target = target.rename(columns={"segment": "unique_id", "timestamp": "ds"}) predictions = self.tfm.forecast_on_df(target, freq=ts.freq, value_name="target") diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index 71d116237..d6a53d551 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -37,16 +37,6 @@ " * [TimesFM Model](#section_3_13)" ] }, - { - "cell_type": "code", - "execution_count": 1, - "id": "05682333", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, { "cell_type": "code", "execution_count": 1, @@ -4661,64 +4651,6 @@ "Chronos is pretrained model for zero-shot forecasting." ] }, - { - "cell_type": "code", - "execution_count": 80, - "id": "d37a75b8", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "id": "17322770", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "id": "3b147434", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "id": "7e37a4b5", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "id": "e596bf7e", - "metadata": {}, - "outputs": [], - "source": [ - "a = 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1be903d8", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 84, @@ -4976,7 +4908,7 @@ { "cell_type": "code", "execution_count": 93, - "id": "1666d622", + "id": "041c767b", "metadata": {}, "outputs": [ { @@ -5107,7 +5039,7 @@ }, { "cell_type": "markdown", - "id": "2337e9ce", + "id": "d438f154", "metadata": {}, "source": [ "### 3.13 TimesFm Model \n", @@ -5118,7 +5050,7 @@ { "cell_type": "code", "execution_count": 98, - "id": "dfea3ca5", + "id": "d2d30ce1", "metadata": {}, "outputs": [], "source": [ @@ -5127,7 +5059,7 @@ }, { "cell_type": "markdown", - "id": "5a30860f", + "id": "7978edc4", "metadata": {}, "source": [ "Now only one model is available." @@ -5136,7 +5068,7 @@ { "cell_type": "code", "execution_count": 99, - "id": "876014e9", + "id": "215231c6", "metadata": {}, "outputs": [ { @@ -5156,7 +5088,7 @@ }, { "cell_type": "markdown", - "id": "85dc8d64", + "id": "32806e28", "metadata": {}, "source": [ "Be careful. `encoder_length` needs to be a multiplier of 32." @@ -5165,13 +5097,13 @@ { "cell_type": "code", "execution_count": 100, - "id": "4cea77dd", + "id": "cf12e1ca", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "20c4816d7a1f43fd86b152a6257ea4cf", + "model_id": "31c00b3df9ce47b9bf5a43dfdb29b79d", "version_major": 2, "version_minor": 0 }, @@ -5188,15 +5120,15 @@ "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", - "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.0s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", @@ -5219,7 +5151,7 @@ { "cell_type": "code", "execution_count": 101, - "id": "92fd49c4", + "id": "8bf266f6", "metadata": {}, "outputs": [ { @@ -5238,7 +5170,7 @@ { "cell_type": "code", "execution_count": 102, - "id": "ba48b07d", + "id": "9ac4f029", "metadata": {}, "outputs": [ { @@ -5258,7 +5190,7 @@ }, { "cell_type": "markdown", - "id": "6a6d6cd7", + "id": "6baa3066", "metadata": {}, "source": [ "Model can work with exogenous features." @@ -5267,13 +5199,13 @@ { "cell_type": "code", "execution_count": 103, - "id": "c92e1d41", + "id": "fcc1d4c4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "eb26373760e0483abd16481719195154", + "model_id": "915baa91dd90481b93504302699e6d67", "version_major": 2, "version_minor": 0 }, @@ -5291,13 +5223,13 @@ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.5s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.5s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", @@ -5332,7 +5264,7 @@ { "cell_type": "code", "execution_count": 104, - "id": "1a98104a", + "id": "a38f407f", "metadata": {}, "outputs": [ { diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index af5608dbc..27f0629ef 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -1,6 +1,7 @@ import os from pathlib import Path +import numpy as np import pandas as pd import pytest from pandas.testing import assert_frame_equal @@ -15,11 +16,32 @@ from etna.transforms import SegmentEncoderTransform -@pytest.fixture -def ts_increasing_integers(): +def generate_increasing_df(): n = 128 df = generate_ar_df(start_time="2001-01-01", periods=n, n_segments=2) df["target"] = list(range(n)) + list(range(100, 100 + n)) + return df + + +@pytest.fixture +def ts_increasing_integers(): + df = generate_increasing_df() + ts = TSDataset(df, freq="D") + return ts + + +@pytest.fixture +def ts_nan_start(): + df = generate_increasing_df() + df.loc[0, "target"] = np.NaN + ts = TSDataset(df, freq="D") + return ts + + +@pytest.fixture +def ts_nan_middle(): + df = generate_increasing_df() + df.loc[120, "target"] = np.NaN ts = TSDataset(df, freq="D") return ts @@ -82,34 +104,68 @@ def test_forecast_warns_big_context_size(ts_increasing_integers): @pytest.mark.parametrize("encoder_length", [32, 64, 128]) -def test_forecast(ts_increasing_integers, expected_ts_increasing_integers, encoder_length): +@pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) +def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request): + ts = request.getfixturevalue(ts) model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) pipeline = Pipeline(model=model, horizon=2) - pipeline.fit(ts_increasing_integers) + pipeline.fit(ts) forecast = pipeline.forecast() assert_frame_equal(forecast.df, expected_ts_increasing_integers.df, atol=1) -def test_forecast_exogenous_features(ts_increasing_integers, expected_ts_increasing_integers): +def test_forecast_failed_nan_middle_target(ts_nan_middle): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128) + pipeline = Pipeline(model=model, horizon=2) + pipeline.fit(ts_nan_middle) + with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("encoder_length", [32, 64, 128]) +@pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) +def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encoder_length, request): + ts = request.getfixturevalue(ts) + horizon = 2 transforms = [ SegmentEncoderTransform(), - LagTransform(in_column="target", lags=[horizon], out_column="lag"), + LagTransform(in_column="target", lags=[horizon, horizon + 1], out_column="lag"), DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, is_weekend=False, out_column="flag"), ] model = TimesFMModel( path_or_url="google/timesfm-1.0-200m-pytorch", - encoder_length=32, + encoder_length=encoder_length, static_categoricals=["segment_code"], - time_varying_reals=[f"lag_{horizon}"], + time_varying_reals=[f"lag_{horizon}", f"lag_{horizon+1}"], time_varying_categoricals=["flag_day_number_in_week"], ) pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) - pipeline.fit(ts_increasing_integers) + pipeline.fit(ts) forecast = pipeline.forecast() assert_frame_equal(forecast.df.loc[:, pd.IndexSlice[:, "target"]], expected_ts_increasing_integers.df, atol=1) +def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): + horizon = 2 + transforms = [ + SegmentEncoderTransform(), + LagTransform(in_column="target", lags=[horizon, horizon + 1], out_column="lag"), + DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, is_weekend=False, out_column="flag"), + ] + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=128, + static_categoricals=["segment_code"], + time_varying_reals=[f"lag_{horizon}", f"lag_{horizon+1}"], + time_varying_categoricals=["flag_day_number_in_week"], + ) + pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) + pipeline.fit(ts_nan_middle) + with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."): + _ = pipeline.forecast() + + @pytest.mark.smoke def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) From 85fbe3c5b70c149dbe6aa3ebfd90231766debe16 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Thu, 26 Dec 2024 16:05:48 +0300 Subject: [PATCH 12/20] minor fixes --- etna/models/nn/timesfm.py | 30 +++++++++++---- tests/test_models/test_nn/test_timesfm.py | 47 ++++++++++++++++++++++- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index a869ec491..0fc504360 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -1,4 +1,5 @@ import os +import reprlib import warnings from pathlib import Path from typing import Dict @@ -32,7 +33,8 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features. - Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill NaNs for stable behaviour. + This model doesn't support NaN in the middle or at the end of time series. + Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill them. Official implementation: https://github.com/google-research/timesfm @@ -179,7 +181,7 @@ def predict( """ raise NotImplementedError("Method predict isn't currently implemented!") - def _get_exog_features(self) -> List[str]: + def _exog_сolumns(self) -> List[str]: static_reals = [] if self.static_reals is None else self.static_reals static_categoricals = [] if self.static_categoricals is None else self.static_categoricals time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals @@ -235,13 +237,20 @@ def forecast( end_idx = len(ts.index) - all_exog = self._get_exog_features() + all_exog = self._exog_сolumns() df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]] - first_valid_index = df_slice.isna().any(axis=1).idxmin() + first_valid_index = ( + df_slice.isna().any(axis=1).idxmin() + ) # If all timestamps contains NaNs, idxmin() returns the first timestamp target_df = df_slice.loc[first_valid_index : ts.index[-prediction_size - 1], pd.IndexSlice[:, "target"]] - if target_df.isna().any().any(): - raise ValueError("There are NaNs in the middle or end of the time series.") + + nan_segment_mask = target_df.isna().any() + if nan_segment_mask.any(): + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).tolist() + raise ValueError( + f"There are NaNs in the middle or at the end of target. Segments with NaNs: {reprlib.repr(nan_segments)}." + ) future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx) @@ -249,8 +258,13 @@ def forecast( target = target_df.values.swapaxes(1, 0).tolist() exog_df = df_slice.loc[first_valid_index:, pd.IndexSlice[:, all_exog]] - if exog_df.isna().any().any(): - raise ValueError("There are NaNs in the middle or end of the exogenous features.") + + nan_segment_mask = exog_df.isna().any() + if nan_segment_mask.any(): + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).tolist() + raise ValueError( + f"There are NaNs in the middle or at the end of exogenous features. Segments with NaNs: {reprlib.repr(nan_segments)}." + ) static_reals_dict = ( { diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 27f0629ef..3778c432b 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -23,6 +23,13 @@ def generate_increasing_df(): return df +def generate_exog(): + n = 128 + df_exog = generate_ar_df(start_time="2001-01-01", periods=n + 2, n_segments=2) + df_exog.rename(columns={"target": "exog"}, inplace=True) + return df_exog + + @pytest.fixture def ts_increasing_integers(): df = generate_increasing_df() @@ -54,6 +61,24 @@ def expected_ts_increasing_integers(): return ts +@pytest.fixture +def ts_exog_middle_nan(): + df = generate_increasing_df() + df_exog = generate_exog() + df_exog.loc[120, "exog"] = np.NaN + ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all") + return ts + + +@pytest.fixture +def ts_exog_all_nan(): + df = generate_increasing_df() + df_exog = generate_exog() + df_exog["exog"] = np.NaN + ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all") + return ts + + @pytest.mark.smoke def test_url(tmp_path): model_name = "timesfm-1.0-200m-pytorch.ckpt" @@ -118,7 +143,7 @@ def test_forecast_failed_nan_middle_target(ts_nan_middle): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128) pipeline = Pipeline(model=model, horizon=2) pipeline.fit(ts_nan_middle) - with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."): + with pytest.raises(ValueError, match=r"There are NaNs in the middle or at the end of target. Segments with NaNs:"): _ = pipeline.forecast() @@ -162,7 +187,25 @@ def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): ) pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) pipeline.fit(ts_nan_middle) - with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."): + with pytest.raises(ValueError, match="There are NaNs in the middle or at the end of target. Segments with NaNs:"): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"]) +def test_forecast_exog_features_failed_exog_nan(ts, request): + ts = request.getfixturevalue(ts) + + horizon = 2 + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=128, + time_varying_reals=["exog"], + ) + pipeline = Pipeline(model=model, transforms=[], horizon=horizon) + pipeline.fit(ts) + with pytest.raises( + ValueError, match="There are NaNs in the middle or at the end of exogenous features. Segments with NaNs:" + ): _ = pipeline.forecast() From 36c3f2b5f43fd2e7d4125064b9bd274ef65b226c Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Thu, 26 Dec 2024 19:16:20 +0300 Subject: [PATCH 13/20] fix --- etna/models/nn/timesfm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index 0fc504360..160e1f6c7 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -33,7 +33,7 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features. - This model doesn't support NaN in the middle or at the end of time series. + This model doesn't support NaN in the middle or at the end of target and exogenous features. Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill them. Official implementation: https://github.com/google-research/timesfm @@ -181,7 +181,7 @@ def predict( """ raise NotImplementedError("Method predict isn't currently implemented!") - def _exog_сolumns(self) -> List[str]: + def _exog_columns(self) -> List[str]: static_reals = [] if self.static_reals is None else self.static_reals static_categoricals = [] if self.static_categoricals is None else self.static_categoricals time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals @@ -247,7 +247,7 @@ def forecast( nan_segment_mask = target_df.isna().any() if nan_segment_mask.any(): - nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).tolist() + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist() raise ValueError( f"There are NaNs in the middle or at the end of target. Segments with NaNs: {reprlib.repr(nan_segments)}." ) @@ -261,7 +261,7 @@ def forecast( nan_segment_mask = exog_df.isna().any() if nan_segment_mask.any(): - nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).tolist() + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist() raise ValueError( f"There are NaNs in the middle or at the end of exogenous features. Segments with NaNs: {reprlib.repr(nan_segments)}." ) From 67fb748860132728f969430fcddc18a1d52b72db Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Thu, 26 Dec 2024 19:19:22 +0300 Subject: [PATCH 14/20] fix --- etna/models/nn/timesfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index 160e1f6c7..793665269 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -237,7 +237,7 @@ def forecast( end_idx = len(ts.index) - all_exog = self._exog_сolumns() + all_exog = self._exog_columns() df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]] first_valid_index = ( df_slice.isna().any(axis=1).idxmin() From 397086fee82f985498aa247454129de1408f55aa Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Fri, 27 Dec 2024 11:05:38 +0300 Subject: [PATCH 15/20] skip timesfm tests in inference tests --- .../test_inference/test_forecast.py | 77 ++++++++++++++++--- .../test_inference/test_predict.py | 63 ++++++++++++--- 2 files changed, 120 insertions(+), 20 deletions(-) diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index 573ea4ce7..eb3a25ea9 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -249,7 +249,12 @@ def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transform @pytest.mark.parametrize( "model, transforms, dataset_name", [ - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_in_sample_full_no_target_failed_timesfm(self, model, transforms, dataset_name, request): @@ -422,7 +427,12 @@ def test_forecast_in_sample_full_failed_chronos(self, model, transforms, dataset @pytest.mark.parametrize( "model, transforms, dataset_name", [ - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_in_sample_full_failed_timesfm(self, model, transforms, dataset_name, request): @@ -516,7 +526,12 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_in_sample_suffix_no_target(self, model, transforms, dataset_name, request): @@ -641,7 +656,12 @@ class TestForecastInSampleSuffix: (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_in_sample_suffix(self, model, transforms, dataset_name, request): @@ -817,7 +837,12 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset_name, request): @@ -920,7 +945,12 @@ def test_forecast_out_sample_int_timestamp(self, model, transforms, dataset_name @pytest.mark.parametrize( "model, transforms, dataset_name", [ - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_out_sample_int_timestamp_failed_timesfm(self, model, transforms, dataset_name, request): @@ -1085,7 +1115,12 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_out_sample_prefix(self, model, transforms, dataset_name, request): @@ -1297,7 +1332,12 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data @pytest.mark.parametrize( "model, transforms, dataset_name", [ - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request): @@ -1446,7 +1486,12 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50 (NBeatsGenericModel(input_size=7, output_size=55, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_mixed_in_out_sample(self, model, transforms, dataset_name, request): @@ -1609,7 +1654,12 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic ), (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_subset_segments(self, model, transforms, dataset_name, request): @@ -1787,7 +1837,12 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_forecast_new_segments(self, model, transforms, dataset_name, request): diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index 601703b36..a6ef08bd8 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -196,7 +196,12 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_in_sample_full_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -328,7 +333,12 @@ def test_predict_in_sample_suffix_datetime_timestamp(self, model, transforms, da (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_in_sample_suffix_datetime_timestamp_failed_not_implemented_predict( @@ -475,7 +485,12 @@ def test_predict_in_sample_suffix_int_timestamp_failed(self, model, transforms, (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_in_sample_suffix_int_timestamp_failed_not_implemented_predict( @@ -618,7 +633,12 @@ def test_predict_out_sample(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -778,7 +798,12 @@ def test_predict_out_sample_prefix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_out_sample_prefix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -952,7 +977,12 @@ def test_predict_out_sample_suffix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_out_sample_suffix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1131,7 +1161,12 @@ def test_predict_mixed_in_out_sample(self, model, transforms, dataset_name, requ (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_mixed_in_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1298,7 +1333,12 @@ def test_predict_subset_segments(self, model, transforms, dataset_name, request) (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_subset_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1430,7 +1470,12 @@ def test_predict_new_segments(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - (TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), [], "example_tsds"), + pytest.param( + TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + [], + "example_tsds", + marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + ), ], ) def test_predict_new_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): From 409ac14d38fb91af15ce8eaca877e8b136225fc9 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Fri, 27 Dec 2024 12:58:54 +0300 Subject: [PATCH 16/20] skip test for timesfm --- tests/test_models/test_nn/test_timesfm.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 3778c432b..da032ea86 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -79,6 +79,7 @@ def ts_exog_all_nan(): return ts +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_url(tmp_path): model_name = "timesfm-1.0-200m-pytorch.ckpt" @@ -87,6 +88,7 @@ def test_url(tmp_path): assert os.path.exists(tmp_path / model_name) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_cache_dir(tmp_path): path_or_url = "google/timesfm-1.0-200m-pytorch" @@ -95,24 +97,28 @@ def test_cache_dir(tmp_path): assert os.path.exists(tmp_path / f"models--google--{model_name}") +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_context_size(): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=10) assert model.context_size == 10 +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_get_model(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") assert isinstance(model.get_model(), TimesFmTorch) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_fit(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") model.fit(example_tsds) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_predict(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") @@ -120,6 +126,7 @@ def test_predict(example_tsds): model.predict(ts=example_tsds, prediction_size=1) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") def test_forecast_warns_big_context_size(ts_increasing_integers): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512) pipeline = Pipeline(model=model, horizon=1) @@ -139,6 +146,7 @@ def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request): assert_frame_equal(forecast.df, expected_ts_increasing_integers.df, atol=1) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") def test_forecast_failed_nan_middle_target(ts_nan_middle): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128) pipeline = Pipeline(model=model, horizon=2) @@ -171,6 +179,7 @@ def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encode assert_frame_equal(forecast.df.loc[:, pd.IndexSlice[:, "target"]], expected_ts_increasing_integers.df, atol=1) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): horizon = 2 transforms = [ @@ -191,6 +200,7 @@ def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"]) def test_forecast_exog_features_failed_exog_nan(ts, request): ts = request.getfixturevalue(ts) @@ -209,6 +219,7 @@ def test_forecast_exog_features_failed_exog_nan(ts, request): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) @@ -221,6 +232,7 @@ def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_exog_int_timestamps(example_tsds_int_timestamp): horizon = 2 @@ -240,6 +252,7 @@ def test_forecast_exog_int_timestamps(example_tsds_int_timestamp): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("encoder_length", [16, 33]) def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) @@ -249,6 +262,7 @@ def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_without_fit(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) @@ -256,6 +270,7 @@ def test_forecast_without_fit(example_tsds): _ = pipeline.forecast(example_tsds) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_fails_components(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") @@ -264,11 +279,13 @@ def test_forecast_fails_components(example_tsds): pipeline.forecast(ts=example_tsds, return_components=True) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_list_models(): assert TimesFMModel.list_models() == ["google/timesfm-1.0-200m-pytorch"] +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_save_load(tmp_path, ts_increasing_integers): path = Path(tmp_path) / "tmp.zip" @@ -282,6 +299,7 @@ def test_save_load(tmp_path, ts_increasing_integers): assert isinstance(loaded_model, TimesFMModel) +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_params_to_tune(): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") From 0d27c55f36d6b92719662d2fa3e0427337be6c5a Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Fri, 27 Dec 2024 13:49:29 +0300 Subject: [PATCH 17/20] skip last 2 tests --- tests/test_models/test_nn/test_timesfm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index da032ea86..491846946 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -135,6 +135,7 @@ def test_forecast_warns_big_context_size(ts_increasing_integers): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("encoder_length", [32, 64, 128]) @pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request): @@ -155,6 +156,7 @@ def test_forecast_failed_nan_middle_target(ts_nan_middle): _ = pipeline.forecast() +@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("encoder_length", [32, 64, 128]) @pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encoder_length, request): From bdd0ec61e60f76b9285da7789d5672095a1d476a Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Fri, 27 Dec 2024 16:24:36 +0300 Subject: [PATCH 18/20] comment inference timesfm tests --- .../test_inference/test_forecast.py | 133 +++++++++--------- .../test_inference/test_predict.py | 109 +++++++------- 2 files changed, 120 insertions(+), 122 deletions(-) diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index eb3a25ea9..7b7557c9b 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -46,7 +46,6 @@ from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TFTNativeModel -from etna.models.nn import TimesFMModel from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import LagTransform @@ -249,12 +248,12 @@ def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transform @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_full_no_target_failed_timesfm(self, model, transforms, dataset_name, request): @@ -427,12 +426,12 @@ def test_forecast_in_sample_full_failed_chronos(self, model, transforms, dataset @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_full_failed_timesfm(self, model, transforms, dataset_name, request): @@ -526,12 +525,12 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_suffix_no_target(self, model, transforms, dataset_name, request): @@ -656,12 +655,12 @@ class TestForecastInSampleSuffix: (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_suffix(self, model, transforms, dataset_name, request): @@ -837,12 +836,12 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset_name, request): @@ -945,12 +944,12 @@ def test_forecast_out_sample_int_timestamp(self, model, transforms, dataset_name @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_int_timestamp_failed_timesfm(self, model, transforms, dataset_name, request): @@ -1115,12 +1114,12 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_prefix(self, model, transforms, dataset_name, request): @@ -1332,12 +1331,12 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request): @@ -1486,12 +1485,12 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50 (NBeatsGenericModel(input_size=7, output_size=55, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_mixed_in_out_sample(self, model, transforms, dataset_name, request): @@ -1654,12 +1653,12 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic ), (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_subset_segments(self, model, transforms, dataset_name, request): @@ -1837,12 +1836,12 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_new_segments(self, model, transforms, dataset_name, request): diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index a6ef08bd8..2c8d433dc 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -44,7 +44,6 @@ from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TFTNativeModel -from etna.models.nn import TimesFMModel from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import LagTransform @@ -196,12 +195,12 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_full_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -333,12 +332,12 @@ def test_predict_in_sample_suffix_datetime_timestamp(self, model, transforms, da (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_suffix_datetime_timestamp_failed_not_implemented_predict( @@ -485,12 +484,12 @@ def test_predict_in_sample_suffix_int_timestamp_failed(self, model, transforms, (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_suffix_int_timestamp_failed_not_implemented_predict( @@ -633,12 +632,12 @@ def test_predict_out_sample(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -798,12 +797,12 @@ def test_predict_out_sample_prefix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_prefix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -977,12 +976,12 @@ def test_predict_out_sample_suffix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_suffix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1161,12 +1160,12 @@ def test_predict_mixed_in_out_sample(self, model, transforms, dataset_name, requ (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_mixed_in_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1333,12 +1332,12 @@ def test_predict_subset_segments(self, model, transforms, dataset_name, request) (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_subset_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1470,12 +1469,12 @@ def test_predict_new_segments(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_new_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): From c911b2f23a1a59247e578f592cf3a5346062ffaa Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Fri, 27 Dec 2024 17:08:34 +0300 Subject: [PATCH 19/20] unlock one test --- tests/test_models/test_nn/test_timesfm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 491846946..4c17fd71f 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -79,7 +79,6 @@ def ts_exog_all_nan(): return ts -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_url(tmp_path): model_name = "timesfm-1.0-200m-pytorch.ckpt" From 2b3e3fc9a1a0f44cb82478398d5191e2853aeec0 Mon Sep 17 00:00:00 2001 From: Egor Baturin Date: Fri, 27 Dec 2024 17:53:19 +0300 Subject: [PATCH 20/20] unlock all tests --- tests/test_models/test_nn/test_timesfm.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 4c17fd71f..3778c432b 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -87,7 +87,6 @@ def test_url(tmp_path): assert os.path.exists(tmp_path / model_name) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_cache_dir(tmp_path): path_or_url = "google/timesfm-1.0-200m-pytorch" @@ -96,28 +95,24 @@ def test_cache_dir(tmp_path): assert os.path.exists(tmp_path / f"models--google--{model_name}") -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_context_size(): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=10) assert model.context_size == 10 -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_get_model(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") assert isinstance(model.get_model(), TimesFmTorch) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_fit(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") model.fit(example_tsds) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_predict(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") @@ -125,7 +120,6 @@ def test_predict(example_tsds): model.predict(ts=example_tsds, prediction_size=1) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") def test_forecast_warns_big_context_size(ts_increasing_integers): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512) pipeline = Pipeline(model=model, horizon=1) @@ -134,7 +128,6 @@ def test_forecast_warns_big_context_size(ts_increasing_integers): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("encoder_length", [32, 64, 128]) @pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request): @@ -146,7 +139,6 @@ def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request): assert_frame_equal(forecast.df, expected_ts_increasing_integers.df, atol=1) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") def test_forecast_failed_nan_middle_target(ts_nan_middle): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128) pipeline = Pipeline(model=model, horizon=2) @@ -155,7 +147,6 @@ def test_forecast_failed_nan_middle_target(ts_nan_middle): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("encoder_length", [32, 64, 128]) @pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encoder_length, request): @@ -180,7 +171,6 @@ def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encode assert_frame_equal(forecast.df.loc[:, pd.IndexSlice[:, "target"]], expected_ts_increasing_integers.df, atol=1) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): horizon = 2 transforms = [ @@ -201,7 +191,6 @@ def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"]) def test_forecast_exog_features_failed_exog_nan(ts, request): ts = request.getfixturevalue(ts) @@ -220,7 +209,6 @@ def test_forecast_exog_features_failed_exog_nan(ts, request): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) @@ -233,7 +221,6 @@ def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_exog_int_timestamps(example_tsds_int_timestamp): horizon = 2 @@ -253,7 +240,6 @@ def test_forecast_exog_int_timestamps(example_tsds_int_timestamp): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.parametrize("encoder_length", [16, 33]) def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) @@ -263,7 +249,6 @@ def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): _ = pipeline.forecast() -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_without_fit(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) @@ -271,7 +256,6 @@ def test_forecast_without_fit(example_tsds): _ = pipeline.forecast(example_tsds) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_forecast_fails_components(example_tsds): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") @@ -280,13 +264,11 @@ def test_forecast_fails_components(example_tsds): pipeline.forecast(ts=example_tsds, return_components=True) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_list_models(): assert TimesFMModel.list_models() == ["google/timesfm-1.0-200m-pytorch"] -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_save_load(tmp_path, ts_increasing_integers): path = Path(tmp_path) / "tmp.zip" @@ -300,7 +282,6 @@ def test_save_load(tmp_path, ts_increasing_integers): assert isinstance(loaded_model, TimesFMModel) -@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.") @pytest.mark.smoke def test_params_to_tune(): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")