Skip to content

Commit

Permalink
update setup
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Oct 20, 2023
1 parent 09f501d commit 7da668f
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 37 deletions.
10 changes: 6 additions & 4 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
include MANIFEST.in
include README.md setup.py LICENSE

prune pyscfad/lib/build
recursive-include pyscfadlib/thirdparty *.so *.dylib
include pyscfadlib/*.so pyscfadlib/*.dylib pyscfadlib/config.h.in

include pyscfad/lib/*.so
include pyscfad/lib/*.dylib
# source code
recursive-include pyscfadlib *.c *.h CMakeLists.txt

recursive-include pyscfad/lib *.c *.h
global-exclude *.py[cod]
prune pyscfadlib/build
3 changes: 3 additions & 0 deletions pyscfad/backend/_jax/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import jax

def is_tensor(x):
return isinstance(x, jax.Array)

def stop_gradient(x):
return jax.lax.stop_gradient(x)

Expand Down
3 changes: 3 additions & 0 deletions pyscfad/backend/_numpy/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

def is_tensor(x):
return isinstance(x, (np.ndarray, np.generic))

def stop_gradient(x):
return x

Expand Down
4 changes: 2 additions & 2 deletions pyscfad/backend/_torch/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from keras_core import ops

is_tensor = ops.is_tensor
stop_gradient = ops.stop_gradient
convert_to_tensor = ops.convert_to_tensor

def convert_to_numpy(x):
x = stop_gradient(x)
return ops.convert_to_numpy(x)

convert_to_tensor = ops.convert_to_tensor
92 changes: 92 additions & 0 deletions pyscfad/lib/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import sys
from functools import wraps
from pyscf.lib import logger
from pyscf.lib.logger import (
DEBUG1 as DEBUG1,
DEBUG as DEBUG,
INFO as INFO,
NOTE as NOTE,
NOTICE as NOTICE,
WARN as WARN,
WARNING as WARNING,
ERR as ERR,
ERROR as ERROR,
QUIET as QUIET,
TIMER_LEVEL as TIMER_LEVEL,
process_clock as process_clock,
perf_counter as perf_counter,
)
from pyscfad import ops

def flush(rec, msg, *args):
arg_list = []
for arg in args:
if ops.is_tensor(arg):
arg = ops.convert_to_numpy(arg)
arg_list.append(arg)
logger.flush(rec, msg, *arg_list)

def log(rec, msg, *args):
if rec.verbose > QUIET:
flush(rec, msg, *args)

def error(rec, msg, *args):
if rec.verbose >= ERROR:
flush(rec, '\nERROR: '+msg+'\n', *args)
#sys.stderr.write('ERROR: ' + (msg%args) + '\n')

def warn(rec, msg, *args):
if rec.verbose >= WARN:
flush(rec, '\nWARN: '+msg+'\n', *args)
#if rec.stdout is not sys.stdout:
# sys.stderr.write('WARN: ' + (msg%args) + '\n')

def info(rec, msg, *args):
if rec.verbose >= INFO:
flush(rec, msg, *args)

def note(rec, msg, *args):
if rec.verbose >= NOTICE:
flush(rec, msg, *args)

def debug(rec, msg, *args):
if rec.verbose >= DEBUG:
flush(rec, msg, *args)

def debug1(rec, msg, *args):
if rec.verbose >= DEBUG1:
flush(rec, msg, *args)

def timer(rec, msg, cpu0=None, wall0=None):
if cpu0 is None:
cpu0 = rec._t0
if wall0 is None:
wall0 = rec._w0
rec._t0, rec._w0 = process_clock(), perf_counter()
if rec.verbose >= TIMER_LEVEL:
flush(rec, ' CPU time for %s %9.2f sec, wall time %9.2f sec'
% (msg, rec._t0-cpu0, rec._w0-wall0))
return rec._t0, rec._w0

class Logger(logger.Logger):
log = log
error = error
warn = warn
note = note
info = info
debug = debug
debug1 = debug1
timer = timer

@wraps(logger.new_logger)
def new_logger(rec=None, verbose=None):
if isinstance(verbose, Logger):
log = verbose
elif isinstance(verbose, int):
if getattr(rec, 'stdout', None):
log = Logger(rec.stdout, verbose)
else:
log = Logger(sys.stdout, verbose)
else:
log = Logger(rec.stdout, rec.verbose)
return log
4 changes: 4 additions & 0 deletions pyscfad/ops/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from pyscfad import backend

__all__ = [
'is_tensor',
'stop_gradient',
'stop_grad',
'convert_to_numpy',
'convert_to_tensor',
]

def is_tensor(x):
return backend.core.is_tensor(x)

def stop_gradient(x):
return backend.core.stop_gradient(x)

Expand Down
6 changes: 3 additions & 3 deletions pyscfad/scf/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy
import jax
from pyscf.data import nist
from pyscf.lib import logger, module_method
from pyscf.lib import module_method
from pyscf.scf import hf as pyscf_hf
from pyscf.scf import chkfile
from pyscf.scf.hf import TIGHT_GRAD_CONV_TOL
Expand All @@ -11,7 +11,7 @@
from pyscfad import ops
from pyscfad import numpy as np
from pyscfad import lib
from pyscfad.lib import jit, stop_trace
from pyscfad.lib import logger, jit, stop_trace
from pyscfad.ops import stop_grad
from pyscfad import util
from pyscfad.implicit_diff import make_implicit_diff
Expand Down Expand Up @@ -47,7 +47,7 @@ def _scf(dm, mf, s1e, h1e, *,
vhf = mf.get_veff(mol, dm)
e_tot = mf.energy_tot(dm, h1e, vhf)
log.info('init E= %.15g', e_tot)
cput1 = log.timer('initialize scf', log._t0, log._w0)
cput1 = log.timer('initialize scf')

for cycle in range(mf.max_cycle):
dm_last = dm
Expand Down
6 changes: 3 additions & 3 deletions pyscfadlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ set(C_LINK_TEMPLATE "<CMAKE_C_COMPILER> <CMAKE_SHARED_LIBRARY_C_FLAGS> <LANGUAGE
include(ExternalProject)

option(BUILD_LIBCINT "Using libcint for analytical gaussian integral" ON)
option(WITH_F12 "Compling F12 integrals" ON)
option(WITH_F12 "Compling F12 integrals" OFF)

#FIXME here I have to rename libcint otherwise it conflicts with the one used by pyscf
set(patch_libcint ${PROJECT_SOURCE_DIR}/apply_patch.sh libcint.patch)

if(BUILD_LIBCINT)
ExternalProject_Add(libcint
ExternalProject_Add(libcintad
GIT_REPOSITORY https://github.com/fishjojo/libcint.git
GIT_TAG ad1
PATCH_COMMAND ${patch_libcint}
Expand All @@ -104,7 +104,7 @@ if(BUILD_LIBCINT)
-DCMAKE_C_CREATE_SHARED_LIBRARY:STRING=${C_LINK_TEMPLATE}
-DBUILD_MARCH_NATIVE:STRING=${BUILD_MARCH_NATIVE}
)
add_dependencies(cgto_vjp libcint)
add_dependencies(cgto_vjp libcintad)
endif()

option(BUILD_PYSCF_LIB "Building PySCF library" ON)
Expand Down
61 changes: 36 additions & 25 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,37 @@
exec(f.read(), _dct)
__version__ = _dct['__version__']

class CMakeBuildExt(build_ext):
def get_platform():
from distutils.util import get_platform
platform = get_platform()
if sys.platform == 'darwin':
arch = os.getenv('CMAKE_OSX_ARCHITECTURES')
if arch:
osname = platform.rsplit('-', 1)[0]
if ';' in arch:
platform = f'{osname}-universal2'
else:
platform = f'{osname}-{arch}'
elif os.getenv('_PYTHON_HOST_PLATFORM'):
# the cibuildwheel environment
platform = os.getenv('_PYTHON_HOST_PLATFORM')
if platform.endswith('arm64'):
os.putenv('CMAKE_OSX_ARCHITECTURES', 'arm64')
elif platform.endswith('x86_64'):
os.putenv('CMAKE_OSX_ARCHITECTURES', 'x86_64')
else:
os.putenv('CMAKE_OSX_ARCHITECTURES', 'arm64;x86_64')
return platform

class CMakeBuildPy(build_py):
def run(self):
extension = self.extensions[0]
assert extension.name == 'pyscfad_lib_placeholder'
self.build_cmake(extension)
self.plat_name = get_platform()
self.build_base = 'build'
self.build_lib = os.path.join(self.build_base, 'lib')
self.build_temp = os.path.join(self.build_base, f'temp.{self.plat_name}')

def build_cmake(self, extension):
self.announce('Configuring extensions', level=3)
src_dir = os.path.abspath(os.path.join(__file__, '..', 'pyscfad', 'lib'))
src_dir = os.path.abspath(os.path.join(__file__, '..', 'pyscfadlib'))
cmd = ['cmake', f'-S{src_dir}', f'-B{self.build_temp}']
configure_args = os.getenv('CMAKE_CONFIGURE_ARGS')
if configure_args:
Expand All @@ -33,23 +55,14 @@ def build_cmake(self, extension):
else:
self.spawn(cmd)

# To remove the infix string like cpython-37m-x86_64-linux-gnu.so
# Python ABI updates since 3.5
# https://www.python.org/dev/peps/pep-3149/
def get_ext_filename(self, ext_name):
ext_path = ext_name.split('.')
filename = build_ext.get_ext_filename(self, ext_name)
name, ext_suffix = os.path.splitext(filename)
return os.path.join(*ext_path) + ext_suffix

#from distutils.command.build import build
#build.sub_commands = ([c for c in build.sub_commands if c[0] == 'build_ext'] +
# [c for c in build.sub_commands if c[0] != 'build_ext'])
super().run()

class BuildExtFirst(build_py):
def run(self):
self.run_command("build_ext")
return super().run()
from wheel.bdist_wheel import bdist_wheel
initialize_options = bdist_wheel.initialize_options
def initialize_with_default_plat_name(self):
initialize_options(self)
self.plat_name = get_platform()
bdist_wheel.initialize_options = initialize_with_default_plat_name

setup(
name='pyscfad',
Expand All @@ -60,9 +73,7 @@ def run(self):
include_package_data=True,
packages=find_packages(exclude=["examples","*test*"]),
python_requires='>=3.8',
#ext_modules=[Extension('pyscfad_lib_placeholder', [])],
#cmdclass={'build_py': BuildExtFirst,
# 'build_ext': CMakeBuildExt},
cmdclass={'build_py': CMakeBuildPy,},
install_requires=[
'numpy>=1.17',
'scipy',
Expand Down

0 comments on commit 7da668f

Please sign in to comment.