Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Jun 22, 2023
1 parent 8ae5c09 commit 109d40e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 28 deletions.
11 changes: 8 additions & 3 deletions src/equisolve/numpy/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,12 @@ def score(self, X: TensorMap, y: TensorMap, parameter_key: str) -> float:
return rmse(y, y_pred, parameter_key)

def export_torchscript(self):
if not(HAS_TORCH):
raise ImportError("To export your model to TorchScript torch needs to be installed. Please reimport the equisolve after installing torch.")
if not (HAS_TORCH):
raise ImportError(
"To export your model to TorchScript torch needs to be installed. Please reimport the equisolve after installing torch."
)
from ..utils import tensor_map_to_torch_tensor_map

return TsRidge(tensor_map_to_torch_tensor_map(self._weights))


Expand All @@ -365,15 +368,17 @@ def __init__(self) -> None:


if HAS_TORCH:
import torch
import equistore.torch
import torch

# Issue #XX temporary import for dot hack
from ..utils import dot as torchscriptdot

class TsRidge(torch.nn.Module):
"""
TorchScript Ridge
"""

def __init__(self, weights: equistore.torch.TensorMap) -> None:
torch.nn.Module.__init__(self)
self._weights = weights
Expand Down
59 changes: 35 additions & 24 deletions src/equisolve/numpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

import equistore
import numpy as np
from equistore import TensorBlock, TensorMap, Labels
from equistore import Labels, TensorBlock, TensorMap

from .. import HAS_TORCH


def array_from_block(block: TensorBlock) -> np.ndarray:
"""Extract parts of a :class:`equistore.TensorBlock` into a array.
Expand Down Expand Up @@ -78,39 +79,46 @@ def dict_to_tensor_map(tensor_map_dict: dict):
np.savez(tmp_filename, **tensor_map_dict)
return equistore.load(tmp_filename)


if HAS_TORCH:
import equistore.torch
import torch

def labels_to_torch_labels(labels: Labels) -> equistore.torch.Labels:
return equistore.torch.Labels(
names = list(labels.dtype.names),
values = torch.tensor(labels.tolist(), dtype=torch.int32)
names=list(labels.dtype.names),
values=torch.tensor(labels.tolist(), dtype=torch.int32),
)

def tensor_map_to_torch_tensor_map(tm: TensorMap) -> equistore.torch.TensorMap:
blocks = []
for _, block in tm: # TODO -> items()
blocks.append(equistore.torch.TensorBlock(
values = torch.tensor(block.values),
samples = labels_to_torch_labels(block.samples),
components = [labels_to_torch_labels(component) for component in block.components],
properties = labels_to_torch_labels(block.properties),
for _, block in tm: # TODO -> items()
blocks.append(
equistore.torch.TensorBlock(
values=torch.tensor(block.values),
samples=labels_to_torch_labels(block.samples),
components=[
labels_to_torch_labels(component)
for component in block.components
],
properties=labels_to_torch_labels(block.properties),
)
)

return equistore.torch.TensorMap(
keys = labels_to_torch_labels(tm.keys),
blocks = blocks,
)
keys=labels_to_torch_labels(tm.keys),
blocks=blocks,
)

#########################################################################################
### all functions below are temporary until equistore.operations supports TorchScript ###
#########################################################################################
from typing import List
def dot(tensor_1: equistore.torch.TensorMap, tensor_2: equistore.torch.TensorMap) -> equistore.torch.TensorMap:
"""Compute the dot product of two :py:class:`TensorMap`.
"""

def dot(
tensor_1: equistore.torch.TensorMap, tensor_2: equistore.torch.TensorMap
) -> equistore.torch.TensorMap:
"""Compute the dot product of two :py:class:`TensorMap`."""
_check_same_keys(tensor_1, tensor_2, "dot")

blocks: List[equistore.torch.TensorBlock] = []
Expand All @@ -120,8 +128,9 @@ def dot(tensor_1: equistore.torch.TensorMap, tensor_2: equistore.torch.TensorMap

return equistore.torch.TensorMap(tensor_1.keys, blocks)


def _dot_block(block_1: equistore.torch.TensorBlock, block_2: equistore.torch.TensorBlock) -> equistore.torch.TensorBlock:
def _dot_block(
block_1: equistore.torch.TensorBlock, block_2: equistore.torch.TensorBlock
) -> equistore.torch.TensorBlock:
if not torch.all(torch.tensor(block_1.properties == block_2.properties)):
raise ValueError("TensorBlocks in `dot` should have the same properties")

Expand Down Expand Up @@ -161,7 +170,9 @@ def _dot_block(block_1: equistore.torch.TensorBlock, block_2: equistore.torch.Te

return result_block

def _check_same_keys(a: equistore.torch.TensorMap, b: equistore.torch.TensorMap, fname: str):
def _check_same_keys(
a: equistore.torch.TensorMap, b: equistore.torch.TensorMap, fname: str
):
"""Check if metadata between two TensorMaps is consistent for an operation.
The functions verifies that
Expand Down Expand Up @@ -189,18 +200,19 @@ def _check_same_keys(a: equistore.torch.TensorMap, b: equistore.torch.TensorMap,
f"got {len(keys_a)} and {len(keys_b)}"
)

#list_keys: List[bool] = []
#for i in range(len(keys_b)):
# list_keys: List[bool] = []
# for i in range(len(keys_b)):
# is_in_key_a = keys_b[i] in keys_a
#for key in keys_b:
# for key in keys_b:
# is_in_key_a = key in keys_a
# list_keys.append(is_in_key_a)

##if not torch.all(torch.tensor(list_keys)):
if not torch.all(torch.tensor([keys_b[i] in keys_a for i in range(len(keys_b))])):
if not torch.all(
torch.tensor([keys_b[i] in keys_a for i in range(len(keys_b))])
):
raise ValueError(f"inputs to {fname} should have the same keys")


def _check_blocks(a: TensorBlock, b: TensorBlock, props: List[str], fname: str):
"""Check if metadata between two TensorBlocks is consistent for an operation.
Expand Down Expand Up @@ -255,7 +267,6 @@ def _check_blocks(a: TensorBlock, b: TensorBlock, props: List[str], fname: str):
"choose from ['samples', 'properties', 'components']"
)


def _dispatch_dot(A, B):
"""Compute dot product of two arrays.
Expand Down
2 changes: 1 addition & 1 deletion tests/equisolve_tests/numpy/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
# SPDX-License-Identifier: BSD-3-Clause
import equistore
import numpy as np
import torch
import pytest
import torch
from equistore import Labels, TensorBlock, TensorMap
from numpy.testing import assert_allclose, assert_equal

Expand Down

0 comments on commit 109d40e

Please sign in to comment.