From 6b8966edf9294de68f388c728c41010d202004f6 Mon Sep 17 00:00:00 2001 From: Michelle Richer Date: Tue, 10 Dec 2024 11:18:25 -0500 Subject: [PATCH] Comment and clean up extrapolation code --- atomdb/species.py | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/atomdb/species.py b/atomdb/species.py index 5daec10..deacf4d 100644 --- a/atomdb/species.py +++ b/atomdb/species.py @@ -15,35 +15,23 @@ r"""AtomDB, a database of atomic and ionic properties.""" -from dataclasses import dataclass, field, asdict - -from glob import glob - -from importlib import import_module - import json - +import re +from dataclasses import asdict, dataclass, field +from importlib import import_module from numbers import Integral - from os import makedirs, path -from msgpack import packb, unpackb - -from msgpack_numpy import encode, decode - import numpy as np - -from numpy import ndarray - import pooch -import re import requests - +from msgpack import packb, unpackb +from msgpack_numpy import decode, encode +from numpy import ndarray from scipy.interpolate import CubicSpline -from atomdb.utils import DEFAULT_DATASET, DEFAULT_DATAPATH, DEFAULT_REMOTE -from atomdb.periodic import element_symbol, Element - +from atomdb.periodic import Element, element_symbol +from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_DATASET, DEFAULT_REMOTE __all__ = [ "Species", @@ -166,6 +154,8 @@ def __init__(self, x, y, log=False): self._log = log self._obj = CubicSpline( x, + # Clip y values to >=ε^2 if using log because they have to be above 0; + # having them be at least ε^2 seems to work based on my testing np.log(y.clip(min=np.finfo(float).eps ** 2)) if log else y, axis=0, bc_type="not-a-knot", @@ -192,6 +182,7 @@ def __call__(self, x, deriv=0): if not (0 <= deriv <= 2): raise ValueError(f"Invalid derivative order {deriv}; must be 0 <= `deriv` <= 2") elif self._log: + # Get y = exp(log y). We'll handle errors from small log y values later. with np.errstate(over="ignore"): y = np.exp(self._obj(x)) if deriv == 1: @@ -205,7 +196,9 @@ def __call__(self, x, deriv=0): y = d2logy.flatten() * y + dlogy.flatten() ** 2 / y else: y = self._obj(x, nu=deriv) + # Handle errors from the y = exp(log y) operation -- set NaN to zero np.nan_to_num(y, nan=0., copy=False) + # Cutoff value: assume y(x) is zero where x >= 2 times final given point x_n y[x > 2 * self._obj.x[-1]] = 0 return y @@ -221,7 +214,7 @@ def default(self, obj): return JSONEncoder.default(self, obj) -class _AtomicOrbitals(object): +class _AtomicOrbitals: """Atomic orbitals class.""" def __init__(self, data) -> None: @@ -886,13 +879,13 @@ def datafile( url=f"{remotepath}{dataset.lower()}/db/repodata.txt", known_hash=None, path=path.join(datapath, dataset.lower(), "db"), - fname=f"repo_data.txt", + fname="repo_data.txt", ) # if the file is not found or remote was not valid, use the local repodata file except (requests.exceptions.HTTPError, ValueError): repodata = path.join(datapath, dataset.lower(), "db", "repo_data.txt") - with open(repodata, "r") as f: + with open(repodata) as f: data = f.read() files = re.findall(rf"\b{elem}+_{charge}+_{mult}+_{nexc}\.msg\b", data) species_list = []