Add timezone flag to the container runs #244

Closed · wants to merge 7 commits
7 changes: 7 additions & 0 deletions Dockerfile
@@ -20,8 +20,15 @@ RUN apt update && apt install -y software-properties-common && \
git \
# Java to build our fork of Hydra
default-jre \
# Audio libraries
ffmpeg \
sox \
libavdevice-dev \
# Clean up
&& rm -rf /var/lib/apt/lists/*
# Install torchaudio
COPY install_torchaudio_latest.sh /install_torchaudio_latest.sh
RUN /bin/bash /install_torchaudio_latest.sh
# To not have to specify `-u origin <BRANCH_NAME>` when pushing
RUN git config --global push.autoSetupRemote true
# To push the current branch to the existing same name branch
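The ffmpeg, sox, and libavdevice packages added above are the system-level backends torchaudio relies on for audio I/O. A minimal sanity check that could be run inside the built image (a sketch, assuming torchaudio >= 2.0 is what install_torchaudio_latest.sh ends up installing):

```python
# Minimal audio-stack sanity check inside the container.
# Assumes install_torchaudio_latest.sh installed torchaudio >= 2.0.
import torchaudio

print(torchaudio.__version__)
# Expect a subset of ['ffmpeg', 'sox', 'soundfile'] depending on the build.
print(torchaudio.list_audio_backends())
```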
5 changes: 1 addition & 4 deletions cneuromax/__init__.py
@@ -173,14 +173,11 @@
login_wandb()
TaskRunner.store_configs_and_run_task()
"""

import os
import warnings

from beartype import BeartypeConf
from beartype.claw import beartype_this_package

os.environ["OPENBLAS_NUM_THREADS"] = "1"
beartype_this_package(conf=BeartypeConf(is_pep484_tower=True))
warnings.filterwarnings(action="ignore", module="beartype")
warnings.filterwarnings(action="ignore", module="lightning")
warnings.filterwarnings(action="ignore", module="gymnasium")
24 changes: 16 additions & 8 deletions cneuromax/fitting/deeplearning/datamodule/base.py
@@ -4,6 +4,7 @@
from typing import Annotated as An
from typing import final

from datasets import Dataset as HFDataset
from lightning.pytorch import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
@@ -13,7 +14,10 @@

@dataclass
class Datasets:
"""Holds stage-specific :class:`~torch.utils.data.Dataset` objects.
"""Holds phase-specific :class:`~torch.utils.data.Dataset` objects.

The word ``phase`` is used here to avoid overloading the :mod:`lightning`
``stage`` terminology reserved for ``fit``, ``validate`` and ``test``.

Args:
train: Training dataset.
@@ -22,10 +26,10 @@ class Datasets:
predict: Prediction dataset.
"""

train: Dataset[Tensor] | None = None
val: Dataset[Tensor] | None = None
test: Dataset[Tensor] | None = None
predict: Dataset[Tensor] | None = None
train: Dataset[Tensor] | HFDataset | None = None
val: Dataset[Tensor] | HFDataset | None = None
test: Dataset[Tensor] | HFDataset | None = None
predict: Dataset[Tensor] | HFDataset | None = None


@dataclass
@@ -44,16 +48,18 @@ class BaseDataModuleConfig:
class BaseDataModule(LightningDataModule, metaclass=ABCMeta):
"""Base :mod:`lightning` ``DataModule``.

With ``<stage>`` being any of ``train``, ``val``, ``test`` or
With ``<phase>`` being any of ``train``, ``val``, ``test`` or
``predict``, subclasses need to properly define the
``datasets.<stage>`` attribute(s) for each desired stage.
``datasets.<phase>`` attribute(s) for each desired phase.

Args:
config: See :class:`BaseDataModuleConfig`.

Attributes:
config (:class:`BaseDataModuleConfig`)
datasets (:class:`Datasets`)
collate_fn (``callable``): See \
:paramref:`torch.utils.data.DataLoader.collate_fn`.
pin_memory (``bool``): Whether to copy tensors into device\
pinned memory before returning them (is set to ``True`` by\
default if :paramref:`~BaseDataModuleConfig.device` is\
@@ -72,6 +78,7 @@ def __init__(self: "BaseDataModule", config: BaseDataModuleConfig) -> None:
super().__init__()
self.config = config
self.datasets = Datasets()
self.collate_fn = None
self.pin_memory = self.config.device == "gpu"
self.per_device_batch_size = 1
self.per_device_num_workers = 0
@@ -108,7 +115,7 @@ def state_dict(self: "BaseDataModule") -> dict[str, int]:
@final
def x_dataloader(
self: "BaseDataModule",
dataset: Dataset[Tensor] | None,
dataset: Dataset[Tensor] | HFDataset | None,
*,
shuffle: bool = True,
) -> DataLoader[Tensor]:
@@ -134,6 +141,7 @@ def x_dataloader(
batch_size=self.per_device_batch_size,
shuffle=shuffle,
num_workers=self.per_device_num_workers,
collate_fn=self.collate_fn,
pin_memory=self.pin_memory,
)

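For context, a hypothetical subclass sketch showing how the `HFDataset` union, the phase-specific dataset slots, and the new `collate_fn` attribute fit together; the dataset, tokenizer, and collator choices below are illustrative assumptions, not part of this PR:

```python
# Hypothetical datamodule sketch: an HF dataset is assigned to the
# phase-specific slots and a custom collator is set on `collate_fn`,
# which x_dataloader() now forwards to the DataLoader.
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

from cneuromax.fitting.deeplearning.datamodule.base import BaseDataModule


class ExampleTextDataModule(BaseDataModule):
    def setup(self, stage: str) -> None:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        dataset = load_dataset("imdb")  # illustrative dataset choice
        dataset = dataset.map(
            lambda x: tokenizer(x["text"], truncation=True),
            remove_columns=["text"],
        )
        self.datasets.train = dataset["train"]
        self.datasets.val = dataset["test"]
        self.collate_fn = DataCollatorWithPadding(tokenizer=tokenizer)
```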
7 changes: 6 additions & 1 deletion cneuromax/fitting/deeplearning/litmodule/__init__.py
@@ -1,6 +1,11 @@
r""":class:`lightning.pytorch.LightningModule`\s."""

from cneuromax.fitting.deeplearning.litmodule.base import (
BaseLitModule,
BaseLitModuleConfig,
)

__all__ = ["BaseLitModule"]
__all__ = [
"BaseLitModule",
"BaseLitModuleConfig",
]
37 changes: 23 additions & 14 deletions cneuromax/fitting/deeplearning/litmodule/base.py
@@ -1,5 +1,7 @@
""":class:`BaseLitModule`."""

from abc import ABCMeta
from dataclasses import dataclass
from functools import partial
from typing import Annotated as An
from typing import Any, final
@@ -10,9 +12,22 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from cneuromax.fitting.deeplearning.utils.type import Batch_type
from cneuromax.utils.beartype import one_of


@dataclass
class BaseLitModuleConfig:
"""Holds :class:`BaseDataModule` config values.

Args:
data_dir: See :paramref:`~.BaseSubtaskConfig.data_dir`.
device: See :paramref:`~.FittingSubtaskConfig.device`.
"""

device: An[str, one_of("cpu", "gpu")] = "${config.device}"


class BaseLitModule(LightningModule, metaclass=ABCMeta):
"""Base :mod:`lightning` ``LitModule``.

@@ -69,11 +84,13 @@ def step(

def __init__(
self: "BaseLitModule",
config: BaseLitModuleConfig,
nnmodule: nn.Module,
optimizer: partial[Optimizer],
scheduler: partial[LRScheduler],
) -> None:
super().__init__()
self.config = config
self.nnmodule = nnmodule
self.optimizer = optimizer(params=self.parameters())
self.scheduler = scheduler(optimizer=self.optimizer)
@@ -86,9 +103,7 @@ def __init__(
@final
def stage_step(
self: "BaseLitModule",
batch: Num[Tensor, " ..."]
| tuple[Num[Tensor, " ..."], ...]
| list[Num[Tensor, " ..."]],
batch: Batch_type,
stage: An[str, one_of("train", "val", "test", "predict")],
) -> Num[Tensor, " ..."]:
"""Generic stage wrapper around the :meth:`step` method.
@@ -105,17 +120,15 @@ def stage_step(
The loss value(s).
"""
if isinstance(batch, list):
tupled_batch: tuple[Num[Tensor, " ..."], ...] = tuple(batch)
loss: Num[Tensor, " ..."] = self.step(tupled_batch, stage)
batch = tuple(batch)
loss: Num[Tensor, " ..."] = self.step(batch, stage)
self.log(name=f"{stage}/loss", value=loss)
return loss

@final
def training_step(
self: "BaseLitModule",
batch: Num[Tensor, " ..."]
| tuple[Num[Tensor, " ..."], ...]
| list[Num[Tensor, " ..."]],
batch: Batch_type,
) -> Num[Tensor, " ..."]:
"""Calls :meth:`stage_step` with argument ``stage="train"``.

@@ -130,9 +143,7 @@ def training_step(
@final
def validation_step(
self: "BaseLitModule",
batch: Num[Tensor, " ..."]
| tuple[Num[Tensor, " ..."], ...]
| list[Num[Tensor, " ..."]],
batch: Batch_type,
# :paramref:`*args` & :paramref:`**kwargs` type annotations
# cannot be more specific because of
# :meth:`LightningModule.validation_step`\'s signature.
@@ -154,9 +165,7 @@ def validation_step(
@final
def test_step(
self: "BaseLitModule",
batch: Num[Tensor, " ..."]
| tuple[Num[Tensor, " ..."], ...]
| list[Num[Tensor, " ..."]],
batch: Batch_type,
) -> Num[Tensor, " ..."]:
"""Calls :meth:`stage_step` with argument ``stage="test"``.

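A hypothetical subclass sketch of the updated `BaseLitModule`, showing the new `config` constructor argument being inherited as-is and a dict batch, the variant added to `Batch_type` in `utils/type.py` below; the masked-LM loss convention is an assumption, not part of this diff:

```python
# Hypothetical litmodule sketch: the base constructor now takes a config
# first, and step() may receive a dict batch (the new Batch_type variant),
# e.g. tokenized text produced by a HuggingFace data collator.
from torch import Tensor

from cneuromax.fitting.deeplearning.litmodule import BaseLitModule


class ExampleMaskedLMModule(BaseLitModule):
    def step(
        self,
        batch: dict[str, Tensor],
        stage: str,
    ) -> Tensor:
        # Assumes `self.nnmodule` is a HuggingFace masked-LM model that
        # returns an object exposing `.loss` when labels are present.
        return self.nnmodule(**batch).loss
```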
11 changes: 11 additions & 0 deletions cneuromax/fitting/deeplearning/utils/type.py
@@ -0,0 +1,11 @@
"""Typing utilities."""

from jaxtyping import Num
from torch import Tensor

Batch_type = (
Num[Tensor, " ..."]
| tuple[Num[Tensor, " ..."], ...]
| list[Num[Tensor, " ..."]]
| dict[str, Num[Tensor, " ..."]]
)
62 changes: 62 additions & 0 deletions cneuromax/projects/friends_language_encoder/__init__.py
@@ -0,0 +1,62 @@
"""Friends language finetuning ``project``."""

from hydra_zen import ZenStore
from transformers import AutoModelForMaskedLM

from cneuromax.fitting.deeplearning.runner import DeepLearningTaskRunner
from cneuromax.utils.hydra_zen import fs_builds

from .datamodule import (
FriendsDataModule,
FriendsDataModuleConfig,
)
from .litmodule import (
FriendsFinetuningModel,
FriendsLitModuleConfig,
)

__all__ = [
"TaskRunner",
"FriendsDataModule",
"FriendsDataModuleConfig",
"FriendsFinetuningModel",
"FriendsLitModuleConfig",
]


class TaskRunner(DeepLearningTaskRunner):
"""``project`` ``task`` runner."""

@classmethod
def store_configs(cls: type["TaskRunner"], store: ZenStore) -> None:
"""Stores :mod:`hydra-core` ``project`` configs.

Args:
store: See :paramref:`~.BaseTaskRunner.store_configs.store`.
"""
super().store_configs(store=store)
store(name="model_name")
store(
fs_builds(
FriendsDataModule,
config=FriendsDataModuleConfig(),
),
name="friends_language_encoder",
group="datamodule",
)
store(
fs_builds(
FriendsFinetuningModel,
config=FriendsLitModuleConfig(),
),
name="friends_language_encoder",
group="litmodule",
)
store(
fs_builds(
AutoModelForMaskedLM.from_pretrained,
pretrained_model_name_or_path="${model_name}",
),
name="friends_language_encoder",
group="litmodule/nnmodule",
)
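Assuming the project is launched through the same entry point quoted in the `cneuromax/__init__.py` docstring earlier in this diff, the new task could then be run roughly as follows (a sketch, not part of the PR):

```python
# Hypothetical launch sketch, mirroring the pattern quoted in the
# cneuromax/__init__.py docstring: store the project's Hydra configs
# and run the selected task.
from cneuromax.projects.friends_language_encoder import TaskRunner

TaskRunner.store_configs_and_run_task()
```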