Skip to content

Commit

Permalink
Comment and clean up extrapolation code
Browse files Browse the repository at this point in the history
  • Loading branch information
msricher committed Dec 16, 2024
1 parent 8af6b2d commit 6b8966e
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions atomdb/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,23 @@

r"""AtomDB, a database of atomic and ionic properties."""

from dataclasses import dataclass, field, asdict

from glob import glob

from importlib import import_module

import json

import re
from dataclasses import asdict, dataclass, field
from importlib import import_module
from numbers import Integral

from os import makedirs, path

from msgpack import packb, unpackb

from msgpack_numpy import encode, decode

import numpy as np

from numpy import ndarray

import pooch
import re
import requests

from msgpack import packb, unpackb
from msgpack_numpy import decode, encode
from numpy import ndarray
from scipy.interpolate import CubicSpline

from atomdb.utils import DEFAULT_DATASET, DEFAULT_DATAPATH, DEFAULT_REMOTE
from atomdb.periodic import element_symbol, Element

from atomdb.periodic import Element, element_symbol
from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_DATASET, DEFAULT_REMOTE

__all__ = [
"Species",
Expand Down Expand Up @@ -166,6 +154,8 @@ def __init__(self, x, y, log=False):
self._log = log
self._obj = CubicSpline(
x,
# Clip y values to >=ε^2 if using log because they have to be above 0;
# having them be at least ε^2 seems to work based on my testing
np.log(y.clip(min=np.finfo(float).eps ** 2)) if log else y,
axis=0,
bc_type="not-a-knot",
Expand All @@ -192,6 +182,7 @@ def __call__(self, x, deriv=0):
if not (0 <= deriv <= 2):
raise ValueError(f"Invalid derivative order {deriv}; must be 0 <= `deriv` <= 2")
elif self._log:
# Get y = exp(log y). We'll handle errors from small log y values later.
with np.errstate(over="ignore"):
y = np.exp(self._obj(x))
if deriv == 1:
Expand All @@ -205,7 +196,9 @@ def __call__(self, x, deriv=0):
y = d2logy.flatten() * y + dlogy.flatten() ** 2 / y
else:
y = self._obj(x, nu=deriv)
# Handle errors from the y = exp(log y) operation -- set NaN to zero
np.nan_to_num(y, nan=0., copy=False)
# Cutoff value: assume y(x) is zero where x >= 2 times final given point x_n
y[x > 2 * self._obj.x[-1]] = 0
return y

Expand All @@ -221,7 +214,7 @@ def default(self, obj):
return JSONEncoder.default(self, obj)


class _AtomicOrbitals(object):
class _AtomicOrbitals:
"""Atomic orbitals class."""

def __init__(self, data) -> None:
Expand Down Expand Up @@ -886,13 +879,13 @@ def datafile(
url=f"{remotepath}{dataset.lower()}/db/repodata.txt",
known_hash=None,
path=path.join(datapath, dataset.lower(), "db"),
fname=f"repo_data.txt",
fname="repo_data.txt",
)
# if the file is not found or remote was not valid, use the local repodata file
except (requests.exceptions.HTTPError, ValueError):
repodata = path.join(datapath, dataset.lower(), "db", "repo_data.txt")

with open(repodata, "r") as f:
with open(repodata) as f:
data = f.read()
files = re.findall(rf"\b{elem}+_{charge}+_{mult}+_{nexc}\.msg\b", data)
species_list = []
Expand Down

0 comments on commit 6b8966e

Please sign in to comment.