Skip to content

Commit

Permalink
Move peripherals to CSR bus. Update fixtures to accomodate amaranth_soc
Browse files Browse the repository at this point in the history
quirks.
  • Loading branch information
cr1901 committed Dec 6, 2023
1 parent 45203d8 commit a8abc00
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 110 deletions.
198 changes: 97 additions & 101 deletions examples/attosoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from bronzebeard.asm import assemble
from elftools.elf.elffile import ELFFile
from amaranth import Module, Memory, Signal, Cat, C
from amaranth.utils import log2_int
from amaranth_soc import wishbone
from amaranth_soc import csr
from amaranth_soc.csr.wishbone import WishboneCSRBridge
from amaranth_soc.memory import MemoryMap
from amaranth.lib.wiring import In, Out, Component, Elaboratable, connect, \
Signature, flipped
Expand Down Expand Up @@ -93,54 +96,45 @@ def signature(self):
return self._signature

def __init__(self):
bus_signature = wishbone.Signature(addr_width=25, data_width=8,
granularity=8)
bus_signature.memory_map = MemoryMap(addr_width=25, data_width=8,
name="leds")
self._signature = Signature({
"bus": In(bus_signature),
self.mux = csr.bus.Multiplexer(addr_width=2, data_width=8, name="gpio")
self._signature = self.mux.signature
self._signature.members += {
"leds": Out(8),
"gpio": In(Signature({
"i": In(1),
"o": Out(1),
"oe": Out(1)
})).array(8)
})
}

self.leds_reg = csr.Element(8, "w", path=("leds",))
self.inout_reg = csr.Element(8, "rw", path=("inout",))
self.oe_reg = csr.Element(8, "rw", path=("oe",))
self.mux.add(self.leds_reg, name="leds")
self.mux.add(self.inout_reg, name="inout")
self.mux.add(self.oe_reg, name="oe")

super().__init__()
bus_signature.memory_map.add_resource(self.leds,
name="leds", size=1)
bus_signature.memory_map.add_resource(((g.o for g in self.gpio),
(g.i for g in self.gpio)),
name="inout", size=1)
bus_signature.memory_map.add_resource((g.oe for g in self.gpio),
name="oe", size=1)

def elaborate(self, plat):
m = Module()
m.submodules.mux = self.mux

with m.If(self.bus.stb & self.bus.cyc & self.bus.ack & self.bus.we
& (self.bus.adr[0:2] == 0) & self.bus.sel[0]):
m.d.sync += self.leds.eq(self.bus.dat_w)
connect(m, flipped(self.bus), self.mux.bus)

with m.If(self.bus.stb & self.bus.cyc & ~self.bus.ack &
(self.bus.adr[0:2] == 1) & self.bus.sel[0]):
with m.If(~self.bus.we):
for i in range(8):
m.d.sync += self.bus.dat_r[i].eq(self.gpio[i].i)
with m.Else():
for i in range(8):
m.d.sync += self.gpio[i].o.eq(self.bus.dat_w[i])
with m.If(self.leds_reg.w_stb):
m.d.sync += self.leds.eq(self.leds_reg.w_data)

with m.If(self.bus.stb & self.bus.cyc & ~self.bus.ack & self.bus.we
& (self.bus.adr[0:2] == 2) & self.bus.sel[0]):
with m.If(self.inout_reg.w_stb):
for i in range(8):
m.d.sync += self.gpio[i].oe.eq(self.bus.dat_w[i])
m.d.sync += self.gpio[i].o.eq(self.leds_reg.w_data[i])

with m.If(self.bus.stb & self.bus.cyc & ~self.bus.ack):
m.d.sync += self.bus.ack.eq(1)
with m.Else():
m.d.sync += self.bus.ack.eq(0)
for i in range(8):
m.d.comb += self.inout_reg.r_data[i].eq(self.gpio[i].i)

with m.If(self.oe_reg.w_stb):
for i in range(8):
m.d.sync += self.gpio[i].oe.eq(self.oe_reg.w_data[i])

