From e48429e25dc9938343e8a66ada1e145e59341fe6 Mon Sep 17 00:00:00 2001
From: Thibaut Lunet <thibaut.lunet@tuhh.de>
Date: Sun, 14 Jul 2024 19:52:23 +0200
Subject: [PATCH] TL: improved Z2N and N2N switch, added SDelta feature

---
 qmat/qcoeff/__init__.py                | 17 ++++++++------
 qmat/qdelta/__init__.py                | 31 +++++++++++++++++++-------
 tests/test_qcoeff/test_base.py         | 11 +++++----
 tests/test_qdelta/test_base.py         | 14 +++++++++---
 tests/test_qdelta/test_timestepping.py | 14 +++++++++---
 5 files changed, 60 insertions(+), 27 deletions(-)

diff --git a/qmat/qcoeff/__init__.py b/qmat/qcoeff/__init__.py
index 17a94b3..beac1f7 100644
--- a/qmat/qcoeff/__init__.py
+++ b/qmat/qcoeff/__init__.py
@@ -71,13 +71,16 @@ def hCoeffs(self):
         approx = LagrangeApproximation(self.nodes)
         return approx.getInterpolationMatrix([1]).ravel()
 
-    def genCoeffs(self, withS=False, hCoeffs=False, embedded=False):
-        out = [self.nodes, self.weights, self.Q]
-
+    def genCoeffs(self, form="Z2N", hCoeffs=False, embedded=False):
+        if form == "Z2N":
+            mat = self.Q
+        elif form == "N2N":
+            mat = self.S
+        else:
+            raise ValueError(f"form must be Z2N or N2N, not {form}")
+        out = [self.nodes, self.weights, mat]
         if embedded:
             out[1] = np.vstack([out[1], self.weightsEmbedded])
-        if withS:
-            out.append(self.S)
         if hCoeffs:
             out.append(self.hCoeffs)
         return out
@@ -141,13 +144,13 @@ def register(cls:QGenerator)->QGenerator:
     storeClass(cls, Q_GENERATORS)
     return cls
 
-def genQCoeffs(qType, withS=False, hCoeffs=False, embedded=False, **params):
+def genQCoeffs(qType, form="Z2N", hCoeffs=False, embedded=False, **params):
     try:
         Generator = Q_GENERATORS[qType]
     except KeyError:
         raise ValueError(f"{qType=!r} is not available")
     gen = Generator(**params)
-    return gen.genCoeffs(withS, hCoeffs, embedded)
+    return gen.genCoeffs(form, hCoeffs, embedded)
 
 
 # Import all local submodules
diff --git a/qmat/qdelta/__init__.py b/qmat/qdelta/__init__.py
index d87ba75..4d1effc 100644
--- a/qmat/qdelta/__init__.py
+++ b/qmat/qdelta/__init__.py
@@ -23,7 +23,7 @@ def size(self):
     def zeros(self):
         M = self.size
         return np.zeros((M, M), dtype=float)
-    
+
     def computeQDelta(self, k=None) -> np.ndarray:
         """Compute and returns the QDelta matrix"""
         raise NotImplementedError("mouahahah")
@@ -41,15 +41,28 @@ def getQDelta(self, k=None, copy=True):
                 raise Exception("some very weird bug happened ... did you do fishy stuff ?")
         return QDelta.copy() if copy else QDelta
 
+    def getSDelta(self, k=None):
+        QDelta = self.getQDelta(k)
+        M = QDelta.shape[0]
+        T = np.eye(M)
+        T[1:,:-1][np.diag_indices(M-1)] = -1
+        return T @ QDelta
+
     @property
     def dTau(self):
         return np.zeros(self.size, dtype=float)
 
-    def genCoeffs(self, k=None, dTau=False):
+    def genCoeffs(self, k=None, form="Z2N", dTau=False):
+        if form == "Z2N":
+            gen = lambda k, copy=False: self.getQDelta(k, copy)
+        elif form == "N2N":
+            gen = lambda k, copy=None: self.getSDelta(k)
+        else:
+            raise ValueError(f"form must be Z2N or N2N, not {form}")
         if isinstance(k, list):
