Skip to content

Commit

Permalink
added src submodules
Browse files Browse the repository at this point in the history
  • Loading branch information
annayesy committed May 21, 2024
1 parent d777f67 commit 248acb1
Show file tree
Hide file tree
Showing 21 changed files with 174 additions and 171 deletions.
1 change: 1 addition & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# slablu_package/__init__.py
42 changes: 22 additions & 20 deletions argparse_driver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import argparse
import torch
torch.set_default_dtype(torch.double)

import numpy as np
from time import time
torch.set_default_dtype(torch.double)

from domain_driver import *
from built_in_funcs import *
from src_solver.slablu_solver import Domain_Solver
import src_disc.built_in_funcs as built_in_funcs
import src_disc.pdo as pdo
import pickle
import os

Expand Down Expand Up @@ -91,19 +93,19 @@


if (args.pde == 'bfield_constant'):
bfield = bfield_constant
bfield = built_in_funcs.bfield_constant
elif (args.pde == 'bfield_bumpy'):
bfield = bfield_bumpy
bfield = built_in_funcs.bfield_bumpy
elif (args.pde == 'bfield_gaussian_bumps'):
bfield = bfield_gaussian_bumps
bfield = built_in_funcs.bfield_gaussian_bumps
elif (args.pde == 'bfield_cavity'):
bfield = bfield_cavity_scattering
bfield = built_in_funcs.bfield_cavity_scattering
elif (args.pde == 'bfield_crystal'):
bfield = bfield_crystal
bfield = built_in_funcs.bfield_crystal
elif (args.pde == 'bfield_crystal_waveguide'):
bfield = bfield_crystal_waveguide
bfield = built_in_funcs.bfield_crystal_waveguide
elif (args.pde == 'bfield_crystal_rhombus'):
bfield = bfield_crystal_rhombus
bfield = built_in_funcs.bfield_crystal_rhombus
else:
raise ValueError

