Making the backend attribute case-insensitive #22

Open · wants to merge 3 commits into base: master
12 changes: 8 additions & 4 deletions src/timesfm.py
@@ -95,7 +95,7 @@ class TimesFm:

   Attributes:
     per_core_batch_size: Batch size on each core for data parallelism.
-    backend: One of "cpu", "gpu" or "tpu".
+    backend: One of "cpu", "gpu" or "tpu" (case-insensitive).
     num_devices: Number of cores provided the backend.
     global_batch_size: per_core_batch_size * num_devices. Each batch of
       inference task will be padded with respect to global_batch_size to
@@ -126,7 +126,9 @@ def __init__(
       num_layers: int,
       model_dims: int,
       per_core_batch_size: int = 32,
-      backend: Literal["cpu", "gpu", "tpu"] = "cpu",
+      backend: Literal["GPU", "gpu", "Gpu", "gPu", "gpU", "GPu", "gPU", "GpU",
+                       "CPU", "Cpu", "cPU", "cpU", "CPu", "cPu", "cpu", "CpU",
+                       "TPU", "Tpu", "tPU", "tpU", "TPu", "tPu", "tpu", "TpU"] = "cpu",
       quantiles: Sequence[float] | None = None,
       verbose: bool = True,
   ) -> None:
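Review note: enumerating all 24 case permutations in the `Literal` is hard to maintain (the original commit also listed "CPU" and "TPU" twice while omitting lowercase "cpu" and "tpu", which would have broken the default value). A runtime check is a lighter alternative; here is a minimal sketch, where the helper name `_normalize_backend` is illustrative and not part of this PR:

```python
# Sketch of an alternative to enumerating case permutations in Literal:
# normalize once, then validate against the canonical lower-case names.
# This helper is hypothetical, not part of the PR.
_VALID_BACKENDS = ("cpu", "gpu", "tpu")

def _normalize_backend(backend: str) -> str:
  """Lower-cases a backend name and rejects unknown values."""
  normalized = backend.lower()
  if normalized not in _VALID_BACKENDS:
    raise ValueError(
        f"backend must be one of {_VALID_BACKENDS}, got {backend!r}")
  return normalized
```

With this approach the parameter could stay annotated as plain `str` while still failing fast on typos such as `"gpi"`.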
@@ -144,12 +146,12 @@ def __init__(
       num_layers: Number of transformer layers.
       model_dims: Model dimension.
       per_core_batch_size: Batch size on each core for data parallelism.
-      backend: One of "cpu", "gpu" or "tpu".
+      backend: One of "cpu", "gpu" or "tpu" (case-insensitive).
       quantiles: list of output quantiles supported by the model.
       verbose: Whether to print logging messages.
     """
     self.per_core_batch_size = per_core_batch_size
-    self.backend = backend
+    self.backend = backend.lower()
     self.num_devices = jax.local_device_count(self.backend)
     self.global_batch_size = self.per_core_batch_size * self.num_devices

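Since `__init__` now lower-cases the value before querying JAX, any casing of the same name resolves to the same device count. A quick standalone check of that behavior, assuming a working JAX install:

```python
import jax

# Mixed-case names all normalize to the same JAX backend string,
# mirroring what __init__ does with backend.lower().
for name in ("cpu", "CPU", "cPu"):
  assert jax.local_device_count(name.lower()) == jax.local_device_count("cpu")
```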
@@ -600,3 +602,5 @@ def forecast_on_df(
       fcst_df[model_name] = fcst_df[q_col]
   logging.info("Finished creating output dataframe.")
   return fcst_df