From bd05e520f30fea3b554e495bb3a92d60d0a08c97 Mon Sep 17 00:00:00 2001
From: Maciej Dudek <mdudek@antmicro.com>
Date: Thu, 5 Dec 2024 12:10:44 +0100
Subject: [PATCH] Change refresh counter to 64 bits

This commit changes refresh counter width from 32 to 64 bits.
It also adds tests to check that values exceeding 32 bits are
represented.

Signed-off-by: Maciej Dudek <mdudek@antmicro.com>
---
 rowhammer_tester/gateware/payload_executor.py |  6 +-
 tests/test_payload_executor.py                | 78 +++++++++++++++++++
 2 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/rowhammer_tester/gateware/payload_executor.py b/rowhammer_tester/gateware/payload_executor.py
index 8ac43012d..de57e8b39 100644
--- a/rowhammer_tester/gateware/payload_executor.py
+++ b/rowhammer_tester/gateware/payload_executor.py
@@ -356,10 +356,12 @@ def __init__(self, with_refresh, dfii, refresher_reset, memtype=""):
         # master (any refresh issued, both by MC and PayloadExecutor)
         if dfii.master.with_sub_channels:
             self.submodules.refresh_counter = RefreshCounter(
-                dfii.master.phases[0].A_, memtype=memtype
+                dfii.master.phases[0].A_, width=64, memtype=memtype
             )
         else:
-            self.submodules.refresh_counter = RefreshCounter(dfii.master.phases[0], memtype=memtype)
+            self.submodules.refresh_counter = RefreshCounter(
+                dfii.master.phases[0], width=64, memtype=memtype
+            )
 
         # If non-zero, we must wait until exactly that refresh count
         # Refresh counter is updated 1 cycle after refresh, so add +1 in the test
diff --git a/tests/test_payload_executor.py b/tests/test_payload_executor.py
index 9524ecd1e..401397b0e 100644
--- a/tests/test_payload_executor.py
+++ b/tests/test_payload_executor.py
@@ -789,6 +789,48 @@ def generator(dut):
         dut.dfi_switch.add_csrs()
         run_simulation(dut, [generator(dut), *dut.get_generators()])
 
+    def test_refresh_counter_64_bit(self):
+        def generator(dut):
+            # Subtract some value from 2^32
+            value = 31
+            yield dut.dfi_switch.refresh_counter.counter.eq(2**32 - value)
+            # wait for `value` refreshes and land in the middle of the refresh period(10)
+            for _ in range(value * 10 + 5):
+                yield
+
+            # start execution, this should wait for the next refresh, then latch refresh count
+            yield dut.payload_executor.start.eq(1)
+            yield
+            yield dut.payload_executor.start.eq(0)
+            yield
+
+            while not (yield dut.payload_executor.ready):
+                yield
+
+            # read refresh count CSR twice
+            at_transition = yield from dut.dfi_switch._refresh_count.read()
+            yield from dut.dfi_switch._refresh_update.write(1)
+            yield
+            forced = yield from dut.dfi_switch._refresh_count.read()
+            yield
+
+            # refreshes during waiting time, +1 between start.eq(1) and actual transition
+            self.assertEqual(at_transition, 2**32 + 1)
+            self.assertEqual(forced, at_transition + 3)  # for payload
+
+        encoder = Encoder(bankbits=3)
+        payload = [
+            encoder.Instruction(OpCode.NOOP, timeslice=2),
+            encoder.Instruction(OpCode.REF, timeslice=8),
+            encoder.Instruction(OpCode.REF, timeslice=8),
+            encoder.Instruction(OpCode.REF, timeslice=8),
+            encoder.Instruction(OpCode.NOOP, timeslice=0),  # STOP
+        ]
+
+        dut = PayloadExecutorDUT(encoder(payload), refresh_delay=9)
+        dut.dfi_switch.add_csrs()
+        run_simulation(dut, [generator(dut), *dut.get_generators()])
+
     def test_switch_at_refresh(self):
         def generator(dut, switch_at):
             yield from dut.dfi_switch._at_refresh.write(switch_at)
@@ -820,6 +862,42 @@ def generator(dut, switch_at):
                 dut.dfi_switch.add_csrs()
                 run_simulation(dut, [generator(dut, switch_at), *dut.get_generators()])
 
+    def test_switch_at_refresh_64_bit(self):
+        def generator(dut, switch_at):
+            yield from dut.dfi_switch._at_refresh.write(switch_at)
+            # Assert that write to _at_refresh doesn't affect refresh_counter
+            self.assertEqual((yield dut.dfi_switch.refresh_counter.counter), 0)
+            # Set refresh_counter to 2**32 minus a small value to
+            # check that counter goes over 32 bits.
+            # Starting from 2*32 - 31 speeds up simulation
+            yield dut.dfi_switch.refresh_counter.counter.eq(2**32 - 31)
+
+            # start execution, this should wait for the next refresh, then latch refresh count
+            yield dut.payload_executor.start.eq(1)
+            yield
+            yield dut.payload_executor.start.eq(0)
+            yield
+
+            while not (yield dut.payload_executor.ready):
+                yield
+
+            # +1 for payload
+            self.assertEqual((yield dut.dfi_switch.refresh_counter.counter), switch_at + 1)
+
+        encoder = Encoder(bankbits=3)
+        payload = [
+            encoder.Instruction(OpCode.NOOP, timeslice=10),
+            encoder.Instruction(OpCode.REF, timeslice=8),
+            encoder.Instruction(OpCode.NOOP, timeslice=10),
+            encoder.Instruction(OpCode.NOOP, timeslice=0),  # STOP
+        ]
+
+        for switch_at in [2**32 + 1]:
+            with self.subTest(switch_at=switch_at):
+                dut = PayloadExecutorDUT(encoder(payload), refresh_delay=3)
+                dut.dfi_switch.add_csrs()
+                run_simulation(dut, [generator(dut, switch_at), *dut.get_generators()])
+
 
 class TestPayloadExecutorDDR5(unittest.TestCase):
     def run_payload(self, dut, **kwargs):