Skip to content

Commit

Permalink
update pytest for cp2kcell
Browse files Browse the repository at this point in the history
  • Loading branch information
robinzyb committed Oct 26, 2023
1 parent 1555e16 commit a6aee13
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 12 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ pip install .
- [Manipulate CP2K Cube Files](./docs/cube/README.md)
- [Manipulate CP2K Pdos Files](./docs/pdos/README.md)


# Feature Request
Any advice is welcome. If you would like to request new feature, please open a issue in github and upload example input and output files.



Expand Down
33 changes: 23 additions & 10 deletions cp2kdata/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
"""


if len(cell_param) == 1 and isinstance(cell_param, float):
if isinstance(cell_param, float):
self.cell_matrix = np.array(
[
[cell_param, 0, 0],
Expand All @@ -37,27 +37,33 @@ def __init__(
]
)
print("input cell_param is a float, the cell is assumed to be cubic")
elif cell_param.shape == 3:
elif cell_param.shape == (3,):
self.cell_matrix = np.array(
[
[cell_param[0], 0, 0],
[0, cell_param[1], 0],
[0, 0, cell_param[2]]
]
)
print("input cell_param is a list or array, the cell is assumed to be orthorhombic")
elif cell_param.shape == 6:
print("the length of input cell_param is 3, "
"the cell is assumed to be orthorhombic")
elif cell_param.shape == (6,):
self.cell_matrix = cellpar_to_cell(cell_param)
print("input cell_param is in [a, b, c, alpha, beta, gamma] form, it is converted to cell matrix")
print("the length of input cell_param is 6, "
"the Cp2kCell assumes it is [a, b, c, alpha, beta, gamma], "
"which will be converted to cell matrix")
elif cell_param.shape == (3, 3):
self.cell_matrix = cell_param
print("input cell_param is a matrix, the cell is read as is")
print("input cell_param is a matrix with shape of (3,3), "
"the cell is read as is")
else:
raise ValueError("The input cell_param is not supported")


if (grid_point is None) and (grid_spacing_matrix is None):
print("No grid point generated")
self.grid_point = None
self.grid_spacing_matrix = None
print("No grid point information")
elif (grid_point is None) and (grid_spacing_matrix is not None):
self.grid_spacing_matrix = grid_spacing_matrix
self.grid_point = np.round(self.cell_matrix/self.grid_spacing_matrix)
Expand All @@ -68,9 +74,13 @@ def __init__(
self.grid_point = np.array(grid_point)
self.grid_spacing_matrix = np.array(grid_spacing_matrix)

self.grid_point = self.grid_point.astype(int)
if grid_point is not None:
self.grid_point = self.grid_point.astype(int)

self.volume = np.linalg.det(self.cell_matrix)
self.dv = np.linalg.det(self.grid_spacing_matrix)

if grid_point is not None:
self.dv = np.linalg.det(self.grid_spacing_matrix)

self.cell_param = cell_to_cellpar(self.cell_matrix)

Expand All @@ -81,7 +91,10 @@ def get_volume(self):
return self.volume

def get_dv(self):
return self.dv
try:
return self.dv
except AttributeError as ae:
print("No grid point information is available")

def get_cell_param(self):
return self.cell_param
Expand Down
2 changes: 1 addition & 1 deletion docs/cube/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mycube = Cp2kCube(cube_file_path)
## Retrieving Cell Information
Users can easily obtain cell information from CP2K cube files by the following method
```python
cell = mycube.get()
cell = mycube.get_cell()
type(cell)
```
As a result, you will get new object `Cp2kCell`
Expand Down
104 changes: 104 additions & 0 deletions tests/test_cell/test_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
import numpy as np
from ase.geometry.cell import cellpar_to_cell
from ase.geometry.cell import cell_to_cellpar
from cp2kdata.cell import Cp2kCell # Replace 'your_module' with the actual module containing the Cp2kCell class.

class TestCp2kCell:
# Define expected cell matrices for different cases
def _create_expected_cell_matrix(self, cell_param):
if isinstance(cell_param, float):
return np.array([[cell_param, 0, 0], [0, cell_param, 0], [0, 0, cell_param]])
elif cell_param.shape == (3,):
return np.array([[cell_param[0], 0, 0], [0, cell_param[1], 0], [0, 0, cell_param[2]]])
elif cell_param.shape == (6,):
return cellpar_to_cell(cell_param)
elif cell_param.shape == (3, 3):
return cell_param
def _create_expected_grid_spacing_matrix(self, cell_matrix, grid_point, grid_spacing_matrix):
if (grid_point is None) and (grid_spacing_matrix is None):
grid_point = None
grid_spacing_matrix = None
print("No grid point information")
elif (grid_point is None) and (grid_spacing_matrix is not None):
grid_spacing_matrix = grid_spacing_matrix
grid_point = np.round(cell_matrix/grid_spacing_matrix)
elif (grid_point is not None) and (grid_spacing_matrix is None):
grid_point = np.array(grid_point)
grid_spacing_matrix = cell_matrix/grid_point[:, np.newaxis]
elif (grid_point is not None) and (grid_spacing_matrix is not None):
grid_point = np.array(grid_point)
grid_spacing_matrix = np.array(grid_spacing_matrix)
return grid_spacing_matrix

def _create_expected_cell_param(self, cell_param):
if isinstance(cell_param, float):
return np.array([cell_param, cell_param, cell_param, 90.0, 90.0, 90.0])
elif cell_param.shape == (3,):
return np.array([cell_param[0], cell_param[1], cell_param[2], 90.0, 90.0, 90.0])
elif cell_param.shape == (6,):
return cell_param
elif cell_param.shape == (3, 3):
return cell_to_cellpar(cell_param)

# Define sample data using a fixture
@pytest.fixture(params=[
(np.array([10.0, 12.0, 15.0]), None, None),
#(np.array([10.0, 12.0, 15.0]), None, np.array([[1.0, 0.0, 0.0],[0,1.0,0],[0,0,1.0]])),
(np.array([10.0, 12.0, 15.0, 90.0, 90.0, 90.0]), None, None),
(np.array([[10.0, 0, 0], [0, 12.0, 0], [0, 0, 15.0]]), [2, 2, 2], None)
])
def sample_data(self, request):
return request.param

def test_constructor(self, sample_data):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
assert np.array_equal(cell.cell_matrix, self._create_expected_cell_matrix(cell_param))

def test_copy(self, sample_data):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
copied_cell = cell.copy()
assert np.array_equal(cell.cell_matrix, copied_cell.cell_matrix)
assert np.array_equal(cell.grid_point, copied_cell.grid_point)
assert np.array_equal(cell.grid_spacing_matrix, copied_cell.grid_spacing_matrix)

def test_get_volume(self, sample_data):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
expected_volume = np.linalg.det(self._create_expected_cell_matrix(cell_param))
assert cell.get_volume() == expected_volume

def test_get_dv(self, sample_data, capsys):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
if (grid_point is None) and (grid_spacing_matrix is None):
cell.get_dv()
captured = capsys.readouterr()
assert captured.out.splitlines()[-1] == "No grid point information is available"
else:
print(cell.cell_matrix, grid_point, grid_spacing_matrix)
expected_grid_spacing_matrix = self._create_expected_grid_spacing_matrix(cell.cell_matrix, grid_point, grid_spacing_matrix)
expected_dv = np.linalg.det(expected_grid_spacing_matrix)
assert cell.get_dv() == expected_dv

def test_get_cell_param(self, sample_data):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
expected_cell_param = self._create_expected_cell_param(cell_param)
assert np.array_equal(cell.get_cell_param(), expected_cell_param)

def test_get_cell_angles(self, sample_data):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
expected_cell_param = self._create_expected_cell_param(cell_param)
expected_cell_angles = expected_cell_param[3:]
assert np.array_equal(cell.get_cell_angles(), expected_cell_angles)

def test_get_cell_lengths(self, sample_data):
cell_param, grid_point, grid_spacing_matrix = sample_data
cell = Cp2kCell(cell_param, grid_point, grid_spacing_matrix)
expected_cell_param = self._create_expected_cell_param(cell_param)
expected_cell_lengths = expected_cell_param[:3]
assert np.array_equal(cell.get_cell_lengths(), expected_cell_lengths)

0 comments on commit a6aee13

Please sign in to comment.