From 035abf226d231249f69dee92cb3c9559e8b74df4 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel) YANG" Date: Fri, 13 Sep 2024 07:49:03 +0800 Subject: [PATCH] Support NumPy 2 (#202) Fixed data type compatibility for support of Numpy 2 --------- Co-authored-by: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> --- .github/workflows/test.yml | 3 + chgnet/data/dataset.py | 40 ++++---- chgnet/graph/converter.py | 10 +- chgnet/graph/crystalgraph.py | 2 +- chgnet/graph/cygraph.pyx | 67 +++++++------- .../fast_converter_libraries/create_graph.c | 92 +++++++++---------- chgnet/graph/graph.py | 2 + chgnet/model/model.py | 6 +- chgnet/utils/vasp_utils.py | 8 +- examples/make_graphs.py | 1 - pyproject.toml | 16 ++-- setup.py | 4 +- tests/test_graph.py | 10 +- 13 files changed, 140 insertions(+), 121 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 95b4a9a7..70dc8cc7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,6 +34,9 @@ jobs: pip install uv uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system + # TODO: remove pin once reverse readline fixed + uv pip install monty==2024.7.12 --system + - name: Run Tests run: pytest --capture=no --cov --cov-report=xml env: diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index 2bf74ee2..e02a164e 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -24,7 +24,7 @@ from chgnet import TrainTask warnings.filterwarnings("ignore") -datatype = torch.float32 +TORCH_DTYPE = torch.float32 class StructureData(Dataset): @@ -163,13 +163,13 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]: struct, graph_id=graph_id, mp_id=mp_id ) targets = { - "e": torch.tensor(self.energies[graph_id], dtype=datatype), - "f": torch.tensor(self.forces[graph_id], dtype=datatype), + "e": torch.tensor(self.energies[graph_id], dtype=TORCH_DTYPE), + "f": torch.tensor(self.forces[graph_id], dtype=TORCH_DTYPE), } if self.stresses is not None: # Convert VASP stress targets["s"] = torch.tensor( - self.stresses[graph_id], dtype=datatype + self.stresses[graph_id], dtype=TORCH_DTYPE ) * (-0.1) if self.magmoms is not None: mag = self.magmoms[graph_id] @@ -177,7 +177,7 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]: if mag is None: targets["m"] = None else: - targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype)) + targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE)) return crystal_graph, targets @@ -275,18 +275,18 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]: for key in self.targets: if key == "e": energy = self.data[graph_id][self.energy_key] - targets["e"] = torch.tensor(energy, dtype=datatype) + targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE) elif key == "f": force = self.data[graph_id][self.force_key] - targets["f"] = torch.tensor(force, dtype=datatype) + targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE) elif key == "s": stress = self.data[graph_id][self.stress_key] # Convert VASP stress - targets["s"] = torch.tensor(stress, dtype=datatype) * -0.1 + targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * -0.1 elif key == "m": mag = self.data[graph_id][self.magmom_key] # use absolute value for magnetic moments - targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype)) + targets["m"] = torch.abs(torch.tensor(mag, dtype=TORCH_DTYPE)) return crystal_graph, targets # Omit structures with isolated atoms. @@ -404,21 +404,23 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict[str, Tensor]]: for key in self.targets: if key == "e": energy = self.labels[mp_id][graph_id][self.energy_key] - targets["e"] = torch.tensor(energy, dtype=datatype) + targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE) elif key == "f": force = self.labels[mp_id][graph_id][self.force_key] - targets["f"] = torch.tensor(force, dtype=datatype) + targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE) elif key == "s": stress = self.labels[mp_id][graph_id][self.stress_key] # Convert VASP stress - targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1) + targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1) elif key == "m": mag = self.labels[mp_id][graph_id][self.magmom_key] # use absolute value for magnetic moments if mag is None: targets["m"] = None else: - targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype)) + targets["m"] = torch.abs( + torch.tensor(mag, dtype=TORCH_DTYPE) + ) return crystal_graph, targets # Omit failed structures. Return another randomly selected structure @@ -629,21 +631,23 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]: for key in self.targets: if key == "e": energy = self.data[mp_id][graph_id][self.energy_key] - targets["e"] = torch.tensor(energy, dtype=datatype) + targets["e"] = torch.tensor(energy, dtype=TORCH_DTYPE) elif key == "f": force = self.data[mp_id][graph_id][self.force_key] - targets["f"] = torch.tensor(force, dtype=datatype) + targets["f"] = torch.tensor(force, dtype=TORCH_DTYPE) elif key == "s": stress = self.data[mp_id][graph_id][self.stress_key] # Convert VASP stress - targets["s"] = torch.tensor(stress, dtype=datatype) * (-0.1) + targets["s"] = torch.tensor(stress, dtype=TORCH_DTYPE) * (-0.1) elif key == "m": mag = self.data[mp_id][graph_id][self.magmom_key] # use absolute value for magnetic moments if mag is None: targets["m"] = None else: - targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype)) + targets["m"] = torch.abs( + torch.tensor(mag, dtype=TORCH_DTYPE) + ) return crystal_graph, targets # Omit structures with isolated atoms. Return another randomly selected @@ -773,7 +777,7 @@ def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tens graphs = [graph for graph, _ in batch_data] all_targets = {key: [] for key in batch_data[0][1]} all_targets["e"] = torch.tensor( - [targets["e"] for _, targets in batch_data], dtype=datatype + [targets["e"] for _, targets in batch_data], dtype=TORCH_DTYPE ) for _, targets in batch_data: diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index b83ea5af..3fb98f15 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -23,7 +23,7 @@ except (ImportError, AttributeError): make_graph = None -DATATYPE = torch.float32 +TORCH_DTYPE = torch.float32 class CrystalGraphConverter(nn.Module): @@ -124,10 +124,10 @@ def forward( requires_grad=False, ) atom_frac_coord = torch.tensor( - structure.frac_coords, dtype=DATATYPE, requires_grad=True + structure.frac_coords, dtype=TORCH_DTYPE, requires_grad=True ) lattice = torch.tensor( - structure.lattice.matrix, dtype=DATATYPE, requires_grad=True + structure.lattice.matrix, dtype=TORCH_DTYPE, requires_grad=True ) center_index, neighbor_index, image, distance = structure.get_neighbor_list( r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8 @@ -177,7 +177,7 @@ def forward( atomic_number=atomic_number, atom_frac_coord=atom_frac_coord, atom_graph=atom_graph, - neighbor_image=torch.tensor(image, dtype=DATATYPE), + neighbor_image=torch.tensor(image, dtype=TORCH_DTYPE), directed2undirected=directed2undirected, undirected2directed=undirected2directed, bond_graph=bond_graph, @@ -250,7 +250,7 @@ def _create_graph_fast( """ center_index = np.ascontiguousarray(center_index) neighbor_index = np.ascontiguousarray(neighbor_index) - image = np.ascontiguousarray(image, dtype=np.int_) + image = np.ascontiguousarray(image, dtype=np.int64) distance = np.ascontiguousarray(distance) gc_saved = gc.get_threshold() gc.set_threshold(0) diff --git a/chgnet/graph/crystalgraph.py b/chgnet/graph/crystalgraph.py index 637b359a..4f4572d1 100644 --- a/chgnet/graph/crystalgraph.py +++ b/chgnet/graph/crystalgraph.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from typing_extensions import Self -datatype = torch.float32 +TORCH_DTYPE = torch.float32 class CrystalGraph: diff --git a/chgnet/graph/cygraph.pyx b/chgnet/graph/cygraph.pyx index 7b82fda7..e7abe360 100644 --- a/chgnet/graph/cygraph.pyx +++ b/chgnet/graph/cygraph.pyx @@ -7,70 +7,75 @@ # cython: profile=False # distutils: language = c -import chgnet.graph.graph + import numpy as np +cimport numpy as np + +import chgnet.graph.graph + from libc.stdlib cimport free + cdef extern from 'fast_converter_libraries/create_graph.c': ctypedef struct Node: - long index + np.int64_t index LongToDirectedEdgeList* neighbors - long num_neighbors + np.int64_t num_neighbors ctypedef struct NodeIndexPair: - long center - long neighbor + np.int64_t center + np.int64_t neighbor ctypedef struct UndirectedEdge: NodeIndexPair nodes - long index - long* directed_edge_indices - long num_directed_edges - double distance + np.int64_t index + np.int64_t* directed_edge_indices + np.int64_t num_directed_edges + np.float64_t distance ctypedef struct DirectedEdge: NodeIndexPair nodes - long index - const long* image - long undirected_edge_index - double distance + np.int64_t index + const np.int64_t* image + np.int64_t undirected_edge_index + np.float64_t distance ctypedef struct LongToDirectedEdgeList: - long key + np.int64_t key DirectedEdge** directed_edges_list int num_directed_edges_in_group ctypedef struct ReturnElems2: - long num_nodes - long num_directed_edges - long num_undirected_edges + np.int64_t num_nodes + np.int64_t num_directed_edges + np.int64_t num_undirected_edges Node* nodes UndirectedEdge** undirected_edges_list DirectedEdge** directed_edges_list ReturnElems2* create_graph( - long* center_index, - long n_e, - long* neighbor_index, - long* image, - double* distance, - long num_atoms) + np.int64_t* center_index, + np.int64_t n_e, + np.int64_t* neighbor_index, + np.int64_t* image, + np.float64_t* distance, + np.int64_t num_atoms) - void free_LongToDirectedEdgeList_in_nodes(Node* nodes, long num_nodes) + void free_LongToDirectedEdgeList_in_nodes(Node* nodes, np.int64_t num_nodes) LongToDirectedEdgeList** get_neighbors(Node* node) def make_graph( - const long[::1] center_index, - const long n_e, - const long[::1] neighbor_index, - const long[:, ::1] image, - const double[::1] distance, - const long num_atoms + const np.int64_t[::1] center_index, + const np.int64_t n_e, + const np.int64_t[::1] neighbor_index, + const np.int64_t[:, ::1] image, + const np.float64_t[::1] distance, + const np.int64_t num_atoms ): cdef ReturnElems2* returned - returned = create_graph( ¢er_index[0], n_e, &neighbor_index[0], &image[0][0], &distance[0], num_atoms) + returned = create_graph( ¢er_index[0], n_e, &neighbor_index[0], &image[0][0], &distance[0], num_atoms) chg_DirectedEdge = chgnet.graph.graph.DirectedEdge chg_Node = chgnet.graph.graph.Node diff --git a/chgnet/graph/fast_converter_libraries/create_graph.c b/chgnet/graph/fast_converter_libraries/create_graph.c index 7a6ee0e3..a4ac9f37 100644 --- a/chgnet/graph/fast_converter_libraries/create_graph.c +++ b/chgnet/graph/fast_converter_libraries/create_graph.c @@ -12,32 +12,32 @@ typedef struct _ReturnElems2 ReturnElems2; // in the graph class in chgnet.graph.graph such that anyone familiar with that code should be able to pick up this // code pretty easily. -long MEM_ERR = 100; +int64_t MEM_ERR = 100; typedef struct _Node { - long index; + int64_t index; LongToDirectedEdgeList* neighbors; // Assuming neighbors can only be directed edge. Key is dest_node, value is DirectedEdge struct - long num_neighbors; + int64_t num_neighbors; } Node; typedef struct _NodeIndexPair { - long center; - long neighbor; + int64_t center; + int64_t neighbor; } NodeIndexPair; typedef struct _UndirectedEdge { NodeIndexPair nodes; - long index; - long* directed_edge_indices; - long num_directed_edges; + int64_t index; + int64_t* directed_edge_indices; + int64_t num_directed_edges; double distance; } UndirectedEdge; typedef struct _DirectedEdge { NodeIndexPair nodes; - long index; - const long* image; // Only access the first 3, never edit - long undirected_edge_index; + int64_t index; + const int64_t* image; // Only access the first 3, never edit + int64_t undirected_edge_index; double distance; } DirectedEdge; @@ -49,7 +49,7 @@ typedef struct _StructToUndirectedEdgeList { } StructToUndirectedEdgeList; typedef struct _LongToDirectedEdgeList { - long key; + int64_t key; DirectedEdge** directed_edges_list; int num_directed_edges_in_group; UT_hash_handle hh; @@ -57,36 +57,36 @@ typedef struct _LongToDirectedEdgeList { typedef struct _ReturnElems2 { - long num_nodes; - long num_directed_edges; - long num_undirected_edges; + int64_t num_nodes; + int64_t num_directed_edges; + int64_t num_undirected_edges; Node* nodes; UndirectedEdge** undirected_edges_list; DirectedEdge** directed_edges_list; } ReturnElems2; bool find_in_undirected(NodeIndexPair* tmp, StructToUndirectedEdgeList** undirected_edges, StructToUndirectedEdgeList** found_entry); -void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, long undirected_index); +void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, int64_t undirected_index); void create_new_undirected_edges_entry(StructToUndirectedEdgeList** undirected_edges, NodeIndexPair* tmp, UndirectedEdge* new_undirected_edge); void append_to_undirected_edges_tmp(UndirectedEdge* undirected, StructToUndirectedEdgeList** undirected_edges, NodeIndexPair* tmp); -void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, long* num_undirected_edges); -void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, long* num_directed_edges); -void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* directed_edge); +void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, int64_t* num_undirected_edges); +void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, int64_t* num_directed_edges); +void add_neighbors_to_node(Node* node, int64_t neighbor_index, DirectedEdge* directed_edge); void print_neighbors(Node* node); -void append_to_directed_edge_indices(UndirectedEdge* undirected_edge, long directed_edge_index); +void append_to_directed_edge_indices(UndirectedEdge* undirected_edge, int64_t directed_edge_index); bool is_reversed_directed_edge(DirectedEdge* directed_edge1, DirectedEdge* directed_edge2); void free_undirected_edges(StructToUndirectedEdgeList** undirected_edges); -void free_LongToDirectedEdgeList_in_nodes(Node* nodes, long num_nodes); +void free_LongToDirectedEdgeList_in_nodes(Node* nodes, int64_t num_nodes); -Node* create_nodes(long num_nodes) { +Node* create_nodes(int64_t num_nodes) { Node* Nodes = (Node*) malloc(sizeof(Node) * num_nodes); if (Nodes == NULL) { return NULL; } - for (long i = 0; i < num_nodes; i++) { + for (int64_t i = 0; i < num_nodes; i++) { Nodes[i].index = i; Nodes[i].num_neighbors = 0; @@ -98,22 +98,22 @@ Node* create_nodes(long num_nodes) { } ReturnElems2* create_graph( - long* center_indices, - long num_edges, - long* neighbor_indices, - long* images, // contiguous memory (row-major) of image elements (total of n_e * 3 integers) + int64_t* center_indices, + int64_t num_edges, + int64_t* neighbor_indices, + int64_t* images, // contiguous memory (row-major) of image elements (total of n_e * 3 integers) double* distances, - long num_atoms + int64_t num_atoms ) { // Initialize pertinent data structures --------------------- Node* nodes = create_nodes(num_atoms); DirectedEdge** directed_edges_list = calloc(num_edges, sizeof(DirectedEdge)); - long num_directed_edges = 0; + int64_t num_directed_edges = 0; // There will never be more undirected edges than directed edges UndirectedEdge** undirected_edges_list = calloc(num_edges, sizeof(UndirectedEdge)); - long num_undirected_edges = 0; + int64_t num_undirected_edges = 0; StructToUndirectedEdgeList* undirected_edges = NULL; // Pointer to beginning of list of UndirectedEdges corresponding to tmp of current iteration @@ -133,7 +133,7 @@ ReturnElems2* create_graph( DirectedEdge* this_directed_edge; // Add all edges to graph information - for (long i = 0; i < num_edges; i++) { + for (int64_t i = 0; i < num_edges; i++) { // Haven't processed the edge yet processed_edge = false; // Create the current directed edge ------------------- @@ -236,11 +236,11 @@ void free_undirected_edges(StructToUndirectedEdgeList** undirected_edges) { } } -void free_LongToDirectedEdgeList_in_nodes(Node* nodes, long num_nodes) { +void free_LongToDirectedEdgeList_in_nodes(Node* nodes, int64_t num_nodes) { LongToDirectedEdgeList* current; LongToDirectedEdgeList* tmp; - for (long node_i = 0; node_i < num_nodes; node_i++) { + for (int64_t node_i = 0; node_i < num_nodes; node_i++) { HASH_ITER(hh, nodes[node_i].neighbors, current, tmp) { HASH_DEL(nodes[node_i].neighbors, current); free(current->directed_edges_list); @@ -323,7 +323,7 @@ void append_to_undirected_edges_tmp(UndirectedEdge* undirected, StructToUndirect StructToUndirectedEdgeList* this_undirected_edges_item; find_in_undirected(tmp, undirected_edges, &this_undirected_edges_item); - long num_undirected_edges = this_undirected_edges_item->num_undirected_edges_in_group; + int64_t num_undirected_edges = this_undirected_edges_item->num_undirected_edges_in_group; // No need to worry about originally malloc'ing memory for this_undirected_edges_item->undirected_edges_list // this is because, we first call create_new_undirected_edges_entry for all entries. This function already mallocs for us. @@ -340,7 +340,7 @@ void append_to_undirected_edges_tmp(UndirectedEdge* undirected, StructToUndirect } -void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, long undirected_index) { +void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, int64_t undirected_index) { // Copy over image and distance undirected->distance = directed->distance; undirected->nodes = directed->nodes; @@ -348,12 +348,12 @@ void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, // Add a new directed_edge_index to the directed_edge_indices pointer. This should be the first undirected->num_directed_edges = 1; - undirected->directed_edge_indices = malloc(sizeof(long)); + undirected->directed_edge_indices = malloc(sizeof(int64_t)); undirected->directed_edge_indices[0] = directed->index; } -void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, long* num_undirected_edges) { +void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, int64_t* num_undirected_edges) { // No need to realloc for space since our original alloc should cover everything // Assign value to next available position @@ -361,7 +361,7 @@ void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, Und *num_undirected_edges += 1; } -void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, long* num_directed_edges) { +void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, int64_t* num_directed_edges) { // No need to realloc for space since our original alloc should cover everything // Assign value to next available position @@ -369,21 +369,21 @@ void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedE *num_directed_edges += 1; } -void append_to_directed_edge_indices(UndirectedEdge* undirected_edge, long directed_edge_index) { +void append_to_directed_edge_indices(UndirectedEdge* undirected_edge, int64_t directed_edge_index) { // TODO: don't need to realloc if we always know that there will be 2 directed edges per undirected edge. Update this later for performance boosts. - // TODO: other random performance boost: don't pass longs into function parameters, pass long* instead - undirected_edge->directed_edge_indices = realloc(undirected_edge->directed_edge_indices, sizeof(long) * (undirected_edge->num_directed_edges + 1)); + // TODO: other random performance boost: don't pass int64_ts into function parameters, pass int64_t* instead + undirected_edge->directed_edge_indices = realloc(undirected_edge->directed_edge_indices, sizeof(int64_t) * (undirected_edge->num_directed_edges + 1)); undirected_edge->directed_edge_indices[undirected_edge->num_directed_edges] = directed_edge_index; undirected_edge->num_directed_edges += 1; } // If there already exists neighbor_index within the Node node, then adds directed_edge to the list of directed edges. // If there doesn't already exist neighbor_index within the Node node, then create a new entry into the node's neighbors hashmap and add the entry -void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* directed_edge) { +void add_neighbors_to_node(Node* node, int64_t neighbor_index, DirectedEdge* directed_edge) { LongToDirectedEdgeList* entry = NULL; // Search for the neighbor_index in our hashmap - HASH_FIND(hh, node->neighbors, &neighbor_index, sizeof(long), entry); + HASH_FIND(hh, node->neighbors, &neighbor_index, sizeof(int64_t), entry); if (entry) { // We found something, update the list within this pointer @@ -401,7 +401,7 @@ void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* direct entry->key = neighbor_index; entry->num_directed_edges_in_group = 1; - HASH_ADD(hh, node->neighbors, key, sizeof(long), entry); + HASH_ADD(hh, node->neighbors, key, sizeof(int64_t), entry); node->num_neighbors += 1; } @@ -409,11 +409,11 @@ void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* direct // Returns a list of LongToDirectedEdgeList pointers which are entries for the neighbors of the inputted node LongToDirectedEdgeList** get_neighbors(Node* node) { - long num_neighbors = HASH_COUNT(node->neighbors); + int64_t num_neighbors = HASH_COUNT(node->neighbors); LongToDirectedEdgeList** entries = malloc(sizeof(LongToDirectedEdgeList*) * num_neighbors); LongToDirectedEdgeList* entry; - long counter = 0; + int64_t counter = 0; for (entry = node->neighbors; entry != NULL; entry = entry->hh.next) { entries[counter] = entry; counter += 1; diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index a32e9936..7e373f2c 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -92,6 +92,8 @@ def __eq__(self, other: object) -> bool: bool: True if other is the same directed edge, or if other is the directed edge with reverse direction of self, else False. """ + if not isinstance(other, DirectedEdge): + return False self_img = (self.info or {}).get("image") other_img = (other.info or {}).get("image") none_img = self_img is other_img is None diff --git a/chgnet/model/model.py b/chgnet/model/model.py index d2030337..abca5b21 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -11,7 +11,7 @@ from torch import Tensor, nn from chgnet.graph import CrystalGraph, CrystalGraphConverter -from chgnet.graph.crystalgraph import datatype +from chgnet.graph.crystalgraph import TORCH_DTYPE from chgnet.model.composition_model import AtomRef from chgnet.model.encoders import AngleEncoder, AtomEmbedding, BondEncoder from chgnet.model.functions import MLP, GatedMLP, find_normalization @@ -808,7 +808,7 @@ def from_graphs( if compute_stress: strain = graph.lattice.new_zeros([3, 3], requires_grad=True) lattice = graph.lattice @ ( - torch.eye(3, dtype=datatype).to(strain.device) + strain + torch.eye(3, dtype=TORCH_DTYPE).to(strain.device) + strain ) else: strain = None @@ -878,7 +878,7 @@ def from_graphs( torch.cat(atom_owners, dim=0).type(torch.int32).to(atomic_numbers.device) ) directed2undirected = torch.cat(directed2undirected, dim=0) - volumes = torch.tensor(volumes, dtype=datatype, device=atomic_numbers.device) + volumes = torch.tensor(volumes, dtype=TORCH_DTYPE, device=atomic_numbers.device) return cls( atomic_numbers=atomic_numbers, diff --git a/chgnet/utils/vasp_utils.py b/chgnet/utils/vasp_utils.py index 5b712262..96943d3c 100644 --- a/chgnet/utils/vasp_utils.py +++ b/chgnet/utils/vasp_utils.py @@ -1,7 +1,8 @@ from __future__ import annotations -import os.path +import os import re +import warnings from typing import TYPE_CHECKING from monty.io import reverse_readfile @@ -114,7 +115,7 @@ def parse_vasp_dir( read_charge = read_mag_x = read_mag_y = read_mag_z = False if len(oszicar.ionic_steps) == len(mag_x_all): # unfinished VASP job - print("Unfinished OUTCAR") + warnings.warn("Unfinished OUTCAR", stacklevel=2) elif len(oszicar.ionic_steps) == (len(mag_x_all) - 1): # finished job mag_x_all.pop(-1) @@ -213,5 +214,6 @@ def solve_charge_by_mag( print(f"Solved oxidation state, {total_charge=}") out_structure.add_oxidation_state_by_site(ox_list) return out_structure - print("Failed to solve oxidation state") + + warnings.warn("Failed to solve oxidation state", stacklevel=2) return None diff --git a/examples/make_graphs.py b/examples/make_graphs.py index 8aacc2a5..15a5fcfe 100644 --- a/examples/make_graphs.py +++ b/examples/make_graphs.py @@ -10,7 +10,6 @@ from chgnet.data.dataset import StructureData, StructureJsonData from chgnet.graph import CrystalGraphConverter -datatype = torch.float32 random.seed(100) diff --git a/pyproject.toml b/pyproject.toml index 0cbeeb10..ce44d983 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,12 @@ license = { text = "Modified BSD" } dependencies = [ "ase>=3.23.0", "cython>=3", - "numpy>=1.26,<2", + # "monty==2024.7.12", # TODO: restore once readline fixed + # "numpy>=1.26", # TODO: remove after test + "numpy>=2.0.0", "nvidia-ml-py3>=7.352.0", - "pymatgen==2024.8.9", - "torch>=1.11.0", + "pymatgen>=2024.9.10", + "torch>=2.4.1", "typing-extensions>=4.12", ] classifiers = [ @@ -33,7 +35,7 @@ test = ["pytest-cov>=4", "pytest>=8", "wandb>=0.17"] # needed to run interactive example notebooks examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"] docs = ["lazydocs>=0.4"] -logging = ["wandb>=0.17"] +logging = ["wandb>=0.17.2"] dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"] [project.urls] @@ -49,11 +51,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] } "chgnet.pretrained" = ["*", "**/*"] [build-system] -requires = [ - "Cython", - "setuptools>=65", - "wheel", -] +requires = ["Cython", "numpy>=2.0.0", "setuptools>=65", "wheel"] build-backend = "setuptools.build_meta" [tool.ruff] diff --git a/setup.py b/setup.py index e9492092..5210e327 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,10 @@ from __future__ import annotations +import numpy as np from setuptools import Extension, setup ext_modules = [Extension("chgnet.graph.cygraph", ["chgnet/graph/cygraph.pyx"])] setup( - ext_modules=ext_modules, - setup_requires=["Cython"], + ext_modules=ext_modules, setup_requires=["Cython"], include_dirs=[np.get_include()] ) diff --git a/tests/test_graph.py b/tests/test_graph.py index c52f8836..f31e6634 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -96,12 +96,18 @@ def test_line_graph(bigraph: Graph) -> None: assert undirected2directed[2] == 3 -def test_directed_edge() -> None: +def test_directed_edge(capsys) -> None: info = {"image": np.zeros(3), "distance": 1} edge = DirectedEdge([0, 1], index=0, info=info) undirected = edge.make_undirected(index=0, info=info) assert edge == edge # noqa: PLR0124 - assert edge == undirected + captured = capsys.readouterr() + expected_message = ( + "!!!!!! the two directed edges are equal but this operation is " + "not supposed to happen\n" + ) + assert captured.err == expected_message + assert edge != undirected assert edge.nodes == [0, 1] assert edge.index == 0 assert repr(edge) == f"DirectedEdge(nodes=[0, 1], index=0, {info=})"