From 24d1657b4fbf80b3556ec69cd3ee495eecabb650 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 15 Sep 2024 20:42:54 -0400 Subject: [PATCH] pass strict=True to zip() added in python 3.10 safer, since it errors if iterables have different lengths --- .pre-commit-config.yaml | 2 +- chgnet/graph/converter.py | 4 +++- chgnet/graph/graph.py | 2 +- chgnet/model/composition_model.py | 2 +- chgnet/model/functions.py | 3 ++- chgnet/model/model.py | 14 +++++++++++--- chgnet/trainer/trainer.py | 2 +- chgnet/utils/vasp_utils.py | 8 ++++---- examples/crystaltoolkit_relax_viewer.ipynb | 4 ++-- 9 files changed, 26 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ed544705..a0de1493 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff args: [--fix] diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 3fb98f15..89276a98 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -215,7 +215,9 @@ def _create_graph_legacy( Graph data structure used to create Crystal_Graph object """ graph = Graph([Node(index=idx) for idx in range(n_atoms)]) - for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance): + for ii, jj, img, dist in zip( + center_index, neighbor_index, image, distance, strict=True + ): graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist) return graph diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index 7e373f2c..084dfbb6 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -309,7 +309,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]] # We will need to find directed edges with center = center1 # and create angles with DE1, then do the same for center2 and DE2 for center, dir_edge in zip( - u_edge.nodes, u_edge.info["directed_edge_index"] + u_edge.nodes, u_edge.info["directed_edge_index"], strict=True ): for directed_edges in self.nodes[center].neighbors.values(): for directed_edge in directed_edges: diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 71f02764..b87ce668 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -141,7 +141,7 @@ def fit( composition_feas = torch.zeros([num_data, self.max_num_elements]) e = torch.zeros([num_data]) for index, (structure, energy) in enumerate( - zip(structures_or_graphs, energies) + zip(structures_or_graphs, energies, strict=True) ): if isinstance(structure, Structure): atomic_number = torch.tensor( diff --git a/chgnet/model/functions.py b/chgnet/model/functions.py index 2eddd1ce..a52da2ad 100644 --- a/chgnet/model/functions.py +++ b/chgnet/model/functions.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools from collections.abc import Sequence import torch @@ -83,7 +84,7 @@ def __init__( find_activation(activation), ] if len(hidden_dim) != 1: - for h_in, h_out in zip(hidden_dim[0:-1], hidden_dim[1:]): + for h_in, h_out in itertools.pairwise(hidden_dim): layers.append(nn.Linear(h_in, h_out, bias=bias)) layers.append(find_activation(activation)) layers.append(nn.Dropout(dropout)) diff --git a/chgnet/model/model.py b/chgnet/model/model.py index abca5b21..d98a2144 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -379,7 +379,10 @@ def forward( if return_site_energies and self.composition_model is not None: site_energy_shifts = self.composition_model.get_site_energies(graphs) prediction["site_energies"] = [ - i + j for i, j in zip(prediction["site_energies"], site_energy_shifts) + i + j + for i, j in zip( + prediction["site_energies"], site_energy_shifts, strict=True + ) ] return prediction @@ -437,7 +440,12 @@ def _compute( # Message Passing for idx, (atom_layer, bond_layer, angle_layer) in enumerate( - zip(self.atom_conv_layers[:-1], self.bond_conv_layers, self.angle_layers) + zip( + self.atom_conv_layers[:-1], + self.bond_conv_layers, + self.angle_layers, + strict=False, + ) ): # Atom Conv atom_feas = atom_layer( @@ -522,7 +530,7 @@ def _compute( ) # Convert Stress unit from eV/A^3 to GPa scale = 1 / g.volumes * 160.21766208 - stress = [i * j for i, j in zip(stress, scale)] + stress = [i * j for i, j in zip(stress, scale, strict=False)] prediction["s"] = stress # Normalize energy if model is intensive diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 60543ab6..67db99e8 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -830,7 +830,7 @@ def forward( if "m" in self.target_str: mag_preds, mag_targets = [], [] m_mae_size = 0 - for mag_pred, mag_target in zip(prediction["m"], targets["m"]): + for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True): # exclude structures without magmom labels if mag_target is not None: mag_preds.append(mag_pred) diff --git a/chgnet/utils/vasp_utils.py b/chgnet/utils/vasp_utils.py index 82ee659c..d17d8a75 100644 --- a/chgnet/utils/vasp_utils.py +++ b/chgnet/utils/vasp_utils.py @@ -81,13 +81,13 @@ def parse_vasp_dir( tokens = [float(token) for token in re.findall(r"[\d\.\-]+", clean)] tokens.pop(0) if read_charge: - charge.append(dict(zip(header, tokens))) + charge.append(dict(zip(header, tokens, strict=True))) elif read_mag_x: - mag_x.append(dict(zip(header, tokens))) + mag_x.append(dict(zip(header, tokens, strict=True))) elif read_mag_y: - mag_y.append(dict(zip(header, tokens))) + mag_y.append(dict(zip(header, tokens, strict=True))) elif read_mag_z: - mag_z.append(dict(zip(header, tokens))) + mag_z.append(dict(zip(header, tokens, strict=True))) elif clean.startswith("tot"): if ion_step_count == (len(mag_x_all) + 1): mag_x_all.append(mag_x) diff --git a/examples/crystaltoolkit_relax_viewer.ipynb b/examples/crystaltoolkit_relax_viewer.ipynb index 60b218ab..363930f5 100644 --- a/examples/crystaltoolkit_relax_viewer.ipynb +++ b/examples/crystaltoolkit_relax_viewer.ipynb @@ -377,7 +377,7 @@ " coords = trajectory.atom_positions[step]\n", " structure.lattice = lattice # update structure in place for efficiency\n", " assert len(structure) == len(coords)\n", - " for site, coord in zip(structure, coords):\n", + " for site, coord in zip(structure, coords, strict=True):\n", " site.coords = coord\n", "\n", " title = make_title(*structure.get_space_group_info())\n", @@ -406,7 +406,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.9" } }, "nbformat": 4,