Lightning train #27

Merged
merged 4 commits into from
May 14, 2024


otavioon
Contributor

This PR adds a LightningTrainer operator for training PyTorch Lightning models.

However, the GPU is still not working inside the Jupyter environment.

Below is a small training script. It assumes that original.npy and envelope.npy each hold a 4-D tensor with shape (N, C, H, W), where N is the number of samples, C is the number of channels (C=1), and H x W is the spatial size of each sample. The model is the U-Net from the Minerva repository, trained to perform regression with original as the input and envelope as the label.

from dasf.datasets import Dataset, DatasetArray
from dasf.ml.dl import LightningTrainer
from dasf.pipeline import Pipeline
from dasf.pipeline.executors import DaskPipelineExecutor
from minerva.models.nets.unet import UNet


class LabeledDataset(Dataset):
    """Pairs an input array with its label array so both are loaded and indexed together."""

    def __init__(self, original_path, label_path, chunks=(1, -1, -1)):
        self.original = DatasetArray(
            name="input", root=original_path, chunks=chunks
        )
        self.label = DatasetArray(name="label", root=label_path, chunks=chunks)

    def load(self):
        self.original.load()
        self.label.load()
        return self

    def _lazy_load_cpu(self):
        return self.load()

    def _load_cpu(self):
        return self.load()

    def _lazy_load_gpu(self):
        return self.load()

    def _load_gpu(self):
        return self.load()

    def __len__(self):
        return len(self.original)

    def __getitem__(self, idx):
        return self.original[idx], self.label[idx]


def main():
    original_path = "/workspaces/dasf/data/original.npy"
    label_path = "/workspaces/dasf/data/envelope.npy"

    model = UNet()
    dataset = LabeledDataset(original_path, label_path)
    trainer = LightningTrainer(model=model, use_gpu=True, unsqueeze_dim=0)

    executor = DaskPipelineExecutor(
        local=False, use_gpu=False, address="172.17.0.5", port=8786
    )
    pipeline = Pipeline(
        name="pipeline",
        executor=executor,
        verbose=True,
    )

    pipeline.add(trainer.fit, train_data=dataset)

    pipeline.run()


if __name__ == "__main__":
    main()
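To try the script above without the real data, the two input files can be replaced by synthetic arrays with the (N, C, H, W) layout described earlier. The following is a minimal sketch using only NumPy. The helper name `make_synthetic_pair`, the sizes, and the output directory are assumptions made for illustration, and the label written here is just the absolute value of the input, a placeholder rather than a real envelope attribute.

```python
import tempfile
from pathlib import Path

import numpy as np


def make_synthetic_pair(out_dir, n_samples=8, height=64, width=64, seed=0):
    """Write placeholder original.npy / envelope.npy files with shape (N, 1, H, W).

    Hypothetical helper for illustration only; it is not part of DASF or Minerva.
    """
    rng = np.random.default_rng(seed)
    shape = (n_samples, 1, height, width)  # (N, C, H, W) with C = 1
    original = rng.standard_normal(shape).astype(np.float32)
    # Placeholder label: absolute value of the input, NOT a real envelope.
    envelope = np.abs(original)

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    original_path = out_dir / "original.npy"
    label_path = out_dir / "envelope.npy"
    np.save(original_path, original)
    np.save(label_path, envelope)
    return original_path, label_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original_path, label_path = make_synthetic_pair(tmp)
        x = np.load(original_path)
        y = np.load(label_path)
        print(x.shape, y.shape, x.dtype)
```

The returned paths could then stand in for `original_path` and `label_path` in `main()`, provided the chosen directory is also visible to the Dask workers.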

SerodioJ merged commit 441e9d3 into main on May 14, 2024
2 of 4 checks passed