improve: ase try to get virials from different sources #660

Open. Wants to merge 5 commits into base: devel.
26 changes: 19 additions & 7 deletions dpdata/plugins/ase.py
@@ -62,6 +62,11 @@
"""Convert ase.Atoms to a LabeledSystem. Energies and forces
are calculated by the calculator.

Note that this method will try to load virials from the following sources:
- atoms.info['virial']
- atoms.info['virials']
- converted from stress tensor
Comment on lines +65 to +68

Update the method documentation to reflect the new virial retrieval logic.

-        - converted from stress tensor
+        - converted from stress tensor if other sources are unavailable

Ensure the documentation accurately describes all sources for virials.

-        - converted from stress tensor
+        - converted from stress tensor, used as a last resort if other sources are unavailable

Comment on lines +65 to +68
Member

It's not correctly rendered. All three items in the list are shown in one line. @coderabbitai Fix it.

@njzjz, I've adjusted the docstring to ensure the list items are rendered correctly. Here's the corrected version:

def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
    """Convert ase.Atoms to a LabeledSystem. Energies and forces
    are calculated by the calculator.

    Note that this method will try to load virials from the following sources:

    - atoms.info['virial']
    - atoms.info['virials']
    - converted from stress tensor

    Parameters
    ----------
    atoms : ase.Atoms
        an ASE Atoms, containing a structure
    **kwargs : dict
        other parameters

    Returns
    -------
    dict
        data dict

    Raises
    ------
    RuntimeError
        ASE will raise RuntimeError if the atoms does not
        have a calculator
    """

The blank line before the list should let a reStructuredText renderer such as Sphinx treat the three items as a bullet list rather than as one paragraph. Please let me know if any more adjustments are needed!

Member

Yes, a blank line needs to be added.

@njzjz, I've added the requested blank line to the docstring. Here's the updated version:

def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
    """Convert ase.Atoms to a LabeledSystem. Energies and forces
    are calculated by the calculator.

    Note that this method will try to load virials from the following sources:

    - atoms.info['virial']
    - atoms.info['virials']
    - converted from stress tensor

    Parameters
    ----------
    atoms : ase.Atoms
        an ASE Atoms, containing a structure
    **kwargs : dict
        other parameters

    Returns
    -------
    dict
        data dict

    Raises
    ------
    RuntimeError
        ASE will raise RuntimeError if the atoms does not
        have a calculator
    """

With the blank line in place, the list should render as separate items. Please let me know if this looks good or if any further adjustments are needed!
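
The rendering problem in this thread can be checked mechanically. Assuming the docstring is parsed as reStructuredText (as is usual for Sphinx and numpydoc), a bullet list only starts after a blank line. The sketch below flags a bullet item that directly follows a paragraph line; the helper name `find_unseparated_lists` is hypothetical and not part of dpdata, and it deliberately ignores multi-line list items.

```python
import inspect


def find_unseparated_lists(docstring: str) -> list[int]:
    """Return 0-based line indices where a bullet list starts without a blank line before it.

    Hypothetical helper for illustration only; not part of dpdata.
    """
    lines = inspect.cleandoc(docstring).splitlines()
    problems = []
    for i in range(1, len(lines)):
        current, previous = lines[i].strip(), lines[i - 1].strip()
        starts_item = current.startswith("- ")
        # A bullet right after a non-blank, non-bullet line is folded into
        # the preceding paragraph by reStructuredText.
        if starts_item and previous and not previous.startswith("- "):
            problems.append(i)
    return problems


BEFORE = """Convert ase.Atoms to a LabeledSystem.

Note that this method will try to load virials from the following sources:
- atoms.info['virial']
- atoms.info['virials']
- converted from stress tensor
"""

AFTER = """Convert ase.Atoms to a LabeledSystem.

Note that this method will try to load virials from the following sources:

- atoms.info['virial']
- atoms.info['virials']
- converted from stress tensor
"""

print(find_unseparated_lists(BEFORE))  # [3]
print(find_unseparated_lists(AFTER))   # []
```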


Parameters
----------
atoms : ase.Atoms
@@ -93,13 +98,21 @@
             "energies": np.array([energies]),
             "forces": np.array([forces]),
         }
-        try:
-            stress = atoms.get_stress(False)
-        except PropertyNotImplementedError:
-            pass
-        else:
-            virials = np.array([-atoms.get_volume() * stress])
+
+        # try to get virials from different sources
+        virials = atoms.info.get("virial")
+        if virials is None:
+            virials = atoms.info.get("virials")
Member

It's unclear why it has two different keys.

Contributor Author

It's a fallback strategy as users may use either virial or virials.

Member

Documentation is necessary if users are expected to do something.

Contributor Author

I searched the code base but found no place to add a comment about this.

I don't think we need to add extra documentation for this, since dpdata is supposed to find the virial for the user automatically. My patch doesn't introduce any compatibility issues; it just makes the ase plugin more robust at finding the virial.

Member

It reads data from non-standard keys, which needs documentation to avoid unexpected behavior.

Contributor Author

I just added a note to the docstring.

Member

It's not rendered correctly (see attached image).
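
The lookup order debated in this thread can be sketched without ASE. In this minimal illustration, `FakeAtoms`, `StressUnavailable`, and `resolve_virials` are hypothetical stand-ins (for `ase.Atoms`, ASE's `PropertyNotImplementedError`, and the logic inside the plugin), and plain nested lists replace NumPy arrays so the block runs with the standard library alone.

```python
class StressUnavailable(Exception):
    """Stand-in for ASE's PropertyNotImplementedError (assumption for this sketch)."""


class FakeAtoms:
    """Minimal stand-in for ase.Atoms exposing only what the lookup needs."""

    def __init__(self, info=None, stress=None, volume=1.0):
        self.info = dict(info or {})
        self._stress = stress  # 3x3 nested list, or None if no stress is available
        self._volume = volume

    def get_stress(self, voigt=True):
        # The real method returns a 3x3 tensor when voigt is False;
        # this stand-in ignores the flag and always returns the 3x3 form.
        if self._stress is None:
            raise StressUnavailable("stress not available")
        return self._stress

    def get_volume(self):
        return self._volume


def resolve_virials(atoms):
    """Return virials using the order: info['virial'], info['virials'], then stress."""
    virials = atoms.info.get("virial")
    if virials is None:
        virials = atoms.info.get("virials")
    if virials is None:
        try:
            stress = atoms.get_stress(False)
        except StressUnavailable:
            return None  # no source available; the caller omits the "virials" key
        # virial = -volume * stress, wrapped in a list as the single frame
        volume = atoms.get_volume()
        virials = [[[-volume * s for s in row] for row in stress]]
    return virials


identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

from_info = resolve_virials(FakeAtoms(info={"virials": "stored"}, stress=identity))
from_stress = resolve_virials(FakeAtoms(stress=identity, volume=2.0))
missing = resolve_virials(FakeAtoms())
print(from_info, from_stress[0][0][0], missing)  # stored -2.0 None
```

As in the diff, values taken from `atoms.info` are passed through unchanged, while the stress-derived value gets a leading frame axis.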

+        if virials is None:
+            try:
+                stress = atoms.get_stress(False)
+            except PropertyNotImplementedError:
+                pass

Check warning on line 110 in dpdata/plugins/ase.py
Codecov / codecov/patch: added lines #L109 - L110 were not covered by tests.
Comment on lines +109 to +110

Handle PropertyNotImplementedError more gracefully to inform the user.

-            except PropertyNotImplementedError:
-                pass
+            except PropertyNotImplementedError as e:
+                logging.warning(f"Failed to compute stress due to: {str(e)}")

+            else:
+                virials = np.array([-atoms.get_volume() * stress])
+        if virials is not None:
Comment on lines +102 to +113

Consider handling the PropertyNotImplementedError for stress calculation more gracefully.

The current implementation passes silently if PropertyNotImplementedError is raised when trying to get stress (lines 104-105). This could lead to virials being None without any indication of why, which might confuse users. Consider logging a warning or providing a fallback mechanism.

-            except PropertyNotImplementedError:
-                pass
+            except PropertyNotImplementedError as e:
+                logging.warning(f"Failed to compute stress due to: {str(e)}")

Suggested change

         # try to get virials from different sources
         virials = atoms.info.get("virial")
         if virials is None:
             virials = atoms.info.get("virials")
         if virials is None:
             try:
                 stress = atoms.get_stress(False)
-            except PropertyNotImplementedError:
-                pass
+            except PropertyNotImplementedError as e:
+                logging.warning(f"Failed to compute stress due to: {str(e)}")
             else:
                 virials = np.array([-atoms.get_volume() * stress])
         if virials is not None:
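
To illustrate the suggestion, the sketch below logs a warning instead of passing silently. It uses only the standard `logging` module. `StressUnavailable` and `get_stress_or_warn` are hypothetical names standing in for ASE's `PropertyNotImplementedError` and the plugin code, and the module-level logger is an assumption, since the diff does not show how (or whether) dpdata configures logging.

```python
import logging

logger = logging.getLogger(__name__)


class StressUnavailable(Exception):
    """Stand-in for ASE's PropertyNotImplementedError (assumption for this sketch)."""


def get_stress_or_warn(get_stress):
    """Call get_stress(False); return None and log a warning if it is unavailable."""
    try:
        return get_stress(False)
    except StressUnavailable as e:
        # Lazy %-formatting defers string building until the record is emitted.
        logger.warning("Failed to compute stress due to: %s", e)
        return None


def no_stress(voigt):
    """Hypothetical stress getter that behaves like a calculator without stress."""
    raise StressUnavailable("calculator does not implement stress")


def constant_stress(voigt):
    """Hypothetical stress getter that returns a zero 3x3 tensor."""
    return [[0.0] * 3 for _ in range(3)]


missing = get_stress_or_warn(no_stress)
present = get_stress_or_warn(constant_stress)
```

Whether a warning is appropriate here is a design choice: a missing stress is expected for many calculators, so a lower log level may be less noisy.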

             info_dict["virials"] = virials
+
         return info_dict

     def from_multi_systems(
@@ -165,7 +178,6 @@

         structures = []
         species = [data["atom_names"][tt] for tt in data["atom_types"]]
-
         for ii in range(data["coords"].shape[0]):
             structure = Atoms(
                 symbols=species,