Expand Down Expand Up @@ -146,7 +148,7 @@ def c(xx):
if (args.buf_constant is None):
args.buf_constant = 0.6
h = 1/n;
dom = Domain_Driver(box_geom,op,\
dom = Domain_Solver(box_geom,op,\
kh,h,buf_constant=args.buf_constant)
N = dom.fd.ns[0] * dom.fd.ns[1]
elif (disc=='hps'):
Expand All @@ -156,7 +158,7 @@ def c(xx):
args.buf_constant = 1.0
p = args.p
npan = n / (p-2); a = 1/(2*npan)
dom = Domain_Driver(box_geom,op,\
dom = Domain_Solver(box_geom,op,\
kh,a,p=p,buf_constant=args.buf_constant,periodic_bc = args.periodic_bc)
N = (p-2) * (p*dom.hps.n[0]*dom.hps.n[1] + dom.hps.n[0] + dom.hps.n[1])
else:
Expand Down Expand Up @@ -195,17 +197,17 @@ def c(xx):
ff_body = None; known_sol = True

if (not curved_domain):
uu_dir = lambda xx: uu_dir_func_greens(xx,kh)
uu_dir = lambda xx: built_in_funcs.uu_dir_func_greens(xx,kh)
else:
uu_dir = lambda xx: uu_dir_func_greens(param_map(xx),kh)
uu_dir = lambda xx: built_in_funcs.uu_dir_func_greens(param_map(xx),kh)

elif (args.bc == 'pulse'):
ff_body = None; known_sol = False

if (not curved_domain):
uu_dir = lambda xx: uu_dir_pulse(xx,kh)
uu_dir = lambda xx: built_in_funcs.uu_dir_pulse(xx,kh)
else:
uu_dir = lambda xx: uu_dir_pulse(param_map(xx),kh)
uu_dir = lambda xx: built_in_funcs.uu_dir_pulse(param_map(xx),kh)

elif (args.bc == 'ones'):
ff_body = None; known_sol = False
Expand All @@ -223,8 +225,8 @@ def c(xx):
assert (not curved_domain)

Lx = 4*np.pi; Ly = 1
uu_dir = lambda xx: uu_dir_func_mms(xx,Lx,Ly)
ff_body = lambda xx: ff_body_func_mms(xx,Lx,Ly)
uu_dir = lambda xx: built_in_funcs.uu_dir_func_mms(xx,Lx,Ly)
ff_body = lambda xx: built_in_funcs.ff_body_func_mms(xx,Lx,Ly)
known_sol = True
else:
raise ValueError
Expand All @@ -235,7 +237,7 @@ def c(xx):
assert kh == 0
assert (not curved_domain)

uu_dir = lambda xx: uu_dir_func_greens(xx,kh)
uu_dir = lambda xx: built_in_funcs.uu_dir_func_greens(xx,kh)
ff_body = None
known_sol = True
else:
Expand Down Expand Up @@ -294,4 +296,4 @@ def c(xx):
f = open(file_loc,"wb+")
pickle.dump(build_info,f)
pickle.dump(solve_info,f)
f.close()
f.close()
Binary file modified figures/picture_annulus.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/picture_crystal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/picture_curvy_annulus.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 5 additions & 5 deletions generate_pictures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys
import plotting_utils
import src_disc.plotting_utils as plotting_utils
import matplotlib.pyplot as plt

if not os.path.exists('pickle_files'):
Expand Down Expand Up @@ -47,7 +47,7 @@ def get_driver_command(nwaves, pickle_loc,\
title_fontsize=50)
plotting_utils.plot_solution_from_pickle(pickle_file_freq1,fig,ax[2],plot_pad=0,axes_labeled=False,\
title_fontsize=50)
plt.savefig('picture_crystal.pdf',bbox_inches='tight')
plt.savefig('figures/picture_crystal.png',bbox_inches='tight')

###########################################################################################
### generate plots for annulus
Expand All @@ -65,7 +65,7 @@ def get_driver_command(nwaves, pickle_loc,\
pickle_file_freq1 = 'pickle_files/annulus_picture1'

plotting_utils.plot_solution_from_pickle(pickle_file_freq1,fig,ax,title_fontsize=50)
plt.savefig('picture_annulus.pdf',bbox_inches='tight')
plt.savefig('figures/picture_annulus.png',bbox_inches='tight')

###########################################################################################
### generate plots for curvy annulus
Expand All @@ -74,7 +74,7 @@ def get_driver_command(nwaves, pickle_loc,\

args_curvy = n,p,'bfield_constant', 'curvy_annulus', 6.0, 1.0, True

nwaves1 = 10.0; nwaves2 = 40.0
nwaves1 = 10.0; nwaves2 = 30.0
os.system(get_driver_command(nwaves1,'curvy_annulus_picture1',*args_curvy))
os.system(get_driver_command(nwaves2,'curvy_annulus_picture2',*args_curvy))

Expand All @@ -89,4 +89,4 @@ def get_driver_command(nwaves, pickle_loc,\
plotting_utils.plot_solution_from_pickle(pickle_file_freq0,fig,ax[0],poly_pad=0.02,title_fontsize=60)
plotting_utils.plot_solution_from_pickle(pickle_file_freq1,fig,ax[1],poly_pad=0.04,title_fontsize=60,\
axes_labeled=False)
plt.savefig('picture_curvy_annulus.pdf',bbox_inches='tight')
plt.savefig('figures/picture_curvy_annulus.png',bbox_inches='tight')
1 change: 1 addition & 0 deletions src_disc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# src_disc/__init__.py
2 changes: 1 addition & 1 deletion built_in_funcs.py → src_disc/built_in_funcs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import numpy as np
torch.set_default_dtype(torch.double)
import pdo

from src_disc import pdo
from scipy.special import hankel1

##################################### FUNCTIONS FOR DIRICHLET DATA AND BODY LOAD #############################
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions hps_multidomain_disc.py → src_disc/hps_multidomain_disc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
torch.set_default_dtype(torch.double)
import hps_subdomain_disc
import hps_parallel_leaf_ops as leaf_ops
import numpy as np
import scipy.sparse as sp

from src_disc import hps_subdomain_disc, hps_parallel_leaf_ops as leaf_ops
from scipy.sparse import coo_matrix
from time import time

def batched_meshgrid(b,npoints,I,J):
Expand Down Expand Up @@ -90,7 +90,7 @@ def sparse_mat(self,device,verbose=False):
toc_flatten = time() - tic

tic = time()
sp_mat = sp.coo_matrix(( np.array(data),\
sp_mat = coo_matrix(( np.array(data),\
(np.array(row_data,dtype=int),np.array(col_data,dtype=int)))).tocsr()
sp_mat = sp_mat.tocsr()
toc_csr_scipy = time() - tic;
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 10 additions & 11 deletions plotting_utils.py → src_disc/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
torch.set_default_dtype(torch.double)
import numpy as np
from matplotlib.patches import Polygon
import pdo
from built_in_funcs import *
from src_disc import pdo,built_in_funcs
import pickle

from scipy import interpolate
Expand Down Expand Up @@ -41,7 +40,7 @@ def plot_solution(XX,box_geom,geom,kh,uu_sol,fig,ax,\
axes_labeled=True,colorbar=True,resolution=1000):

op, param_map, inv_param_map = pdo.get_param_map_and_pdo(geom, \
bfield_constant,kh)
built_in_funcs.bfield_constant,kh)

solution = uu_sol.reshape(XX.shape[0],)
max_sol = torch.max(solution)
Expand Down Expand Up @@ -142,19 +141,19 @@ def plot_bfield_from_pickle(pickle_file,fig,ax,title=None,\
geom = build_dict['domain']

if (bfield_str == 'bfield_constant'):
bfield = bfield_constant
bfield = built_in_funcs.bfield_constant
elif (bfield_str == 'bfield_bumpy'):
bfield = bfield_bumpy
bfield = built_in_funcs.bfield_bumpy
elif (bfield_str == 'bfield_gaussian_bumps'):
bfield = bfield_gaussian_bumps
bfield = built_in_funcs.bfield_gaussian_bumps
elif (bfield_str == 'bfield_cavity'):
bfield = bfield_cavity_scattering
bfield = built_in_funcs.bfield_cavity_scattering
elif (bfield_str == 'bfield_crystal'):
bfield = bfield_crystal
bfield = built_in_funcs.bfield_crystal
elif (bfield_str == 'bfield_crystal_waveguide'):
bfield = bfield_crystal_waveguide
bfield = built_in_funcs.bfield_crystal_waveguide
elif (bfield_str == 'bfield_crystal_rhombus'):
bfield = bfield_crystal_rhombus
bfield = built_in_funcs.bfield_crystal_rhombus
else:
raise ValueError

Expand All @@ -175,7 +174,7 @@ def plot_solution_from_pickle(pickle_file,fig,ax,title=None,title_fontsize=45,\
box_geom = build_dict['box_geom']

kh = build_dict['kh']
bfield_str = bfield_constant
bfield_str = built_in_funcs.bfield_constant
uu_sol = solve_dict['sol']

if ('domain' not in build_dict):
Expand Down
1 change: 1 addition & 0 deletions src_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# src_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch
torch.set_default_dtype(torch.double)
from time import time
import sys

#################################### STATIC METHODS ##################################

Expand Down
Loading

0 comments on commit 248acb1

Please sign in to comment.