Skip to content

Commit

Permalink
Codegen: Outsource the linear algebra c functions (inverse, solve) in…
Browse files Browse the repository at this point in the history
…to separate files and read them in as strings for the preamble of the respective loopy Callable.
  • Loading branch information
sv2518 committed Sep 4, 2020
1 parent 1ff2577 commit 2c71321
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 60 deletions.
29 changes: 29 additions & 0 deletions pyop2/codegen/c/inverse.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <petscsys.h>
#include <petscblaslapack.h>

#ifndef PYOP2_WORK_ARRAYS
#define PYOP2_WORK_ARRAYS
#define BUF_SIZE 30
static PetscBLASInt ipiv_buffer[BUF_SIZE];
static PetscScalar work_buffer[BUF_SIZE*BUF_SIZE];
#endif

static void inverse(PetscScalar* __restrict__ Aout, const PetscScalar* __restrict__ A, PetscBLASInt N)
{
PetscBLASInt info;
PetscBLASInt *ipiv = N <= BUF_SIZE ? ipiv_buffer : malloc(N*sizeof(*ipiv));
PetscScalar *Awork = N <= BUF_SIZE ? work_buffer : malloc(N*N*sizeof(*Awork));
memcpy(Aout, A, N*N*sizeof(PetscScalar));
LAPACKgetrf_(&N, &N, Aout, &N, ipiv, &info);
if(info == 0){
LAPACKgetri_(&N, Aout, &N, ipiv, Awork, &N, &info);
}
if(info != 0){
fprintf(stderr, "Getri throws nonzero info.");
abort();
}
if ( N > BUF_SIZE ) {
free(Awork);
free(ipiv);
}
}
33 changes: 33 additions & 0 deletions pyop2/codegen/c/solve.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <petscsys.h>
#include <petscblaslapack.h>

#ifndef PYOP2_WORK_ARRAYS
#define PYOP2_WORK_ARRAYS
#define BUF_SIZE 30
static PetscBLASInt ipiv_buffer[BUF_SIZE];
static PetscScalar work_buffer[BUF_SIZE*BUF_SIZE];
#endif

static void solve(PetscScalar* __restrict__ out, const PetscScalar* __restrict__ A, const PetscScalar* __restrict__ B, PetscBLASInt N)
{
PetscBLASInt info;
PetscBLASInt *ipiv = N <= BUF_SIZE ? ipiv_buffer : malloc(N*sizeof(*ipiv));
memcpy(out,B,N*sizeof(PetscScalar));
PetscScalar *Awork = N <= BUF_SIZE ? work_buffer : malloc(N*N*sizeof(*Awork));
memcpy(Awork,A,N*N*sizeof(PetscScalar));
PetscBLASInt NRHS = 1;
const char T = 'T';
LAPACKgetrf_(&N, &N, Awork, &N, ipiv, &info);
if(info == 0){
LAPACKgetrs_(&T, &N, &NRHS, Awork, &N, ipiv, out, &N, &info);
}
if(info != 0){
fprintf(stderr, "Gesv throws nonzero info.");
abort();
}

if ( N > BUF_SIZE ) {
free(ipiv);
free(Awork);
}
}
67 changes: 8 additions & 59 deletions pyop2/codegen/rep2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,33 +160,10 @@ class INVCallable(LACallable):
"""
def generate_preambles(self, target):
assert isinstance(target, loopy.CTarget)
inverse_preamble = """
#define Inverse_HPP
#define BUF_SIZE 30
static PetscBLASInt ipiv_buffer[BUF_SIZE];
static PetscScalar work_buffer[BUF_SIZE*BUF_SIZE];
static void inverse(PetscScalar* __restrict__ Aout, const PetscScalar* __restrict__ A, PetscBLASInt N)
{
PetscBLASInt info;
PetscBLASInt *ipiv = N <= BUF_SIZE ? ipiv_buffer : malloc(N*sizeof(*ipiv));
PetscScalar *Awork = N <= BUF_SIZE ? work_buffer : malloc(N*N*sizeof(*Awork));
memcpy(Aout, A, N*N*sizeof(PetscScalar));
LAPACKgetrf_(&N, &N, Aout, &N, ipiv, &info);
if(info == 0){
LAPACKgetri_(&N, Aout, &N, ipiv, Awork, &N, &info);
}
if(info != 0){
fprintf(stderr, \"Getri throws nonzero info.\");
abort();
}
if ( N > BUF_SIZE ) {
free(Awork);
free(ipiv);
}
}
"""
yield ("inverse", "#include <petscsys.h>\n#include <petscblaslapack.h>\n" + inverse_preamble)
import os
with open(os.path.dirname(__file__)+"/c/inverse.c", "r") as myfile:
inverse_preamble = myfile.read()
yield ("inverse", inverse_preamble)
return


Expand All @@ -204,38 +181,10 @@ class SolveCallable(LACallable):
"""
def generate_preambles(self, target):
assert isinstance(target, loopy.CTarget)
code = """
#define Solve_HPP
#define BUF_SIZE 30
static PetscBLASInt ipiv_buffer[BUF_SIZE];
static PetscScalar work_buffer[BUF_SIZE*BUF_SIZE];
static void solve(PetscScalar* __restrict__ out, const PetscScalar* __restrict__ A, const PetscScalar* __restrict__ B, PetscBLASInt N)
{
PetscBLASInt info;
PetscBLASInt *ipiv = N <= BUF_SIZE ? ipiv_buffer : malloc(N*sizeof(*ipiv));
memcpy(out,B,N*sizeof(PetscScalar));
PetscScalar *Awork = N <= BUF_SIZE ? work_buffer : malloc(N*N*sizeof(*Awork));
memcpy(Awork,A,N*N*sizeof(PetscScalar));
PetscBLASInt NRHS = 1;
const char T = 'T';
LAPACKgetrf_(&N, &N, Awork, &N, ipiv, &info);
if(info == 0){
LAPACKgetrs_(&T, &N, &NRHS, Awork, &N, ipiv, out, &N, &info);
}
if(info != 0){
fprintf(stderr, \"Gesv throws nonzero info.\");
abort();
}
if ( N > BUF_SIZE ) {
free(ipiv);
free(Awork);
}
}
"""

yield ("solve", "#include <petscsys.h>\n#include <petscblaslapack.h>\n" + code)
import os
with open(os.path.dirname(__file__)+"/c/solve.c", "r") as myfile:
solve_preamble = myfile.read()
yield ("solve", solve_preamble)
return


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def run(self):
test_requires=test_requires,
packages=['pyop2', 'pyop2.codegen'],
package_data={
'pyop2': ['assets/*', '*.h', '*.pxd', '*.pyx']},
'pyop2': ['assets/*', '*.h', '*.pxd', '*.pyx', 'codegen/c/*.c']},
scripts=glob('scripts/*'),
cmdclass=cmdclass,
ext_modules=[Extension('pyop2.sparsity', sparsity_sources,
Expand Down

0 comments on commit 2c71321

Please sign in to comment.