PyTorch APIs added, prototype in progress
mpatrou committed Jan 18, 2024
1 parent f6fcf20 commit ce90a04
Showing 2 changed files with 59 additions and 27 deletions.
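
The commit follows one pattern throughout: numpy arrays and reductions are swapped for torch tensors allocated on a selected device, so the same kernel loops can run on either CPU or GPU. A minimal sketch of the substitution, distilled from the hunks below (illustrative only, not code from the commit):

import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Before: the accumulator lives in numpy on the host.
total = np.zeros(200, np.float64)

# After: same shape and dtype, but allocated on the chosen device,
# so later arithmetic runs through torch (and the GPU when present).
total_t = torch.zeros(200, dtype=torch.float64).to(device)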
33 changes: 20 additions & 13 deletions example/sphere_pytorch_prototype.py
@@ -4,21 +4,27 @@
from matplotlib import pyplot as plt
from sasmodels.core import load_model
from sasmodels.direct_model import call_kernel,get_mesh
from sasmodels.details import make_kernel_args, dispersion_mesh
from sasmodels.details import make_kernel_args

import sasmodels.kerneltorch as kt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("device",device)

def make_kernel(model, q_vectors,device):
"""Instantiate the python kernel with input *q_vectors*"""
q_input = kt.PyInput(q_vectors, dtype=torch.float64)
return kt.PyKernel(model.info, q_input, device = device)

#cuda_src = sas_3j1x_x + sphere_c

#Step 1. Define r and q vectors

model = load_model('_spherepy')
q = logspace(-3, -1, 20)
q = logspace(-3, -1, 200)
print("q",q[6])
kernel = model.make_kernel([q])

pars = {'radius': 20, 'radius_pd': 0.1, 'radius_pd_n':100, 'scale': 2}
pars = {'radius': 200, 'radius_pd': 0.1, 'radius_pd_n':100, 'scale': 2}

t_before = time.time()
Iq = call_kernel(kernel, pars)
@@ -28,19 +28,34 @@
print("Tota Numpy: ",total_np)

t_before = time.time()
q_t = torch.logspace(start=-3, end=-1, steps=200)
kernel = model.make_kernel([q_t])
q_t = torch.logspace(start=-3, end=-1, steps=200).to(device)
kernel = make_kernel(model, [q_t],device)
Iq_t = call_kernel(kernel, pars)

# call_kernel unwrap
calculator = kernel
cutoff=0.
mono=False
#calculator = kernel
#cutoff=0.
#mono=False

mesh = get_mesh(calculator.info, pars, dim=calculator.dim, mono=mono)
#mesh = get_mesh(calculator.info, pars, dim=calculator.dim, mono=mono)
#print("in call_kernel: pars:", list(zip(*mesh))[0])
call_details, values, is_magnetic = make_kernel_args(calculator, mesh)
#call_details, values, is_magnetic = make_kernel_args(calculator, mesh)
#print("in call_kernel: values:", values)
Iq_t = calculator(call_details, values, cutoff, is_magnetic)
#Iq_t = calculator(call_details, values, cutoff, is_magnetic)

t_after = time.time()
total_torch = t_after -t_before
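For reference, the torch path of the prototype pieced together from the hunks above (a sketch: the file's leading imports and its plotting tail fall outside the diff, so those parts are assumed here):

import time
import torch
from sasmodels.core import load_model
from sasmodels.direct_model import call_kernel
import sasmodels.kerneltorch as kt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def make_kernel(model, q_vectors, device):
    """Instantiate the python kernel with input *q_vectors*."""
    q_input = kt.PyInput(q_vectors, dtype=torch.float64)
    return kt.PyKernel(model.info, q_input, device=device)

model = load_model('_spherepy')
pars = {'radius': 200, 'radius_pd': 0.1, 'radius_pd_n': 100, 'scale': 2}

# Build the q grid directly as a torch tensor on the device.
q_t = torch.logspace(start=-3, end=-1, steps=200).to(device)
kernel = make_kernel(model, [q_t], device)

t_before = time.time()
Iq_t = call_kernel(kernel, pars)
total_torch = time.time() - t_before
print("Total Torch:", total_torch)
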
53 changes: 39 additions & 14 deletions sasmodels/kerneltorch.py
@@ -108,6 +108,8 @@ class PyKernel(Kernel):
*q_input* is the DllInput q vectors at which the kernel should be
evaluated.
*device* : cpu or cuda for calculations with pytorch on CPUs or GPUs
The resulting call method takes the *pars*, a list of values for
the fixed parameters to the kernel, and *pd_pars*, a list of (value,weight)
vectors for the polydisperse parameters. *cutoff* determines the
@@ -116,8 +118,9 @@ class PyKernel(Kernel):
Call :meth:`release` when done with the kernel instance.
"""
def __init__(self, model_info, q_input):
# type: (ModelInfo, List[np.ndarray]) -> None
def __init__(self, model_info, q_input, device):
# type: (ModelInfo, List[np.ndarray], str) -> None
self.device = device
self.dtype = np.dtype('d')
self.info = model_info
self.q_input = q_input
@@ -184,13 +187,13 @@ def _call_kernel(self, call_details, values, cutoff, magnetic, radius_effective_
# type: (CallDetails, np.ndarray, np.ndarray, float, bool) -> None
if magnetic:
raise NotImplementedError("Magnetism not implemented for pure python models")
#print("Calling python kernel")

#call_details.show(values)
radius = ((lambda: 0.0) if radius_effective_mode == 0
else (lambda: self._radius(radius_effective_mode)))
self.result = _loops(
self._parameter_vector, self._form, self._volume, radius,
self.q_input.nq, call_details, values, cutoff)
self.q_input.nq, call_details, values, cutoff,self.device)

def release(self):
# type: () -> None
@@ -202,8 +205,8 @@ def release(self):


def _loops(parameters, form, form_volume, form_radius, nq, call_details,
values, cutoff):
# type: (np.ndarray, Callable[[], np.ndarray], Callable[[], float], Callable[[], float], int, CallDetails, np.ndarray, float) -> None
values, cutoff,device = 'cpu'):
# type: (np.ndarray, Callable[[], np.ndarray], Callable[[], float], Callable[[], float], int, CallDetails, np.ndarray, float, str) -> None
################################################################
# #
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #
@@ -221,17 +224,25 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
# mesh, we update the components with the polydispersity values before
# calling the respective functions.
n_pars = len(parameters)
parameters[:] = values[2:n_pars+2]
parameters = torch.DoubleTensor(parameters).to(device)

#parameters[:] = values[2:n_pars+2]
parameters[:] = torch.DoubleTensor(values[2:n_pars+2])

print("parameters",parameters)
if call_details.num_active == 0:
total = form()
weight_norm = 1.0
weighted_shell, weighted_form = form_volume()
weighted_radius = form_radius()

else:
pd_value = values[2+n_pars:2+n_pars + call_details.num_weights]
pd_weight = values[2+n_pars + call_details.num_weights:]
# convert to torch tensors
pd_value = torch.DoubleTensor(values[2+n_pars:2+n_pars + call_details.num_weights])
pd_weight = torch.DoubleTensor(values[2+n_pars + call_details.num_weights:])

#print("pd_value",pd_value)
#print("pd_weight",pd_weight)

weight_norm = 0.0
weighted_form = 0.0
@@ -250,15 +261,26 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
pd_stride = call_details.pd_stride[:call_details.num_active]
pd_length = call_details.pd_length[:call_details.num_active]

total = np.zeros(nq, np.float64)
#total = np.zeros(nq, np.float64)
total = torch.zeros(nq, dtype= torch.float64).to(device)

#print("ll", range(call_details.num_eval))
#parallel for loop
# each loop_index can be a GPU/CPU thread: tid
# each thread has its own pd_index and p0_index

for loop_index in range(call_details.num_eval):
# Update polydispersity parameter values.
if p0_index == p0_length:
pd_index = (loop_index//pd_stride)%pd_length
parameters[pd_par] = pd_value[pd_offset+pd_index]
partial_weight = np.prod(pd_weight[pd_offset+pd_index][1:])
#partial_weight = np.prod(pd_weight[pd_offset+pd_index][1:])
partial_weight = torch.prod(pd_weight[pd_offset+pd_index][1:])

p0_index = loop_index%p0_length

# weight can become an array of weights calculated in parallel.
# weights[tid] = partial_weight * pd_weight[p0_offset + p0_index]
weight = partial_weight * pd_weight[p0_offset + p0_index]
parameters[p0_par] = pd_value[p0_offset + p0_index]
p0_index += 1
@@ -267,8 +289,9 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
# Assume that NaNs are only generated if the parameters are bad;
# exclude all q for that NaN. Even better would be to have an
# INVALID expression like the C models, but that is expensive.
Iq = np.asarray(form(), 'd')
if np.isnan(Iq).any():
#Iq = np.asarray(form(), 'd')
Iq = torch.asarray(form()).to(device)
if torch.isnan(Iq).any():
continue

# Update value and norm.
@@ -279,7 +302,9 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
weighted_form += weight * unweighted_form
weighted_radius += weight * form_radius()

result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius))
#result = np.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius))
result = torch.hstack((total, weight_norm, weighted_form, weighted_shell, weighted_radius)).to(device)

return result


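The comments added inside _loops point at the eventual goal: replacing the Python-level loop over call_details.num_eval with a batched tensor computation in which each loop index becomes a GPU thread. A minimal sketch of that idea for the weight lookup, assuming a single polydisperse parameter; batched_weights is a hypothetical helper, not part of this commit:

import torch

def batched_weights(pd_value, pd_weight, p0_offset, p0_length, device='cpu'):
    # tid plays the role of loop_index in _loops above; computing all
    # per-thread values at once removes the Python loop entirely.
    tid = torch.arange(p0_length, device=device)
    values = pd_value[p0_offset + tid]    # parameters[p0_par] for every tid
    weights = pd_weight[p0_offset + tid]  # weights[tid], gathered in one shot
    return values, weights

The weighted sum over the mesh then reduces to a single tensor operation, e.g. total = torch.sum(weights[:, None] * Iq_batch, dim=0), where Iq_batch holds form() evaluated for every parameter value in the batch.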
