Skip to content

Commit

Permalink
pass strict=True to zip() added in python 3.10
Browse files Browse the repository at this point in the history
safer, since it errors if iterables have different lengths
  • Loading branch information
janosh committed Sep 16, 2024
1 parent 877caaa commit 24d1657
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.5
hooks:
- id: ruff
args: [--fix]
Expand Down
4 changes: 3 additions & 1 deletion chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def _create_graph_legacy(
Graph data structure used to create Crystal_Graph object
"""
graph = Graph([Node(index=idx) for idx in range(n_atoms)])
for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance):
for ii, jj, img, dist in zip(
center_index, neighbor_index, image, distance, strict=True
):
graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist)

return graph
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]
# We will need to find directed edges with center = center1
# and create angles with DE1, then do the same for center2 and DE2
for center, dir_edge in zip(
u_edge.nodes, u_edge.info["directed_edge_index"]
u_edge.nodes, u_edge.info["directed_edge_index"], strict=True
):
for directed_edges in self.nodes[center].neighbors.values():
for directed_edge in directed_edges:
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def fit(
composition_feas = torch.zeros([num_data, self.max_num_elements])
e = torch.zeros([num_data])
for index, (structure, energy) in enumerate(
zip(structures_or_graphs, energies)
zip(structures_or_graphs, energies, strict=True)
):
if isinstance(structure, Structure):
atomic_number = torch.tensor(
Expand Down
3 changes: 2 additions & 1 deletion chgnet/model/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from collections.abc import Sequence

import torch
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
find_activation(activation),
]
if len(hidden_dim) != 1:
for h_in, h_out in zip(hidden_dim[0:-1], hidden_dim[1:]):
for h_in, h_out in itertools.pairwise(hidden_dim):
layers.append(nn.Linear(h_in, h_out, bias=bias))
layers.append(find_activation(activation))
layers.append(nn.Dropout(dropout))
Expand Down
14 changes: 11 additions & 3 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,10 @@ def forward(
if return_site_energies and self.composition_model is not None:
site_energy_shifts = self.composition_model.get_site_energies(graphs)
prediction["site_energies"] = [
i + j for i, j in zip(prediction["site_energies"], site_energy_shifts)
i + j
for i, j in zip(
prediction["site_energies"], site_energy_shifts, strict=True
)
]
return prediction

Expand Down Expand Up @@ -437,7 +440,12 @@ def _compute(

# Message Passing
for idx, (atom_layer, bond_layer, angle_layer) in enumerate(
zip(self.atom_conv_layers[:-1], self.bond_conv_layers, self.angle_layers)
zip(
self.atom_conv_layers[:-1],
self.bond_conv_layers,
self.angle_layers,
strict=False,
)
):
# Atom Conv
atom_feas = atom_layer(
Expand Down Expand Up @@ -522,7 +530,7 @@ def _compute(
)
# Convert Stress unit from eV/A^3 to GPa
scale = 1 / g.volumes * 160.21766208
stress = [i * j for i, j in zip(stress, scale)]
stress = [i * j for i, j in zip(stress, scale, strict=False)]
prediction["s"] = stress

# Normalize energy if model is intensive
Expand Down
2 changes: 1 addition & 1 deletion chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def forward(
if "m" in self.target_str:
mag_preds, mag_targets = [], []
m_mae_size = 0
for mag_pred, mag_target in zip(prediction["m"], targets["m"]):
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
# exclude structures without magmom labels
if mag_target is not None:
mag_preds.append(mag_pred)
Expand Down
8 changes: 4 additions & 4 deletions chgnet/utils/vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def parse_vasp_dir(
tokens = [float(token) for token in re.findall(r"[\d\.\-]+", clean)]
tokens.pop(0)
if read_charge:
charge.append(dict(zip(header, tokens)))
charge.append(dict(zip(header, tokens, strict=True)))
elif read_mag_x:
mag_x.append(dict(zip(header, tokens)))
mag_x.append(dict(zip(header, tokens, strict=True)))
elif read_mag_y:
mag_y.append(dict(zip(header, tokens)))
mag_y.append(dict(zip(header, tokens, strict=True)))
elif read_mag_z:
mag_z.append(dict(zip(header, tokens)))
mag_z.append(dict(zip(header, tokens, strict=True)))
elif clean.startswith("tot"):
if ion_step_count == (len(mag_x_all) + 1):
mag_x_all.append(mag_x)
Expand Down
4 changes: 2 additions & 2 deletions examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@
" coords = trajectory.atom_positions[step]\n",
" structure.lattice = lattice # update structure in place for efficiency\n",
" assert len(structure) == len(coords)\n",
" for site, coord in zip(structure, coords):\n",
" for site, coord in zip(structure, coords, strict=True):\n",
" site.coords = coord\n",
"\n",
" title = make_title(*structure.get_space_group_info())\n",
Expand Down Expand Up @@ -406,7 +406,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 24d1657

Please sign in to comment.