Skip to content

Commit

Permalink
deps: Vendor more_itertools.consecutive_groups
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhollas committed Jun 26, 2024
1 parent 7984c8e commit 6dc123a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
39 changes: 28 additions & 11 deletions aiidalab_widgets_base/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Some utility functions used acrross the repository."""

import itertools
import operator
import threading
from enum import Enum
from typing import Any

import ase
import ase.io
import ipywidgets as ipw
import more_itertools as mit
import numpy as np
import traitlets
import traitlets as tl
from aiida.plugins import DataFactory
from ase import Atoms
from ase.io import read

CifData = DataFactory("core.cif") # pylint: disable=invalid-name
StructureData = DataFactory("core.structure") # pylint: disable=invalid-name
Expand Down Expand Up @@ -43,24 +44,40 @@ def get_ase_from_file(fname, file_format=None): # pylint: disable=redefined-bui
# store_tags parameter is useful for CIF files
# https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#cif
if file_format == "cif":
traj = read(fname, format=file_format, index=":", store_tags=True)
traj = ase.io.read(fname, format=file_format, index=":", store_tags=True)
else:
traj = read(fname, format=file_format, index=":")
traj = ase.io.read(fname, format=file_format, index=":")
if not traj:
raise ValueError(f"Could not read any information from the file {fname}")
return traj


def find_ranges(iterable):
"""Yield range of consecutive numbers."""
for grp in mit.consecutive_groups(iterable):
for grp in _consecutive_groups(iterable):
group = list(grp)
if len(group) == 1:
yield group[0]
else:
yield group[0], group[-1]


def _consecutive_groups(iterable, ordering=lambda x: x):
"""Yield groups of consecutive items using :func:`itertools.groupby`.
The *ordering* function determines whether two items are adjacent by
returning their position.
This is a vendored version of more_itertools.consecutive_groups
https://more-itertools.readthedocs.io/en/v10.3.0/_modules/more_itertools/more.html#consecutive_groups
Distributed under MIT license: https://more-itertools.readthedocs.io/en/v10.3.0/license.html
Thank you Bo Bayles for the original implementation. <3
"""
for _, g in itertools.groupby(
enumerate(iterable), key=lambda x: x[0] - ordering(x[1])
):
yield map(operator.itemgetter(1), g)


def list_to_string_range(lst, shift=1):
"""Converts a list like [0, 2, 3, 4] into a string like '1 3..5'.
Expand Down Expand Up @@ -124,15 +141,15 @@ def inverse_matrix(self):
return np.linalg.inv(self.matrix)


class _StatusWidgetMixin(traitlets.HasTraits):
class _StatusWidgetMixin(tl.HasTraits):
"""Show temporary messages for example for status updates.
This is a mixin class that is meant to be part of an inheritance
tree of an actual widget with a 'value' traitlet that is used
to convey a status message. See the non-private classes below
for examples.
"""

message = traitlets.Unicode(default_value="", allow_none=True)
message = tl.Unicode(default_value="", allow_none=True)
new_line = "\n"

def __init__(self, clear_after=3, *args, **kwargs):
Expand Down Expand Up @@ -169,7 +186,7 @@ class StatusHTML(_StatusWidgetMixin, ipw.HTML):

# This method should be part of _StatusWidgetMixin, but that does not work
# for an unknown reason.
@traitlets.observe("message")
@tl.observe("message")
def _observe_message(self, change):
self.show_temporary_message(change["new"])

Expand Down Expand Up @@ -201,7 +218,7 @@ def wrap_message(message, level=MessageLevel.INFO):
"""


def ase2spglib(ase_structure: Atoms) -> tuple[Any, Any, Any]:
def ase2spglib(ase_structure: ase.Atoms) -> tuple[Any, Any, Any]:
"""
Convert ase Atoms instance to spglib cell in the format defined at
https://spglib.github.io/spglib/python-spglib.html#crystal-structure-cell
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ install_requires =
traitlets~=5.9.0
ipywidgets~=7.7
widgetsnbextension<3.6.3
more-itertools~=8.0
pymysql~=0.9
nglview~=3.0
spglib>=1.14,<3
Expand Down

0 comments on commit 6dc123a

Please sign in to comment.