Skip to content

Commit

Permalink
TL: fix on Jumper implementation + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tlunet committed Jan 10, 2025
1 parent f5db01a commit c6c8046
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 3 additions & 2 deletions qmat/qdelta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def genCoeffs(self, k=None, form="Z2N", dTau=False):
gen = lambda k, copy=None: self.getSDelta(k)
else:
raise ValueError(f"form must be Z2N or N2N, not {form}")
if isinstance(k, list):
try:
k = list(k)
out = [np.array([gen(_k, copy=False) for _k in k])]
else:
except TypeError:
out = [gen(k)]
if dTau:
out += [self.dTau]
Expand Down
22 changes: 22 additions & 0 deletions tests/test_qdelta/test_min.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,25 @@ def testFlex(nNodes, nodeType, quadType):
QDelta1 = gen.getQDelta(1)
assert np.allclose(QDelta0, QDelta1), \
"default QDelta is not equal to k=1"


@pytest.mark.parametrize("quadType", ["GAUSS", "RADAU-RIGHT"])
@pytest.mark.parametrize("nodeType", NODE_TYPES)
@pytest.mark.parametrize("nNodes", [2, 3, 4])
def testJumper(nNodes, nodeType, quadType):
coll = Collocation(nNodes=nNodes, nodeType=nodeType, quadType=quadType)
nodes, Q = coll.nodes, coll.Q
k = np.arange(nNodes)+1

gen = module.Jumper(nodes=nodes)
genFlex = module.MIN_SR_FLEX(coll=coll)
QDeltas = gen.genCoeffs(k)
QDeltasFlex = genFlex.genCoeffs(k)

assert np.allclose(QDeltas, QDeltasFlex/2)

gen2 = module.FlexJumper(nodes=nodes)
QDeltas2 = gen2.genCoeffs(k)

assert np.allclose(QDeltas2[0], QDeltasFlex[0])
assert np.allclose(QDeltas2[1:], QDeltas[:-1])

0 comments on commit c6c8046

Please sign in to comment.