-            out = [np.array([self.getQDelta(_k, copy=False) for _k in k])]
+            out = [np.array([gen(_k, copy=False) for _k in k])]
         else:
-            out = [self.getQDelta(k)]
+            out = [gen(k)]
         if dTau:
             out += [self.dTau]
         return out if len(out) > 1 else out[0]
@@ -71,7 +84,8 @@ def register(cls:QDeltaGenerator)->QDeltaGenerator:
     storeClass(cls, QDELTA_GENERATORS)
     return cls
 
-def genQDeltaCoeffs(qDeltaType, nSweeps=None, dTau=False, **params):
+
+def genQDeltaCoeffs(qDeltaType, nSweeps=None, form="Z2N", dTau=False, **params):
 
     # Check arguments
     if isinstance(qDeltaType, str):
@@ -103,7 +117,7 @@ def genQDeltaCoeffs(qDeltaType, nSweeps=None, dTau=False, **params):
             raise ValueError(f"qDeltaType={qDeltaType} is not available")
 
         gen = Generator(**params)
-        return gen.genCoeffs(dTau=dTau)
+        return gen.genCoeffs(form=form, dTau=dTau)
 
     else:  # Multiple matrices return
         try:
@@ -113,12 +127,13 @@ def genQDeltaCoeffs(qDeltaType, nSweeps=None, dTau=False, **params):
 
         if len(qDeltaType) == 1:  # Single QDelta generator
             gen = Generators[0](**params)
-            return gen.genCoeffs(k=[k+1 for k in range(nSweeps)], dTau=dTau)
+            return gen.genCoeffs(
+                k=[k+1 for k in range(nSweeps)], form=form, dTau=dTau)
 
         else:  # Multiple QDelta generators
             gens = [Gen(**params) for Gen in Generators]
             out = [np.array(
-                [gen.getQDelta(k+1) for k, gen in enumerate(gens)]
+                [gen.genCoeffs(k+1, form) for k, gen in enumerate(gens)]
                 )]
             if dTau:
                 out += [gens[0].dTau]
diff --git a/tests/test_qcoeff/test_base.py b/tests/test_qcoeff/test_base.py
index 7087b0f..333e7b2 100644
--- a/tests/test_qcoeff/test_base.py
+++ b/tests/test_qcoeff/test_base.py
@@ -57,22 +57,21 @@ def testAdditionalCoeffs(name):
         f"hCoeffs for {name} has inconsistent size : {h1.size}"
 
     try:
-        _, _, _, S2, h2 = genQCoeffs(name, withS=True, hCoeffs=True)
+        _, _, S2, h2 = genQCoeffs(name, form="N2N", hCoeffs=True)
     except TypeError:
-        _, _, _, S2, h2 = genQCoeffs(name, withS=True, hCoeffs=True,
-                                **GENERATORS[name].DEFAULT_PARAMS)
+        _, _, S2, h2 = genQCoeffs(
+            name, form="N2N", hCoeffs=True, **GENERATORS[name].DEFAULT_PARAMS)
     assert np.allclose(S1, S2), \
         f"OOP S matrix {S1} and PP S matrix {S2} are not equals for {name}"
     assert np.allclose(h1, h2), \
         f"OOP hCoeffs {h1} and PP hCoeffs {h2} are not equals for {name}"
 
-
     try:
         try:
             _, b, _  = genQCoeffs(name, embedded=True)
         except TypeError:
-            _, b, _  = genQCoeffs(name, embedded=True, **GENERATORS[name].DEFAULT_PARAMS)
-
+            _, b, _  = genQCoeffs(
+                name, embedded=True, **GENERATORS[name].DEFAULT_PARAMS)
         assert type(b) == np.ndarray
         assert b.ndim == 2
     except NotImplementedError:
diff --git a/tests/test_qdelta/test_base.py b/tests/test_qdelta/test_base.py
index b789991..318711f 100644
--- a/tests/test_qdelta/test_base.py
+++ b/tests/test_qdelta/test_base.py
@@ -30,19 +30,27 @@ def testGeneration(name, nNodes):
     assert np.allclose(QD1, QD2), \
         f"OOP QDelta and PP QDelta are not equals for {name}"
 