return m

Expand All @@ -151,37 +145,35 @@ def signature(self):
return self._signature

def __init__(self):
bus_signature = wishbone.Signature(addr_width=30, data_width=8,
granularity=8)
bus_signature.memory_map = MemoryMap(addr_width=30, data_width=8,
name="timer")
self._signature = Signature({
"bus": In(bus_signature),
"irq": Out(1),
})
self.mux = csr.bus.Multiplexer(addr_width=1, data_width=8,
name="timer")
self._signature = self.mux.signature
self._signature.members += {
"irq": Out(1)
}

self.irq_reg = csr.Element(8, "r", path=("irq",))
self.mux.add(self.irq_reg, name="irq")

super().__init__()
bus_signature.memory_map.add_resource(self.irq, name="irq", size=1)

def elaborate(self, plat):
m = Module()
m.submodules.mux = self.mux

prescalar = Signal(15)
connect(m, flipped(self.bus), self.mux.bus)

m.d.sync += prescalar.eq(prescalar + 1)
m.d.comb += self.irq.eq(prescalar[14])
prescaler = Signal(15)

with m.If(self.bus.stb & self.bus.cyc & ~self.bus.we):
with m.If(self.bus.sel[0] == 1):
m.d.sync += self.bus.dat_r.eq(self.irq)

with m.If(self.irq):
m.d.sync += prescalar[14].eq(0)
m.d.sync += prescaler.eq(prescaler + 1)
m.d.comb += [
self.irq.eq(prescaler[14]),
self.irq_reg.r_data.eq(prescaler[14])
]

with m.If(self.bus.stb & self.bus.cyc & ~self.bus.ack):
m.d.sync += self.bus.ack.eq(1)
with m.Else():
m.d.sync += self.bus.ack.eq(0)
with m.If(self.irq_reg.r_stb):
with m.If(self.irq):
m.d.sync += prescaler[14].eq(0)

return m

Expand Down Expand Up @@ -280,26 +272,29 @@ def signature(self):
return self._signature

def __init__(self):
bus_signature = wishbone.Signature(addr_width=30, data_width=8,
granularity=8)
bus_signature.memory_map = MemoryMap(addr_width=30, data_width=8,
name="serial")
self._signature = Signature({
"bus": In(bus_signature),
self.mux = csr.bus.Multiplexer(addr_width=1, data_width=8,
name="serial")
self._signature = self.mux.signature
self._signature.members += {
"rx": In(1),
"tx": Out(1),
"irq": Out(1),
})
"irq": Out(1)
}

self.txrx_reg = csr.Element(8, "rw", path=("txrx",))
self.mux.add(self.txrx_reg, name="txrx")
self.irq_reg = csr.Element(8, "r", path=("irq",))
self.mux.add(self.irq_reg, name="irq")

super().__init__()
bus_signature.memory_map.add_resource((self.tx, self.rx), name="rxtx",
size=1)
bus_signature.memory_map.add_resource(self.irq, name="irq", size=1)
self.serial = UART(divisor=12000000 // 9600)

def elaborate(self, plat):
m = Module()
m.submodules.ser_internal = self.serial
m.submodules.mux = self.mux

connect(m, flipped(self.bus), self.mux.bus)

rx_rdy_irq = Signal()
rx_rdy_prev = Signal()
Expand All @@ -317,31 +312,23 @@ def elaborate(self, plat):
tx_ack_prev.eq(self.serial.tx_ack),
]

with m.If(self.bus.stb & self.bus.cyc & self.bus.sel[0] &
~self.bus.adr[0]):
m.d.sync += self.bus.dat_r.eq(self.serial.rx_data)
with m.If(~self.bus.we):
m.d.comb += self.serial.rx_ack.eq(1)
m.d.comb += [
self.serial.tx_data.eq(self.txrx_reg.w_data),
self.txrx_reg.r_data.eq(self.serial.rx_data),
self.irq_reg.r_data.eq(Cat(rx_rdy_irq, tx_ack_irq))
]

