diff --git a/CHANGELOG.md b/CHANGELOG.md index 729813490..53a902d45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `ChronosModel` ([#511](https://github.com/etna-team/etna/pull/511)) - Add `ChronosBoltModel` ([#511](https://github.com/etna-team/etna/pull/511)) - Add usage example of `ChronosModel` and `ChronosBoltModel` in `202-NN_examples` notebook ([#511](https://github.com/etna-team/etna/pull/511)) -- -- +- Add `TimesFMModel` ([#544](https://github.com/etna-team/etna/pull/544)) +- Add usage example of `TimesFMModel` in `202-NN_examples` notebook ([#544](https://github.com/etna-team/etna/pull/544)) - - - Add `MissingCounter` metric ([#520](https://github.com/etna-team/etna/pull/520)) diff --git a/README.md b/README.md index baae3b461..64016e1ad 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,8 @@ Available user extensions are the following: * `auto`: adds AutoML functionality, * `statsforecast`: adds models from [statsforecast](https://nixtla.github.io/statsforecast/), * `classiciation`: adds time series classification functionality, -* `chronos`: adds Chronos-like pretrained models. +* `chronos`: adds Chronos-like pretrained models, +* `timesfm`: adds TimesFM pretrained models. 
Install extension: ```bash diff --git a/docs/source/api_reference/models.rst b/docs/source/api_reference/models.rst index daea65539..ef06daee4 100644 --- a/docs/source/api_reference/models.rst +++ b/docs/source/api_reference/models.rst @@ -122,4 +122,5 @@ Pretrained neural network models: :template: class.rst nn.ChronosModel - nn.ChronosBoltModel \ No newline at end of file + nn.ChronosBoltModel + nn.TimesFMModel \ No newline at end of file diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 86cd18123..90ac3d214 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -24,7 +24,8 @@ Available user extensions are the following: - ``auto``: adds AutoML functionality, - ``statsforecast``: adds models from `statsforecast <https://nixtla.github.io/statsforecast/>`_, - ``classiciation``: adds time series classification functionality, -- ``chronos``: adds Chronos-like pretrained models. +- ``chronos``: adds Chronos-like pretrained models, +- ``timesfm``: adds TimesFM pretrained models. Install extension: diff --git a/etna/libs/timesfm/__init__.py b/etna/libs/timesfm/__init__.py new file mode 100644 index 000000000..2346aa8a2 --- /dev/null +++ b/etna/libs/timesfm/__init__.py @@ -0,0 +1,155 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + + +from etna.libs.timesfm.timesfm import TimesFmTorch +from etna.libs.timesfm.timesfm_base import TimesFmHparams, TimesFmCheckpoint diff --git a/etna/libs/timesfm/patched_decoder.py b/etna/libs/timesfm/patched_decoder.py new file mode 100644 index 000000000..53cd5c6c1 --- /dev/null +++ b/etna/libs/timesfm/patched_decoder.py @@ -0,0 +1,948 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. 
+ "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the 
following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/patched_decoder.py) + +"""Pytorch version of patched decoder.""" + +import dataclasses +import math +from typing import List, Tuple, Optional +import torch +from torch import nn +import torch.nn.functional as F + + +def _create_quantiles() -> List[float]: + return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + +@dataclasses.dataclass +class TimesFMConfig: + """Config for initializing timesfm patched_decoder class.""" + + # The number of blocks in the model. + num_layers: int = 20 + # The number of attention heads used in the attention layers of the model. + num_heads: int = 16 + # The number of key-value heads for implementing attention. + num_kv_heads: int = 16 + # The hidden size of the model. + hidden_size: int = 1280 + # The dimension of the MLP representations. + intermediate_size: int = 1280 + # The number of head dimensions. + head_dim: int = 80 + # The epsilon used by the rms normalization layers. + rms_norm_eps: float = 1e-6 + # Patch length + patch_len: int = 32 + # Horizon length + horizon_len: int = 128 + # quantiles + quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles) + # Padding value + pad_val: float = 1123581321.0 + # Tolerance + tolerance: float = 1e-6 + # The dtype of the weights. 
+ dtype: str = "bfloat32" + # use positional embedding + use_positional_embedding: bool = True + + +def _masked_mean_std( + inputs: torch.Tensor, + padding: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculates mean and standard deviation of `inputs` across axis 1. + + It excludes values where `padding` is 1. + + Args: + inputs: A PyTorch tensor of shape [b, n, p]. + padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. + + Returns: + A tuple containing the mean and standard deviation. + We return the statistics of the first patch with more than three non-padded + values. + """ + # Selecting the first patch with more than 3 unpadded values. + pad_sum = torch.sum(1 - padding, dim=2) + + def _get_patch_index(arr: torch.Tensor): + indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) + row_sum = (arr >= 3).to(torch.int32).sum(dim=1) + return torch.where(row_sum == 0, arr.shape[1] - 1, indices) + + patch_indices = _get_patch_index(pad_sum) + bidxs = torch.arange(inputs.shape[0]) + + arr = inputs[bidxs, patch_indices, :] + pad = padding[bidxs, patch_indices, :] + + # Create a mask where padding is 0 + mask = 1 - pad + + # Calculate the number of valid elements + num_valid_elements = torch.sum(mask, dim=1) + num_valid_elements = torch.where( + num_valid_elements == 0, + torch.tensor(1, + dtype=num_valid_elements.dtype, + device=num_valid_elements.device), + num_valid_elements, + ) + + # Calculate the masked sum and squared sum + masked_sum = torch.sum(arr * mask, dim=1) + masked_squared_sum = torch.sum((arr * mask)**2, dim=1) + + # Calculate the masked mean and standard deviation + masked_mean = masked_sum / num_valid_elements + masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 + masked_var = torch.where( + masked_var < 0.0, + torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), + masked_var, + ) + masked_std = torch.sqrt(masked_var) + + return masked_mean, masked_std + + +def _shift_padded_seq(mask: 
torch.Tensor, seq: torch.Tensor) -> torch.Tensor: + """Shifts rows of seq based on the first 0 in each row of the mask. + + Args: + mask: mask tensor of shape [B, N] + seq: seq tensor of shape [B, N, P] + + Returns: + Returns the shifted sequence. + """ + batch_size, num_seq, feature_dim = seq.shape + + new_mask: torch.BoolTensor = mask == 0 + + # Use argmax to find the first True value in each row + indices = new_mask.to(torch.int32).argmax(dim=1) + + # Handle rows with all zeros + indices[~new_mask.any(dim=1)] = -1 + + # Create index ranges for each sequence in the batch + idx_range = (torch.arange(num_seq).to( + seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, + feature_dim)) + + # Calculate shifted indices for each element in each sequence + shifted_idx = (idx_range - indices[:, None, None]) % num_seq + + # Gather values from seq using shifted indices + shifted_seq = seq.gather(1, shifted_idx) + + return shifted_seq + + +def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor: + """Returns a large negative value for the given dtype.""" + if dtype.is_floating_point: + dtype_max = torch.finfo(dtype).max + else: + dtype_max = torch.iinfo(dtype).max + return torch.tensor(-0.7 * dtype_max, dtype=dtype) + + +def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Applies a floating-point mask to a set of logits. + + Args: + logits: A torch.Tensor of logit values. + mask: A torch.Tensor (float32) of mask values with the encoding described + in the function documentation. + + Returns: + Masked logits. + """ + + min_value = get_large_negative_number(logits.dtype) + + return torch.where((mask >= min_value * 0.5), logits, min_value) + + +def convert_paddings_to_mask(paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Converts binary paddings to a logit mask ready to add to attention matrix. + + Args: + paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding + token. 
+ dtype: data type of the input. + + Returns: + A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits. + """ + attention_mask = paddings.detach().clone() + attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis + attention_mask *= get_large_negative_number(dtype) + return attention_mask + + +def causal_mask(input_t: torch.Tensor) -> torch.Tensor: + """Computes and returns causal mask. + + Args: + input_t: A torch.Tensor of shape [B, T, D]. + + Returns: + An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has + already been converted to large negative values. + """ + assert input_t.dtype.is_floating_point, input_t.dtype + large_negative_number = get_large_negative_number(input_t.dtype) + t = input_t.shape[1] + col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1) + row_idx = torch.arange(t).unsqueeze(1).repeat(1, t) + mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number + return mask.unsqueeze(0).unsqueeze(0).to(input_t.device) # Equivalent to jnp.newaxis + + +def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Merges 2 masks. + + logscale mask is expected but 0/1 mask is also fine. + + Args: + a: torch.Tensor of shape [1|B, 1, 1|T, S]. + b: torch.Tensor of shape [1|B, 1, 1|T, S]. + + Returns: + torch.Tensor of shape [1|B, 1, 1|T, S]. + """ + + def expand_t(key_mask): + query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose + return torch.minimum(query_mask, key_mask) + + if a.shape[2] != b.shape[2]: + if a.shape[2] == 1: + a = expand_t(a) + else: + assert b.shape[2] == 1 + b = expand_t(b) + + assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}." 
+ return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum + + +class ResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super(ResidualBlock, self).__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.SiLU(), + ) + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.hidden_layer(x) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class RMSNorm(torch.nn.Module): + """Pax rms norm in pytorch.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + if self.add_unit_offset: + output = output * (1 + self.weight.float()) + else: + output = output * self.weight.float() + return output.type_as(x) + + +class TransformerMLP(nn.Module): + """Pax transformer MLP in pytorch.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return 
outputs + x + + +class TimesFMAttention(nn.Module): + """Implements the attention used in TimesFM.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = nn.Parameter( + torch.empty((self.head_dim,), dtype=torch.float32),) + + self.qkv_proj = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: + # [batch_size, n_local_heads, input_len, head_dim] + r_softplus_0 = 1.442695041 + softplus_func = torch.nn.Softplus() + scale = r_softplus_0 / math.sqrt(self.head_dim) + scale = scale * softplus_func(self.scaling) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + + batch_size, input_len, _ = hidden_states_shape + + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xq = self._per_dim_scaling(xq) + + # Write new kv cache. 
+ # [batch_size, input_len, n_local_kv_heads, head_dim] + if kv_cache is not None and kv_write_indices is not None: + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + else: + key = xk + value = xv + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + # return scores, output.transpose(1, 2).contiguous() + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return scores, output + + +class TimesFMDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + self.self_attn = TimesFMAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + ) + self.mlp = TransformerMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: Optional[torch.Tensor] = None, + kv_cache: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + scores, hidden_states = self.self_attn( + hidden_states=hidden_states, + mask=mask, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class StackedDecoder(nn.Module): + """Stacked transformer layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + num_layers: int, + rms_norm_eps: float = 1e-6, + ): + super().__init__() + + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TimesFMDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + )) + + def forward( + self, + hidden_states: torch.Tensor, + paddings: torch.Tensor, + kv_write_indices: Optional[torch.Tensor] = None, + kv_caches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> torch.Tensor: + padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype) + atten_mask = causal_mask(hidden_states) + mask = merge_masks(padding_mask, atten_mask) + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = kv_caches[i] if kv_caches is not None else None + _, hidden_states = layer( + hidden_states=hidden_states, + mask=mask, + paddings=paddings, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + ) + return hidden_states + + +class PositionalEmbedding(torch.nn.Module): + """Generates position embedding for a given 1-d sequence. + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal.
+ embedding_dims: Dimension of the embedding to be generated. + """ + + def __init__( + self, + embedding_dims: int, + min_timescale: int = 1, + max_timescale: int = 10_000, + ) -> None: + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dims = embedding_dims + + def forward(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length; + required if the `position` argument is not specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. + + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None: + assert seq_length is not None + # [1, seqlen] + position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) + else: + assert position.ndim == 2, position.shape + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale)) / max( + num_timescales - 1, 1) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) * + -log_timescale_increment) + scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze( + 0) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + # Pad the last (embedding) dimension when `embedding_dims` is odd. + # F.pad reads its pad tuple starting from the last dimension. + signal = F.pad(signal, (0, self.embedding_dims % 2)) + return signal + + +class PatchedTimeSeriesDecoder(nn.Module): + """Patched time-series decoder.""" + + def __init__(self, config: TimesFMConfig): + super().__init__() + self.config = config + self.input_ff_layer = ResidualBlock( + input_dims=2 * config.patch_len, + output_dims=config.hidden_size, + hidden_dims=config.intermediate_size, + ) + self.freq_emb = nn.Embedding(num_embeddings=3, + embedding_dim=config.hidden_size) +
self.horizon_ff_layer = ResidualBlock( + input_dims=config.hidden_size, + output_dims=config.horizon_len * (1 + len(config.quantiles)), + hidden_dims=config.intermediate_size, + ) + self.stacked_transformer = StackedDecoder( + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size, + num_heads=self.config.num_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + num_layers=self.config.num_layers, + rms_norm_eps=self.config.rms_norm_eps, + ) + if self.config.use_positional_embedding: + self.position_emb = PositionalEmbedding(self.config.hidden_size) + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = _masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(self.config.pad_val, + dtype=outputs.dtype, + device=outputs.device), + outputs, + ) + return outputs, (mu, sigma) + + def _reverse_transform(self, outputs: torch.Tensor, stats: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Output is of shape [B, N, P, Q].""" + mu, sigma = stats + return outputs * sigma[:, None, None, None] + mu[:, None, None, None] + + def _preprocess_input( + self, + input_ts: torch.Tensor, + input_padding: torch.Tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Optional[Tuple[torch.Tensor, torch.Tensor]], + torch.Tensor, + ]: + """Preprocess input for stacked transformer.""" + + # Reshape into patches (using view for efficiency) + bsize = input_ts.shape[0] + patched_inputs = input_ts.view(bsize, -1, self.config.patch_len) + patched_pads = input_padding.view(bsize, -1, 
self.config.patch_len) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, + dtype=patched_inputs.dtype, + device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, + patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, + dim=-1)[0] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = _shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + return model_input, patched_padding, stats, patched_inputs + + def _postprocess_output( + self, + model_output: torch.Tensor, + num_outputs: int, + stats: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs) + + return self._reverse_transform(output_ts, stats) + + def forward( + self, + input_ts: torch.Tensor, + input_padding: torch.LongTensor, + freq: torch.Tensor, + ) -> torch.Tensor: + num_outputs = len(self.config.quantiles) + 1 + model_input, patched_padding, stats, _ = self._preprocess_input( + input_ts=input_ts, + input_padding=input_padding, + ) + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += 
f_emb + model_output = self.stacked_transformer(model_input, patched_padding) + + output_ts = self._postprocess_output(model_output, num_outputs, stats) + return output_ts + + def decode( + self, + input_ts: torch.Tensor, + paddings: torch.Tensor, + freq: torch.LongTensor, + horizon_len: int, + output_patch_len: Optional[int] = None, + max_len: int = 512, + return_forecast_on_context: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Auto-regressive decoding without caching. + + Args: + input_ts: input time-series and paddings. Time-series shape B x C. + paddings: padding shape B x (C + H) where H is the prediction length. + freq: frequency shape B x 1 + horizon_len: prediction length. + output_patch_len: output length to be fetched from one step of + auto-regressive decoding. + max_len: maximum training context length. + return_forecast_on_context: whether to return the model forecast on the + context except the first input patch. + + Returns: + Tuple of two forecasting results: + - Point (mean) output predictions as a tensor with shape B x H'. + - Full predictions (mean and quantiles) as a tensor with shape + B x H' x (1 + # quantiles). + In particular, if return_forecast_on_context is True, H' is H plus + the forecastable context length, i.e. context_len - (first) patch_len. 
+ """ + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + if paddings.shape[1] != final_out.shape[1] + horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}") + if output_patch_len is None: + output_patch_len = self.config.horizon_len + num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = paddings[:, 0:final_out.shape[1]] + input_ts = final_out[:, -max_len:] + input_padding = current_padding[:, -max_len:] + fprop_outputs = self(input_ts, input_padding, freq) + if return_forecast_on_context and step_index == 0: + # For the first decoding step, collect the model forecast on the + # context except the unavailable first input patch forecast. + new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :] + # Flatten the sliced tensor; use reshape because the slice may be non-contiguous. + new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, + new_full_ts.size(3)) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = torch.concat([final_out, new_ts], dim=-1) # TODO torch.concatenate(axis) => torch.concat(dim) + + if return_forecast_on_context: + # `full_outputs` indexing starts after the first input patch. + full_outputs = torch.concat( # TODO torch.concatenate(axis) => torch.concat(dim) + full_outputs, + dim=1)[:, :(context_len - self.config.patch_len + horizon_len), :] + else: + # `full_outputs` indexing starts at the forecast horizon.
+ full_outputs = torch.concat(full_outputs, dim=1)[:, 0:horizon_len, :] # TODO torch.concatenate(axis) => torch.concat(dim) + + return full_outputs[:, :, 0], full_outputs diff --git a/etna/libs/timesfm/timesfm.py b/etna/libs/timesfm/timesfm.py new file mode 100644 index 000000000..d46782e3c --- /dev/null +++ b/etna/libs/timesfm/timesfm.py @@ -0,0 +1,325 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_torch.py) +# Add method to change horizon after initialization. +# Minor logic change of loading model. + +"""TimesFM pytorch forecast API for inference.""" + +import logging +from os import path +from typing import Any, Sequence, Optional, Tuple +import os +import numpy as np +import torch +from huggingface_hub import snapshot_download +from etna.libs.timesfm import timesfm_base + +from etna.libs.timesfm import patched_decoder as ppd + +_TOL = 1e-6 + + +class TimesFmTorch(timesfm_base.TimesFmBase): + """TimesFM forecast API for inference.""" + + def __post_init__(self): + self._model_config = ppd.TimesFMConfig( + num_layers=self.num_layers, + num_heads=self.num_heads, + hidden_size=self.model_dims, + intermediate_size=self.model_dims, + patch_len=self.input_patch_len, + horizon_len=self.output_patch_len, + head_dim=self.model_dims // self.num_heads, + quantiles=self.quantiles, + use_positional_embedding=self.use_pos_emb, + ) + self._model = None + self.num_cores = 1 + self.global_batch_size = self.per_core_batch_size + self._device = torch.device("cuda:0" if ( + torch.cuda.is_available() and self.backend == "gpu") else "cpu") + self._median_index = -1 + + def _set_horizon(self, horizon): # changed: added to change horizon after initialization + self.horizon_len = horizon + + def load_from_checkpoint( + self, + checkpoint: timesfm_base.TimesFmCheckpoint, + ) -> None: + """Loads a checkpoint and compiles the decoder.""" + checkpoint_path = checkpoint.path + repo_id = checkpoint.huggingface_repo_id + if not os.path.exists(checkpoint_path): # changed: make loading similar to chronos + checkpoint_path = path.join(snapshot_download(checkpoint_path, cache_dir=checkpoint.local_dir), "torch_model.ckpt") + self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) + loaded_checkpoint = torch.load(checkpoint_path) # 
changed: remove weights_only=True because the argument is absent in older torch versions + logging.info("Loading checkpoint from %s", checkpoint_path) + self._model.load_state_dict(loaded_checkpoint) + logging.info("Sending checkpoint to device %s", f"{self._device}") + self._model.to(self._device) + self._model.eval() + # TODO: add compilation. + + def _forecast( + self, + inputs: Sequence[Any], + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to a NumPy array by `np.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + + Returns: + A tuple of NumPy arrays: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + if not self._model: + raise ValueError( + "Checkpoint not loaded.
Call `load_from_checkpoint` before" + " `forecast`.") + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + inputs = [np.array(ts)[-fcontext_len:] for ts in inputs] + + if window_size is not None: + new_inputs = [] + for ts in inputs: + new_inputs.extend(timesfm_base.moving_average(ts, window_size)) + inputs = new_inputs + + if freq is None: + logging.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq) + with torch.no_grad(): + mean_outputs = [] + full_outputs = [] + assert input_ts.shape[0] % self.global_batch_size == 0 + for i in range(input_ts.shape[0] // self.global_batch_size): + input_ts_in = torch.from_numpy( + np.array(input_ts[i * self.global_batch_size:(i + 1) * + self.global_batch_size], + dtype=np.float32)).to(self._device) + input_padding_in = torch.from_numpy( + np.array(input_padding[i * self.global_batch_size:(i + 1) * + self.global_batch_size], + dtype=np.float32)).to(self._device) + inp_freq_in = torch.from_numpy( + np.array(inp_freq[ + i * self.global_batch_size:(i + 1) * self.global_batch_size, + :, + ], + dtype=np.int32)).long().to(self._device) + mean_output, full_output = self._model.decode( + input_ts=input_ts_in, + paddings=input_padding_in, + freq=inp_freq_in, + horizon_len=self.horizon_len, + return_forecast_on_context=return_forecast_on_context, + ) + mean_output = mean_output.detach().cpu().numpy() + full_output = full_output.detach().cpu().numpy() + mean_output = np.array(mean_output) + full_output = np.array(full_output) + mean_outputs.append(mean_output) + full_outputs.append(full_output) + + mean_outputs = np.concatenate(mean_outputs, axis=0) + full_outputs = np.concatenate(full_outputs, axis=0) + + if pmap_pad > 0: + mean_outputs = mean_outputs[:-pmap_pad, ...] + full_outputs = full_outputs[:-pmap_pad, ...] 
+ + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + return mean_outputs, full_outputs \ No newline at end of file diff --git a/etna/libs/timesfm/timesfm_base.py b/etna/libs/timesfm/timesfm_base.py new file mode 100644 index 000000000..14755cbb3 --- /dev/null +++ b/etna/libs/timesfm/timesfm_base.py @@ -0,0 +1,812 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_base.py) +# replace print with logging + +import warnings + +"""Base class for TimesFM inference. This will be common to PAX and Pytorch.""" + +import collections +import dataclasses +import logging +import multiprocessing +from typing import Any, Literal, Sequence, Optional, Tuple, List, Dict, Union +from pathlib import Path +import numpy as np +import pandas as pd + +from utilsforecast.processing import make_future_dataframe + +from etna.libs.timesfm import xreg_lib + +Category = xreg_lib.Category +XRegMode = xreg_lib.XRegMode + +_TOL = 1e-6 +DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) + + +def process_group(key, group, value_name, forecast_context_len): + group = group.tail(forecast_context_len) + return np.array(group[value_name], dtype=np.float32), key + + +def moving_average(arr, window_size): + """Calculates the moving average using NumPy's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = np.pad(arr, (window_size - 1, 0), "constant") + smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), "valid") / + window_size) + return [smoothed_arr, arr - smoothed_arr] + + +def freq_map(freq: Optional[str]): + """Returns the frequency map for the given frequency string.""" + if freq is None: # changed: added this case to handle int timestamps during forecasting with exogenous features + warnings.warn("Frequency is None. Mapping it to 0, that can be not optimal. 
Better to set it to known frequency") + return 0 + freq = str.upper(freq) + if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or + freq.endswith("D") or freq.endswith("B") or freq.endswith("U")): + return 0 + elif freq.endswith(("W", "M", "MS")): + return 1 + elif freq.endswith("Y") or freq.endswith("Q"): + return 2 + else: + raise ValueError(f"Invalid frequency: {freq}") + + +# Per time series normalization: forward. +def _normalize(batch): + stats = [ + (np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch + ] + new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)] + return new_batch, stats + + +# Per time series normalization: inverse. +def _renormalize(batch, stats): + return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)] + + +@dataclasses.dataclass() +class TimesFmHparams: + """Hparams used to initialize a TimesFM model for inference. + + These are a sufficient subset of hparams to configure TimesFM inference + agnostic to the checkpoint version, and are not necessarily the same as the + hparams used to train the checkpoint. + + Attributes: + context_len: Largest context length the model allows for each decode call. + This can technically be arbitrarily large, but in practice it should be + set to the context length the checkpoint was trained with. + horizon_len: Forecast horizon. + input_patch_len: Input patch len. + output_patch_len: Output patch len. How many timepoints are taken from a + single step of autoregressive decoding. Can be set to the training horizon + of the checkpoint. + num_layers: Number of transformer layers in the model. + model_dims: Model dimension. + per_core_batch_size: Batch size on each core for data parallelism. + backend: One of "cpu", "gpu" or "tpu". + quantiles: Which quantiles are output by the model.
+ """ + + context_len: int = 512 + horizon_len: int = 128 + input_patch_len: int = 32 + output_patch_len: int = 128 + num_layers: int = 20 + num_heads: int = 16 + model_dims: int = 1280 + per_core_batch_size: int = 32 + backend: Literal["cpu", "gpu", "tpu"] = "cpu" + quantiles: Optional[Sequence[float]] = DEFAULT_QUANTILES + use_positional_embedding: bool = True + # Hparams beyond the model. + point_forecast_mode: Literal["mean", "median"] = "median" + + +@dataclasses.dataclass() +class TimesFmCheckpoint: + """Checkpoint used to initialize a TimesFM model for inference. + + Attributes: + version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc. + The factory will create the corresponding TimesFm inference class based on + this version. + path: Path to the checkpoint. + type: If provided, type of the checkpoint used by the specific checkpoint + loader per version. + step: If provided, step of the checkpoint. + """ + + version: str = "jax" + path: Optional[Union[str, Path]] = None + huggingface_repo_id: Optional[str] = None + type: Any = None + step: Optional[int] = None + local_dir: Optional[Union[str, Path]] = None + + +class TimesFmBase: + """Base TimesFM forecast API for inference. + + This class is the scaffolding for calling TimesFM forecast. To properly use: + 1. Create an instance with the correct hyperparameters of a TimesFM model. + 2. Call `load_from_checkpoint` to load a compatible checkpoint. + 3. Call `forecast` for inference. + """ + + def _logging(self, s): + print(s) + + def __post_init__(self) -> None: + """Additional initialization for subclasses before checkpoint loading.""" + pass + + def __init__(self, hparams: TimesFmHparams, + checkpoint: TimesFmCheckpoint) -> None: + """Initializes the TimesFM forecast API. + + Args: + hparams: Hyperparameters of the model. + checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide + which TimesFM version to use. 
+ """ + self.hparams = hparams + + # Expand hparams for conciseness within the model code. + self.context_len = hparams.context_len + self.horizon_len = hparams.horizon_len + self.input_patch_len = hparams.input_patch_len + self.output_patch_len = hparams.output_patch_len + self.num_layers = hparams.num_layers + self.model_dims = hparams.model_dims + self.backend = hparams.backend + self.quantiles = hparams.quantiles + self.num_heads = hparams.num_heads + self.use_pos_emb = hparams.use_positional_embedding + + # Rewrite these values in __post_init__ for SPMD. + self.num_cores = 1 + self.per_core_batch_size = hparams.per_core_batch_size + self.global_batch_size = hparams.per_core_batch_size + + self._horizon_start = self.context_len - self.input_patch_len + self.__post_init__() + self.load_from_checkpoint(checkpoint) + + def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None: + """Loads a checkpoint and compiles the decoder.""" + raise NotImplementedError("`load_from_checkpoint` is not implemented.") + + def _preprocess( + self, inputs: Sequence[np.ndarray], + freq: Sequence[int]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """Formats and pads raw inputs to feed into the model. + + This function both pads each time series to match the context length, and + pads the inputs to meet the SPMD shape requirement. + + Args: + inputs: A list of 1d JTensors. Each JTensor is the context time series of + a single forecast task. + freq: list of frequencies + + Returns: + A tuple of: + - the padded input time series to meet the model required context. + - the padding indicator. + - the frequency of each input time series. + - the number of padded examples for SPMD so that each core has the same + number (a multiple of `batch_size`) of examples. 
+ """ + + input_ts, input_padding, inp_freq = [], [], [] + + pmap_pad = ((len(inputs) - 1) // self.global_batch_size + + 1) * self.global_batch_size - len(inputs) + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts], + axis=0) + padding = np.concatenate( + [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0) + elif input_len > self.context_len: + ts = ts[-self.context_len:] + padding = padding[-(self.context_len + self.horizon_len):] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + # Padding the remainder batch. + for _ in range(pmap_pad): + input_ts.append(input_ts[-1]) + input_padding.append(input_padding[-1]) + inp_freq.append(inp_freq[-1]) + + return ( + np.stack(input_ts, axis=0), + np.stack(input_padding, axis=0), + np.array(inp_freq).astype(np.int32).reshape(-1, 1), + pmap_pad, + ) + + def _forecast( + self, + inputs: Sequence[Any], + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. 
+ + Returns: + A tuple of np.ndarray: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded. + """ + raise NotImplementedError("`_forecast` is not implemented.") + + def forecast( + self, + inputs: Sequence[Any], + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + normalize: bool = False, + ) -> Tuple[np.ndarray, np.ndarray]: + """Forecasts on a list of time series. + + Args: + inputs: list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + return_forecast_on_context: True to return the forecast on the context + when available, i.e. after the first input patch. + normalize: If True, then we normalize the inputs before forecasting and + the outputs are then renormalized to the original scale. + + Returns: + A tuple of np.ndarray: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size + (# inputs, # forecast horizon, 1 + # quantiles). + + Raises: + ValueError: If the checkpoint is not properly loaded.
+ """ + stats = None + if normalize: + inputs, stats = _normalize(inputs) + mean_forecast, quantile_forecast = self._forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context, + ) + if stats is not None: + stats = np.array(stats) + mu = stats[:, 0] + sigma = stats[:, 1] + mean_forecast = mean_forecast * sigma[:, None] + mu[:, None] + quantile_forecast = (quantile_forecast * sigma[:, None, None] + + mu[:, None, None]) + if self.hparams.point_forecast_mode == "mean": + return mean_forecast, quantile_forecast + elif self.hparams.point_forecast_mode == "median": + if self._median_index == -1: + for i, quantile in enumerate(self.quantiles): + if quantile == 0.5: + self._median_index = i + break + if self._median_index == -1: + raise ValueError("Median (0.5) is not found in the model quantiles:" + f" {self.quantiles}. Please check the hparams.") + return ( + quantile_forecast[:, :, 1 + self._median_index], + quantile_forecast, + ) + else: + raise ValueError( + "Unsupported point forecast mode:" + f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.") + + def forecast_with_covariates( + self, + inputs: List[Sequence[float]], + dynamic_numerical_covariates: Optional[Dict[str, Sequence[Sequence[float]]]] = None, + dynamic_categorical_covariates: Optional[Dict[str, Sequence[Sequence[Category]]]] = None, + static_numerical_covariates: Optional[Dict[str, Sequence[float]]] = None, + static_categorical_covariates: Optional[Dict[str, Sequence[Category]]]= None, + freq: Optional[Sequence[int]] = None, + window_size: Optional[int] = None, + forecast_context_len: Optional[int]= None, + xreg_mode: XRegMode = "xreg + timesfm", + normalize_xreg_target_per_input: bool = True, + ridge: float = 0.0, + max_rows_per_col: int = 0, + force_on_cpu: bool = False, + ): + """Forecasts on a list of time series with covariates. + + To optimize inference speed, avoid string valued categorical covariates. 
+ + Args: + inputs: A list of time series forecast contexts. Each context time series + should be in a format convertible to JTensor by `jnp.array`. + dynamic_numerical_covariates: A dict of dynamic numerical covariates. + dynamic_categorical_covariates: A dict of dynamic categorical covariates. + static_numerical_covariates: A dict of static numerical covariates. + static_categorical_covariates: A dict of static categorical covariates. + freq: frequency of each context time series. 0 for high frequency + (default), 1 for medium, and 2 for low. Notice this is different from + the `freq` required by `forecast_on_df`. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + forecast_context_len: optional max context length. + xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "timesfm + xreg" + fits a model on the residuals of the TimesFM forecast. "xreg + timesfm" + fits a model on the targets then forecasts on the residuals via TimesFM. + normalize_xreg_target_per_input: whether to normalize the xreg target per + input in the given batch. + ridge: ridge penalty for the linear model. + max_rows_per_col: max number of rows per column for the linear model. + force_on_cpu: whether to force running on cpu for the linear model. + + Returns: + A tuple of two lists. The first is the outputs of the model. The second is + the outputs of the xreg. + """ + + # Verify and bookkeep covariates. + if not (dynamic_numerical_covariates or dynamic_categorical_covariates or + static_numerical_covariates or static_categorical_covariates): + raise ValueError( + "At least one of dynamic_numerical_covariates," + " dynamic_categorical_covariates, static_numerical_covariates," + " static_categorical_covariates must be set.") + + # Track the lengths of (1) each input, (2) the part that can be used in the + # linear model, and (3) the horizon.
+ input_lens, train_lens, test_lens = [], [], [] + + for i, input_ts in enumerate(inputs): + input_len = len(input_ts) + input_lens.append(input_len) + + if xreg_mode == "timesfm + xreg": + # For fitting residuals, no TimesFM forecast on the first patch. + train_lens.append(max(0, input_len - self.input_patch_len)) + elif xreg_mode == "xreg + timesfm": + train_lens.append(input_len) + else: + raise ValueError(f"Unsupported mode: {xreg_mode}") + + if dynamic_numerical_covariates: + test_lens.append( + len(list(dynamic_numerical_covariates.values())[0][i]) - input_len) + elif dynamic_categorical_covariates: + test_lens.append( + len(list(dynamic_categorical_covariates.values())[0][i]) - + input_len) + else: + test_lens.append(self.horizon_len) + + if test_lens[-1] > self.horizon_len: + raise ValueError( + "Forecast requested longer horizon than the model definition " + f"supports: {test_lens[-1]} vs {self.horizon_len}.") + + # Prepare the covariates into train and test. + train_dynamic_numerical_covariates = collections.defaultdict(list) + test_dynamic_numerical_covariates = collections.defaultdict(list) + train_dynamic_categorical_covariates = collections.defaultdict(list) + test_dynamic_categorical_covariates = collections.defaultdict(list) + for covariates, train_covariates, test_covariates in ( + ( + dynamic_numerical_covariates, + train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates, + ), + ( + dynamic_categorical_covariates, + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates, + ), + ): + if not covariates: + continue + for covariate_name, covariate_values in covariates.items(): + for input_len, train_len, covariate_value in zip( + input_lens, train_lens, covariate_values): + train_covariates[covariate_name].append( + covariate_value[(input_len - train_len):input_len]) + test_covariates[covariate_name].append(covariate_value[input_len:]) + + # Fit models. 
+ if xreg_mode == "timesfm + xreg": + # Forecast via TimesFM then fit a model on the residuals. + mean_outputs, _ = self.forecast( + inputs, + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + targets = [ + (np.array(input_ts)[-train_len:] - + mean_output[(self._horizon_start - train_len):self._horizon_start]) + for input_ts, mean_output, train_len in zip(inputs, mean_outputs, + train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = _normalize(targets) + xregs = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates= + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates= + test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=False, + assert_covariates=True, + assert_covariate_shapes=True, + ) + if normalize_xreg_target_per_input: + xregs = _renormalize(xregs, per_instance_stats) + outputs = [ + (mean_output[self._horizon_start:(self._horizon_start + test_len)] + + xreg) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + + else: + # Fit a model on the targets then forecast on the residuals via TimesFM. 
+ targets = [ + np.array(input_ts)[-train_len:] + for input_ts, train_len in zip(inputs, train_lens) + ] + per_instance_stats = None + if normalize_xreg_target_per_input: + targets, per_instance_stats = _normalize(targets) + xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear( + targets=targets, + train_lens=train_lens, + test_lens=test_lens, + train_dynamic_numerical_covariates=train_dynamic_numerical_covariates, + test_dynamic_numerical_covariates=test_dynamic_numerical_covariates, + train_dynamic_categorical_covariates= + train_dynamic_categorical_covariates, + test_dynamic_categorical_covariates= + test_dynamic_categorical_covariates, + static_numerical_covariates=static_numerical_covariates, + static_categorical_covariates=static_categorical_covariates, + ).fit( + ridge=ridge, + one_hot_encoder_drop=None if ridge > 0 else "first", + max_rows_per_col=max_rows_per_col, + force_on_cpu=force_on_cpu, + debug_info=True, + assert_covariates=True, + assert_covariate_shapes=True, + ) + mean_outputs, _ = self.forecast( + [ + target - xreg_on_context + for target, xreg_on_context in zip(targets, xregs_on_context) + ], + freq, + window_size, + forecast_context_len, + return_forecast_on_context=True, + ) + outputs = [ + (mean_output[self._horizon_start:(self._horizon_start + test_len)] + + xreg) + for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs) + ] + if normalize_xreg_target_per_input: + outputs = _renormalize(outputs, per_instance_stats) + + return outputs, xregs + + def forecast_on_df( + self, + inputs: pd.DataFrame, + freq: str, + forecast_context_len: int = 0, + value_name: str = "values", + model_name: str = "timesfm", + window_size: Optional[int] = None, + num_jobs: int = 1, + verbose: bool = True, + ) -> pd.DataFrame: + """Forecasts on a list of time series. + + Args: + inputs: A pd.DataFrame of all time series. 
The dataframe should have a + `unique_id` column for identifying the time series, a `ds` column for + timestamps and a value column for the time series values. + freq: string valued `freq` of data. Notice this is different from the + `freq` required by `forecast`. See `freq_map` for allowed values. + forecast_context_len: If nonzero, we take the last + `forecast_context_len` time-points from each series as the forecast + context instead of the `context_len` set by the model. + value_name: The name of the value column. + model_name: name of the model to be written into future df. + window_size: window size of trend + residual decomposition. If None then + we do not do decomposition. + num_jobs: number of parallel processes to use for dataframe processing. + verbose: output model states in terminal. + + Returns: + Future forecasts dataframe. + """ + if not ("unique_id" in inputs.columns and "ds" in inputs.columns and + value_name in inputs.columns): + raise ValueError( + f"DataFrame must have unique_id, ds and {value_name} columns.") + if not forecast_context_len: + forecast_context_len = self.context_len + logging.info("Preprocessing dataframe.") + df_sorted = inputs.sort_values(by=["unique_id", "ds"]) + new_inputs = [] + uids = [] + if num_jobs == 1: + if verbose: + logging.info("Processing dataframe with single process.") # changed: replace print + for key, group in df_sorted.groupby("unique_id"): + inp, uid = process_group( + key, + group, + value_name, + forecast_context_len, + ) + new_inputs.append(inp) + uids.append(uid) + else: + if num_jobs == -1: + num_jobs = multiprocessing.cpu_count() + if verbose: + logging.info("Processing dataframe with multiple processes.") # changed: replace print + with multiprocessing.Pool(processes=num_jobs) as pool: + results = pool.starmap( + process_group, + [(key, group, value_name, forecast_context_len) + for key, group in df_sorted.groupby("unique_id")], + ) + new_inputs, uids = zip(*results) + if verbose: +
logging.info("Finished preprocessing dataframe.") # changed: replace print + freq_inps = [freq_map(freq)] * len(new_inputs) + _, full_forecast = self.forecast(new_inputs, + freq=freq_inps, + window_size=window_size) + if verbose: + logging.info("Finished forecasting.") + fcst_df = make_future_dataframe( + uids=uids, + last_times=df_sorted.groupby("unique_id")["ds"].tail(1), + h=self.horizon_len, + freq=freq, + ) + fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1) + + for i, q in enumerate(self.quantiles): + q_col = f"{model_name}-q-{q}" + fcst_df[q_col] = full_forecast[:, 0:self.horizon_len, + 1 + i].reshape(-1, 1) + if q == 0.5: + fcst_df[model_name] = fcst_df[q_col] + logging.info("Finished creating output dataframe.") + return fcst_df \ No newline at end of file diff --git a/etna/libs/timesfm/xreg_lib.py b/etna/libs/timesfm/xreg_lib.py new file mode 100644 index 000000000..4521009bf --- /dev/null +++ b/etna/libs/timesfm/xreg_lib.py @@ -0,0 +1,643 @@ +""" + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. +""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/xreg_lib.py) +# add check of sklearn version for OHE +"""Helper functions for in-context covariates and regression.""" + +import itertools +import math +from typing import Any, Iterable, Literal, Mapping, Sequence, Union, Optional, Tuple, List + +import jax +import jax.numpy as jnp +import numpy as np +from sklearn import preprocessing +from sklearn import __version__ as sklearn_version + +Category = Union[int, str] + +_TOL = 1e-6 +XRegMode = Literal["timesfm + xreg", "xreg + timesfm"] + + +def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray: + return np.array(list(itertools.chain.from_iterable(nested))) + + +def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray: + return np.array( + list( + itertools.chain.from_iterable(map(itertools.repeat, elements, + counts)))) + + +def _to_padded_jax_array(x: np.ndarray) -> jax.Array: + if x.ndim == 1: + (i,) = x.shape + di = 2**math.ceil(math.log2(i)) - i + return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0) + elif x.ndim == 2: + i, j = x.shape + di = 2**math.ceil(math.log2(i)) - i + dj = 2**math.ceil(math.log2(j)) - j + return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0) + else: + raise ValueError(f"Unsupported array shape: {x.shape}") + + +class BatchedInContextXRegBase: + """Helper class for in-context regression covariate formatting. + + Attributes: + targets: List of targets (responses) of the in-context regression. + train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. 
+ train_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to the + dynamic categorical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the static + categorical covariates of each forecast task. + """ + + def __init__( + self, + targets: Sequence[Sequence[float]], + train_lens: Sequence[int], + test_lens: Sequence[int], + train_dynamic_numerical_covariates: Optional[Mapping[str, Sequence[Sequence[float]]]] = None, + train_dynamic_categorical_covariates: Optional[Mapping[str, Sequence[Sequence[Category]]]] = None, + test_dynamic_numerical_covariates: Optional[Mapping[str, Sequence[Sequence[float]]]] = None, + test_dynamic_categorical_covariates: Optional[Mapping[str, Sequence[Sequence[Category]]]] = None, + static_numerical_covariates: Optional[Mapping[str, Sequence[float]]] = None, + static_categorical_covariates: Optional[Mapping[str, Sequence[Category]]] = None, + ) -> None: + """Initializes with the exogenous covariate inputs. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. We assume batched inputs. To properly format the + request: + + - `train_lens` represents the contexts in the batch. Targets and all train + dynamic covariates should have the same lengths as the corresponding + elements + in `train_lens`. 
Notice each `train_len` can be different from the exact + length of the corresponding context depending on how much of the context is + used for fitting the in-context model. + - `test_lens` represents the horizon lengths in the batch. All test + dynamic + covariates should have the same lengths as the corresponding elements in + `test_lens`. + - Static covariates should be one for each input. + - For train and test dynamic covariates, they should have the same + covariate + names. + + Pass an empty dict {} for a covariate type if it is not present. + + Example: + Here is a set of valid inputs whose schema can be used for reference. + ``` + targets = [ + [0.0, 0.1, 0.2], + [0.0, 0.1, 0.2, 0.3], + ] # Two inputs in this batch. + train_lens = [3, 4] + test_lens = [2, 5] # Forecast horizons 2 and 5 respectively. + train_dynamic_numerical_covariates = { + "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]], + "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]], + } # Each train dynamic covariate has 3 and 4 elements respectively. + test_dynamic_numerical_covariates = { + "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]], + "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]], + } # Each test dynamic covariate has 2 and 5 elements respectively. + train_dynamic_categorical_covariates = { + "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]], + "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad", + "bad"]], + } + test_dynamic_categorical_covariates = { + "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]], + "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]], + } + static_numerical_covariates = { + "cov_1_sn": [0.0, 3.0], + "cov_2_sn": [2.0, 1.0], + "cov_3_sn": [1.0, 2.0], + } # Each static covariate has 1 element for each input. + static_categorical_covariates = { + "cov_1_sc": ["apple", "orange"], + "cov_2_sc": [2, 3], + } + ``` + + Args: + targets: List of targets (responses) of the in-context regression.
+ train_lens: List of lengths of each target vector from the context. + test_lens: List of lengths of each forecast horizon. + train_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the context. Their + lengths should match the corresponding lengths in `train_lens`. + train_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the context. + Their lengths should match the corresponding lengths in `train_lens`. + test_dynamic_numerical_covariates: Dict of covariate names mapping to the + dynamic numerical covariates of each forecast task on the horizon. Their + lengths should match the corresponding lengths in `test_lens`. + test_dynamic_categorical_covariates: Dict of covariate names mapping to + the dynamic categorical covariates of each forecast task on the horizon. + Their lengths should match the corresponding lengths in `test_lens`. + static_numerical_covariates: Dict of covariate names mapping to the static + numerical covariates of each forecast task. + static_categorical_covariates: Dict of covariate names mapping to the + static categorical covariates of each forecast task. 
+ """ + self.targets = targets + self.train_lens = train_lens + self.test_lens = test_lens + self.train_dynamic_numerical_covariates = ( + train_dynamic_numerical_covariates or {}) + self.train_dynamic_categorical_covariates = ( + train_dynamic_categorical_covariates or {}) + self.test_dynamic_numerical_covariates = (test_dynamic_numerical_covariates + or {}) + self.test_dynamic_categorical_covariates = ( + test_dynamic_categorical_covariates or {}) + self.static_numerical_covariates = static_numerical_covariates or {} + self.static_categorical_covariates = static_categorical_covariates or {} + + def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None: + """Verifies the validity of the covariate inputs.""" + + # Check presence. + if (self.train_dynamic_numerical_covariates and + not self.test_dynamic_numerical_covariates) or ( + not self.train_dynamic_numerical_covariates and + self.test_dynamic_numerical_covariates): + raise ValueError( + "train_dynamic_numerical_covariates and" + " test_dynamic_numerical_covariates must be both present or both" + " absent.") + + if (self.train_dynamic_categorical_covariates and + not self.test_dynamic_categorical_covariates) or ( + not self.train_dynamic_categorical_covariates and + self.test_dynamic_categorical_covariates): + raise ValueError( + "train_dynamic_categorical_covariates and" + " test_dynamic_categorical_covariates must be both present or both" + " absent.") + + # Check keys. 
+ for dict_a, dict_b, dict_a_name, dict_b_name in ( + ( + self.train_dynamic_numerical_covariates, + self.test_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + "test_dynamic_numerical_covariates", + ), + ( + self.train_dynamic_categorical_covariates, + self.test_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + "test_dynamic_categorical_covariates", + ), + ): + if w := set(dict_a.keys()) - set(dict_b.keys()): + raise ValueError( + f"{dict_a_name} has keys not present in {dict_b_name}: {w}") + if w := set(dict_b.keys()) - set(dict_a.keys()): + raise ValueError( + f"{dict_b_name} has keys not present in {dict_a_name}: {w}") + + # Check shapes. + if assert_covariate_shapes: + if len(self.targets) != len(self.train_lens): + raise ValueError( + "targets and train_lens must have the same number of elements.") + + if len(self.train_lens) != len(self.test_lens): + raise ValueError( + "train_lens and test_lens must have the same number of elements.") + + for i, (target, train_len) in enumerate(zip(self.targets, + self.train_lens)): + if len(target) != train_len: + raise ValueError( + f"targets[{i}] has length {len(target)} != expected {train_len}.") + + for key, values in self.static_numerical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_numerical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}.") + + for key, values in self.static_categorical_covariates.items(): + if len(values) != len(self.train_lens): + raise ValueError( + f"static_categorical_covariates has key {key} with number of" + f" examples {len(values)} != expected {len(self.train_lens)}.") + + for lens, dict_cov, dict_cov_name in ( + ( + self.train_lens, + self.train_dynamic_numerical_covariates, + "train_dynamic_numerical_covariates", + ), + ( + self.train_lens, + self.train_dynamic_categorical_covariates, + "train_dynamic_categorical_covariates", + ), + ( + 
self.test_lens, + self.test_dynamic_numerical_covariates, + "test_dynamic_numerical_covariates", + ), + ( + self.test_lens, + self.test_dynamic_categorical_covariates, + "test_dynamic_categorical_covariates", + ), + ): + for key, cov_values in dict_cov.items(): + if len(cov_values) != len(lens): + raise ValueError( + f"{dict_cov_name} has key {key} with number of examples" + f" {len(cov_values)} != expected {len(lens)}.") + for i, cov_value in enumerate(cov_values): + if len(cov_value) != lens[i]: + raise ValueError( + f"{dict_cov_name} has key {key} with its {i}-th example" + f" length {len(cov_value)} != expected {lens[i]}.") + + def create_covariate_matrix( + self, + one_hot_encoder_drop: Optional[str]= "first", + use_intercept: bool = True, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Creates target vector and covariate matrices for in context regression. + + Here we use model fitting language to refer to the context as 'train' and + the horizon as 'test'. + + Args: + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. + use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + A tuple of the target vector, the covariate matrix for the context, and + the covariate matrix for the horizon. + """ + if assert_covariates: + self._assert_covariates(assert_covariate_shapes) + + x_train, x_test = [], [] + + # Numerical features. 
+ for name in sorted(self.train_dynamic_numerical_covariates): + x_train.append( + _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis]) + x_test.append( + _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis]) + + for covs in self.static_numerical_covariates.values(): + x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis]) + x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis]) + + if x_train: + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + # Normalize for robustness. + x_mean = np.mean(x_train, axis=0, keepdims=True) + x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w, + 1.0) + x_train = [(x_train - x_mean) / x_std] + x_test = [(x_test - x_mean) / x_std] + + sklearn_version_tuple = tuple(map(int, sklearn_version.split("."))) + encoder_params = {} + if sklearn_version_tuple < (1, 2): + encoder_params["sparse"] = False + else: + encoder_params["sparse_output"] = False + + # Categorical features. Encode one by one. 
+ one_hot_encoder = preprocessing.OneHotEncoder( + drop=one_hot_encoder_drop, + handle_unknown="ignore", + **encoder_params + ) + for name in sorted(self.train_dynamic_categorical_covariates.keys()): + ohe_train = _unnest( + self.train_dynamic_categorical_covariates[name])[:, np.newaxis] + ohe_test = _unnest( + self.test_dynamic_categorical_covariates[name])[:, np.newaxis] + x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train))) + x_test.append(np.array(one_hot_encoder.transform(ohe_test))) + + for covs in self.static_categorical_covariates.values(): + ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis]) + x_train.append(_repeat(ohe, self.train_lens)) + x_test.append(_repeat(ohe, self.test_lens)) + + x_train = np.concatenate(x_train, axis=1) + x_test = np.concatenate(x_test, axis=1) + + if use_intercept: + x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0) + x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0) + + return _unnest(self.targets), x_train, x_test + + def fit(self) -> Any: + raise NotImplementedError("Fit is not implemented.") + + +class BatchedInContextXRegLinear(BatchedInContextXRegBase): + """Linear in-context regression model.""" + + def fit( + self, + ridge: float = 0.0, + one_hot_encoder_drop: Optional[str] = "first", + use_intercept: bool = True, + force_on_cpu: bool = False, + max_rows_per_col: int = 0, + max_rows_per_col_sample_seed: int = 42, + debug_info: bool = False, + assert_covariates: bool = False, + assert_covariate_shapes: bool = False, + ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray], jax.Array, jax.Array, jax.Array]]: + """Fits a linear model for in-context regression. + + Args: + ridge: A non-negative value for specifying the ridge regression penalty. + If 0 is provided, fallback to ordinary least squares. Note this penalty + is added to the normalized covariate matrix. + one_hot_encoder_drop: Which drop strategy to use for the one hot encoder. 
+ use_intercept: Whether to prepare an intercept (all 1) column in the + matrices. + force_on_cpu: Whether to force execution on cpu for accelerator machines. + max_rows_per_col: How many rows to subsample per column. 0 for no + subsampling. This is for speeding up model fitting. + max_rows_per_col_sample_seed: The seed for the subsampling if needed by + `max_rows_per_col`. + debug_info: Whether to return debug info. + assert_covariates: Whether to assert the validity of the covariate inputs. + assert_covariate_shapes: Whether to assert the shapes of the covariate + inputs when `assert_covariates` is True. + + Returns: + If `debug_info` is False: + The linear fits on the horizon. + If `debug_info` is True: + A tuple of: + - the linear fits on the horizon, + - the linear fits on the context, + - the flattened target vector, + - the covariate matrix for the context, and + - the covariate matrix for the horizon. + """ + flat_targets, x_train_raw, x_test = self.create_covariate_matrix( + one_hot_encoder_drop=one_hot_encoder_drop, + use_intercept=use_intercept, + assert_covariates=assert_covariates, + assert_covariate_shapes=assert_covariate_shapes, + ) + + x_train = x_train_raw.copy() + if max_rows_per_col: + nrows, ncols = x_train.shape + if nrows > (w := ncols * max_rows_per_col): + subsample = jax.random.choice( + jax.random.PRNGKey(max_rows_per_col_sample_seed), + nrows, + (w,), + replace=False, + ) + x_train = x_train[subsample] + flat_targets = flat_targets[subsample] + + device = jax.devices("cpu")[0] if force_on_cpu else None + # Runs jitted version of the solvers which are quicker at the cost of + # running jitting during the first time calling. Re-jitting happens whenever + # new (padded) shapes are encountered. + # Occasionally it helps with the speed and the accuracy if we force single + # thread execution on cpu for accelerator machines: + # 1. Avoid moving data to accelerator memory. + # 2. Avoid precision loss if any. 
+ with jax.default_device(device): + x_train_raw = _to_padded_jax_array(x_train_raw) + x_train = _to_padded_jax_array(x_train) + flat_targets = _to_padded_jax_array(flat_targets) + x_test = _to_padded_jax_array(x_test) + beta_hat = (jnp.linalg.pinv( + x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]), + hermitian=True, + ) @ x_train.T @ flat_targets) + y_hat = x_test @ beta_hat + y_hat_context = x_train_raw @ beta_hat if debug_info else None + + outputs = [] + outputs_context = [] + + # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits. + train_index, test_index = 0, 0 + for train_index_delta, test_index_delta in zip(self.train_lens, + self.test_lens): + outputs.append(np.array(y_hat[test_index:(test_index + + test_index_delta)])) + if debug_info: + outputs_context.append( + np.array(y_hat_context[train_index:(train_index + + train_index_delta)])) + train_index += train_index_delta + test_index += test_index_delta + + if debug_info: + return outputs, outputs_context, flat_targets, x_train, x_test + else: + return outputs \ No newline at end of file diff --git a/etna/models/nn/__init__.py b/etna/models/nn/__init__.py index b972e2aab..4512ac420 100644 --- a/etna/models/nn/__init__.py +++ b/etna/models/nn/__init__.py @@ -16,3 +16,6 @@ if SETTINGS.chronos_required: from etna.models.nn.chronos import ChronosBoltModel from etna.models.nn.chronos import ChronosModel + +if SETTINGS.timesfm_required: + from etna.models.nn.timesfm import TimesFMModel diff --git a/etna/models/nn/chronos/base.py b/etna/models/nn/chronos/base.py index a20514a1a..7505b150c 100644 --- a/etna/models/nn/chronos/base.py +++ b/etna/models/nn/chronos/base.py @@ -202,10 +202,9 @@ def _forecast( if max_context_size < self.context_size: warnings.warn("Actual length of a dataset is less that context size. 
All history will be used as context.") - available_context_size = min(max_context_size, self.context_size) - target = ts.df.loc[:, pd.IndexSlice[:, "target"]] - context = torch.tensor(target.values.T[:, :available_context_size]) + target = ts.df.loc[:, pd.IndexSlice[:, "target"]].dropna() + context = torch.tensor(target.values.T) if prediction_interval: quantiles_forecast, target_forecast = self.pipeline.predict_quantiles( diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py new file mode 100644 index 000000000..793665269 --- /dev/null +++ b/etna/models/nn/timesfm.py @@ -0,0 +1,376 @@ +import os +import reprlib +import warnings +from pathlib import Path +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from urllib import request + +import numpy as np +import pandas as pd + +from etna import SETTINGS +from etna.datasets import TSDataset +from etna.distributions import BaseDistribution +from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel + +if SETTINGS.timesfm_required: + from etna.libs.timesfm import TimesFmCheckpoint + from etna.libs.timesfm import TimesFmHparams + from etna.libs.timesfm import TimesFmTorch + from etna.libs.timesfm.timesfm_base import freq_map + +_DOWNLOAD_PATH = Path.home() / ".etna" / "timesfm" + + +class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): + """ + Class for pretrained timesfm models. + + This model is only for zero-shot forecasting: it doesn't support training on data during ``fit``. + + This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features. + + This model doesn't support NaN in the middle or at the end of target and exogenous features. + Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill them. + + Official implementation: https://github.com/google-research/timesfm + + Note + ---- + This model requires ``timesfm`` extension to be installed. 
+ Read more about this at :ref:`installation page `. + """ + + def __init__( + self, + path_or_url: str, + encoder_length: int = 512, + device: Literal["cpu", "gpu"] = "cpu", + batch_size: int = 128, + static_reals: Optional[List[str]] = None, + static_categoricals: Optional[List[str]] = None, + time_varying_reals: Optional[List[str]] = None, + time_varying_categoricals: Optional[List[str]] = None, + cache_dir: Path = _DOWNLOAD_PATH, + ): + """ + Init TimesFM model. + + Parameters + ---------- + path_or_url: + Path to the model. It can be huggingface repository, local path or external url. + + - If huggingface repository, the available models are: + + - 'google/timesfm-1.0-200m-pytorch'. + During the first initialization the model is downloaded from huggingface and saved to local ``cache_dir``. + On all following initializations the model is loaded from ``cache_dir``. + - If local path, it should be a file with model weights, that can be loaded by :py:func:`torch.load`. + - If external url, it must be a file with model weights, that can be loaded by :py:func:`torch.load`. Model will be downloaded to ``cache_dir``. + device: + Device type. Can be "cpu" or "gpu". + encoder_length: + Number of last timestamps to use as a context. It needs to be a multiple of 32. + batch_size: + Batch size. It can be useful when inference is done on gpu. + static_reals: + Continuous features that have one unique feature value for the whole series. The first value in the series will be used for each feature. + static_categoricals: + Categorical features that have one unique feature value for the whole series. The first value in the series will be used for each feature. + time_varying_reals: + Time varying continuous features known in the future. + time_varying_categoricals: + Time varying categorical features known in the future. + cache_dir: + Local path to save model from huggingface during first model initialization.
On all following class initializations the appropriate model version is loaded from this path. + """ + self.path_or_url = path_or_url + self.encoder_length = encoder_length + self.device = device + self.batch_size = batch_size + self.static_reals = static_reals + self.static_categoricals = static_categoricals + self.time_varying_reals = time_varying_reals + self.time_varying_categoricals = time_varying_categoricals + self.cache_dir = cache_dir + + self._set_pipeline() + + def _set_pipeline(self): + """Set ``tfm`` attribute.""" + if self._is_url(): + full_model_path = self._download_model_from_url() + self.tfm = TimesFmTorch( + hparams=TimesFmHparams( + context_len=self.encoder_length, per_core_batch_size=self.batch_size, backend=self.device + ), + checkpoint=TimesFmCheckpoint(path=full_model_path), + ) + else: + self.tfm = TimesFmTorch( + hparams=TimesFmHparams( + context_len=self.encoder_length, per_core_batch_size=self.batch_size, backend=self.device + ), + checkpoint=TimesFmCheckpoint(path=self.path_or_url, local_dir=self.cache_dir), + ) + + def _is_url(self): + """Check whether ``path_or_url`` is url.""" + return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://") + + def _download_model_from_url(self) -> str: + """Download model from url to local cache_dir.""" + model_file = self.path_or_url.split("/")[-1] + full_model_path = f"{self.cache_dir}/{model_file}" + if not os.path.exists(full_model_path): + request.urlretrieve(url=self.path_or_url, filename=full_model_path) + return full_model_path + + @property + def context_size(self) -> int: + """Context size for model.""" + return self.encoder_length + + def get_model(self) -> TimesFmTorch: + """Get model.""" + return self.tfm + + def fit(self, ts: TSDataset): + """Fit model. + + For this model, fit does nothing. + + Parameters + ---------- + ts: + Dataset with features.
+ + Returns + ------- + : + Model after fit + """ + return self + + def predict( + self, + ts: TSDataset, + prediction_size: int, + return_components: bool = False, + ) -> TSDataset: + """Make predictions using true values as autoregression context (teacher forcing). + + Parameters + ---------- + ts: + Dataset with features. + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context. + return_components: + If True additionally returns forecast components. + + Returns + ------- + : + Dataset with predictions. + """ + raise NotImplementedError("Method predict isn't currently implemented!") + + def _exog_columns(self) -> List[str]: + static_reals = [] if self.static_reals is None else self.static_reals + static_categoricals = [] if self.static_categoricals is None else self.static_categoricals + time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals + time_varying_categoricals = [] if self.time_varying_categoricals is None else self.time_varying_categoricals + + return static_reals + static_categoricals + time_varying_reals + time_varying_categoricals + + def forecast( + self, + ts: TSDataset, + prediction_size: int, + return_components: bool = False, + ) -> TSDataset: + """Make autoregressive forecasts. + + Parameters + ---------- + ts: + Dataset with features. + prediction_size: + Number of last timestamps to leave after making prediction. + Previous timestamps will be used as a context. + return_components: + If True additionally returns forecast components. + + Returns + ------- + : + Dataset with predictions. + + Raises + ------ + NotImplementedError: + if return_components mode is used. + ValueError: + if dataset doesn't have any context timestamps. + ValueError: + if there are NaNs in the middle or end of the time series. + NotImplementedError: + if forecasting is done without exogenous features and dataset has None frequency. 
+ """ + if return_components: + raise NotImplementedError("This mode isn't currently implemented!") + + max_context_size = len(ts.index) - prediction_size + if max_context_size <= 0: + raise ValueError("Dataset doesn't have any context timestamps.") + + if max_context_size < self.context_size: + warnings.warn("Actual length of a dataset is less that context size. All history will be used as context.") + + self.tfm._set_horizon(prediction_size) + + end_idx = len(ts.index) + + all_exog = self._exog_columns() + df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]] + first_valid_index = ( + df_slice.isna().any(axis=1).idxmin() + ) # If all timestamps contains NaNs, idxmin() returns the first timestamp + + target_df = df_slice.loc[first_valid_index : ts.index[-prediction_size - 1], pd.IndexSlice[:, "target"]] + + nan_segment_mask = target_df.isna().any() + if nan_segment_mask.any(): + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist() + raise ValueError( + f"There are NaNs in the middle or at the end of target. Segments with NaNs: {reprlib.repr(nan_segments)}." + ) + + future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx) + + if len(all_exog) > 0: + target = target_df.values.swapaxes(1, 0).tolist() + + exog_df = df_slice.loc[first_valid_index:, pd.IndexSlice[:, all_exog]] + + nan_segment_mask = exog_df.isna().any() + if nan_segment_mask.any(): + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist() + raise ValueError( + f"There are NaNs in the middle or at the end of exogenous features. Segments with NaNs: {reprlib.repr(nan_segments)}." 
+ ) + + static_reals_dict = ( + { + column: exog_df.loc[exog_df.index[0], pd.IndexSlice[:, column]].values.tolist() + for column in self.static_reals + } + if self.static_reals is not None + else None + ) + static_categoricals_dict = ( + { + column: exog_df.loc[exog_df.index[0], pd.IndexSlice[:, column]].values.tolist() + for column in self.static_categoricals + } + if self.static_categoricals is not None + else None + ) + time_varying_reals_dict = ( + { + column: exog_df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() + for column in self.time_varying_reals + } + if self.time_varying_reals is not None + else None + ) + time_varying_categoricals_dict = ( + { + column: exog_df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist() + for column in self.time_varying_categoricals + } + if self.time_varying_categoricals is not None + else None + ) + + complex_forecast, _ = self.tfm.forecast_with_covariates( + inputs=target, + dynamic_numerical_covariates=time_varying_reals_dict, + dynamic_categorical_covariates=time_varying_categoricals_dict, + static_numerical_covariates=static_reals_dict, + static_categorical_covariates=static_categoricals_dict, + freq=[freq_map(ts.freq)] * len(ts.segments), + ) + future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = np.vstack(complex_forecast).swapaxes(1, 0) + else: + if ts.freq is None: + raise NotImplementedError( + "Forecasting misaligned data with freq=None without exogenous features isn't currently implemented." 
+ ) + + target = TSDataset.to_flatten(df=target_df) + target = target.rename(columns={"segment": "unique_id", "timestamp": "ds"}) + + predictions = self.tfm.forecast_on_df(target, freq=ts.freq, value_name="target") + + predictions = predictions.rename(columns={"unique_id": "segment", "ds": "timestamp", "timesfm": "target"}) + predictions = TSDataset.to_dataset(predictions) + future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = predictions.loc[ + :, pd.IndexSlice[:, "target"] + ].values # .values is needed to cast predictions to the type of the initial target in ts + return future_ts + + @staticmethod + def list_models() -> List[str]: + """ + Return a list of available pretrained timesfm models. + + Returns + ------- + : + List of available pretrained timesfm models. + """ + return ["google/timesfm-1.0-200m-pytorch"] + + def save(self, path: Path): + """Save the model. This method doesn't save model's weights. + + During ``load`` weights are loaded from the path where they were saved during ``init``. + + Parameters + ---------- + path: + Path to save object to. + """ + self._save(path=path, skip_attributes=["tfm"]) + + @classmethod + def load(cls, path: Path): + """Load the model. + + Parameters + ---------- + path: + Path to load object from. + """ + obj: TimesFMModel = super().load(path=path) + obj._set_pipeline() + return obj + + def params_to_tune(self) -> Dict[str, BaseDistribution]: + """Get default grid for tuning hyperparameters. + + This grid is empty. + + Returns + ------- + : + Grid to tune.
+ """ + return {} diff --git a/etna/settings.py b/etna/settings.py index 987525548..afbcc5d8a 100644 --- a/etna/settings.py +++ b/etna/settings.py @@ -52,6 +52,21 @@ def _is_chronos_available(): return False +def _is_timesfm_available(): + true_case = ( + _module_available("torch") + & _module_available("jax") + & _module_available("jaxlib") + & _module_available("huggingface_hub") + & _module_available("utilsforecast") + ) + if true_case: + return True + else: + warnings.warn("etna[timesfm] is not available, to install it, run `pip install etna[timesfm]`") + return False + + def _is_wandb_available(): if _module_available("wandb"): return True @@ -112,6 +127,7 @@ def __init__( # noqa: D107 self, torch_required: Optional[bool] = None, chronos_required: Optional[bool] = None, + timesfm_required: Optional[bool] = None, prophet_required: Optional[bool] = None, wandb_required: Optional[bool] = None, classification_required: Optional[bool] = None, @@ -131,6 +147,11 @@ def __init__( # noqa: D107 _is_chronos_available, "etna[chronos] is not available, to install it, run `pip install etna[chronos]`.", ) + self.timesfm_required: bool = _get_optional_value( + timesfm_required, + _is_timesfm_available, + "etna[timesfm] is not available, to install it, run `pip install etna[timesfm]`.", + ) self.wandb_required: bool = _get_optional_value( wandb_required, _is_wandb_available, "wandb is not available, to install it, " "run `pip install wandb`." 
) diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index 29bdeee07..d6a53d551 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -33,7 +33,8 @@ " * [N-BEATS Model](#section_3_9)\n", " * [PatchTS Model](#section_3_10)\n", " * [Chronos Model](#section_3_11)\n", - " * [Chronos Bolt Model](#section_3_12)" + " * [Chronos Bolt Model](#section_3_12)\n", + " * [TimesFM Model](#section_3_13)" ] }, { @@ -43,7 +44,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"etna[torch,chronos]\" -q" + "!pip install \"etna[torch,chronos,timesfm]\" -q" ] }, { @@ -4717,15 +4718,15 @@ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s 
finished\n" ] } ], @@ -4760,27 +4761,6 @@ "print(f\"Average SMAPE for Chronos tiny: {score:.3f}\")" ] }, - { - "cell_type": "code", - "execution_count": 88, - "id": "8334cd6a", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plot_backtest(forecast_chronos, ts, history_len=20)" - ] - }, { "cell_type": "markdown", "id": "e78ef2ad", @@ -4791,7 +4771,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 88, "id": "cfff335c", "metadata": {}, "outputs": [], @@ -4801,7 +4781,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 89, "id": "5b07ab78", "metadata": {}, "outputs": [ @@ -4815,15 +4795,15 @@ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.2s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -4841,7 +4821,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 90, "id": "cd7238b2", 
"metadata": {}, "outputs": [ @@ -4868,7 +4848,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 91, "id": "bc2fde04", "metadata": {}, "outputs": [ @@ -4882,15 +4862,15 @@ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 3.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 7.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 11.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 11.4s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.9s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.9s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -4908,7 +4888,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 92, "id": "7b5d7d58", "metadata": {}, "outputs": [ @@ -4925,6 +4905,27 @@ "print(f\"Average SMAPE for Chronos small with long context: {score:.3f}\")" ] }, + { + "cell_type": "code", + "execution_count": 93, + "id": "041c767b", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_backtest(forecast_chronos, ts, history_len=20)" + ] + }, { "cell_type": "markdown", "id": "6bbea6f5", @@ -4989,19 +4990,19 @@ "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n" ] } ], @@ -5035,6 +5036,249 @@ "score = metrics_chronos_bolt[\"SMAPE\"].mean()\n", 
"print(f\"Average SMAPE for Chronos Bolt small with long context: {score:.3f}\")" ] + }, + { + "cell_type": "markdown", + "id": "d438f154", + "metadata": {}, + "source": [ + "### 3.13 TimesFm Model \n", + "\n", + "`TimesFMModel` is one more pretrained model for zero-shot forecasting. It has similar interface to `ChronosBoltModel` and `ChronosModel`." + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "d2d30ce1", + "metadata": {}, + "outputs": [], + "source": [ + "from etna.models.nn import TimesFMModel" + ] + }, + { + "cell_type": "markdown", + "id": "7978edc4", + "metadata": {}, + "source": [ + "Now only one model is available." + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "215231c6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['google/timesfm-1.0-200m-pytorch']" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "TimesFMModel.list_models()" + ] + }, + { + "cell_type": "markdown", + "id": "32806e28", + "metadata": {}, + "source": [ + "Be careful. `encoder_length` needs to be a multiplier of 32." + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "cf12e1ca", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "31c00b3df9ce47b9bf5a43dfdb29b79d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 3 files: 0%| | 0/3 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_backtest(forecast_timesfm, ts, history_len=20)" + ] + }, + { + "cell_type": "markdown", + "id": "6baa3066", + "metadata": {}, + "source": [ + "Model can work with exogenous features." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "fcc1d4c4", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "915baa91dd90481b93504302699e6d67", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 3 files: 0%| | 0/3 [00:00=0.5.2)", "pipreqs", "requirementslib" plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] +[[package]] +name = "jax" +version = "0.4.13" +description = "Differentiate, compile, and transform Numpy code." +optional = true +python-versions = ">=3.8" +files = [ + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, +] + +[package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} +ml_dtypes = ">=0.1.0" +numpy = ">=1.21" +opt_einsum = "*" +scipy = ">=1.7" + +[package.extras] +australis = ["protobuf (>=3.13,<4)"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] + +[[package]] +name = "jaxlib" +version = "0.4.13" +description = "XLA library for 
JAX" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jaxlib-0.4.13-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac"}, + {file = "jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8"}, + {file = "jaxlib-0.4.13-cp310-cp310-win_amd64.whl", hash = "sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649"}, + {file = "jaxlib-0.4.13-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89"}, + {file = "jaxlib-0.4.13-cp311-cp311-win_amd64.whl", hash = "sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0"}, + {file = "jaxlib-0.4.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc"}, + {file = "jaxlib-0.4.13-cp39-cp39-manylinux2014_x86_64.whl", hash = 
"sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5"}, + {file = "jaxlib-0.4.13-cp39-cp39-win_amd64.whl", hash = "sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e"}, +] + +[package.dependencies] +ml-dtypes = ">=0.1.0" +numpy = ">=1.21" +scipy = ">=1.7" + +[package.extras] +cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] + [[package]] name = "jedi" version = "0.18.2" @@ -2673,6 +2736,41 @@ files = [ {file = "mistune-2.0.5.tar.gz", hash = "sha256:0246113cb2492db875c6be56974a7c893333bf26cd92891c85f63151cee09d34"}, ] +[[package]] +name = "ml-dtypes" +version = "0.2.0" +description = "" +optional = true +python-versions = ">=3.7" +files = [ + {file = "ml_dtypes-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719"}, + {file = "ml_dtypes-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c"}, + {file = 
"ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca"}, + {file = "ml_dtypes-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b"}, + {file = "ml_dtypes-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab"}, + {file = "ml_dtypes-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd"}, + {file = "ml_dtypes-0.2.0.tar.gz", hash = "sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.21.2", markers = "python_version > \"3.9\""}, + {version = ">1.20", markers = "python_version <= 
\"3.9\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "msgpack" version = "1.0.5" @@ -3215,6 +3313,17 @@ files = [ antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" +[[package]] +name = "opt-einsum" +version = "3.4.0" +description = "Path optimization of einsum functions." +optional = true +python-versions = ">=3.8" +files = [ + {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"}, + {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"}, +] + [[package]] name = "optuna" version = "2.10.1" @@ -6031,6 +6140,27 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "utilsforecast" +version = "0.2.10" +description = "Forecasting utilities" +optional = true +python-versions = ">=3.8" +files = [ + {file = "utilsforecast-0.2.10-py3-none-any.whl", hash = "sha256:ee7860f18a6df5dd695b51e603f3866a00dd0da0eb6c12d07f052e54390cf1a7"}, + {file = "utilsforecast-0.2.10.tar.gz", hash = "sha256:6058ca1a00b7e9dc02346a071a8e3f4dabe2a01f6f6b5a563c6c849754a86d3e"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +pandas = ">=1.1.1" + +[package.extras] +dev = ["black", "datasetsforecast (==0.0.8)", "nbdev (<2.3.26)", "numba (>=0.58.0)", "pandas[plot]", "plotly", "plotly-resampler", "polars[numpy]", "pyarrow", "scipy"] +plotting = ["pandas[plot]", "plotly", "plotly-resampler"] +polars = ["polars[numpy]"] + [[package]] name = "wandb" version = "0.12.21" @@ -6348,8 +6478,8 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest 
(>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["accelerate", "einops", "huggingface-hub", "optuna", "prophet", "pytorch-forecasting", "pytorch-lightning", "pyts", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "wandb"] -all-dev = ["GitPython", "Sphinx", "accelerate", "black", "click", "click", "codespell", "einops", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "huggingface-hub", "ipywidgets", "isort", "jupyter", "mypy", "myst-parser", "nbconvert", "nbqa", "nbsphinx", "optuna", "pep8-naming", "prophet", "pydata-sphinx-theme", "pytest", "pytest-cov", "pytest-shard", "pytorch-forecasting", "pytorch-lightning", "pyts", "semver", "semver", "sphinx-design", "sphinx-mathjax-offline", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "types-PyYAML", "types-setuptools", "wandb"] +all = ["accelerate", "einops", "huggingface-hub", "jax", "jaxlib", "optuna", "prophet", "pytorch-forecasting", "pytorch-lightning", "pyts", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "utilsforecast", "wandb"] +all-dev = ["GitPython", "Sphinx", "accelerate", "black", "click", "click", "codespell", "einops", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "huggingface-hub", "ipywidgets", "isort", "jax", "jaxlib", "jupyter", "mypy", "myst-parser", "nbconvert", "nbqa", "nbsphinx", "optuna", "pep8-naming", "prophet", "pydata-sphinx-theme", "pytest", "pytest-cov", "pytest-shard", "pytorch-forecasting", "pytorch-lightning", "pyts", "semver", "semver", "sphinx-design", "sphinx-mathjax-offline", "sqlalchemy", "statsforecast", "torch", "transformers", "tsfresh", "types-PyYAML", "types-setuptools", "utilsforecast", "wandb"] auto = ["optuna", "sqlalchemy"] chronos = ["accelerate", "huggingface-hub", "torch", "transformers"] classification = ["pyts", "tsfresh"] @@ -6360,10 +6490,11 @@ 
release = ["click", "semver"] statsforecast = ["statsforecast"] style = ["black", "codespell", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "isort", "mypy", "nbqa", "pep8-naming", "types-PyYAML", "types-setuptools"] tests = ["pytest", "pytest-cov", "pytest-shard"] +timesfm = ["huggingface-hub", "jax", "jaxlib", "torch", "utilsforecast"] torch = ["einops", "pytorch-forecasting", "pytorch-lightning", "torch"] wandb = ["wandb"] [metadata] lock-version = "2.0" python-versions = ">=3.8.0, <3.11.0" -content-hash = "e77d4814a86e46ab4009fd8d0ba2900a3099263a8aad209befbdd764447e4691" +content-hash = "560122b4e3d4cbbe17fa501f1eb961c3a9605cf52a40b1f9df7129d3ab58f5cf" diff --git a/pyproject.toml b/pyproject.toml index c00f004ac..2e4cd2239 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,10 @@ transformers = {version = "<5", optional = true} accelerate = {version = "<1", optional = true} huggingface-hub = {version = ">=0.21,<0.23", optional = true} +jax = {version = "<1", optional = true} +jaxlib = {version = "<1", optional = true} +utilsforecast = {version = ">=0.1.10,<1", optional = true} + sphinx-mathjax-offline = {version = "^0.0.2", optional = true} nbsphinx = {version = "^0.9.0", optional = true} Sphinx = {version = "^6.2", optional = true} @@ -132,6 +136,7 @@ auto = ["optuna", "sqlalchemy"] classification = ["pyts", "tsfresh"] statsforecast = ["statsforecast"] chronos = ["torch", "transformers", "accelerate", "huggingface-hub"] +timesfm = ["torch", "jax", "jaxlib", "huggingface-hub", "utilsforecast"] # dev deps release = ["click", "semver"] @@ -147,7 +152,8 @@ all = [ "optuna", "sqlalchemy", "pyts", "tsfresh", "statsforecast", - "transformers", "accelerate", "huggingface-hub" + "transformers", "accelerate", "huggingface-hub", + "jax", "jaxlib", "utilsforecast" ] all-dev = [ @@ -163,7 +169,8 @@ all-dev = [ "jupyter", "nbconvert", "ipywidgets", "pyts", "tsfresh", "statsforecast", - "transformers", "accelerate", "huggingface-hub" + 
"transformers", "accelerate", "huggingface-hub", + "jax", "jaxlib", "utilsforecast" ] [tool.poetry.scripts] diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index bc3b9da02..7b7557c9b 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -245,6 +245,22 @@ def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transform with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): self._test_forecast_in_sample_full_no_target(ts, model, transforms) + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), + ], + ) + def test_forecast_in_sample_full_no_target_failed_timesfm(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): + self._test_forecast_in_sample_full_no_target(ts, model, transforms) + class TestForecastInSampleFull: """Test forecast on full train dataset. 
@@ -402,7 +418,23 @@ def test_forecast_in_sample_full_not_implemented(self, model, transforms, datase (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), ], ) - def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transforms, dataset_name, request): + def test_forecast_in_sample_full_failed_chronos(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): + _test_prediction_in_sample_full(ts, model, transforms, method_name="forecast") + + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), + ], + ) + def test_forecast_in_sample_full_failed_timesfm(self, model, transforms, dataset_name, request): ts = request.getfixturevalue(dataset_name) with pytest.raises(ValueError, match="Dataset doesn't have any context timestamps."): _test_prediction_in_sample_full(ts, model, transforms, method_name="forecast") @@ -493,6 +525,12 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_suffix_no_target(self, model, transforms, dataset_name, request): @@ -617,6 +655,12 @@ class TestForecastInSampleSuffix: 
(NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_suffix(self, model, transforms, dataset_name, request): @@ -792,6 +836,12 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset_name, request): @@ -891,6 +941,23 @@ def test_forecast_out_sample_int_timestamp(self, model, transforms, dataset_name ts_int_timestamp = convert_ts_to_int_timestamp(ts, shift=10) self._test_forecast_out_sample(ts_int_timestamp, model, transforms) + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), + ], + ) + def test_forecast_out_sample_int_timestamp_failed_timesfm(self, model, transforms, dataset_name, request): + ts = request.getfixturevalue(dataset_name) + ts_int_timestamp = 
convert_ts_to_int_timestamp(ts, shift=10) + with pytest.raises(NotImplementedError, match="Data with None frequency isn't currently implemented!"): + self._test_forecast_out_sample(ts_int_timestamp, model, transforms) + @pytest.mark.parametrize( "model, transforms, dataset_name", [ @@ -1047,6 +1114,12 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_prefix(self, model, transforms, dataset_name, request): @@ -1255,6 +1328,23 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data with pytest.raises(AssertionError): self._test_forecast_out_sample_suffix(ts, model, transforms) + @pytest.mark.parametrize( + "model, transforms, dataset_name", + [ + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), + ], + ) + def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request): + """This test is expected to fail due to patch strategy in TimesFM""" + ts = request.getfixturevalue(dataset_name) + with pytest.raises(AssertionError): + self._test_forecast_out_sample_suffix(ts, model, transforms) + @to_be_fixed( raises=NotImplementedError, match="This model can't make forecast on out-of-sample data that goes after training data with a gap", @@ -1395,6 +1485,12 @@ def 
_test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50 (NBeatsGenericModel(input_size=7, output_size=55, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_mixed_in_out_sample(self, model, transforms, dataset_name, request): @@ -1557,6 +1653,12 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic ), (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_subset_segments(self, model, transforms, dataset_name, request): @@ -1734,6 +1836,12 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_new_segments(self, model, transforms, dataset_name, request): diff --git 
a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index f53019df5..2c8d433dc 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -195,6 +195,12 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_full_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -326,6 +332,12 @@ def test_predict_in_sample_suffix_datetime_timestamp(self, model, transforms, da (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_suffix_datetime_timestamp_failed_not_implemented_predict( @@ -472,6 +484,12 @@ def test_predict_in_sample_suffix_int_timestamp_failed(self, model, transforms, (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), 
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_suffix_int_timestamp_failed_not_implemented_predict( @@ -614,6 +632,12 @@ def test_predict_out_sample(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -773,6 +797,12 @@ def test_predict_out_sample_prefix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_prefix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -946,6 +976,12 @@ def test_predict_out_sample_suffix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, 
trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_suffix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1124,6 +1160,12 @@ def test_predict_mixed_in_out_sample(self, model, transforms, dataset_name, requ (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_mixed_in_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1290,6 +1332,12 @@ def test_predict_subset_segments(self, model, transforms, dataset_name, request) (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_subset_segments_failed_not_implemented_predict(self, 
model, transforms, dataset_name, request): @@ -1421,6 +1469,12 @@ def test_predict_new_segments(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_new_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py new file mode 100644 index 000000000..3778c432b --- /dev/null +++ b/tests/test_models/test_nn/test_timesfm.py @@ -0,0 +1,288 @@ +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from etna.datasets import TSDataset +from etna.datasets import generate_ar_df +from etna.libs.timesfm import TimesFmTorch +from etna.models.nn import TimesFMModel +from etna.pipeline import Pipeline +from etna.transforms import DateFlagsTransform +from etna.transforms import LagTransform +from etna.transforms import SegmentEncoderTransform + + +def generate_increasing_df(): + n = 128 + df = generate_ar_df(start_time="2001-01-01", periods=n, n_segments=2) + df["target"] = list(range(n)) + list(range(100, 100 + n)) + return df + + +def generate_exog(): + n = 128 + df_exog = generate_ar_df(start_time="2001-01-01", periods=n + 2, n_segments=2) + df_exog.rename(columns={"target": "exog"}, inplace=True) + return df_exog + + +@pytest.fixture +def ts_increasing_integers(): + df = generate_increasing_df() + ts = TSDataset(df, freq="D") + return ts + + 
+@pytest.fixture +def ts_nan_start(): + df = generate_increasing_df() + df.loc[0, "target"] = np.NaN + ts = TSDataset(df, freq="D") + return ts + + +@pytest.fixture +def ts_nan_middle(): + df = generate_increasing_df() + df.loc[120, "target"] = np.NaN + ts = TSDataset(df, freq="D") + return ts + + +@pytest.fixture +def expected_ts_increasing_integers(): + df = generate_ar_df(start_time="2001-05-09", periods=2, n_segments=2) + df["target"] = [128.0, 129.0] + [228.0, 229.0] + ts = TSDataset(df, freq="D") + return ts + + +@pytest.fixture +def ts_exog_middle_nan(): + df = generate_increasing_df() + df_exog = generate_exog() + df_exog.loc[120, "exog"] = np.NaN + ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all") + return ts + + +@pytest.fixture +def ts_exog_all_nan(): + df = generate_increasing_df() + df_exog = generate_exog() + df_exog["exog"] = np.NaN + ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all") + return ts + + +@pytest.mark.smoke +def test_url(tmp_path): + model_name = "timesfm-1.0-200m-pytorch.ckpt" + url = f"http://etna-github-prod.cdn-tinkoff.ru/timesfm/{model_name}" + _ = TimesFMModel(path_or_url=url, cache_dir=tmp_path) + assert os.path.exists(tmp_path / model_name) + + +@pytest.mark.smoke +def test_cache_dir(tmp_path): + path_or_url = "google/timesfm-1.0-200m-pytorch" + model_name = path_or_url.split("/")[-1] + _ = TimesFMModel(path_or_url=path_or_url, cache_dir=tmp_path) + assert os.path.exists(tmp_path / f"models--google--{model_name}") + + +@pytest.mark.smoke +def test_context_size(): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=10) + assert model.context_size == 10 + + +@pytest.mark.smoke +def test_get_model(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + assert isinstance(model.get_model(), TimesFmTorch) + + +@pytest.mark.smoke +def test_fit(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + 
model.fit(example_tsds) + + +@pytest.mark.smoke +def test_predict(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + with pytest.raises(NotImplementedError, match="Method predict isn't currently implemented!"): + model.predict(ts=example_tsds, prediction_size=1) + + +def test_forecast_warns_big_context_size(ts_increasing_integers): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512) + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(ts_increasing_integers) + with pytest.warns(UserWarning, match="Actual length of a dataset is less that context size."): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("encoder_length", [32, 64, 128]) +@pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) +def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request): + ts = request.getfixturevalue(ts) + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) + pipeline = Pipeline(model=model, horizon=2) + pipeline.fit(ts) + forecast = pipeline.forecast() + assert_frame_equal(forecast.df, expected_ts_increasing_integers.df, atol=1) + + +def test_forecast_failed_nan_middle_target(ts_nan_middle): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128) + pipeline = Pipeline(model=model, horizon=2) + pipeline.fit(ts_nan_middle) + with pytest.raises(ValueError, match=r"There are NaNs in the middle or at the end of target. 
Segments with NaNs:"): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("encoder_length", [32, 64, 128]) +@pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"]) +def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encoder_length, request): + ts = request.getfixturevalue(ts) + + horizon = 2 + transforms = [ + SegmentEncoderTransform(), + LagTransform(in_column="target", lags=[horizon, horizon + 1], out_column="lag"), + DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, is_weekend=False, out_column="flag"), + ] + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=encoder_length, + static_categoricals=["segment_code"], + time_varying_reals=[f"lag_{horizon}", f"lag_{horizon+1}"], + time_varying_categoricals=["flag_day_number_in_week"], + ) + pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) + pipeline.fit(ts) + forecast = pipeline.forecast() + assert_frame_equal(forecast.df.loc[:, pd.IndexSlice[:, "target"]], expected_ts_increasing_integers.df, atol=1) + + +def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): + horizon = 2 + transforms = [ + SegmentEncoderTransform(), + LagTransform(in_column="target", lags=[horizon, horizon + 1], out_column="lag"), + DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, is_weekend=False, out_column="flag"), + ] + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=128, + static_categoricals=["segment_code"], + time_varying_reals=[f"lag_{horizon}", f"lag_{horizon+1}"], + time_varying_categoricals=["flag_day_number_in_week"], + ) + pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) + pipeline.fit(ts_nan_middle) + with pytest.raises(ValueError, match="There are NaNs in the middle or at the end of target. 
Segments with NaNs:"): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"]) +def test_forecast_exog_features_failed_exog_nan(ts, request): + ts = request.getfixturevalue(ts) + + horizon = 2 + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=128, + time_varying_reals=["exog"], + ) + pipeline = Pipeline(model=model, transforms=[], horizon=horizon) + pipeline.fit(ts) + with pytest.raises( + ValueError, match="There are NaNs in the middle or at the end of exogenous features. Segments with NaNs:" + ): + _ = pipeline.forecast() + + +@pytest.mark.smoke +def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(example_tsds_int_timestamp) + with pytest.raises( + NotImplementedError, + match="Forecasting misaligned data with freq=None without exogenous features isn't currently implemented.", + ): + _ = pipeline.forecast() + + +@pytest.mark.smoke +def test_forecast_exog_int_timestamps(example_tsds_int_timestamp): + horizon = 2 + transforms = [SegmentEncoderTransform(), LagTransform(in_column="target", lags=[horizon], out_column="lag")] + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=32, + static_categoricals=["segment_code"], + time_varying_reals=[f"lag_{horizon}"], + ) + pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) + pipeline.fit(example_tsds_int_timestamp) + with pytest.warns( + UserWarning, + match="Frequency is None. Mapping it to 0, that can be not optimal. 
Better to set it to known frequency", + ): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("encoder_length", [16, 33]) +def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length) + pipeline = Pipeline(model=model, horizon=1) + pipeline.fit(ts_increasing_integers) + with pytest.raises(RuntimeError, match=r"shape .+ is invalid for input of size \d+"): + _ = pipeline.forecast() + + +@pytest.mark.smoke +def test_forecast_without_fit(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) + pipeline = Pipeline(model=model, horizon=1) + _ = pipeline.forecast(example_tsds) + + +@pytest.mark.smoke +def test_forecast_fails_components(example_tsds): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + pipeline = Pipeline(model=model, horizon=1) + with pytest.raises(NotImplementedError, match="This mode isn't currently implemented!"): + pipeline.forecast(ts=example_tsds, return_components=True) + + +@pytest.mark.smoke +def test_list_models(): + assert TimesFMModel.list_models() == ["google/timesfm-1.0-200m-pytorch"] + + +@pytest.mark.smoke +def test_save_load(tmp_path, ts_increasing_integers): + path = Path(tmp_path) / "tmp.zip" + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32) + model.save(path) + loaded_model = TimesFMModel.load(path) + + pipeline = Pipeline(model=loaded_model, horizon=1) + pipeline.fit(ts_increasing_integers) + _ = pipeline.forecast() + assert isinstance(loaded_model, TimesFMModel) + + +@pytest.mark.smoke +def test_params_to_tune(): + model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch") + assert len(model.params_to_tune()) == 0
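
Note on the `expected_ts_increasing_integers` fixture in `test_timesfm.py`: its start date (`2001-05-09`) and targets (`128, 129` and `228, 229`) follow directly from `generate_increasing_df()`, which builds 128 daily points per segment starting at `2001-01-01`, with values `0..127` and `100..227`. The sketch below reproduces that arithmetic with the standard library only. It is an illustration, not part of the PR; the constant and helper names (`N_POINTS`, `SEGMENT_OFFSETS`, `expected_forecast`) are hypothetical and simply mirror the numbers in the test file.

```python
from datetime import date, timedelta

# Assumption: these constants mirror generate_increasing_df() and the
# horizon=2 pipelines in test_timesfm.py; the names are hypothetical.
N_POINTS = 128              # points per segment
START = date(2001, 1, 1)    # first timestamp, daily frequency
HORIZON = 2                 # forecast length used in test_forecast
SEGMENT_OFFSETS = (0, 100)  # first target value of each segment


def expected_forecast(offset, n_points=N_POINTS, horizon=HORIZON):
    """Continue the progression offset, offset + 1, ... for `horizon` more steps."""
    return [float(offset + n_points + step) for step in range(horizon)]


# The training data covers N_POINTS days, so the first forecast timestamp
# lies N_POINTS days after the first training timestamp.
first_forecast_day = START + timedelta(days=N_POINTS)

# Flatten per-segment expectations in the same order the fixture writes them.
expected = [value for offset in SEGMENT_OFFSETS for value in expected_forecast(offset)]

print(first_forecast_day.isoformat())  # 2001-05-09
print(expected)                        # [128.0, 129.0, 228.0, 229.0]
```

The tests compare against these values with `atol=1`, so the model is only required to continue the linear trend approximately, not exactly.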