[BUG] TypeError: model must be a LightningModule or torch._dynamo.OptimizedModule, got TemporalFusionTransformer #1723

Open
hagersalehahmed opened this issue Dec 7, 2024 · 1 comment
Labels
bug Something isn't working

hagersalehahmed commented Dec 7, 2024

What is the main reason for this error?

```
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/compile.py in _maybe_unwrap_optimized(model)
    123     if isinstance(model, pl.LightningModule):
    124         return model
--> 125     raise TypeError(
    126         f"model must be a LightningModule or torch._dynamo.OptimizedModule, got {type(model).__qualname__}"
    127     )
```
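A likely cause, offered as an assumption rather than a confirmed diagnosis: the model and the `Trainer` come from two different Lightning packages. The traceback runs through `pytorch_lightning/utilities/compile.py` (the standalone package), while recent pytorch-forecasting releases build their models on the unified `lightning` package (`lightning.pytorch.LightningModule`). Two classes that share the name `LightningModule` but live in different packages are unrelated types, so the `isinstance` check fails. The minimal pure-Python sketch below reproduces that situation; `StandaloneLightningModule`, `UnifiedLightningModule` and `maybe_unwrap` are hypothetical stand-ins, and the `torch._dynamo.OptimizedModule` branch of the real function is omitted.

```python
# Two unrelated classes that happen to share the name "LightningModule".
# Hypothetical stand-ins for pytorch_lightning.LightningModule and
# lightning.pytorch.LightningModule; they are NOT the real library classes.
StandaloneLightningModule = type(
    "LightningModule", (), {"__module__": "pytorch_lightning.core.module"}
)
UnifiedLightningModule = type(
    "LightningModule", (), {"__module__": "lightning.pytorch.core.module"}
)


class TemporalFusionTransformer(UnifiedLightningModule):
    """Stand-in for a model built on the unified `lightning` package."""


def maybe_unwrap(model, expected_base):
    """Simplified version of the check in the traceback (OptimizedModule branch omitted)."""
    if isinstance(model, expected_base):
        return model
    raise TypeError(
        "model must be a LightningModule or torch._dynamo.OptimizedModule, "
        f"got {type(model).__qualname__}"
    )


model = TemporalFusionTransformer()
maybe_unwrap(model, UnifiedLightningModule)  # same base class: passes
try:
    maybe_unwrap(model, StandaloneLightningModule)  # same name, different class
except TypeError as err:
    print(err)  # ... got TemporalFusionTransformer
```

If this is indeed the cause, importing the trainer from the package the model is built on (for example `import lightning.pytorch as pl` instead of `import pytorch_lightning as pl`) should let the check pass; this is a suggestion to verify, not a confirmed fix for this report.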

There is one subject (`subject_id` = 12345). I split the data into 80% for training and 20% for testing:

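```python
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# split by row position (no shuffling): first 80% for training, rest for testing
data["time_idx"] = range(1, len(data) + 1)
split_index = int(0.8 * len(data))
train_df = data[:split_index].copy()  # .copy() avoids SettingWithCopyWarning below
test_df = data[split_index:].copy()
print(len(train_df))

# fit each scaler on the training split only, then apply it to the test split
scaler_rr = MinMaxScaler()
scaler_spo2 = MinMaxScaler()
train_df["RR"] = scaler_rr.fit_transform(train_df[["RR"]])
test_df["RR"] = scaler_rr.transform(test_df[["RR"]])

train_df["SpO2"] = scaler_spo2.fit_transform(train_df[["SpO2"]])
test_df["SpO2"] = scaler_spo2.transform(test_df[["SpO2"]])
test_df.head()
```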
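The split and scaling steps above can be sketched in plain Python, assuming a made-up series of ten RR readings; `chronological_split`, `fit_min_max` and `apply_min_max` are hypothetical helpers, not scikit-learn API. The sketch shows two properties the code relies on: the split keeps row order, and the min-max parameters come from the training rows only, so test values outside the training range map outside [0, 1] (scikit-learn's `MinMaxScaler` does not clip by default).

```python
def chronological_split(values, train_fraction=0.8):
    """Split a sequence by position, keeping time order (no shuffling)."""
    split_index = int(train_fraction * len(values))
    return values[:split_index], values[split_index:]


def fit_min_max(train_values):
    """Return (min, max - min) learned from the training values only."""
    lo, hi = min(train_values), max(train_values)
    return lo, hi - lo  # assumes max > min


def apply_min_max(values, lo, span):
    """Scale with the training parameters: (x - min) / (max - min)."""
    return [(x - lo) / span for x in values]


rr = [12, 14, 16, 18, 20, 22, 24, 26, 28, 30]  # made-up RR readings
train_rr, test_rr = chronological_split(rr)
lo, span = fit_min_max(train_rr)
train_scaled = apply_min_max(train_rr, lo, span)
test_scaled = apply_min_max(test_rr, lo, span)
print(len(train_rr), len(test_rr))  # 8 2
print(test_scaled)  # both values are greater than 1
```

Fitting on the training rows only avoids leaking test statistics into training, at the cost of test values that may fall outside the training range.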


```python
max_encoder_length = 5  # number of past time steps fed to the encoder
max_prediction_length = 3  # number of future time steps to predict

training_cutoff = train_df["time_idx"].max() - max_prediction_length

# renumber time_idx from 0 within each split
train_df["time_idx"] = pd.factorize(train_df["time_idx"])[0]
test_df["time_idx"] = pd.factorize(test_df["time_idx"])[0]
```
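`pd.factorize` assigns integer codes in order of first appearance. Because `time_idx` was built as a strictly increasing sequence, the two calls above renumber each split from 0, so the test split's `time_idx` restarts at 0 instead of continuing after the training split. The stdlib sketch below illustrates this; `factorize` is a hypothetical helper that mimics `pd.factorize(...)[0]` for this case and ignores pandas' NaN handling.

```python
def factorize(values):
    """Return integer codes by order of first appearance, like pd.factorize(...)[0]."""
    codes, seen = [], {}
    for v in values:
        if v not in seen:
            seen[v] = len(seen)
        codes.append(seen[v])
    return codes


time_idx = list(range(1, 11))  # 1-based and strictly increasing, as in the issue
split_index = int(0.8 * len(time_idx))
train_idx, test_idx = time_idx[:split_index], time_idx[split_index:]

print(factorize(train_idx))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(factorize(test_idx))   # [0, 1]: the test split restarts at 0
```

Note also that `training_cutoff` is computed before the renumbering, so it is expressed on the original 1-based scale.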


```python
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer

training = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target=["RR", "SpO2"],
    group_ids=["subject_id"],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=[],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["RR", "SpO2"],
    # one GroupNormalizer per target
    target_normalizer=MultiNormalizer(
        [
            GroupNormalizer(groups=["subject_id"], scale_by_group=True)
            for _ in ["RR", "SpO2"]
        ]
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
```
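`TimeSeriesDataSet` turns each group's series into (encoder, decoder) samples. As a simplified illustration, assuming fixed lengths of 5 past steps and 3 future steps (the real class can also produce shorter windows, down to `min_encoder_length` and `min_prediction_length`), a sliding window over a series of length `n` yields `n - 5 - 3 + 1` samples. `sliding_windows` is a hypothetical helper, not part of pytorch-forecasting.

```python
def sliding_windows(series, encoder_length, prediction_length):
    """Yield fixed-length (encoder, decoder) pairs over one series."""
    window = encoder_length + prediction_length
    for start in range(len(series) - window + 1):
        encoder = series[start : start + encoder_length]
        decoder = series[start + encoder_length : start + window]
        yield encoder, decoder


series = list(range(10))  # time_idx 0..9 for a single subject
windows = list(sliding_windows(series, encoder_length=5, prediction_length=3))
print(len(windows))  # 3 == 10 - 5 - 3 + 1
print(windows[0])    # ([0, 1, 2, 3, 4], [5, 6, 7])
```

With a single subject, the number of training samples is bounded by the length of that one series, which matters for small splits.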


```python
testing = TimeSeriesDataSet.from_dataset(
    training,
    test_df,
    predict=True,
    stop_randomization=True,
)
```
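With `predict=True`, `from_dataset` builds the samples from the last entries of each group's series (the pytorch-forecasting docstring describes this as one prediction per group). Since this data has a single subject, `testing` should then hold one sample, so the validation loop sees a single batch. The sketch below assumes fixed lengths and a hypothetical test split of 10 time steps; `last_window_per_group` is a hypothetical helper, and unlike the real class it simply skips series too short for a full window.

```python
def last_window_per_group(groups, encoder_length, prediction_length):
    """Take the final (encoder, decoder) window of each group's series."""
    samples = {}
    for group_id, series in groups.items():
        if len(series) < encoder_length + prediction_length:
            continue  # simplification: skip series too short for a full window
        samples[group_id] = (
            series[-(encoder_length + prediction_length) : -prediction_length],
            series[-prediction_length:],
        )
    return samples


# Hypothetical test split: one subject with 10 time steps, renumbered from 0.
test_groups = {12345: list(range(10))}
samples = last_window_per_group(test_groups, encoder_length=5, prediction_length=3)
print(len(samples))    # 1 sample for the single subject
print(samples[12345])  # ([2, 3, 4, 5, 6], [7, 8, 9])
```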


```python
batch_size = 16

train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)

test_dataloader = testing.to_dataloader(
    train=False, batch_size=batch_size * 10, num_workers=0
)
```



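```python
# the traceback path (pytorch_lightning/utilities/compile.py) shows that `pl`
# here is the standalone pytorch_lightning package
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

trainer = pl.Trainer(
    accelerator="cpu",
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=8,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    loss=QuantileLoss(),
    optimizer="adam",
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

# this call raises the TypeError shown in the traceback above
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)
```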
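To check which Lightning package a model is built on, one option is to walk its method resolution order and report the module of every base class named `LightningModule`. The stdlib sketch below uses hypothetical stand-in classes (`FakeLightningModule`, `FakeTFT`) so it runs without Lightning installed; `lightning_bases` is a hypothetical helper, not a library function.

```python
def lightning_bases(obj):
    """Return the modules of all classes named LightningModule in obj's MRO."""
    return [
        cls.__module__
        for cls in type(obj).__mro__
        if cls.__name__ == "LightningModule"
    ]


# Hypothetical stand-ins; with the real libraries you would pass `tft` instead
# and compare the result with `pl.LightningModule.__module__`.
FakeLightningModule = type(
    "LightningModule", (), {"__module__": "lightning.pytorch.core.module"}
)


class FakeTFT(FakeLightningModule):
    pass


print(lightning_bases(FakeTFT()))  # ['lightning.pytorch.core.module']
```

With the real objects, `lightning_bases(tft)` and `pl.LightningModule.__module__` naming different top-level packages (`lightning` versus `pytorch_lightning`) would be consistent with a package mismatch behind this error.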
hagersalehahmed added the bug (Something isn't working) label Dec 7, 2024
github-project-automation bot moved this to Needs triage & validation in Bugfixing - pytorch-forecasting Dec 7, 2024
fnhirwa (Member) commented Dec 9, 2024

Hey @hagersalehahmed, would you please provide the full code example you were running, so that I can reproduce the bug?
Thanks