with m.If(self.bus.ack & self.bus.we):
m.d.comb += [
self.serial.tx_data.eq(self.bus.dat_w),
self.serial.tx_rdy.eq(1)
]
with m.If(self.txrx_reg.w_stb):
m.d.comb += self.serial.tx_rdy.eq(1)
with m.If(self.txrx_reg.r_stb):
m.d.comb += self.serial.rx_ack.eq(1)

with m.If(self.bus.stb & self.bus.cyc & self.bus.sel[0] &
self.bus.adr[0] & ~self.bus.we & ~self.bus.ack):
with m.If(self.irq_reg.r_stb):
m.d.sync += [
self.bus.dat_r.eq(Cat(rx_rdy_irq, tx_ack_irq)),
rx_rdy_irq.eq(0),
tx_ack_irq.eq(0)
]

with m.If(self.bus.stb & self.bus.cyc & ~self.bus.ack):
m.d.sync += self.bus.ack.eq(1)
with m.Else():
m.d.sync += self.bus.ack.eq(0)

# Don't accidentally miss an IRQ
with m.If(self.serial.rx_rdy & ~rx_rdy_prev):
m.d.sync += rx_rdy_irq.eq(1)
Expand All @@ -359,6 +346,8 @@ def __init__(self, *, sim=False, num_bytes=0x400):
self.mem = WBMemory(sim=sim, num_bytes=num_bytes)
self.leds = Leds()
self.sim = sim
self.decoder = wishbone.Decoder(addr_width=30, data_width=32,
granularity=8, alignment=25)
if not self.sim:
self.timer = Timer()
self.serial = WBSerial()
Expand All @@ -384,13 +373,28 @@ def rom(self, source_or_list):
def elaborate(self, plat):
m = Module()

decoder = wishbone.Decoder(addr_width=30, data_width=32, granularity=8,
alignment=25)
# CSR (has to be done first other mem map "frozen" errors?)
periph_bus = csr.Decoder(addr_width=25, data_width=8, alignment=23,
name="periph")
periph_bus.add(self.leds.bus)
if not self.sim:
periph_bus.add(self.timer.bus)
periph_bus.add(self.serial.bus)

# Wishbone
periph_wb = WishboneCSRBridge(periph_bus.bus, data_width=32)
self.decoder.add(flipped(self.mem.bus))
self.decoder.add(flipped(periph_wb.wb_bus))

m.submodules.cpu = self.cpu
m.submodules.mem = self.mem
m.submodules.leds = self.leds
m.submodules.decoder = decoder
m.submodules.decoder = self.decoder
m.submodules.periph_bus = periph_bus
m.submodules.periph_wb = periph_wb
if not self.sim:
m.submodules.timer = self.timer
m.submodules.serial = self.serial

if plat:
for i in range(8):
Expand All @@ -399,7 +403,12 @@ def elaborate(self, plat):
except ResourceError:
break
m.d.comb += led.o.eq(self.leds.leds[i])

ser = plat.request("uart")
m.d.comb += [
self.serial.rx.eq(ser.rx.i),
ser.tx.o.eq(self.serial.tx)
]

for i in range(8):
try:
Expand All @@ -413,29 +422,16 @@ def elaborate(self, plat):
gpio.o.eq(self.leds.gpio[i].o)
]

decoder.add(flipped(self.mem.bus))
decoder.add(flipped(self.leds.bus), sparse=True)
if not self.sim:
m.submodules.timer = self.timer
decoder.add(flipped(self.timer.bus), sparse=True)

m.submodules.serial = self.serial
decoder.add(flipped(self.serial.bus), sparse=True)
m.d.comb += [
self.serial.rx.eq(ser.rx.i),
ser.tx.o.eq(self.serial.tx)
]

m.d.comb += self.cpu.irq.eq(self.timer.irq | self.serial.irq)

def destruct_res(res):
return ("/".join(res.path), res.start, res.end, res.width)

