diff --git a/CHANGELOG.md b/CHANGELOG.md
index 729813490..53a902d45 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `ChronosModel` ([#511](https://github.com/etna-team/etna/pull/511))
- Add `ChronosBoltModel` ([#511](https://github.com/etna-team/etna/pull/511))
- Add usage example of `ChronosModel` and `ChronosBoltModel` in `202-NN_examples` notebook ([#511](https://github.com/etna-team/etna/pull/511))
--
--
+- Add `TimesFMModel` ([#544](https://github.com/etna-team/etna/pull/544))
+- Add usage example of `TimesFMModel` in `202-NN_examples` notebook ([#544](https://github.com/etna-team/etna/pull/544))
-
-
- Add `MissingCounter` metric ([#520](https://github.com/etna-team/etna/pull/520))
diff --git a/README.md b/README.md
index baae3b461..64016e1ad 100644
--- a/README.md
+++ b/README.md
@@ -149,7 +149,8 @@ Available user extensions are the following:
* `auto`: adds AutoML functionality,
* `statsforecast`: adds models from [statsforecast](https://nixtla.github.io/statsforecast/),
* `classiciation`: adds time series classification functionality,
-* `chronos`: adds Chronos-like pretrained models.
+* `chronos`: adds Chronos-like pretrained models,
+* `timesfm`: adds TimesFM pretrained models.
Install extension:
```bash
diff --git a/docs/source/api_reference/models.rst b/docs/source/api_reference/models.rst
index daea65539..ef06daee4 100644
--- a/docs/source/api_reference/models.rst
+++ b/docs/source/api_reference/models.rst
@@ -122,4 +122,5 @@ Pretrained neural network models:
:template: class.rst
nn.ChronosModel
- nn.ChronosBoltModel
\ No newline at end of file
+ nn.ChronosBoltModel
+ nn.TimesFMModel
\ No newline at end of file
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 86cd18123..90ac3d214 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -24,7 +24,8 @@ Available user extensions are the following:
- ``auto``: adds AutoML functionality,
- ``statsforecast``: adds models from `statsforecast <https://nixtla.github.io/statsforecast/>`_,
- ``classiciation``: adds time series classification functionality,
-- ``chronos``: adds Chronos-like pretrained models.
+- ``chronos``: adds Chronos-like pretrained models,
+- ``timesfm``: adds TimesFM pretrained models.
Install extension:
diff --git a/etna/libs/timesfm/__init__.py b/etna/libs/timesfm/__init__.py
new file mode 100644
index 000000000..2346aa8a2
--- /dev/null
+++ b/etna/libs/timesfm/__init__.py
@@ -0,0 +1,155 @@
+"""
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+"""
+
+
+from etna.libs.timesfm.timesfm import TimesFmTorch
+from etna.libs.timesfm.timesfm_base import TimesFmHparams, TimesFmCheckpoint
diff --git a/etna/libs/timesfm/patched_decoder.py b/etna/libs/timesfm/patched_decoder.py
new file mode 100644
index 000000000..53cd5c6c1
--- /dev/null
+++ b/etna/libs/timesfm/patched_decoder.py
@@ -0,0 +1,948 @@
+"""
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+"""
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/patched_decoder.py)
+
+"""Pytorch version of patched decoder."""
+
+import dataclasses
+import math
+from typing import List, Tuple, Optional
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+def _create_quantiles() -> List[float]:
+ return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+
+
+@dataclasses.dataclass
+class TimesFMConfig:
+ """Config for initializing timesfm patched_decoder class."""
+
+ # The number of blocks in the model.
+ num_layers: int = 20
+ # The number of attention heads used in the attention layers of the model.
+ num_heads: int = 16
+ # The number of key-value heads for implementing attention.
+ num_kv_heads: int = 16
+ # The hidden size of the model.
+ hidden_size: int = 1280
+ # The dimension of the MLP representations.
+ intermediate_size: int = 1280
+ # The dimension of each attention head.
+ head_dim: int = 80
+ # The epsilon used by the rms normalization layers.
+ rms_norm_eps: float = 1e-6
+ # Patch length
+ patch_len: int = 32
+ # Horizon length
+ horizon_len: int = 128
+ # quantiles
+ quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)
+ # Padding value
+ pad_val: float = 1123581321.0
+ # Tolerance
+ tolerance: float = 1e-6
+ # The dtype of the weights.
+ dtype: str = "bfloat32"
+ # use positional embedding
+ use_positional_embedding: bool = True
+
+
+def _masked_mean_std(
+ inputs: torch.Tensor,
+ padding: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculates mean and standard deviation of `inputs` across axis 1.
+
+ It excludes values where `padding` is 1.
+
+ Args:
+ inputs: A PyTorch tensor of shape [b, n, p].
+ padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+ Returns:
+ A tuple containing the mean and standard deviation.
+ We return the statistics of the first patch with at least three non-padded
+ values.
+ """
+ # Selecting the first patch with at least 3 unpadded values.
+ pad_sum = torch.sum(1 - padding, dim=2)
+
+ def _get_patch_index(arr: torch.Tensor):
+ indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+ row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+ return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+ patch_indices = _get_patch_index(pad_sum)
+ bidxs = torch.arange(inputs.shape[0])
+
+ arr = inputs[bidxs, patch_indices, :]
+ pad = padding[bidxs, patch_indices, :]
+
+ # Create a mask where padding is 0
+ mask = 1 - pad
+
+ # Calculate the number of valid elements
+ num_valid_elements = torch.sum(mask, dim=1)
+ num_valid_elements = torch.where(
+ num_valid_elements == 0,
+ torch.tensor(1,
+ dtype=num_valid_elements.dtype,
+ device=num_valid_elements.device),
+ num_valid_elements,
+ )
+
+ # Calculate the masked sum and squared sum
+ masked_sum = torch.sum(arr * mask, dim=1)
+ masked_squared_sum = torch.sum((arr * mask)**2, dim=1)
+
+ # Calculate the masked mean and standard deviation
+ masked_mean = masked_sum / num_valid_elements
+ masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+ masked_var = torch.where(
+ masked_var < 0.0,
+ torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+ masked_var,
+ )
+ masked_std = torch.sqrt(masked_var)
+
+ return masked_mean, masked_std
+
+
+def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+ """Shifts rows of seq based on the first 0 in each row of the mask.
+
+ Args:
+ mask: mask tensor of shape [B, N]
+ seq: seq tensor of shape [B, N, P]
+
+ Returns:
+ Returns the shifted sequence.
+ """
+ batch_size, num_seq, feature_dim = seq.shape
+
+ new_mask: torch.BoolTensor = mask == 0
+
+ # Use argmax to find the first True value in each row
+ indices = new_mask.to(torch.int32).argmax(dim=1)
+
+ # Handle rows whose mask contains no zeros
+ indices[~new_mask.any(dim=1)] = -1
+
+ # Create index ranges for each sequence in the batch
+ idx_range = (torch.arange(num_seq).to(
+ seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,
+ feature_dim))
+
+ # Calculate shifted indices for each element in each sequence
+ shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+ # Gather values from seq using shifted indices
+ shifted_seq = seq.gather(1, shifted_idx)
+
+ return shifted_seq
+
+
+def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
+ """Returns a large negative value for the given dtype."""
+ if dtype.is_floating_point:
+ dtype_max = torch.finfo(dtype).max
+ else:
+ dtype_max = torch.iinfo(dtype).max
+ return torch.tensor(-0.7 * dtype_max, dtype=dtype)
+
+
+def apply_mask_to_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Applies a floating-point mask to a set of logits.
+
+ Args:
+ logits: A torch.Tensor of logit values.
+ mask: A torch.Tensor (float32) of additive mask values; entries below half
+ of the large negative number mark positions to mask out.
+
+ Returns:
+ Masked logits.
+ """
+
+ min_value = get_large_negative_number(logits.dtype)
+
+ return torch.where((mask >= min_value * 0.5), logits, min_value)
+
+
+def convert_paddings_to_mask(paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Converts binary paddings to a logit mask ready to add to attention matrix.
+
+ Args:
+ paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
+ token.
+ dtype: data type of the input.
+
+ Returns:
+ A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.
+ """
+ attention_mask = paddings.detach().clone()
+ attention_mask = attention_mask[:, None, None, :] # Equivalent to jnp.newaxis
+ attention_mask *= get_large_negative_number(dtype)
+ return attention_mask
+
+
+def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
+ """Computes and returns causal mask.
+
+ Args:
+ input_t: A torch.Tensor of shape [B, T, D].
+
+ Returns:
+ An attention_mask torch.Tensor of shape [1, 1, T, T]. Attention mask has
+ already been converted to large negative values.
+ """
+ assert input_t.dtype.is_floating_point, input_t.dtype
+ large_negative_number = get_large_negative_number(input_t.dtype)
+ t = input_t.shape[1]
+ col_idx = torch.arange(t).unsqueeze(0).repeat(t, 1)
+ row_idx = torch.arange(t).unsqueeze(1).repeat(1, t)
+ mask = (row_idx < col_idx).to(input_t.dtype) * large_negative_number
+ return mask.unsqueeze(0).unsqueeze(0).to(input_t.device) # Equivalent to jnp.newaxis
+
+
+def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """Merges 2 masks.
+
+ logscale mask is expected but 0/1 mask is also fine.
+
+ Args:
+ a: torch.Tensor of shape [1|B, 1, 1|T, S].
+ b: torch.Tensor of shape [1|B, 1, 1|T, S].
+
+ Returns:
+ torch.Tensor of shape [1|B, 1, 1|T, S].
+ """
+
+ def expand_t(key_mask):
+ query_mask = key_mask.transpose(-1, -2) # Equivalent of jnp.transpose
+ return torch.minimum(query_mask, key_mask)
+
+ if a.shape[2] != b.shape[2]:
+ if a.shape[2] == 1:
+ a = expand_t(a)
+ else:
+ assert b.shape[2] == 1
+ b = expand_t(b)
+
+ assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
+ return torch.minimum(a, b) # Element-wise minimum, similar to jnp.minimum
+
+
+class ResidualBlock(nn.Module):
+ """TimesFM residual block."""
+
+ def __init__(
+ self,
+ input_dims,
+ hidden_dims,
+ output_dims,
+ ):
+ super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dims = hidden_dims
+ self.output_dims = output_dims
+
+ # Hidden Layer
+ self.hidden_layer = nn.Sequential(
+ nn.Linear(input_dims, hidden_dims),
+ nn.SiLU(),
+ )
+
+ # Output Layer
+ self.output_layer = nn.Linear(hidden_dims, output_dims)
+ # Residual Layer
+ self.residual_layer = nn.Linear(input_dims, output_dims)
+
+ def forward(self, x):
+ hidden = self.hidden_layer(x)
+ output = self.output_layer(hidden)
+ residual = self.residual_layer(x)
+ return output + residual
+
+
+class RMSNorm(torch.nn.Module):
+ """Pax rms norm in pytorch."""
+
+ def __init__(
+ self,
+ dim: int,
+ eps: float = 1e-6,
+ add_unit_offset: bool = False,
+ ):
+ super().__init__()
+ self.eps = eps
+ self.add_unit_offset = add_unit_offset
+ self.weight = nn.Parameter(torch.zeros(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float())
+ if self.add_unit_offset:
+ output = output * (1 + self.weight.float())
+ else:
+ output = output * self.weight.float()
+ return output.type_as(x)
+
+
+class TransformerMLP(nn.Module):
+ """Pax transformer MLP in pytorch."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size)
+ self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+ def forward(self, x, paddings=None):
+ gate_inp = self.layer_norm(x)
+ gate = self.gate_proj(gate_inp)
+ gate = F.relu(gate)
+ outputs = self.down_proj(gate)
+ if paddings is not None:
+ outputs = outputs * (1.0 - paddings[:, :, None])
+ return outputs + x
+
+
+class TimesFMAttention(nn.Module):
+ """Implements the attention used in TimesFM."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ ):
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+
+ assert self.num_heads % self.num_kv_heads == 0
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = nn.Parameter(
+ torch.empty((self.head_dim,), dtype=torch.float32),)
+
+ self.qkv_proj = nn.Linear(
+ self.hidden_size,
+ (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
+ )
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+ def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
+ # [batch_size, n_local_heads, input_len, head_dim]
+ r_softplus_0 = 1.442695041
+ softplus_func = torch.nn.Softplus()
+ scale = r_softplus_0 / math.sqrt(self.head_dim)
+ scale = scale * softplus_func(self.scaling)
+ return query * scale[None, None, None, :]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ mask: torch.Tensor,
+ kv_write_indices: Optional[torch.Tensor] = None,
+ kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ hidden_states_shape = hidden_states.shape
+ assert len(hidden_states_shape) == 3
+
+ batch_size, input_len, _ = hidden_states_shape
+
+ qkv = self.qkv_proj(hidden_states)
+ xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
+ xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
+ xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
+ xq = self._per_dim_scaling(xq)
+
+ # Write new kv cache.
+ # [batch_size, input_len, n_local_kv_heads, head_dim]
+ if kv_cache is not None and kv_write_indices is not None:
+ k_cache, v_cache = kv_cache
+ k_cache.index_copy_(1, kv_write_indices, xk)
+ v_cache.index_copy_(1, kv_write_indices, xv)
+
+ key = k_cache
+ value = v_cache
+ else:
+ key = xk
+ value = xv
+ if self.num_kv_heads != self.num_heads:
+ # [batch_size, max_seq_len, n_local_heads, head_dim]
+ key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
+ value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)
+
+ # [batch_size, n_local_heads, input_len, head_dim]
+ q = xq.transpose(1, 2)
+ # [batch_size, n_local_heads, max_seq_len, head_dim]
+ k = key.transpose(1, 2)
+ v = value.transpose(1, 2)
+
+ # [batch_size, n_local_heads, input_len, max_seq_len]
+ scores = torch.matmul(q, k.transpose(2, 3))
+ scores = scores + mask
+ scores = F.softmax(scores.float(), dim=-1).type_as(q)
+
+ # [batch_size, n_local_heads, input_len, head_dim]
+ output = torch.matmul(scores, v)
+ # return scores, output.transpose(1, 2).contiguous()
+
+ # [batch_size, input_len, hidden_dim]
+ output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
+ output = self.o_proj(output)
+ return scores, output
+
+
+class TimesFMDecoderLayer(nn.Module):
+ """Transformer layer."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ rms_norm_eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.self_attn = TimesFMAttention(
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ )
+ self.mlp = TransformerMLP(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ )
+ self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ mask: torch.Tensor,
+ paddings: torch.Tensor,
+ kv_write_indices: Optional[torch.Tensor] = None,
+ kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ scores, hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ mask=mask,
+ kv_write_indices=kv_write_indices,
+ kv_cache=kv_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # MLP
+ hidden_states = self.mlp(hidden_states, paddings=paddings)
+
+ return scores, hidden_states
+
+
+class StackedDecoder(nn.Module):
+ """Stacked transformer layer."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ num_layers: int,
+ rms_norm_eps: float = 1e-6,
+ ):
+ super().__init__()
+
+ self.layers = nn.ModuleList()
+ for _ in range(num_layers):
+ self.layers.append(
+ TimesFMDecoderLayer(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ rms_norm_eps=rms_norm_eps,
+ ))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ paddings: torch.Tensor,
+ kv_write_indices: Optional[torch.Tensor] = None,
+ kv_caches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+ ) -> torch.Tensor:
+ padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)
+ atten_mask = causal_mask(hidden_states)
+ mask = merge_masks(padding_mask, atten_mask)
+ for i in range(len(self.layers)):
+ layer = self.layers[i]
+ kv_cache = kv_caches[i] if kv_caches is not None else None
+ _, hidden_states = layer(
+ hidden_states=hidden_states,
+ mask=mask,
+ paddings=paddings,
+ kv_write_indices=kv_write_indices,
+ kv_cache=kv_cache,
+ )
+ return hidden_states
+
+
+class PositionalEmbedding(torch.nn.Module):
+ """Generates position embedding for a given 1-d sequence.
+
+ Attributes:
+ min_timescale: Start of the geometric index. Determines the periodicity of
+ the added signal.
+ max_timescale: End of the geometric index. Determines the frequency of the
+ added signal.
+ embedding_dims: Dimension of the embedding to be generated.
+ """
+
+ def __init__(
+ self,
+ embedding_dims: int,
+ min_timescale: int = 1,
+ max_timescale: int = 10_000,
+ ) -> None:
+ super().__init__()
+ self.min_timescale = min_timescale
+ self.max_timescale = max_timescale
+ self.embedding_dims = embedding_dims
+
+ def forward(self, seq_length=None, position=None):
+ """Generates a Tensor of sinusoids with different frequencies.
+
+ Args:
+ seq_length: an optional Python int defining the output sequence length;
+ required if the `position` argument is not specified.
+ position: [B, seq_length], optional position for each token in the
+ sequence, only required when the sequence is packed.
+
+ Returns:
+ [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+ """
+ if position is None:
+ assert seq_length is not None
+ # [1, seqlen]
+ position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
+ else:
+ assert position.ndim == 2, position.shape
+
+ num_timescales = self.embedding_dims // 2
+ log_timescale_increment = math.log(
+ float(self.max_timescale) / float(self.min_timescale)) / max(
+ num_timescales - 1, 1)
+ inv_timescales = self.min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float32) *
+ -log_timescale_increment)
+ scaled_time = position.unsqueeze(2) * inv_timescales.unsqueeze(0).unsqueeze(
+ 0)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+ # Padding to ensure correct embedding dimension
+ signal = F.pad(signal, (0, self.embedding_dims % 2))
+ return signal
+
+
+class PatchedTimeSeriesDecoder(nn.Module):
+ """Patched time-series decoder."""
+
+ def __init__(self, config: TimesFMConfig):
+ super().__init__()
+ self.config = config
+ self.input_ff_layer = ResidualBlock(
+ input_dims=2 * config.patch_len,
+ output_dims=config.hidden_size,
+ hidden_dims=config.intermediate_size,
+ )
+ self.freq_emb = nn.Embedding(num_embeddings=3,
+ embedding_dim=config.hidden_size)
+ self.horizon_ff_layer = ResidualBlock(
+ input_dims=config.hidden_size,
+ output_dims=config.horizon_len * (1 + len(config.quantiles)),
+ hidden_dims=config.intermediate_size,
+ )
+ self.stacked_transformer = StackedDecoder(
+ hidden_size=self.config.hidden_size,
+ intermediate_size=self.config.intermediate_size,
+ num_heads=self.config.num_heads,
+ num_kv_heads=self.config.num_kv_heads,
+ head_dim=self.config.head_dim,
+ num_layers=self.config.num_layers,
+ rms_norm_eps=self.config.rms_norm_eps,
+ )
+ if self.config.use_positional_embedding:
+ self.position_emb = PositionalEmbedding(self.config.hidden_size)
+
+ def _forward_transform(
+ self, inputs: torch.Tensor, patched_pads: torch.Tensor
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Input is of shape [B, N, P]."""
+ mu, sigma = _masked_mean_std(inputs, patched_pads)
+ sigma = torch.where(
+ sigma < self.config.tolerance,
+ torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+ sigma,
+ )
+
+ # Normalize each patch
+ outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+ outputs = torch.where(
+ torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+ torch.tensor(self.config.pad_val,
+ dtype=outputs.dtype,
+ device=outputs.device),
+ outputs,
+ )
+ return outputs, (mu, sigma)
+
+ def _reverse_transform(self, outputs: torch.Tensor, stats: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+ """Output is of shape [B, N, P, Q]."""
+ mu, sigma = stats
+ return outputs * sigma[:, None, None, None] + mu[:, None, None, None]
+
+ def _preprocess_input(
+ self,
+ input_ts: torch.Tensor,
+ input_padding: torch.Tensor,
+ ) -> Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ Optional[Tuple[torch.Tensor, torch.Tensor]],
+ torch.Tensor,
+ ]:
+ """Preprocess input for stacked transformer."""
+
+ # Reshape into patches (using view for efficiency)
+ bsize = input_ts.shape[0]
+ patched_inputs = input_ts.view(bsize, -1, self.config.patch_len)
+ patched_pads = input_padding.view(bsize, -1, self.config.patch_len)
+
+ patched_inputs = torch.where(
+ torch.abs(patched_pads - 1.0) < self.config.tolerance,
+ torch.tensor(0.0,
+ dtype=patched_inputs.dtype,
+ device=patched_inputs.device),
+ patched_inputs,
+ )
+ patched_pads = torch.where(
+ torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+ torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+ patched_pads,
+ )
+ patched_inputs, stats = self._forward_transform(patched_inputs,
+ patched_pads)
+
+ # B x N x D
+ patched_inputs = patched_inputs * (1.0 - patched_pads)
+ concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+ model_input = self.input_ff_layer(concat_inputs)
+
+ # A patch counts as unpadded if at least one of its elements is unpadded (0).
+ patched_padding = torch.min(patched_pads,
+ dim=-1)[0] # Get the values from the min result
+ if self.config.use_positional_embedding:
+ pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
+ pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+ pos_emb = _shift_padded_seq(patched_padding, pos_emb)
+ model_input += pos_emb
+
+ return model_input, patched_padding, stats, patched_inputs
+
+ def _postprocess_output(
+ self,
+ model_output: torch.Tensor,
+ num_outputs: int,
+ stats: Tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ """Postprocess output of stacked transformer."""
+
+ # B x N x (H.Q)
+ output_ts = self.horizon_ff_layer(model_output)
+
+ # Reshape using view
+ b, n, _ = output_ts.shape
+ output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)
+
+ return self._reverse_transform(output_ts, stats)
+
+ def forward(
+ self,
+ input_ts: torch.Tensor,
+ input_padding: torch.LongTensor,
+ freq: torch.Tensor,
+ ) -> torch.Tensor:
+ num_outputs = len(self.config.quantiles) + 1
+ model_input, patched_padding, stats, _ = self._preprocess_input(
+ input_ts=input_ts,
+ input_padding=input_padding,
+ )
+ f_emb = self.freq_emb(freq) # B x 1 x D
+ model_input += f_emb
+ model_output = self.stacked_transformer(model_input, patched_padding)
+
+ output_ts = self._postprocess_output(model_output, num_outputs, stats)
+ return output_ts
+
+ def decode(
+ self,
+ input_ts: torch.Tensor,
+ paddings: torch.Tensor,
+ freq: torch.LongTensor,
+ horizon_len: int,
+ output_patch_len: Optional[int] = None,
+ max_len: int = 512,
+ return_forecast_on_context: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Auto-regressive decoding without caching.
+
+ Args:
+ input_ts: input time-series of shape B x C.
+ paddings: padding shape B x (C + H) where H is the prediction length.
+ freq: frequency shape B x 1
+ horizon_len: prediction length.
+ output_patch_len: output length to be fetched from one step of
+ auto-regressive decoding.
+ max_len: maximum training context length.
+ return_forecast_on_context: whether to return the model forecast on the
+ context except the first input patch.
+
+ Returns:
+ Tuple of two forecasting results:
+ - Point (mean) output predictions as a tensor with shape B x H'.
+ - Full predictions (mean and quantiles) as a tensor with shape
+ B x H' x (1 + # quantiles).
+ In particular, if return_forecast_on_context is True, H' is H plus
+ the forecastable context length, i.e. context_len - (first) patch_len.
+ """
+ final_out = input_ts
+ context_len = final_out.shape[1]
+ full_outputs = []
+ if paddings.shape[1] != final_out.shape[1] + horizon_len:
+ raise ValueError(
+ "Length of paddings must match length of input + horizon_len:"
+ f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
+ if output_patch_len is None:
+ output_patch_len = self.config.horizon_len
+ num_decode_patches = (horizon_len + output_patch_len - 1) // output_patch_len
+ for step_index in range(num_decode_patches):
+ current_padding = paddings[:, 0:final_out.shape[1]]
+ input_ts = final_out[:, -max_len:]
+ input_padding = current_padding[:, -max_len:]
+ fprop_outputs = self(input_ts, input_padding, freq)
+ if return_forecast_on_context and step_index == 0:
+ # For the first decoding step, collect the model forecast on the
+ # context except the unavailable first input patch forecast.
+ new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]
+ new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,
+ new_full_ts.size(3))
+
+ full_outputs.append(new_full_ts)
+
+ # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+ new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+ new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+ # (full batch, last patch, output_patch_len, all output indices)
+ full_outputs.append(new_full_ts)
+ final_out = torch.concat([final_out, new_ts], dim=-1) # TODO torch.concatenate(axis) => torch.concat(dim)
+
+ if return_forecast_on_context:
+ # `full_outputs` indexing starts after the first input patch.
+ full_outputs = torch.concat( # TODO torch.concatenate(axis) => torch.concat(dim)
+ full_outputs,
+ dim=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
+ else:
+ # `full_outputs` indexing starts at the forecast horizon.
+ full_outputs = torch.concat(full_outputs, dim=1)[:, 0:horizon_len, :] # TODO torch.concatenate(axis) => torch.concat(dim)
+
+ return full_outputs[:, :, 0], full_outputs
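Review note on `decode`: the loop above runs ceil(horizon_len / output_patch_len) forward passes, feeds at most the last `max_len` points into each pass, and trims the concatenated forecast to `horizon_len` at the end. The following is a minimal pure-Python sketch of that bookkeeping only, with no model involved. `num_decode_steps` and `decode_trace` are hypothetical helper names used for illustration; they do not exist in the patch.

```python
from typing import List, Tuple


def num_decode_steps(horizon_len: int, output_patch_len: int) -> int:
    """Ceiling division, as in `num_decode_patches` in `decode`."""
    return (horizon_len + output_patch_len - 1) // output_patch_len


def decode_trace(
    context_len: int,
    horizon_len: int,
    output_patch_len: int,
    max_len: int = 512,
) -> List[Tuple[int, int]]:
    """Return (model_input_len, total_len_after_step) for every decode step.

    Hypothetical helper: it only mirrors the length arithmetic of the loop,
    where each step appends `output_patch_len` points to the running series
    and the model sees at most the last `max_len` points.
    """
    total_len = context_len
    trace = []
    for _ in range(num_decode_steps(horizon_len, output_patch_len)):
        model_input_len = min(total_len, max_len)  # final_out[:, -max_len:]
        total_len += output_patch_len  # concat of new_ts onto final_out
        trace.append((model_input_len, total_len))
    return trace


if __name__ == "__main__":
    for step, (seen, total) in enumerate(decode_trace(512, 300, 128)):
        print(f"step {step}: model sees {seen} points, series grows to {total}")
```

For a 512-point context, a 300-step horizon and 128-point output patches, this gives three passes that generate 384 points, of which the first 300 are kept by the final slice.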
diff --git a/etna/libs/timesfm/timesfm.py b/etna/libs/timesfm/timesfm.py
new file mode 100644
index 000000000..d46782e3c
--- /dev/null
+++ b/etna/libs/timesfm/timesfm.py
@@ -0,0 +1,325 @@
+"""
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+"""
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_torch.py)
+# Add method to change horizon after initialization.
+# Minor logic change of loading model.
+
+"""TimesFM pytorch forecast API for inference."""
+
+import logging
+from os import path
+from typing import Any, Sequence, Optional, Tuple
+import os
+import numpy as np
+import torch
+from huggingface_hub import snapshot_download
+from etna.libs.timesfm import timesfm_base
+
+from etna.libs.timesfm import patched_decoder as ppd
+
+_TOL = 1e-6
+
+
+class TimesFmTorch(timesfm_base.TimesFmBase):
+ """TimesFM forecast API for inference."""
+
+ def __post_init__(self):
+ self._model_config = ppd.TimesFMConfig(
+ num_layers=self.num_layers,
+ num_heads=self.num_heads,
+ hidden_size=self.model_dims,
+ intermediate_size=self.model_dims,
+ patch_len=self.input_patch_len,
+ horizon_len=self.output_patch_len,
+ head_dim=self.model_dims // self.num_heads,
+ quantiles=self.quantiles,
+ use_positional_embedding=self.use_pos_emb,
+ )
+ self._model = None
+ self.num_cores = 1
+ self.global_batch_size = self.per_core_batch_size
+ self._device = torch.device("cuda:0" if (
+ torch.cuda.is_available() and self.backend == "gpu") else "cpu")
+ self._median_index = -1
+
+ def _set_horizon(self, horizon): # changed: added to change horizon after initialization
+ self.horizon_len = horizon
+
+ def load_from_checkpoint(
+ self,
+ checkpoint: timesfm_base.TimesFmCheckpoint,
+ ) -> None:
+ """Loads a checkpoint and compiles the decoder."""
+ checkpoint_path = checkpoint.path
+ repo_id = checkpoint.huggingface_repo_id
+ if not os.path.exists(checkpoint_path): # changed: make loading similar to chronos
+ checkpoint_path = path.join(snapshot_download(checkpoint_path, cache_dir=checkpoint.local_dir), "torch_model.ckpt")
+ self._model = ppd.PatchedTimeSeriesDecoder(self._model_config)
+ loaded_checkpoint = torch.load(checkpoint_path) # changed: remove weights_only=True due to attribute absence in low torch versions
+ logging.info("Loading checkpoint from %s", checkpoint_path)
+ self._model.load_state_dict(loaded_checkpoint)
+ logging.info("Sending checkpoint to device %s", f"{self._device}")
+ self._model.to(self._device)
+ self._model.eval()
+ # TODO: add compilation.
+
+ def _forecast(
+ self,
+ inputs: Sequence[Any],
+ freq: Optional[Sequence[int]] = None,
+ window_size: Optional[int] = None,
+ forecast_context_len: Optional[int] = None,
+ return_forecast_on_context: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Forecasts on a list of time series.
+
+ Args:
+ inputs: list of time series forecast contexts. Each context time series
+ should be in a format convertible to a NumPy array by `np.array`.
+ freq: frequency of each context time series. 0 for high frequency
+ (default), 1 for medium, and 2 for low. Notice this is different from
+ the `freq` required by `forecast_on_df`.
+ window_size: window size of trend + residual decomposition. If None then
+ we do not do decomposition.
+ forecast_context_len: optional max context length.
+ return_forecast_on_context: True to return the forecast on the context
+ when available, i.e. after the first input patch.
+
+ Returns:
+ A tuple of NumPy arrays:
+ - the mean forecast of size (# inputs, # forecast horizon),
+ - the full forecast (mean + quantiles) of size
+ (# inputs, # forecast horizon, 1 + # quantiles).
+
+ Raises:
+ ValueError: If the checkpoint is not properly loaded.
+ """
+ if not self._model:
+ raise ValueError(
+ "Checkpoint not loaded. Call `load_from_checkpoint` before"
+ " `forecast`.")
+ if forecast_context_len is None:
+ fcontext_len = self.context_len
+ else:
+ fcontext_len = forecast_context_len
+ inputs = [np.array(ts)[-fcontext_len:] for ts in inputs]
+
+ if window_size is not None:
+ new_inputs = []
+ for ts in inputs:
+ new_inputs.extend(timesfm_base.moving_average(ts, window_size))
+ inputs = new_inputs
+
+ if freq is None:
+ logging.info("No frequency provided via `freq`. Default to high (0).")
+ freq = [0] * len(inputs)
+
+ input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)
+ with torch.no_grad():
+ mean_outputs = []
+ full_outputs = []
+ assert input_ts.shape[0] % self.global_batch_size == 0
+ for i in range(input_ts.shape[0] // self.global_batch_size):
+ input_ts_in = torch.from_numpy(
+ np.array(input_ts[i * self.global_batch_size:(i + 1) *
+ self.global_batch_size],
+ dtype=np.float32)).to(self._device)
+ input_padding_in = torch.from_numpy(
+ np.array(input_padding[i * self.global_batch_size:(i + 1) *
+ self.global_batch_size],
+ dtype=np.float32)).to(self._device)
+ inp_freq_in = torch.from_numpy(
+ np.array(inp_freq[
+ i * self.global_batch_size:(i + 1) * self.global_batch_size,
+ :,
+ ],
+ dtype=np.int32)).long().to(self._device)
+ mean_output, full_output = self._model.decode(
+ input_ts=input_ts_in,
+ paddings=input_padding_in,
+ freq=inp_freq_in,
+ horizon_len=self.horizon_len,
+ return_forecast_on_context=return_forecast_on_context,
+ )
+ mean_output = mean_output.detach().cpu().numpy()
+ full_output = full_output.detach().cpu().numpy()
+ mean_output = np.array(mean_output)
+ full_output = np.array(full_output)
+ mean_outputs.append(mean_output)
+ full_outputs.append(full_output)
+
+ mean_outputs = np.concatenate(mean_outputs, axis=0)
+ full_outputs = np.concatenate(full_outputs, axis=0)
+
+ if pmap_pad > 0:
+ mean_outputs = mean_outputs[:-pmap_pad, ...]
+ full_outputs = full_outputs[:-pmap_pad, ...]
+
+ if window_size is not None:
+ mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+ full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+ return mean_outputs, full_outputs
\ No newline at end of file
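Review note on the `window_size` branch of `_forecast`: each input series is replaced by two series, which the docstring describes as trend and residual, so the batch becomes `[trend_0, residual_0, trend_1, residual_1, ...]`. After decoding, the `[0::2] + [1::2]` slices add the two forecasts back together per original series. The sketch below illustrates only that interleaving and recombination in pure Python. `moving_average_split` is a hypothetical stand-in: `timesfm_base.moving_average` is not shown in this part of the diff, and the trailing, zero-padded window used here is an assumption about its behavior.

```python
from typing import List, Tuple


def moving_average_split(
    series: List[float], window_size: int
) -> Tuple[List[float], List[float]]:
    """Hypothetical stand-in: split a series into (trend, residual).

    The trend is a trailing moving average that treats points before the
    start of the series as zeros (an assumption, see the note above).
    """
    trend = []
    for i in range(len(series)):
        window = series[max(0, i - window_size + 1): i + 1]
        trend.append(sum(window) / window_size)
    residual = [x - t for x, t in zip(series, trend)]
    return trend, residual


def interleave_components(
    inputs: List[List[float]], window_size: int
) -> List[List[float]]:
    """Build [trend_0, residual_0, trend_1, residual_1, ...]."""
    components = []
    for ts in inputs:
        components.extend(moving_average_split(ts, window_size))
    return components


def recombine(outputs: List[List[float]]) -> List[List[float]]:
    """Sum even-indexed (trend) and odd-indexed (residual) rows pairwise."""
    return [
        [t + r for t, r in zip(trend, resid)]
        for trend, resid in zip(outputs[0::2], outputs[1::2])
    ]


if __name__ == "__main__":
    inputs = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
    components = interleave_components(inputs, window_size=2)
    # Using the components themselves in place of model forecasts,
    # recombination recovers one row per original series.
    print(recombine(components))
```

Because the two components are forecast as separate rows, the batch size doubles when `window_size` is set, and the output has one row per original series only after the pairwise sum.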
diff --git a/etna/libs/timesfm/timesfm_base.py b/etna/libs/timesfm/timesfm_base.py
new file mode 100644
index 000000000..14755cbb3
--- /dev/null
+++ b/etna/libs/timesfm/timesfm_base.py
@@ -0,0 +1,812 @@
+"""
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+"""
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/timesfm_base.py)
+# replace print with logging
+
+import warnings
+
+"""Base class for TimesFM inference. This will be common to PAX and Pytorch."""
+
+import collections
+import dataclasses
+import logging
+import multiprocessing
+from typing import Any, Literal, Sequence, Optional, Tuple, List, Dict, Union
+from pathlib import Path
+import numpy as np
+import pandas as pd
+
+from utilsforecast.processing import make_future_dataframe
+
+from etna.libs.timesfm import xreg_lib
+
+Category = xreg_lib.Category
+XRegMode = xreg_lib.XRegMode
+
+_TOL = 1e-6
+DEFAULT_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
+
+
+def process_group(key, group, value_name, forecast_context_len):
+ group = group.tail(forecast_context_len)
+ return np.array(group[value_name], dtype=np.float32), key
+
+
+def moving_average(arr, window_size):
+ """Calculates the moving average using NumPy's convolution function."""
+ # Pad with zeros to handle initial window positions
+ arr_padded = np.pad(arr, (window_size - 1, 0), "constant")
+ smoothed_arr = (np.convolve(arr_padded, np.ones(window_size), "valid") /
+ window_size)
+ return [smoothed_arr, arr - smoothed_arr]
+
+
+def freq_map(freq: Optional[str]):
+ """Returns the frequency map for the given frequency string."""
+ if freq is None: # changed: added this case to handle int timestamps during forecasting with exogenous features
+ warnings.warn("Frequency is None. Mapping it to 0, which may be suboptimal. Consider setting a known frequency.")
+ return 0
+ freq = str.upper(freq)
+ if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or
+ freq.endswith("D") or freq.endswith("B") or freq.endswith("U")):
+ return 0
+ elif freq.endswith(("W", "M", "MS")):
+ return 1
+ elif freq.endswith("Y") or freq.endswith("Q"):
+ return 2
+ else:
+ raise ValueError(f"Invalid frequency: {freq}")
+
+
+# Per time series normalization: forward.
+def _normalize(batch):
+ stats = [
+ (np.mean(x), np.where((w := np.std(x)) > _TOL, w, 1.0)) for x in batch
+ ]
+ new_batch = [(x - stat[0]) / stat[1] for x, stat in zip(batch, stats)]
+ return new_batch, stats
+
+
+# Per time series normalization: inverse.
+def _renormalize(batch, stats):
+ return [x * stat[1] + stat[0] for x, stat in zip(batch, stats)]
+
+
+@dataclasses.dataclass()
+class TimesFmHparams:
+ """Hparams used to initialize a TimesFM model for inference.
+
+ These hparams are a sufficient subset to configure TimesFM inference
+ agnostic to the checkpoint version, and are not necessarily the same as the
+ hparams used to train the checkpoint.
+
+ Attributes:
+ context_len: Largest context length the model allows for each decode call.
+ This can technically be arbitrarily large, but in practice it should be set
+ to the context length the checkpoint was trained with.
+ horizon_len: Forecast horizon.
+ input_patch_len: Input patch len.
+ output_patch_len: Output patch len. How many timepoints are taken from a
+ single step of autoregressive decoding. Can be set as the training horizon
+ of the checkpoint.
+ num_layers: Number of transformer layers in the model.
+ model_dims: Model dimension.
+ per_core_batch_size: Batch size on each core for data parallelism.
+ backend: One of "cpu", "gpu" or "tpu".
+ quantiles: Which quantiles are output by the model.
+ """
+
+ context_len: int = 512
+ horizon_len: int = 128
+ input_patch_len: int = 32
+ output_patch_len: int = 128
+ num_layers: int = 20
+ num_heads: int = 16
+ model_dims: int = 1280
+ per_core_batch_size: int = 32
+ backend: Literal["cpu", "gpu", "tpu"] = "cpu"
+ quantiles: Optional[Sequence[float]] = DEFAULT_QUANTILES
+ use_positional_embedding: bool = True
+ # Hparams beyond the model.
+ point_forecast_mode: Literal["mean", "median"] = "median"
+
+
+@dataclasses.dataclass()
+class TimesFmCheckpoint:
+ """Checkpoint used to initialize a TimesFM model for inference.
+
+ Attributes:
+ version: Version of the checkpoint, e.g. "jax", "torch", "tensorflow", etc.
+ The factory will create the corresponding TimesFm inference class based on
+ this version.
+ path: Path to the checkpoint.
+ type: If provided, type of the checkpoint used by the specific checkpoint
+ loader per version.
+ step: If provided, step of the checkpoint.
+ """
+
+ version: str = "jax"
+ path: Optional[Union[str, Path]] = None
+ huggingface_repo_id: Optional[str] = None
+ type: Any = None
+ step: Optional[int] = None
+ local_dir: Optional[Union[str, Path]] = None
+
+
+class TimesFmBase:
+ """Base TimesFM forecast API for inference.
+
+ This class is the scaffolding for calling TimesFM forecast. To properly use:
+ 1. Create an instance with the correct hyperparameters of a TimesFM model.
+ 2. Call `load_from_checkpoint` to load a compatible checkpoint.
+ 3. Call `forecast` for inference.
+ """
+
+ def _logging(self, s):
+ logging.info(s)
+
+ def __post_init__(self) -> None:
+ """Additional initialization for subclasses before checkpoint loading."""
+ pass
+
+ def __init__(self, hparams: TimesFmHparams,
+ checkpoint: TimesFmCheckpoint) -> None:
+ """Initializes the TimesFM forecast API.
+
+ Args:
+ hparams: Hyperparameters of the model.
+ checkpoint: Checkpoint to load. Notice `checkpoint.version` will decide
+ which TimesFM version to use.
+ """
+ self.hparams = hparams
+
+ # Expand hparams for conciseness within the model code.
+ self.context_len = hparams.context_len
+ self.horizon_len = hparams.horizon_len
+ self.input_patch_len = hparams.input_patch_len
+ self.output_patch_len = hparams.output_patch_len
+ self.num_layers = hparams.num_layers
+ self.model_dims = hparams.model_dims
+ self.backend = hparams.backend
+ self.quantiles = hparams.quantiles
+ self.num_heads = hparams.num_heads
+ self.use_pos_emb = hparams.use_positional_embedding
+
+ # Rewrite these values in __post_init__ for SPMD.
+ self.num_cores = 1
+ self.per_core_batch_size = hparams.per_core_batch_size
+ self.global_batch_size = hparams.per_core_batch_size
+
+ self._horizon_start = self.context_len - self.input_patch_len
+ self.__post_init__()
+ self.load_from_checkpoint(checkpoint)
+
+ def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:
+ """Loads a checkpoint and compiles the decoder."""
+ raise NotImplementedError("`load_from_checkpoint` is not implemented.")
+
+ def _preprocess(
+ self, inputs: Sequence[np.ndarray],
+ freq: Sequence[int]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
+ """Formats and pads raw inputs to feed into the model.
+
+ This function both pads each time series to match the context length, and
+ pads the inputs to meet the SPMD shape requirement.
+
+ Args:
+ inputs: A list of 1d JTensors. Each JTensor is the context time series of
+ a single forecast task.
+ freq: list of frequencies
+
+ Returns:
+ A tuple of:
+ - the padded input time series to meet the model required context.
+ - the padding indicator.
+ - the frequency of each input time series.
+ - the number of padded examples for SPMD so that each core has the same
+ number (a multiple of `batch_size`) of examples.
+ """
+
+ input_ts, input_padding, inp_freq = [], [], []
+
+ pmap_pad = ((len(inputs) - 1) // self.global_batch_size +
+ 1) * self.global_batch_size - len(inputs)
+
+ for i, ts in enumerate(inputs):
+ input_len = ts.shape[0]
+ padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float)
+ if input_len < self.context_len:
+ num_front_pad = self.context_len - input_len
+ ts = np.concatenate([np.zeros(shape=(num_front_pad,), dtype=float), ts],
+ axis=0)
+ padding = np.concatenate(
+ [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0)
+ elif input_len > self.context_len:
+ ts = ts[-self.context_len:]
+ padding = padding[-(self.context_len + self.horizon_len):]
+
+ input_ts.append(ts)
+ input_padding.append(padding)
+ inp_freq.append(freq[i])
+
+ # Padding the remainder batch.
+ for _ in range(pmap_pad):
+ input_ts.append(input_ts[-1])
+ input_padding.append(input_padding[-1])
+ inp_freq.append(inp_freq[-1])
+
+ return (
+ np.stack(input_ts, axis=0),
+ np.stack(input_padding, axis=0),
+ np.array(inp_freq).astype(np.int32).reshape(-1, 1),
+ pmap_pad,
+ )
+
+ def _forecast(
+ self,
+ inputs: Sequence[Any],
+ freq: Optional[Sequence[int]] = None,
+ window_size: Optional[int] = None,
+ forecast_context_len: Optional[int] = None,
+ return_forecast_on_context: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Forecasts on a list of time series.
+
+ Args:
+ inputs: list of time series forecast contexts. Each context time series
+ should be in a format convertible to JTensor by `jnp.array`.
+ freq: frequency of each context time series. 0 for high frequency
+ (default), 1 for medium, and 2 for low. Notice this is different from
+ the `freq` required by `forecast_on_df`.
+ window_size: window size of trend + residual decomposition. If None then
+ we do not do decomposition.
+ forecast_context_len: optional max context length.
+ return_forecast_on_context: True to return the forecast on the context
+ when available, i.e. after the first input patch.
+
+ Returns:
+ A tuple of np.ndarray:
+ - the mean forecast of size (# inputs, # forecast horizon),
+ - the full forecast (mean + quantiles) of size
+ (# inputs, # forecast horizon, 1 + # quantiles).
+
+ Raises:
+ ValueError: If the checkpoint is not properly loaded.
+ """
+ raise NotImplementedError("`_forecast` is not implemented.")
+
+ def forecast(
+ self,
+ inputs: Sequence[Any],
+ freq: Optional[Sequence[int]] = None,
+ window_size: Optional[int] = None,
+ forecast_context_len: Optional[int] = None,
+ return_forecast_on_context: bool = False,
+ normalize: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Forecasts on a list of time series.
+
+ Args:
+ inputs: list of time series forecast contexts. Each context time series
+ should be in a format convertible to JTensor by `jnp.array`.
+ freq: frequency of each context time series. 0 for high frequency
+ (default), 1 for medium, and 2 for low. Notice this is different from
+ the `freq` required by `forecast_on_df`.
+ window_size: window size of trend + residual decomposition. If None then
+ we do not do decomposition.
+ forecast_context_len: optional max context length.
+ return_forecast_on_context: True to return the forecast on the context
+ when available, i.e. after the first input patch.
+ normalize: If True, then we normalize the inputs before forecasting and
+ the outputs are then renormalized to the original scale.
+
+ Returns:
+ A tuple of np.ndarray:
+ - the mean forecast of size (# inputs, # forecast horizon),
+ - the full forecast (mean + quantiles) of size
+ (# inputs, # forecast horizon, 1 + # quantiles).
+
+ Raises:
+ ValueError: If the checkpoint is not properly loaded.
+ """
+ stats = None
+ if normalize:
+ inputs, stats = _normalize(inputs)
+ mean_forecast, quantile_forecast = self._forecast(
+ inputs,
+ freq,
+ window_size,
+ forecast_context_len,
+ return_forecast_on_context,
+ )
+ if stats is not None:
+ stats = np.array(stats)
+ mu = stats[:, 0]
+ sigma = stats[:, 1]
+ mean_forecast = mean_forecast * sigma[:, None] + mu[:, None]
+ quantile_forecast = (quantile_forecast * sigma[:, None, None] +
+ mu[:, None, None])
+ if self.hparams.point_forecast_mode == "mean":
+ return mean_forecast, quantile_forecast
+ elif self.hparams.point_forecast_mode == "median":
+ if self._median_index == -1:
+ for i, quantile in enumerate(self.quantiles):
+ if quantile == 0.5:
+ self._median_index = i
+ break
+ if self._median_index == -1:
+ raise ValueError("Median (0.5) is not found in the model quantiles:"
+ f" {self.quantiles}. Please check the hparams.")
+ return (
+ quantile_forecast[:, :, 1 + self._median_index],
+ quantile_forecast,
+ )
+ else:
+ raise ValueError(
+ "Unsupported point forecast mode:"
+ f" {self.hparams.point_forecast_mode}. Use 'mean' or 'median'.")
+
+ def forecast_with_covariates(
+ self,
+ inputs: List[Sequence[float]],
+ dynamic_numerical_covariates: Optional[Dict[str, Sequence[Sequence[float]]]] = None,
+ dynamic_categorical_covariates: Optional[Dict[str, Sequence[Sequence[Category]]]] = None,
+ static_numerical_covariates: Optional[Dict[str, Sequence[float]]] = None,
+ static_categorical_covariates: Optional[Dict[str, Sequence[Category]]]= None,
+ freq: Optional[Sequence[int]] = None,
+ window_size: Optional[int] = None,
+ forecast_context_len: Optional[int]= None,
+ xreg_mode: XRegMode = "xreg + timesfm",
+ normalize_xreg_target_per_input: bool = True,
+ ridge: float = 0.0,
+ max_rows_per_col: int = 0,
+ force_on_cpu: bool = False,
+ ):
+ """Forecasts on a list of time series with covariates.
+
+ To optimize inference speed, avoid string-valued categorical covariates.
+
+ Args:
+ inputs: A list of time series forecast contexts. Each context time series
+ should be in a format convertible to JTensor by `jnp.array`.
+ dynamic_numerical_covariates: A dict of dynamic numerical covariates.
+ dynamic_categorical_covariates: A dict of dynamic categorical covariates.
+ static_numerical_covariates: A dict of static numerical covariates.
+ static_categorical_covariates: A dict of static categorical covariates.
+ freq: frequency of each context time series. 0 for high frequency
+ (default), 1 for medium, and 2 for low. Notice this is different from
+ the `freq` required by `forecast_on_df`.
+ window_size: window size of trend + residual decomposition. If None then
+ we do not do decomposition.
+ forecast_context_len: optional max context length.
+ xreg_mode: one of "xreg + timesfm" or "timesfm + xreg". "timesfm + xreg"
+ fits a model on the residuals of the TimesFM forecast. "xreg + timesfm"
+ fits a model on the targets then forecasts on the residuals via TimesFM.
+ normalize_xreg_target_per_input: whether to normalize the xreg target per
+ input in the given batch.
+ ridge: ridge penalty for the linear model.
+ max_rows_per_col: max number of rows per column for the linear model.
+ force_on_cpu: whether to force running on cpu for the linear model.
+
+ Returns:
+ A tuple of two lists. The first is the outputs of the model. The second is
+ the outputs of the xreg.
+ """
+
+ # Verify and bookkeep covariates.
+ if not (dynamic_numerical_covariates or dynamic_categorical_covariates or
+ static_numerical_covariates or static_categorical_covariates):
+ raise ValueError(
+ "At least one of dynamic_numerical_covariates,"
+ " dynamic_categorical_covariates, static_numerical_covariates,"
+ " static_categorical_covariates must be set.")
+
+ # Track the lengths of (1) each input, (2) the part that can be used in the
+ # linear model, and (3) the horizon.
+ input_lens, train_lens, test_lens = [], [], []
+
+ for i, input_ts in enumerate(inputs):
+ input_len = len(input_ts)
+ input_lens.append(input_len)
+
+ if xreg_mode == "timesfm + xreg":
+ # For fitting residuals, no TimesFM forecast on the first patch.
+ train_lens.append(max(0, input_len - self.input_patch_len))
+ elif xreg_mode == "xreg + timesfm":
+ train_lens.append(input_len)
+ else:
+ raise ValueError(f"Unsupported mode: {xreg_mode}")
+
+ if dynamic_numerical_covariates:
+ test_lens.append(
+ len(list(dynamic_numerical_covariates.values())[0][i]) - input_len)
+ elif dynamic_categorical_covariates:
+ test_lens.append(
+ len(list(dynamic_categorical_covariates.values())[0][i]) -
+ input_len)
+ else:
+ test_lens.append(self.horizon_len)
+
+ if test_lens[-1] > self.horizon_len:
+ raise ValueError(
+ "Forecast requested longer horizon than the model definition "
+ f"supports: {test_lens[-1]} vs {self.horizon_len}.")
+
+ # Prepare the covariates into train and test.
+ train_dynamic_numerical_covariates = collections.defaultdict(list)
+ test_dynamic_numerical_covariates = collections.defaultdict(list)
+ train_dynamic_categorical_covariates = collections.defaultdict(list)
+ test_dynamic_categorical_covariates = collections.defaultdict(list)
+ for covariates, train_covariates, test_covariates in (
+ (
+ dynamic_numerical_covariates,
+ train_dynamic_numerical_covariates,
+ test_dynamic_numerical_covariates,
+ ),
+ (
+ dynamic_categorical_covariates,
+ train_dynamic_categorical_covariates,
+ test_dynamic_categorical_covariates,
+ ),
+ ):
+ if not covariates:
+ continue
+ for covariate_name, covariate_values in covariates.items():
+ for input_len, train_len, covariate_value in zip(
+ input_lens, train_lens, covariate_values):
+ train_covariates[covariate_name].append(
+ covariate_value[(input_len - train_len):input_len])
+ test_covariates[covariate_name].append(covariate_value[input_len:])
+
+ # Fit models.
+ if xreg_mode == "timesfm + xreg":
+ # Forecast via TimesFM then fit a model on the residuals.
+ mean_outputs, _ = self.forecast(
+ inputs,
+ freq,
+ window_size,
+ forecast_context_len,
+ return_forecast_on_context=True,
+ )
+ targets = [
+ (np.array(input_ts)[-train_len:] -
+ mean_output[(self._horizon_start - train_len):self._horizon_start])
+ for input_ts, mean_output, train_len in zip(inputs, mean_outputs,
+ train_lens)
+ ]
+ per_instance_stats = None
+ if normalize_xreg_target_per_input:
+ targets, per_instance_stats = _normalize(targets)
+ xregs = xreg_lib.BatchedInContextXRegLinear(
+ targets=targets,
+ train_lens=train_lens,
+ test_lens=test_lens,
+ train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
+ test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
+ train_dynamic_categorical_covariates=
+ train_dynamic_categorical_covariates,
+ test_dynamic_categorical_covariates=
+ test_dynamic_categorical_covariates,
+ static_numerical_covariates=static_numerical_covariates,
+ static_categorical_covariates=static_categorical_covariates,
+ ).fit(
+ ridge=ridge,
+ one_hot_encoder_drop=None if ridge > 0 else "first",
+ max_rows_per_col=max_rows_per_col,
+ force_on_cpu=force_on_cpu,
+ debug_info=False,
+ assert_covariates=True,
+ assert_covariate_shapes=True,
+ )
+ if normalize_xreg_target_per_input:
+ xregs = _renormalize(xregs, per_instance_stats)
+ outputs = [
+ (mean_output[self._horizon_start:(self._horizon_start + test_len)] +
+ xreg)
+ for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)
+ ]
+
+ else:
+ # Fit a model on the targets then forecast on the residuals via TimesFM.
+ targets = [
+ np.array(input_ts)[-train_len:]
+ for input_ts, train_len in zip(inputs, train_lens)
+ ]
+ per_instance_stats = None
+ if normalize_xreg_target_per_input:
+ targets, per_instance_stats = _normalize(targets)
+ xregs, xregs_on_context, _, _, _ = xreg_lib.BatchedInContextXRegLinear(
+ targets=targets,
+ train_lens=train_lens,
+ test_lens=test_lens,
+ train_dynamic_numerical_covariates=train_dynamic_numerical_covariates,
+ test_dynamic_numerical_covariates=test_dynamic_numerical_covariates,
+ train_dynamic_categorical_covariates=
+ train_dynamic_categorical_covariates,
+ test_dynamic_categorical_covariates=
+ test_dynamic_categorical_covariates,
+ static_numerical_covariates=static_numerical_covariates,
+ static_categorical_covariates=static_categorical_covariates,
+ ).fit(
+ ridge=ridge,
+ one_hot_encoder_drop=None if ridge > 0 else "first",
+ max_rows_per_col=max_rows_per_col,
+ force_on_cpu=force_on_cpu,
+ debug_info=True,
+ assert_covariates=True,
+ assert_covariate_shapes=True,
+ )
+ mean_outputs, _ = self.forecast(
+ [
+ target - xreg_on_context
+ for target, xreg_on_context in zip(targets, xregs_on_context)
+ ],
+ freq,
+ window_size,
+ forecast_context_len,
+ return_forecast_on_context=True,
+ )
+ outputs = [
+ (mean_output[self._horizon_start:(self._horizon_start + test_len)] +
+ xreg)
+ for mean_output, test_len, xreg in zip(mean_outputs, test_lens, xregs)
+ ]
+ if normalize_xreg_target_per_input:
+ outputs = _renormalize(outputs, per_instance_stats)
+
+ return outputs, xregs
+
+ def forecast_on_df(
+ self,
+ inputs: pd.DataFrame,
+ freq: str,
+ forecast_context_len: int = 0,
+ value_name: str = "values",
+ model_name: str = "timesfm",
+ window_size: Optional[int] = None,
+ num_jobs: int = 1,
+ verbose: bool = True,
+ ) -> pd.DataFrame:
+ """Forecasts on a dataframe of time series.
+
+ Args:
+ inputs: A pd.DataFrame of all time series. The dataframe should have a
+ `unique_id` column for identifying the time series, a `ds` column for
+ timestamps and a value column for the time series values.
+ freq: string valued `freq` of data. Notice this is different from the
+ `freq` required by `forecast`. See `freq_map` for allowed values.
+ forecast_context_len: If nonzero, we take the last
+ `forecast_context_len` time-points from each series as the forecast
+ context instead of the `context_len` set by the model.
+ value_name: The name of the value column.
+ model_name: name of the model to be written into future df.
+ window_size: window size of trend + residual decomposition. If None then
+ we do not do decomposition.
+ num_jobs: number of parallel processes to use for dataframe processing.
+ verbose: whether to log progress messages.
+
+ Returns:
+ Future forecasts dataframe.
+ """
+ if not ("unique_id" in inputs.columns and "ds" in inputs.columns and
+ value_name in inputs.columns):
+ raise ValueError(
+ f"DataFrame must have unique_id, ds and {value_name} columns.")
+ if not forecast_context_len:
+ forecast_context_len = self.context_len
+ logging.info("Preprocessing dataframe.")
+ df_sorted = inputs.sort_values(by=["unique_id", "ds"])
+ new_inputs = []
+ uids = []
+ if num_jobs == 1:
+ if verbose:
+ logging.info("Processing dataframe with single process.") # changed: replace print
+ for key, group in df_sorted.groupby("unique_id"):
+ inp, uid = process_group(
+ key,
+ group,
+ value_name,
+ forecast_context_len,
+ )
+ new_inputs.append(inp)
+ uids.append(uid)
+ else:
+ if num_jobs == -1:
+ num_jobs = multiprocessing.cpu_count()
+ if verbose:
+ logging.info("Processing dataframe with multiple processes.") # changed: replace print
+ with multiprocessing.Pool(processes=num_jobs) as pool:
+ results = pool.starmap(
+ process_group,
+ [(key, group, value_name, forecast_context_len)
+ for key, group in df_sorted.groupby("unique_id")],
+ )
+ new_inputs, uids = zip(*results)
+ if verbose:
+ logging.info("Finished preprocessing dataframe.") # changed: replace print
+ freq_inps = [freq_map(freq)] * len(new_inputs)
+ _, full_forecast = self.forecast(new_inputs,
+ freq=freq_inps,
+ window_size=window_size)
+ if verbose:
+ logging.info("Finished forecasting.")
+ fcst_df = make_future_dataframe(
+ uids=uids,
+ last_times=df_sorted.groupby("unique_id")["ds"].tail(1),
+ h=self.horizon_len,
+ freq=freq,
+ )
+ fcst_df[model_name] = full_forecast[:, 0:self.horizon_len, 0].reshape(-1, 1)
+
+ for i, q in enumerate(self.quantiles):
+ q_col = f"{model_name}-q-{q}"
+ fcst_df[q_col] = full_forecast[:, 0:self.horizon_len,
+ 1 + i].reshape(-1, 1)
+ if q == 0.5:
+ fcst_df[model_name] = fcst_df[q_col]
+ logging.info("Finished creating output dataframe.")
+ return fcst_df
\ No newline at end of file
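Review note on `forecast(..., normalize=True)` in the file above: each series is standardized with its own mean and standard deviation, and the forecasts are mapped back by broadcasting those statistics over the horizon axis. The sketch below reproduces that round trip in plain NumPy under simplifying assumptions: `normalize` is a simplified re-statement of `_normalize`, and `fake_forecast` is a hypothetical placeholder for the model output (all zeros, i.e. "predict the normalized mean").

```python
import numpy as np

_TOL = 1e-6


def normalize(batch):
    """Standardize each series; use scale 1.0 for (near-)constant series."""
    stats = []
    for x in batch:
        std = np.std(x)
        stats.append((np.mean(x), std if std > _TOL else 1.0))
    normed = [(x - mu) / sigma for x, (mu, sigma) in zip(batch, stats)]
    return normed, stats


batch = [np.array([1.0, 2.0, 3.0]), np.array([5.0, 5.0, 5.0, 5.0])]
normed, stats = normalize(batch)

# Hypothetical model output in normalized units: (batch, horizon).
fake_forecast = np.zeros((2, 4))
# Hypothetical full output: (batch, horizon, 1 + number of quantiles).
fake_quantiles = np.zeros((2, 4, 10))

stats_arr = np.array(stats)
mu, sigma = stats_arr[:, 0], stats_arr[:, 1]

# Broadcast per-series statistics over the trailing axes.
rescaled = fake_forecast * sigma[:, None] + mu[:, None]
rescaled_quantiles = fake_quantiles * sigma[:, None, None] + mu[:, None, None]
```

The fallback scale of 1.0 keeps a constant series from causing a division by zero, so it normalizes to all zeros and maps back to its original level.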
diff --git a/etna/libs/timesfm/xreg_lib.py b/etna/libs/timesfm/xreg_lib.py
new file mode 100644
index 000000000..4521009bf
--- /dev/null
+++ b/etna/libs/timesfm/xreg_lib.py
@@ -0,0 +1,643 @@
+"""
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+"""
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Note: Copied from timesfm repository (https://github.com/google-research/timesfm/blob/154248137ccce29b01f4c3a765e85c3d9e4d92ba/src/timesfm/xreg_lib.py)
+# add check of sklearn version for OHE
+"""Helper functions for in-context covariates and regression."""
+
+import itertools
+import math
+from typing import Any, Iterable, Literal, Mapping, Sequence, Union, Optional, Tuple, List
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from sklearn import preprocessing
+from sklearn import __version__ as sklearn_version
+
+Category = Union[int, str]
+
+_TOL = 1e-6
+XRegMode = Literal["timesfm + xreg", "xreg + timesfm"]
+
+
+def _unnest(nested: Sequence[Sequence[Any]]) -> np.ndarray:
+ return np.array(list(itertools.chain.from_iterable(nested)))
+
+
+def _repeat(elements: Iterable[Any], counts: Iterable[int]) -> np.ndarray:
+ return np.array(
+ list(
+ itertools.chain.from_iterable(map(itertools.repeat, elements,
+ counts))))
+
+
+def _to_padded_jax_array(x: np.ndarray) -> jax.Array:
+ if x.ndim == 1:
+ (i,) = x.shape
+ di = 2**math.ceil(math.log2(i)) - i
+ return jnp.pad(x, ((0, di),), mode="constant", constant_values=0.0)
+ elif x.ndim == 2:
+ i, j = x.shape
+ di = 2**math.ceil(math.log2(i)) - i
+ dj = 2**math.ceil(math.log2(j)) - j
+ return jnp.pad(x, ((0, di), (0, dj)), mode="constant", constant_values=0.0)
+ else:
+ raise ValueError(f"Unsupported array shape: {x.shape}")
+
+
+class BatchedInContextXRegBase:
+ """Helper class for in-context regression covariate formatting.
+
+ Attributes:
+ targets: List of targets (responses) of the in-context regression.
+ train_lens: List of lengths of each target vector from the context.
+ test_lens: List of lengths of each forecast horizon.
+ train_dynamic_numerical_covariates: Dict of covariate names mapping to the
+ dynamic numerical covariates of each forecast task on the context. Their
+ lengths should match the corresponding lengths in `train_lens`.
+ train_dynamic_categorical_covariates: Dict of covariate names mapping to the
+ dynamic categorical covariates of each forecast task on the context. Their
+ lengths should match the corresponding lengths in `train_lens`.
+ test_dynamic_numerical_covariates: Dict of covariate names mapping to the
+ dynamic numerical covariates of each forecast task on the horizon. Their
+ lengths should match the corresponding lengths in `test_lens`.
+ test_dynamic_categorical_covariates: Dict of covariate names mapping to the
+ dynamic categorical covariates of each forecast task on the horizon. Their
+ lengths should match the corresponding lengths in `test_lens`.
+ static_numerical_covariates: Dict of covariate names mapping to the static
+ numerical covariates of each forecast task.
+ static_categorical_covariates: Dict of covariate names mapping to the static
+ categorical covariates of each forecast task.
+ """
+
+ def __init__(
+ self,
+ targets: Sequence[Sequence[float]],
+ train_lens: Sequence[int],
+ test_lens: Sequence[int],
+ train_dynamic_numerical_covariates: Optional[Mapping[str, Sequence[Sequence[float]]]] = None,
+ train_dynamic_categorical_covariates: Optional[Mapping[str, Sequence[Sequence[Category]]]] = None,
+ test_dynamic_numerical_covariates: Optional[Mapping[str, Sequence[Sequence[float]]]] = None,
+ test_dynamic_categorical_covariates: Optional[Mapping[str, Sequence[Sequence[Category]]]] = None,
+ static_numerical_covariates: Optional[Mapping[str, Sequence[float]]] = None,
+ static_categorical_covariates: Optional[Mapping[str, Sequence[Category]]] = None,
+ ) -> None:
+ """Initializes with the exogenous covariate inputs.
+
+ Here we use model fitting language to refer to the context as 'train' and
+ the horizon as 'test'. We assume batched inputs. To properly format the
+ request:
+
+ - `train_lens` represents the contexts in the batch. Targets and all train
+ dynamic covariates should have the same lengths as the corresponding
+ elements
+ in `train_lens`. Notice each `train_len` can be different from the exact
+ length of the corresponding context depending on how much of the context is
+ used for fitting the in-context model.
+ - `test_lens` represents the horizon lengths in the batch. All test
+ dynamic
+ covariates should have the same lengths as the corresponding elements in
+ `test_lens`.
+ - Static covariates should be one for each input.
+ - For train and test dynamic covariates, they should have the same
+ covariate
+ names.
+
+ Pass an empty dict {} for a covariate type if it is not present.
+
+ Example:
+ Here is a set of valid inputs whose schema can be used for reference.
+ ```
+ targets = [
+ [0.0, 0.1, 0.2],
+ [0.0, 0.1, 0.2, 0.3],
+ ] # Two inputs in this batch.
+ train_lens = [3, 4]
+ test_lens = [2, 5] # Forecast horizons 2 and 5 respectively.
+ train_dynamic_numerical_covariates = {
+ "cov_1_dn": [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0, 1.5]],
+ "cov_2_dn": [[0.0, 1.5, 1.0], [0.0, 1.5, 1.0, 2.5]],
+ } # Each train dynamic covariate has 3 and 4 elements respectively.
+ test_dynamic_numerical_covariates = {
+ "cov_1_dn": [[0.1, 0.6], [0.1, 0.6, 1.1, 1.6, 2.4]],
+ "cov_2_dn": [[0.1, 1.1], [0.1, 1.6, 1.1, 2.6, 10.0]],
+ } # Each test dynamic covariate has 2 and 5 elements respectively.
+ train_dynamic_categorical_covariates = {
+ "cov_1_dc": [[0, 1, 0], [0, 1, 2, 3]],
+ "cov_2_dc": [["good", "bad", "good"], ["good", "good", "bad",
+ "bad"]],
+ }
+ test_dynamic_categorical_covariates = {
+ "cov_1_dc": [[1, 0], [1, 0, 2, 3, 1]],
+ "cov_2_dc": [["bad", "good"], ["bad", "bad", "bad", "bad", "bad"]],
+ }
+ static_numerical_covariates = {
+ "cov_1_sn": [0.0, 3.0],
+ "cov_2_sn": [2.0, 1.0],
+ "cov_3_sn": [1.0, 2.0],
+ } # Each static covariate has 1 element for each input.
+ static_categorical_covariates = {
+ "cov_1_sc": ["apple", "orange"],
+ "cov_2_sc": [2, 3],
+ }
+ ```
+
+ Args:
+ targets: List of targets (responses) of the in-context regression.
+ train_lens: List of lengths of each target vector from the context.
+ test_lens: List of lengths of each forecast horizon.
+ train_dynamic_numerical_covariates: Dict of covariate names mapping to the
+ dynamic numerical covariates of each forecast task on the context. Their
+ lengths should match the corresponding lengths in `train_lens`.
+ train_dynamic_categorical_covariates: Dict of covariate names mapping to
+ the dynamic categorical covariates of each forecast task on the context.
+ Their lengths should match the corresponding lengths in `train_lens`.
+ test_dynamic_numerical_covariates: Dict of covariate names mapping to the
+ dynamic numerical covariates of each forecast task on the horizon. Their
+ lengths should match the corresponding lengths in `test_lens`.
+ test_dynamic_categorical_covariates: Dict of covariate names mapping to
+ the dynamic categorical covariates of each forecast task on the horizon.
+ Their lengths should match the corresponding lengths in `test_lens`.
+ static_numerical_covariates: Dict of covariate names mapping to the static
+ numerical covariates of each forecast task.
+ static_categorical_covariates: Dict of covariate names mapping to the
+ static categorical covariates of each forecast task.
+ """
+ self.targets = targets
+ self.train_lens = train_lens
+ self.test_lens = test_lens
+ self.train_dynamic_numerical_covariates = (
+ train_dynamic_numerical_covariates or {})
+ self.train_dynamic_categorical_covariates = (
+ train_dynamic_categorical_covariates or {})
+ self.test_dynamic_numerical_covariates = (test_dynamic_numerical_covariates
+ or {})
+ self.test_dynamic_categorical_covariates = (
+ test_dynamic_categorical_covariates or {})
+ self.static_numerical_covariates = static_numerical_covariates or {}
+ self.static_categorical_covariates = static_categorical_covariates or {}
+
+ def _assert_covariates(self, assert_covariate_shapes: bool = False) -> None:
+ """Verifies the validity of the covariate inputs."""
+
+ # Check presence.
+ if (self.train_dynamic_numerical_covariates and
+ not self.test_dynamic_numerical_covariates) or (
+ not self.train_dynamic_numerical_covariates and
+ self.test_dynamic_numerical_covariates):
+ raise ValueError(
+ "train_dynamic_numerical_covariates and"
+ " test_dynamic_numerical_covariates must be both present or both"
+ " absent.")
+
+ if (self.train_dynamic_categorical_covariates and
+ not self.test_dynamic_categorical_covariates) or (
+ not self.train_dynamic_categorical_covariates and
+ self.test_dynamic_categorical_covariates):
+ raise ValueError(
+ "train_dynamic_categorical_covariates and"
+ " test_dynamic_categorical_covariates must be both present or both"
+ " absent.")
+
+ # Check keys.
+ for dict_a, dict_b, dict_a_name, dict_b_name in (
+ (
+ self.train_dynamic_numerical_covariates,
+ self.test_dynamic_numerical_covariates,
+ "train_dynamic_numerical_covariates",
+ "test_dynamic_numerical_covariates",
+ ),
+ (
+ self.train_dynamic_categorical_covariates,
+ self.test_dynamic_categorical_covariates,
+ "train_dynamic_categorical_covariates",
+ "test_dynamic_categorical_covariates",
+ ),
+ ):
+ if w := set(dict_a.keys()) - set(dict_b.keys()):
+ raise ValueError(
+ f"{dict_a_name} has keys not present in {dict_b_name}: {w}")
+ if w := set(dict_b.keys()) - set(dict_a.keys()):
+ raise ValueError(
+ f"{dict_b_name} has keys not present in {dict_a_name}: {w}")
+
+ # Check shapes.
+ if assert_covariate_shapes:
+ if len(self.targets) != len(self.train_lens):
+ raise ValueError(
+ "targets and train_lens must have the same number of elements.")
+
+ if len(self.train_lens) != len(self.test_lens):
+ raise ValueError(
+ "train_lens and test_lens must have the same number of elements.")
+
+ for i, (target, train_len) in enumerate(zip(self.targets,
+ self.train_lens)):
+ if len(target) != train_len:
+ raise ValueError(
+ f"targets[{i}] has length {len(target)} != expected {train_len}.")
+
+ for key, values in self.static_numerical_covariates.items():
+ if len(values) != len(self.train_lens):
+ raise ValueError(
+ f"static_numerical_covariates has key {key} with number of"
+ f" examples {len(values)} != expected {len(self.train_lens)}.")
+
+ for key, values in self.static_categorical_covariates.items():
+ if len(values) != len(self.train_lens):
+ raise ValueError(
+ f"static_categorical_covariates has key {key} with number of"
+ f" examples {len(values)} != expected {len(self.train_lens)}.")
+
+ for lens, dict_cov, dict_cov_name in (
+ (
+ self.train_lens,
+ self.train_dynamic_numerical_covariates,
+ "train_dynamic_numerical_covariates",
+ ),
+ (
+ self.train_lens,
+ self.train_dynamic_categorical_covariates,
+ "train_dynamic_categorical_covariates",
+ ),
+ (
+ self.test_lens,
+ self.test_dynamic_numerical_covariates,
+ "test_dynamic_numerical_covariates",
+ ),
+ (
+ self.test_lens,
+ self.test_dynamic_categorical_covariates,
+ "test_dynamic_categorical_covariates",
+ ),
+ ):
+ for key, cov_values in dict_cov.items():
+ if len(cov_values) != len(lens):
+ raise ValueError(
+ f"{dict_cov_name} has key {key} with number of examples"
+ f" {len(cov_values)} != expected {len(lens)}.")
+ for i, cov_value in enumerate(cov_values):
+ if len(cov_value) != lens[i]:
+ raise ValueError(
+ f"{dict_cov_name} has key {key} with its {i}-th example"
+ f" length {len(cov_value)} != expected {lens[i]}.")
+
+ def create_covariate_matrix(
+ self,
+ one_hot_encoder_drop: Optional[str] = "first",
+ use_intercept: bool = True,
+ assert_covariates: bool = False,
+ assert_covariate_shapes: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Creates target vector and covariate matrices for in context regression.
+
+ Here we use model fitting language to refer to the context as 'train' and
+ the horizon as 'test'.
+
+ Args:
+ one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
+ use_intercept: Whether to prepare an intercept (all 1) column in the
+ matrices.
+ assert_covariates: Whether to assert the validity of the covariate inputs.
+ assert_covariate_shapes: Whether to assert the shapes of the covariate
+ inputs when `assert_covariates` is True.
+
+ Returns:
+ A tuple of the target vector, the covariate matrix for the context, and
+ the covariate matrix for the horizon.
+ """
+ if assert_covariates:
+ self._assert_covariates(assert_covariate_shapes)
+
+ x_train, x_test = [], []
+
+ # Numerical features.
+ for name in sorted(self.train_dynamic_numerical_covariates):
+ x_train.append(
+ _unnest(self.train_dynamic_numerical_covariates[name])[:, np.newaxis])
+ x_test.append(
+ _unnest(self.test_dynamic_numerical_covariates[name])[:, np.newaxis])
+
+ for covs in self.static_numerical_covariates.values():
+ x_train.append(_repeat(covs, self.train_lens)[:, np.newaxis])
+ x_test.append(_repeat(covs, self.test_lens)[:, np.newaxis])
+
+ if x_train:
+ x_train = np.concatenate(x_train, axis=1)
+ x_test = np.concatenate(x_test, axis=1)
+
+ # Normalize for robustness.
+ x_mean = np.mean(x_train, axis=0, keepdims=True)
+ x_std = np.where((w := np.std(x_train, axis=0, keepdims=True)) > _TOL, w,
+ 1.0)
+ x_train = [(x_train - x_mean) / x_std]
+ x_test = [(x_test - x_mean) / x_std]
+
+ sklearn_version_tuple = tuple(map(int, sklearn_version.split(".")[:2]))  # major/minor only, so suffixes like "0rc1" or "dev0" cannot break int()
+ encoder_params = {}
+ if sklearn_version_tuple < (1, 2):
+ encoder_params["sparse"] = False
+ else:
+ encoder_params["sparse_output"] = False
+
+ # Categorical features. Encode one by one.
+ one_hot_encoder = preprocessing.OneHotEncoder(
+ drop=one_hot_encoder_drop,
+ handle_unknown="ignore",
+ **encoder_params
+ )
+ for name in sorted(self.train_dynamic_categorical_covariates.keys()):
+ ohe_train = _unnest(
+ self.train_dynamic_categorical_covariates[name])[:, np.newaxis]
+ ohe_test = _unnest(
+ self.test_dynamic_categorical_covariates[name])[:, np.newaxis]
+ x_train.append(np.array(one_hot_encoder.fit_transform(ohe_train)))
+ x_test.append(np.array(one_hot_encoder.transform(ohe_test)))
+
+ for covs in self.static_categorical_covariates.values():
+ ohe = one_hot_encoder.fit_transform(np.array(covs)[:, np.newaxis])
+ x_train.append(_repeat(ohe, self.train_lens))
+ x_test.append(_repeat(ohe, self.test_lens))
+
+ x_train = np.concatenate(x_train, axis=1)
+ x_test = np.concatenate(x_test, axis=1)
+
+ if use_intercept:
+ x_train = np.pad(x_train, ((0, 0), (1, 0)), constant_values=1.0)
+ x_test = np.pad(x_test, ((0, 0), (1, 0)), constant_values=1.0)
+
+ return _unnest(self.targets), x_train, x_test
+
+ def fit(self) -> Any:
+ raise NotImplementedError("Fit is not implemented.")
+
+
+class BatchedInContextXRegLinear(BatchedInContextXRegBase):
+ """Linear in-context regression model."""
+
+ def fit(
+ self,
+ ridge: float = 0.0,
+ one_hot_encoder_drop: Optional[str] = "first",
+ use_intercept: bool = True,
+ force_on_cpu: bool = False,
+ max_rows_per_col: int = 0,
+ max_rows_per_col_sample_seed: int = 42,
+ debug_info: bool = False,
+ assert_covariates: bool = False,
+ assert_covariate_shapes: bool = False,
+ ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray], jax.Array, jax.Array, jax.Array]]:
+ """Fits a linear model for in-context regression.
+
+ Args:
+ ridge: A non-negative value for specifying the ridge regression penalty.
+ If 0 is provided, fallback to ordinary least squares. Note this penalty
+ is added to the normalized covariate matrix.
+ one_hot_encoder_drop: Which drop strategy to use for the one hot encoder.
+ use_intercept: Whether to prepare an intercept (all 1) column in the
+ matrices.
+ force_on_cpu: Whether to force execution on cpu for accelerator machines.
+ max_rows_per_col: How many rows to subsample per column. 0 for no
+ subsampling. This is for speeding up model fitting.
+ max_rows_per_col_sample_seed: The seed for the subsampling if needed by
+ `max_rows_per_col`.
+ debug_info: Whether to return debug info.
+ assert_covariates: Whether to assert the validity of the covariate inputs.
+ assert_covariate_shapes: Whether to assert the shapes of the covariate
+ inputs when `assert_covariates` is True.
+
+ Returns:
+ If `debug_info` is False:
+ The linear fits on the horizon.
+ If `debug_info` is True:
+ A tuple of:
+ - the linear fits on the horizon,
+ - the linear fits on the context,
+ - the flattened target vector,
+ - the covariate matrix for the context, and
+ - the covariate matrix for the horizon.
+ """
+ flat_targets, x_train_raw, x_test = self.create_covariate_matrix(
+ one_hot_encoder_drop=one_hot_encoder_drop,
+ use_intercept=use_intercept,
+ assert_covariates=assert_covariates,
+ assert_covariate_shapes=assert_covariate_shapes,
+ )
+
+ x_train = x_train_raw.copy()
+ if max_rows_per_col:
+ nrows, ncols = x_train.shape
+ if nrows > (w := ncols * max_rows_per_col):
+ subsample = jax.random.choice(
+ jax.random.PRNGKey(max_rows_per_col_sample_seed),
+ nrows,
+ (w,),
+ replace=False,
+ )
+ x_train = x_train[subsample]
+ flat_targets = flat_targets[subsample]
+
+ device = jax.devices("cpu")[0] if force_on_cpu else None
+ # Runs the jitted version of the solvers, which is quicker at the cost of
+ # JIT compilation on the first call. Re-jitting happens whenever
+ # new (padded) shapes are encountered.
+ # Occasionally it helps with the speed and the accuracy if we force single
+ # thread execution on cpu for accelerator machines:
+ # 1. Avoid moving data to accelerator memory.
+ # 2. Avoid precision loss if any.
+ with jax.default_device(device):
+ x_train_raw = _to_padded_jax_array(x_train_raw)
+ x_train = _to_padded_jax_array(x_train)
+ flat_targets = _to_padded_jax_array(flat_targets)
+ x_test = _to_padded_jax_array(x_test)
+ beta_hat = (jnp.linalg.pinv(
+ x_train.T @ x_train + ridge * jnp.eye(x_train.shape[1]),
+ hermitian=True,
+ ) @ x_train.T @ flat_targets)
+ y_hat = x_test @ beta_hat
+ y_hat_context = x_train_raw @ beta_hat if debug_info else None
+
+ outputs = []
+ outputs_context = []
+
+ # Reconstruct the ragged 2-dim batched forecasts from flattened linear fits.
+ train_index, test_index = 0, 0
+ for train_index_delta, test_index_delta in zip(self.train_lens,
+ self.test_lens):
+ outputs.append(np.array(y_hat[test_index:(test_index +
+ test_index_delta)]))
+ if debug_info:
+ outputs_context.append(
+ np.array(y_hat_context[train_index:(train_index +
+ train_index_delta)]))
+ train_index += train_index_delta
+ test_index += test_index_delta
+
+ if debug_info:
+ return outputs, outputs_context, flat_targets, x_train, x_test
+ else:
+ return outputs
\ No newline at end of file
diff --git a/etna/models/nn/__init__.py b/etna/models/nn/__init__.py
index b972e2aab..4512ac420 100644
--- a/etna/models/nn/__init__.py
+++ b/etna/models/nn/__init__.py
@@ -16,3 +16,6 @@
if SETTINGS.chronos_required:
from etna.models.nn.chronos import ChronosBoltModel
from etna.models.nn.chronos import ChronosModel
+
+if SETTINGS.timesfm_required:
+ from etna.models.nn.timesfm import TimesFMModel
diff --git a/etna/models/nn/chronos/base.py b/etna/models/nn/chronos/base.py
index a20514a1a..7505b150c 100644
--- a/etna/models/nn/chronos/base.py
+++ b/etna/models/nn/chronos/base.py
@@ -202,10 +202,9 @@ def _forecast(
if max_context_size < self.context_size:
warnings.warn("Actual length of a dataset is less that context size. All history will be used as context.")
- available_context_size = min(max_context_size, self.context_size)
- target = ts.df.loc[:, pd.IndexSlice[:, "target"]]
- context = torch.tensor(target.values.T[:, :available_context_size])
+ target = ts.df.loc[:, pd.IndexSlice[:, "target"]].dropna()
+ context = torch.tensor(target.values.T)
if prediction_interval:
quantiles_forecast, target_forecast = self.pipeline.predict_quantiles(
diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py
new file mode 100644
index 000000000..793665269
--- /dev/null
+++ b/etna/models/nn/timesfm.py
@@ -0,0 +1,376 @@
+import os
+import reprlib
+import warnings
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from urllib import request
+
+import numpy as np
+import pandas as pd
+
+from etna import SETTINGS
+from etna.datasets import TSDataset
+from etna.distributions import BaseDistribution
+from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel
+
+if SETTINGS.timesfm_required:
+ from etna.libs.timesfm import TimesFmCheckpoint
+ from etna.libs.timesfm import TimesFmHparams
+ from etna.libs.timesfm import TimesFmTorch
+ from etna.libs.timesfm.timesfm_base import freq_map
+
+_DOWNLOAD_PATH = Path.home() / ".etna" / "timesfm"
+
+
+class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel):
+ """
+ Class for pretrained timesfm models.
+
+ This model is only for zero-shot forecasting: it doesn't support training on data during ``fit``.
+
+ This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features.
+
+ This model doesn't support NaN in the middle or at the end of target and exogenous features.
+ Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill them.
+
+ Official implementation: https://github.com/google-research/timesfm
+
+ Note
+ ----
+ This model requires ``timesfm`` extension to be installed.
+ Read more about this at :ref:`installation page `.
+ """
+
+ def __init__(
+ self,
+ path_or_url: str,
+ encoder_length: int = 512,
+ device: Literal["cpu", "gpu"] = "cpu",
+ batch_size: int = 128,
+ static_reals: Optional[List[str]] = None,
+ static_categoricals: Optional[List[str]] = None,
+ time_varying_reals: Optional[List[str]] = None,
+ time_varying_categoricals: Optional[List[str]] = None,
+ cache_dir: Path = _DOWNLOAD_PATH,
+ ):
+ """
+ Init TimesFM model.
+
+ Parameters
+ ----------
+ path_or_url:
+ Path to the model. It can be a huggingface repository, a local path or an external url.
+
+ - If huggingface repository, the available models are:
+
+ - 'google/timesfm-1.0-200m-pytorch'.
+ During the first initialization, the model is downloaded from huggingface and saved to the local ``cache_dir``.
+ On all subsequent initializations, the model is loaded from ``cache_dir``.
+ - If local path, it should be a file with model weights that can be loaded by :py:func:`torch.load`.
+ - If external url, it must be a file with model weights that can be loaded by :py:func:`torch.load`. The model will be downloaded to ``cache_dir``.
+ device:
+ Device type. Can be "cpu" or "gpu".
+ encoder_length:
+ Number of last timestamps to use as a context. It needs to be a multiple of 32.
+ batch_size:
+ Batch size. It can be useful when inference is done on gpu.
+ static_reals:
+ Continuous features that have one unique feature value for the whole series. The first value in the series will be used for each feature.
+ static_categoricals:
+ Categorical features that have one unique feature value for the whole series. The first value in the series will be used for each feature.
+ time_varying_reals:
+ Time-varying continuous features whose future values are known.
+ time_varying_categoricals:
+ Time-varying categorical features whose future values are known.
+ cache_dir:
+ Local path to save the model from Hugging Face during the first model initialization. On all following class initializations the appropriate model version will be loaded from this path.
+ """
+ self.path_or_url = path_or_url
+ self.encoder_length = encoder_length
+ self.device = device
+ self.batch_size = batch_size
+ self.static_reals = static_reals
+ self.static_categoricals = static_categoricals
+ self.time_varying_reals = time_varying_reals
+ self.time_varying_categoricals = time_varying_categoricals
+ self.cache_dir = cache_dir
+
+ self._set_pipeline()
+
+ def _set_pipeline(self):
+ """Set ``tfm`` attribute."""
+ if self._is_url():
+ full_model_path = self._download_model_from_url()
+ self.tfm = TimesFmTorch(
+ hparams=TimesFmHparams(
+ context_len=self.encoder_length, per_core_batch_size=self.batch_size, backend=self.device
+ ),
+ checkpoint=TimesFmCheckpoint(path=full_model_path),
+ )
+ else:
+ self.tfm = TimesFmTorch(
+ hparams=TimesFmHparams(
+ context_len=self.encoder_length, per_core_batch_size=self.batch_size, backend=self.device
+ ),
+ checkpoint=TimesFmCheckpoint(path=self.path_or_url, local_dir=self.cache_dir),
+ )
+
+ def _is_url(self):
+ """Check whether ``path_or_url`` is url."""
+ return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://")
+
+ def _download_model_from_url(self) -> str:
+ """Download model from url to local cache_dir."""
+ model_file = self.path_or_url.split("/")[-1]
+ full_model_path = f"{self.cache_dir}/{model_file}"
+ if not os.path.exists(full_model_path):
+ request.urlretrieve(url=self.path_or_url, filename=full_model_path)
+ return full_model_path
+
+ @property
+ def context_size(self) -> int:
+ """Context size for model."""
+ return self.encoder_length
+
+ def get_model(self) -> TimesFmTorch:
+ """Get model."""
+ return self.tfm
+
+ def fit(self, ts: TSDataset):
+ """Fit model.
+
+ For this model, fit does nothing.
+
+ Parameters
+ ----------
+ ts:
+ Dataset with features.
+
+ Returns
+ -------
+ :
+ Model after fit.
+ """
+ return self
+
+ def predict(
+ self,
+ ts: TSDataset,
+ prediction_size: int,
+ return_components: bool = False,
+ ) -> TSDataset:
+ """Make predictions using true values as autoregression context (teacher forcing).
+
+ Parameters
+ ----------
+ ts:
+ Dataset with features.
+ prediction_size:
+ Number of last timestamps to leave after making prediction.
+ Previous timestamps will be used as a context.
+ return_components:
+ If True additionally returns forecast components.
+
+ Returns
+ -------
+ :
+ Dataset with predictions.
+ """
+ raise NotImplementedError("Method predict isn't currently implemented!")
+
+ def _exog_columns(self) -> List[str]:
+ static_reals = [] if self.static_reals is None else self.static_reals
+ static_categoricals = [] if self.static_categoricals is None else self.static_categoricals
+ time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals
+ time_varying_categoricals = [] if self.time_varying_categoricals is None else self.time_varying_categoricals
+
+ return static_reals + static_categoricals + time_varying_reals + time_varying_categoricals
+
+ def forecast(
+ self,
+ ts: TSDataset,
+ prediction_size: int,
+ return_components: bool = False,
+ ) -> TSDataset:
+ """Make autoregressive forecasts.
+
+ Parameters
+ ----------
+ ts:
+ Dataset with features.
+ prediction_size:
+ Number of last timestamps to leave after making prediction.
+ Previous timestamps will be used as a context.
+ return_components:
+ If True additionally returns forecast components.
+
+ Returns
+ -------
+ :
+ Dataset with predictions.
+
+ Raises
+ ------
+ NotImplementedError:
+ if return_components mode is used.
+ ValueError:
+ if dataset doesn't have any context timestamps.
+ ValueError:
+ if there are NaNs in the middle or end of the time series.
+ NotImplementedError:
+ if forecasting is done without exogenous features and dataset has None frequency.
+ """
+ if return_components:
+ raise NotImplementedError("This mode isn't currently implemented!")
+
+ max_context_size = len(ts.index) - prediction_size
+ if max_context_size <= 0:
+ raise ValueError("Dataset doesn't have any context timestamps.")
+
+ if max_context_size < self.context_size:
+ warnings.warn("Actual length of a dataset is less than context size. All history will be used as context.")
+
+ self.tfm._set_horizon(prediction_size)
+
+ end_idx = len(ts.index)
+
+ all_exog = self._exog_columns()
+ df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]]
+ first_valid_index = (
+ df_slice.isna().any(axis=1).idxmin()
+ ) # If all timestamps contain NaNs, idxmin() returns the first timestamp
+
+ target_df = df_slice.loc[first_valid_index : ts.index[-prediction_size - 1], pd.IndexSlice[:, "target"]]
+
+ nan_segment_mask = target_df.isna().any()
+ if nan_segment_mask.any():
+ nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist()
+ raise ValueError(
+ f"There are NaNs in the middle or at the end of target. Segments with NaNs: {reprlib.repr(nan_segments)}."
+ )
+
+ future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx)
+
+ if len(all_exog) > 0:
+ target = target_df.values.swapaxes(1, 0).tolist()
+
+ exog_df = df_slice.loc[first_valid_index:, pd.IndexSlice[:, all_exog]]
+
+ nan_segment_mask = exog_df.isna().any()
+ if nan_segment_mask.any():
+ nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist()
+ raise ValueError(
+ f"There are NaNs in the middle or at the end of exogenous features. Segments with NaNs: {reprlib.repr(nan_segments)}."
+ )
+
+ static_reals_dict = (
+ {
+ column: exog_df.loc[exog_df.index[0], pd.IndexSlice[:, column]].values.tolist()
+ for column in self.static_reals
+ }
+ if self.static_reals is not None
+ else None
+ )
+ static_categoricals_dict = (
+ {
+ column: exog_df.loc[exog_df.index[0], pd.IndexSlice[:, column]].values.tolist()
+ for column in self.static_categoricals
+ }
+ if self.static_categoricals is not None
+ else None
+ )
+ time_varying_reals_dict = (
+ {
+ column: exog_df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist()
+ for column in self.time_varying_reals
+ }
+ if self.time_varying_reals is not None
+ else None
+ )
+ time_varying_categoricals_dict = (
+ {
+ column: exog_df.loc[:, pd.IndexSlice[:, column]].values.swapaxes(1, 0).tolist()
+ for column in self.time_varying_categoricals
+ }
+ if self.time_varying_categoricals is not None
+ else None
+ )
+
+ complex_forecast, _ = self.tfm.forecast_with_covariates(
+ inputs=target,
+ dynamic_numerical_covariates=time_varying_reals_dict,
+ dynamic_categorical_covariates=time_varying_categoricals_dict,
+ static_numerical_covariates=static_reals_dict,
+ static_categorical_covariates=static_categoricals_dict,
+ freq=[freq_map(ts.freq)] * len(ts.segments),
+ )
+ future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = np.vstack(complex_forecast).swapaxes(1, 0)
+ else:
+ if ts.freq is None:
+ raise NotImplementedError(
+ "Forecasting misaligned data with freq=None without exogenous features isn't currently implemented."
+ )
+
+ target = TSDataset.to_flatten(df=target_df)
+ target = target.rename(columns={"segment": "unique_id", "timestamp": "ds"})
+
+ predictions = self.tfm.forecast_on_df(target, freq=ts.freq, value_name="target")
+
+ predictions = predictions.rename(columns={"unique_id": "segment", "ds": "timestamp", "timesfm": "target"})
+ predictions = TSDataset.to_dataset(predictions)
+ future_ts.df.loc[:, pd.IndexSlice[:, "target"]] = predictions.loc[
+ :, pd.IndexSlice[:, "target"]
+ ].values # .values is needed to cast predictions to the type of the initial target in ts
+ return future_ts
+
+ @staticmethod
+ def list_models() -> List[str]:
+ """
+ Return a list of available pretrained timesfm models.
+
+ Returns
+ -------
+ :
+ List of available pretrained timesfm models.
+ """
+ return ["google/timesfm-1.0-200m-pytorch"]
+
+ def save(self, path: Path):
+ """Save the model. This method doesn't save model's weights.
+
+ During ``load``, weights are loaded from the path where they were saved during ``init``.
+
+ Parameters
+ ----------
+ path:
+ Path to save object to.
+ """
+ self._save(path=path, skip_attributes=["tfm"])
+
+ @classmethod
+ def load(cls, path: Path):
+ """Load the model.
+
+ Parameters
+ ----------
+ path:
+ Path to load object from.
+ """
+ obj: TimesFMModel = super().load(path=path)
+ obj._set_pipeline()
+ return obj
+
+ def params_to_tune(self) -> Dict[str, BaseDistribution]:
+ """Get default grid for tuning hyperparameters.
+
+ This grid is empty.
+
+ Returns
+ -------
+ :
+ Grid to tune.
+ """
+ return {}
diff --git a/etna/settings.py b/etna/settings.py
index 987525548..afbcc5d8a 100644
--- a/etna/settings.py
+++ b/etna/settings.py
@@ -52,6 +52,21 @@ def _is_chronos_available():
return False
+def _is_timesfm_available():
+ true_case = (
+ _module_available("torch")
+ & _module_available("jax")
+ & _module_available("jaxlib")
+ & _module_available("huggingface_hub")
+ & _module_available("utilsforecast")
+ )
+ if true_case:
+ return True
+ else:
+ warnings.warn("etna[timesfm] is not available, to install it, run `pip install etna[timesfm]`")
+ return False
+
+
def _is_wandb_available():
if _module_available("wandb"):
return True
@@ -112,6 +127,7 @@ def __init__( # noqa: D107
self,
torch_required: Optional[bool] = None,
chronos_required: Optional[bool] = None,
+ timesfm_required: Optional[bool] = None,
prophet_required: Optional[bool] = None,
wandb_required: Optional[bool] = None,
classification_required: Optional[bool] = None,
@@ -131,6 +147,11 @@ def __init__( # noqa: D107
_is_chronos_available,
"etna[chronos] is not available, to install it, run `pip install etna[chronos]`.",
)
+ self.timesfm_required: bool = _get_optional_value(
+ timesfm_required,
+ _is_timesfm_available,
+ "etna[timesfm] is not available, to install it, run `pip install etna[timesfm]`.",
+ )
self.wandb_required: bool = _get_optional_value(
wandb_required, _is_wandb_available, "wandb is not available, to install it, " "run `pip install wandb`."
)
diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb
index 29bdeee07..d6a53d551 100644
--- a/examples/202-NN_examples.ipynb
+++ b/examples/202-NN_examples.ipynb
@@ -33,7 +33,8 @@
" * [N-BEATS Model](#section_3_9)\n",
" * [PatchTS Model](#section_3_10)\n",
" * [Chronos Model](#section_3_11)\n",
- " * [Chronos Bolt Model](#section_3_12)"
+ " * [Chronos Bolt Model](#section_3_12)\n",
+ " * [TimesFM Model](#section_3_13)"
]
},
{
@@ -43,7 +44,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install \"etna[torch,chronos]\" -q"
+ "!pip install \"etna[torch,chronos,timesfm]\" -q"
]
},
{
@@ -4717,15 +4718,15 @@
"[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n",
"[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n",
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
- "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n",
- "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n",
- "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s remaining: 0.0s\n",
- "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s finished\n",
+ "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n",
+ "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n",
+ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s remaining: 0.0s\n",
+ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.3s finished\n",
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
"[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n",
"[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s\n",
- "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n",
- "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n"
+ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s remaining: 0.0s\n",
+ "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.0s finished\n"
]
}
],
@@ -4760,27 +4761,6 @@
"print(f\"Average SMAPE for Chronos tiny: {score:.3f}\")"
]
},
- {
- "cell_type": "code",
- "execution_count": 88,
- "id": "8334cd6a",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "",
- "text/plain": [
- "