-    _, dTau1 = gen.genCoeffs(dTau=True)
+    SD1, dTau1 = gen.genCoeffs(form="N2N", dTau=True)
     assert type(dTau1) == np.ndarray, \
         f"dTau for {name} is not np.ndarray but {type(dTau1)}"
     assert dTau1.ndim == 1, \
         f"dTau for {name} is not 1D : {dTau1}"
     assert dTau1.size == nNodes, \
         f"dTau for {name} has not the correct size : {dTau1}"
-
-    _, dTau2 = genQDeltaCoeffs(name, Q=Q, dTau=True)
+    assert SD1.ndim == 2, \
+        f"SDelta for {name} is not 2D : {SD1}"
+    assert SD1.shape == QD2.shape, \
+        f"SDelta for {name} has not the correct shape : {SD1}"
+
+    SD2, dTau2 = genQDeltaCoeffs(name, Q=Q, form="N2N", dTau=True)
+    assert np.allclose(SD1, SD2), \
+        f"OOP SDelta and PP SDelta are not equals for {name}"
     assert np.allclose(dTau1, dTau2), \
         f"OOP dTau and PP dTau are not equals for {name}"
 
 
+
+
 nNodes = 4
 @pytest.mark.parametrize("nSweeps", [1, 2, 3])
 @pytest.mark.parametrize("name", GENERATORS.keys())
diff --git a/tests/test_qdelta/test_timestepping.py b/tests/test_qdelta/test_timestepping.py
index 7b9ba32..c0a1122 100644
--- a/tests/test_qdelta/test_timestepping.py
+++ b/tests/test_qdelta/test_timestepping.py
@@ -17,13 +17,18 @@
 def testBE(nNodes, nodeType, quadType):
     coll = Collocation(nNodes, nodeType, quadType)
     nodes = coll.nodes
-    QDelta = module.BE(nodes).getQDelta()
+    gen =  module.BE(nodes)
+    QDelta = gen.getQDelta()
 
     assert np.allclose(np.tril(QDelta), QDelta), \
         "QDelta is not lower triangular"
     assert np.allclose(QDelta.sum(axis=1), nodes), \
         "sum over the columns is not equal to nodes"
 
+    SDelta = gen.getSDelta()
+    assert np.allclose(np.diag(np.diag(SDelta)), SDelta), \
+        "SDelta is not diagonal"
+
 
 @pytest.mark.parametrize("quadType", QUAD_TYPES)
 @pytest.mark.parametrize("nodeType", NODE_TYPES)
@@ -31,7 +36,8 @@ def testBE(nNodes, nodeType, quadType):
 def testFE(nNodes, nodeType, quadType):
     coll = Collocation(nNodes, nodeType, quadType)
     nodes = coll.nodes
-    QDelta = module.FE(nodes).getQDelta()
+    gen = module.FE(nodes)
+    QDelta = gen.getQDelta()
 
     assert np.allclose(np.tril(QDelta), QDelta), \
         "QDelta is not lower triangular"
@@ -40,7 +46,7 @@ def testFE(nNodes, nodeType, quadType):
     assert np.allclose(QDelta.sum(axis=1)[1:], np.cumsum(np.diff(coll.nodes))), \
         "sum over the columns is not equal to cumsum of node differences"
 
-    _, dTau = module.FE(nodes).genCoeffs(dTau=True)
+    SDelta, dTau = module.FE(nodes).genCoeffs(form="N2N", dTau=True)
     assert type(dTau) == np.ndarray, \
         f"dTau is not np.ndarray but {type(dTau)}"
     assert dTau.ndim == 1, \
@@ -49,6 +55,8 @@ def testFE(nNodes, nodeType, quadType):
         f"dTau has not the correct size : {dTau}"
     assert np.allclose(dTau, coll.nodes[0]), \
         "dTau is not equal to nodes[0]"
+    assert np.allclose(np.diag(np.diag(SDelta, k=-1), k=-1), SDelta), \
+        "SDelta is not strictly lower diagonal"
 
 
 @pytest.mark.parametrize("quadType", QUAD_TYPES)