print(tabulate(map(destruct_res,
decoder.bus.memory_map.all_resources()),
self.decoder.bus.memory_map.all_resources()),
intfmt=("", "#010x", "#010x", ""),
headers=["name", "start", "end", "width"]))
connect(m, self.cpu.bus, decoder.bus)
connect(m, self.cpu.bus, self.decoder.bus)

return m

Expand Down
10 changes: 5 additions & 5 deletions sentinel-rt/examples/attosoc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ static mut TX_CONS: MaybeUninit<Consumer<'static, u8, 64>> = MaybeUninit::uninit
// proven that we have exclusive access or have opted into unsafety previously.
// These are all valid I/O port addresses.
fn read_timer_int(_cs: CriticalSection) -> u8 {
unsafe { read_volatile(0x40000000 as *const u8) }
unsafe { read_volatile(0x02800000 as *const u8) }
}

fn read_serial_int(_cs: CriticalSection) -> u8 {
unsafe { read_volatile(0x80000004 as *const u8) }
unsafe { read_volatile(0x03000001 as *const u8) }
}

fn read_serial_rx(_cs: CriticalSection) -> u8 {
unsafe { read_volatile(0x80000000 as *const u8) }
unsafe { read_volatile(0x03000000 as *const u8) }
}

fn write_serial_tx(_cs: CriticalSection, val: u8) {
unsafe { write_volatile(0x80000000 as *mut u8, val) }
unsafe { write_volatile(0x03000000 as *mut u8, val) }
}

fn read_inp_port(_cs: CriticalSection,) -> u8 {
unsafe { read_volatile(0x02000004 as *const u8) }
unsafe { read_volatile(0x02000001 as *const u8) }
}

fn write_leds(_cs: CriticalSection, val: u8) {
Expand Down
18 changes: 17 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import pytest

from amaranth import Value
Expand Down Expand Up @@ -29,7 +30,22 @@ def pytest_collection_modifyitems(config, items):

class SimulatorFixture:
def __init__(self, req, cfg):
self.mod = req.node.get_closest_marker("module").args[0]
mod = req.node.get_closest_marker("module").args[0]
# FIXME: Depending on module contents, some amaranth code, such as
# amaranth_soc.csr classes don't interact well with elaborating
# the same object multiple times. This happens during parametrized
# tests. Therefore, provide an escape hatch to create a fresh object
# for all arguments of a parameterized test.
#
# Ideally, I should figure out the exact conditions under where it's
# safe to reuse an already-elaborated object (if ever); the tests
# didn't break until I started using amaranth_soc.csr. But this will
# do for now.
if isinstance(mod, functools.partial):
self.mod = mod()
else:
self.mod = mod

self.name = req.node.name
self.vcds = cfg.getoption("vcds")
self.clks = req.node.get_closest_marker("clks").args[0]
Expand Down
6 changes: 3 additions & 3 deletions tests/upstream/test_upstream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import functools
import pytest

from enum import Enum, auto
Expand Down Expand Up @@ -89,7 +89,7 @@ def wait_for_host_write():
]


@pytest.mark.module(AttoSoC(sim=True, num_bytes=4096))
@pytest.mark.module(functools.partial(AttoSoC, sim=True, num_bytes=4096))
@pytest.mark.clks((1.0 / 12e6,))
@pytest.mark.parametrize("test_bin", RV32UI_TESTS, indirect=True)
def test_rv32ui(sim_mod, ucode_panic, test_bin, wait_for_host_write):
Expand All @@ -106,7 +106,7 @@ def test_rv32ui(sim_mod, ucode_panic, test_bin, wait_for_host_write):
]


@pytest.mark.module(AttoSoC(sim=True, num_bytes=4096))
@pytest.mark.module(functools.partial(AttoSoC, sim=True, num_bytes=4096))
@pytest.mark.clks((1.0 / 12e6,))
@pytest.mark.parametrize("test_bin", RV32MI_TESTS, indirect=True)
def test_rv32mi(sim_mod, ucode_panic, test_bin, wait_for_host_write):
Expand Down

0 comments on commit a8abc00

Please sign in to comment.