Torch module: Implement the Model module #74 #111

Draft: wants to merge 14 commits into main
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -5,7 +5,11 @@ description = "The core package for creating autonomous machine intelligence"
license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">=3.12"
dependencies = []
dependencies = [
"torch>=2.6.0",
"torchaudio>=2.6.0",
"torchvision>=0.21.0",
]

keywords = [
"AI",
7 changes: 7 additions & 0 deletions src/pamiq_core/torch/__init__.py
@@ -0,0 +1,7 @@
from .model import TorchInferenceModel, TorchTrainingModel, default_infer_procedure

__all__ = [
"TorchInferenceModel",
"TorchTrainingModel",
"default_infer_procedure",
]
185 changes: 185 additions & 0 deletions src/pamiq_core/torch/model.py
@@ -0,0 +1,185 @@
import copy
from threading import RLock
from typing import Any, Protocol, override

import torch
import torch.nn as nn

from pamiq_core.model import InferenceModel, TrainingModel

CPU_DEVICE = torch.device("cpu")


class InferenceProcedureCallable(Protocol):
"""Typing for `inference_procedure` argument of TorchTrainingModel because
`typing.Callable` can not typing `*args` and `**kwds`."""

def __call__(self, inference_model: nn.Module, *args: Any, **kwds: Any) -> Any: ...
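

# A custom procedure matching this protocol might look like the following
# (an illustrative sketch, not part of the module):
#
#     def my_procedure(inference_model: nn.Module, *args: Any, **kwds: Any) -> Any:
#         # The caller is responsible for device placement here.
#         return inference_model(*args, **kwds)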


def get_device[T](
module: nn.Module, default_device: T = CPU_DEVICE
) -> torch.device | T:
"""Retrieves the device where the module runs.

    Args:
        module: The module whose device you want to determine.
        default_device: The device to return if the module has no parameters or buffers.

    Returns:
        The device on which the module's parameters or buffers reside, or `default_device`.
"""
for param in module.parameters():
return param.device
for buf in module.buffers():
return buf.device
return default_device
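

# Example (illustrative): the device is taken from the first parameter, falling
# back to the first buffer, then to `default_device`.
#
#     get_device(nn.Linear(2, 2))  # -> device("cpu") on a default setup
#     get_device(nn.Module())      # -> CPU_DEVICE (no parameters or buffers)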


def default_infer_procedure(inference_model: nn.Module, *args: Any, **kwds: Any) -> Any:
"""Default inference forward flow.

    Tensors in `args` and `kwds` are moved to the model's computing device. If
    you supply a custom procedure instead, remember to move the input tensors
    to the computing device yourself.
"""
device = get_device(inference_model, CPU_DEVICE)
new_args: list[Any] = []
new_kwds: dict[Any, Any] = {}
    for a in args:
        if isinstance(a, torch.Tensor):
            a = a.to(device)
        new_args.append(a)

for k, v in kwds.items():
if isinstance(v, torch.Tensor):
v = v.to(device)
new_kwds[k] = v

return inference_model(*new_args, **new_kwds)
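

# A minimal usage sketch (illustrative): tensor arguments are moved onto the
# model's device before the call, so CPU tensors can be fed to a model on
# another device. Assumes a CUDA device is available.
#
#     model = nn.Linear(3, 5).to("cuda")
#     x = torch.randn(2, 3)                  # CPU tensor
#     y = default_infer_procedure(model, x)  # x is moved to CUDA automatically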


class TorchInferenceModel[T: nn.Module](InferenceModel):
"""Wrapper class for torch model to infer in InferenceThread."""

def __init__(
self, model: T, inference_procedure: InferenceProcedureCallable
) -> None:
"""Initialize.

Args:
model: A torch model for inference.
inference_procedure: An inference procedure as Callable.
"""
self._model = model
self._inference_procedure = inference_procedure
self._lock = RLock()

@property
def raw_model(self) -> T:
"""Returns the internal dnn model.

Do not access this property in the inference thread. This
property is used to switch the model between training and
inference model."
"""
return self._model

@raw_model.setter
def raw_model(self, m: T) -> None:
"""Sets the model in a thread-safe manner."""
with self._lock:
self._model = m

@torch.inference_mode()
@override
def infer(self, *args: Any, **kwds: Any) -> Any:
"""Performs the inference in a thread-safe manner."""
with self._lock:
return self._inference_procedure(self.raw_model, *args, **kwds)
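

# A minimal usage sketch (illustrative): wrap a plain nn.Module and run
# thread-safe inference. `torch.inference_mode()` disables gradient tracking.
#
#     model = nn.Linear(3, 5)
#     inference = TorchInferenceModel(model, default_infer_procedure)
#     out = inference.infer(torch.randn(2, 3))  # shape (2, 5), no grad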


class TorchTrainingModel[T: nn.Module](TrainingModel[TorchInferenceModel[T]]):
"""Wrapper class for training torch model in TrainingThread.

    Needed to run training and inference in parallel across threads.
"""

@override
def __init__(
self,
model: T,
has_inference_model: bool = True,
inference_thread_only: bool = False,
default_device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
inference_procedure: InferenceProcedureCallable = default_infer_procedure,
):
"""Initialize.

Args:
model: A torch model.
            has_inference_model: Whether to create an inference model.
            inference_thread_only: Whether the model is used only in the inference thread.
            default_device: The device to use if none can be inferred from the model.
            dtype: The data type to convert the model to.
            inference_procedure: The inference procedure as a callable.
"""
super().__init__(has_inference_model, inference_thread_only)
if dtype is not None:
model = model.type(dtype)
self.model: T = model
        if default_device is None:
            # Avoid moving the model to the CPU unintentionally.
            default_device = get_device(model, CPU_DEVICE)
self._default_device = torch.device(default_device)
self._inference_procedure = inference_procedure

self.model.to(self._default_device)

@override
def _create_inference_model(self) -> TorchInferenceModel[T]:
"""Create inference model.

Returns:
TorchInferenceModel.
"""
        model = self.model
        if not self.inference_thread_only:
            # Copy so the training and inference threads hold separate modules;
            # no copy is needed if the model is used only in the inference thread.
            model = copy.deepcopy(model)
return TorchInferenceModel(model, self._inference_procedure)

@override
def sync_impl(self, inference_model: TorchInferenceModel[T]) -> None:
"""Copies params of training model to self._inference_model.

Args:
            inference_model: The inference model to synchronize.
"""

        # Call `eval()` via getattr to avoid triggering the python-no-eval lint check.
        eval_of_raw_model = getattr(self.model, "eval")
        eval_of_raw_model()

# Hold the grads.
grads: list[torch.Tensor | None] = []
for p in self.model.parameters():
grads.append(p.grad)
p.grad = None

        # Swap the training and inference modules so the inference side receives
        # the freshly trained weights, then load the same state into the training side.
self.model, inference_model.raw_model = (
inference_model.raw_model,
self.model,
)
self.model.load_state_dict(self.inference_model.raw_model.state_dict())

# Assign the model grads.
for i, p in enumerate(self.model.parameters()):
p.grad = grads[i]

self.model.train()

@override
def forward(self, *args: Any, **kwds: Any) -> Any:
"""forward."""
return self.model(*args, **kwds)
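

# An end-to-end sketch (illustrative): train on one module while a deep-copied
# twin serves inference, then publish new weights. It is assumed here that the
# `TrainingModel` base class exposes `inference_model` and a `sync()` that
# calls `sync_impl`.
#
#     training = TorchTrainingModel(nn.Linear(3, 5))
#     inference = training.inference_model
#     y_old = inference.infer(torch.randn(2, 3))        # old weights
#     loss = training.forward(torch.randn(2, 3)).sum()
#     loss.backward()                                   # ...optimizer step...
#     training.sync()                                   # inference sees new weights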
Empty file.
58 changes: 58 additions & 0 deletions tests/pamiq_core/torch/test_model.py
@@ -0,0 +1,58 @@
import copy
from threading import RLock
from typing import Any, Protocol, override

import pytest
import torch
import torch.nn as nn
from pytest_mock import MockerFixture

from pamiq_core.model import InferenceModel, TrainingModel
from pamiq_core.torch import (
TorchInferenceModel,
TorchTrainingModel,
default_infer_procedure,
)


def test_get_device() -> None:
pass
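

# A possible implementation sketch (not the author's code), assuming
# `get_device` is importable from `pamiq_core.torch.model`:
def test_get_device_sketch() -> None:
    from pamiq_core.torch.model import get_device

    # A module with parameters reports the device of its first parameter.
    assert get_device(nn.Linear(2, 2)) == torch.device("cpu")
    # A module with no parameters or buffers falls back to the default.
    assert get_device(nn.Module(), torch.device("cpu")) == torch.device("cpu")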


def test_default_infer_procedure() -> None:
pass
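

# A possible implementation sketch (not the author's code): the default
# procedure should match calling the model directly on the same device.
def test_default_infer_procedure_sketch() -> None:
    model = nn.Linear(3, 5)
    x = torch.randn(2, 3)
    assert torch.equal(default_infer_procedure(model, x), model(x))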


class TestTorchInferenceModel:
@pytest.fixture
def model(self) -> nn.Module:
return nn.Linear(3, 5)

@pytest.fixture
def torch_inference_model(self, model: nn.Module) -> TorchInferenceModel:
torch_inference_model = TorchInferenceModel(model, default_infer_procedure)
return torch_inference_model

def test_raw_model(
self, torch_inference_model: TorchInferenceModel, model: nn.Module
) -> None:
assert model is torch_inference_model.raw_model

def test_infer(
self, torch_inference_model: TorchInferenceModel, model: nn.Module
) -> None:
input_tensor = torch.randn([2, 3])
output_tensor = torch_inference_model.infer(input_tensor)
expected_tensor = model(input_tensor)
assert torch.equal(output_tensor, expected_tensor)


class TestTorchTrainingModel:
def test_create_inference(self) -> None:
pass

def test_sync_impl(self) -> None:
pass

def test_forward(self) -> None:
pass
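
    # A possible sketch for test_sync_impl (not the author's code). It assumes
    # the `TrainingModel` base class exposes an `inference_model` property.
    def test_sync_impl_sketch(self) -> None:
        training_model = TorchTrainingModel(nn.Linear(3, 5))
        inference_model = training_model.inference_model
        with torch.no_grad():
            for p in training_model.model.parameters():
                p.add_(1.0)  # perturb the training weights
        training_model.sync_impl(inference_model)
        # After syncing, both sides hold identical parameters.
        for p_train, p_inf in zip(
            training_model.model.parameters(),
            inference_model.raw_model.parameters(),
        ):
            assert torch.equal(p_train, p_inf)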