From ed82d3cf16e5886ceae20c9dff76efc4817d6cbb Mon Sep 17 00:00:00 2001 From: ourownstory Date: Wed, 3 Jul 2024 15:01:16 -0700 Subject: [PATCH] clarify ID drop --- neuralprophet/time_dataset.py | 4 +++- tests/utils/benchmark_time_dataset.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/neuralprophet/time_dataset.py b/neuralprophet/time_dataset.py index 3d7ed0ab0..855896e49 100644 --- a/neuralprophet/time_dataset.py +++ b/neuralprophet/time_dataset.py @@ -98,8 +98,10 @@ def __init__( # self.tensor_data = torch.tensor(df.values, dtype=torch.float32 self.df["ds"] = self.df["ds"].astype(int) // 10**9 # Convert to Unix timestamp in seconds + # skipping col "ID" is string type that is interpreted as object by torch (self.df[col].dtype == "O") + # "ID" is stored in self.meta["df_name"] self.tensor_dict = { - col: torch.tensor(self.df[col].values, dtype=torch.float32) for col in self.df if self.df[col].dtype != "O" + col: torch.tensor(self.df[col].values, dtype=torch.float32) for col in self.df if col != "ID" } # Construct index map diff --git a/tests/utils/benchmark_time_dataset.py b/tests/utils/benchmark_time_dataset.py index d80bd4f88..88d1a6f28 100644 --- a/tests/utils/benchmark_time_dataset.py +++ b/tests/utils/benchmark_time_dataset.py @@ -470,4 +470,4 @@ def measure_times(): compare.print() -# measure_times() +measure_times()