Skip to content

Commit

Permalink
[NDArray] Update order to match numpy (#1354)
Browse files Browse the repository at this point in the history
Addresses #1310
  • Loading branch information
leonardt authored Jan 30, 2024
1 parent dcd5b89 commit 8d68a6c
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 404 deletions.
30 changes: 16 additions & 14 deletions magma/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def __new__(cls, name, bases, namespace, info=(None, None, None), **kwargs):

return type_

def _make_ndarray(cls, Ns, T):
for N in reversed(Ns):
T = Array[N, T]
return T

def __getitem__(cls, index: tuple) -> 'ArrayMeta':
mcs = type(cls)

Expand Down Expand Up @@ -104,14 +109,11 @@ def __getitem__(cls, index: tuple) -> 'ArrayMeta':
if len(index[0]) == 0:
raise ValueError("Cannot create array with length 0 tuple "
"for N")
if len(index[0]) > 1:
T = index[1]
# ND Array
for N in index[0]:
T = Array[N, T]
return T
# len(index[0]) == 1, Treat as normal Array
index = index[0]
elif len(index[0]) > 1:
return cls._make_ndarray(*index)
else:
# len(index[0]) == 1, Treat as normal Array
index = index[0]

if (not isinstance(index[0], int) or index[0] <= 0):
raise TypeError(
Expand Down Expand Up @@ -771,19 +773,19 @@ def _ndarray_getitem(self, key: tuple):

if len(key) == 1:
return self[key[0]]
if not isinstance(key[-1], slice):
return self[key[-1]][key[:-1]]
if not self._is_whole_slice(key[-1]):
if not isinstance(key[0], slice):
return self[key[0]][key[1:]]
if not self._is_whole_slice(key[0]):
# If it's not a slice of the whole array, first slice the
# current array (self), then replace with a slice of the whole
# array (this is how we determine that we're ready to traverse
# into the children)
this_key = key[-1]
result = self[this_key][key[:-1] + (slice(None), )]
this_key = key[0]
result = self[this_key][(slice(None), ) + key[1:]]
return result
# Last index is selecting the whole array, recurse into the
# children and slice off the inner indices
inner_ts = [t[key[:-1]] for t in self.ts]
inner_ts = [t[key[1:]] for t in self.ts]
# Get the type from the children and return the final value
return type(self)[len(self), type(inner_ts[0])](inner_ts)

Expand Down
51 changes: 21 additions & 30 deletions tests/test_primitives/test_set_index.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import os
import pytest

from hwtypes import BitVector, Bit
import fault

import magma as m
from magma.config import config as magma_config


@pytest.fixture(autouse=True)
def test_dir():
magma_config.compile_dir = 'normal'
yield
magma_config.compile_dir = 'callee_file_dir'


def test_set_index():
Expand All @@ -23,11 +32,7 @@ class test_set_index(m.Circuit):
I[int(idx)] = val
tester.circuit.O.expect(I)

m.compile("build/test_set_index", test_set_index)
tester.compile_and_run("verilator", skip_compile=True,
directory=os.path.join(os.path.dirname(__file__),
"build"),
flags=["-Wno-unused"])
tester.compile_and_run("verilator", tmp_dir=True, flags=["-Wno-unused"])


def test_set_index_array():
Expand All @@ -47,45 +52,31 @@ class test_set_index_array(m.Circuit):
I[int(idx)] = val
tester.circuit.O.expect(I)

m.compile("build/test_set_index_array", test_set_index_array)
tester.compile_and_run("verilator", skip_compile=True,
directory=os.path.join(os.path.dirname(__file__),
"build"),
flags=["-Wno-unused"])
tester.compile_and_run("verilator", tmp_dir=True, flags=["-Wno-unused"])


def test_set_ndindex_array():
class test_set_ndindex_array(m.Circuit):
io = m.IO(I=m.In(m.Array[(2, 2, 2), m.Bits[4]]),
io = m.IO(I=m.In(m.Array[(2, 4, 8), m.Bits[4]]),
val=m.In(m.Bits[4]),
idx_z=m.In(m.UInt[1]),
idx_y=m.In(m.UInt[1]),
idx_x=m.In(m.UInt[1]),
O=m.Out(m.Array[(2, 2, 2), m.Bits[4]]))
idx_y=m.In(m.UInt[2]),
idx_x=m.In(m.UInt[3]),
O=m.Out(m.Array[(2, 4, 8), m.Bits[4]]))
io.O @= m.set_index(io.I, io.val, [io.idx_z, io.idx_y, io.idx_x])

tester = fault.Tester(test_set_ndindex_array)
for i in range(5):
tester.circuit.I = I = [
[
[BitVector.random(4), BitVector.random(4)],
[BitVector.random(4), BitVector.random(4)]
],
[
[BitVector.random(4), BitVector.random(4)],
[BitVector.random(4), BitVector.random(4)]
]
]
[[BitVector.random(4) for _ in range(8)] for _ in range(4)]
for _ in range(2)
]
tester.circuit.val = val = BitVector.random(4)
tester.circuit.idx_z = idx_z = BitVector.random(1)
tester.circuit.idx_y = idx_y = BitVector.random(1)
tester.circuit.idx_x = idx_x = BitVector.random(1)
tester.circuit.idx_y = idx_y = BitVector.random(2)
tester.circuit.idx_x = idx_x = BitVector.random(3)
I[int(idx_z)][int(idx_y)][int(idx_x)] = val
tester.eval()
tester.circuit.O.expect(I)

m.compile("build/test_set_ndindex_array", test_set_ndindex_array)
tester.compile_and_run("verilator", skip_compile=True,
directory=os.path.join(os.path.dirname(__file__),
"build"),
flags=["-Wno-unused"])
tester.compile_and_run("verilator", tmp_dir=True, flags=["-Wno-unused"])
12 changes: 7 additions & 5 deletions tests/test_type/gold/test_ndarray_basic.v
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module Main (
input [2:0] I [4:0],
output [4:0] O [2:0]
input [4:0] I [2:0],
output [2:0] O [4:0]
);
assign O[2] = {I[4][2],I[3][2],I[2][2],I[1][2],I[0][2]};
assign O[1] = {I[4][1],I[3][1],I[2][1],I[1][1],I[0][1]};
assign O[0] = {I[4][0],I[3][0],I[2][0],I[1][0],I[0][0]};
assign O[4] = {I[2][4],I[1][4],I[0][4]};
assign O[3] = {I[2][3],I[1][3],I[0][3]};
assign O[2] = {I[2][2],I[1][2],I[0][2]};
assign O[1] = {I[2][1],I[1][1],I[0][1]};
assign O[0] = {I[2][0],I[1][0],I[0][0]};
endmodule

108 changes: 47 additions & 61 deletions tests/test_type/gold/test_ndarray_dynamic_getitem.v
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ module coreir_reg #(
endmodule

module Register (
input [1:0] I [3:0][2:0],
output [1:0] O [3:0][2:0],
input [2:0] I [3:0][1:0],
output [2:0] O [3:0][1:0],
input CLK
);
wire [23:0] reg_P24_inst0_out;
wire [23:0] reg_P24_inst0_in;
assign reg_P24_inst0_in = {I[3][2],I[3][1],I[3][0],I[2][2],I[2][1],I[2][0],I[1][2],I[1][1],I[1][0],I[0][2],I[0][1],I[0][0]};
assign reg_P24_inst0_in = {I[3][1],I[3][0],I[2][1],I[2][0],I[1][1],I[1][0],I[0][1],I[0][0]};
coreir_reg #(
.clk_posedge(1'b1),
.init(24'h000000),
Expand All @@ -33,97 +33,83 @@ coreir_reg #(
.in(reg_P24_inst0_in),
.out(reg_P24_inst0_out)
);
assign O[3][2] = {reg_P24_inst0_out[23],reg_P24_inst0_out[22]};
assign O[3][1] = {reg_P24_inst0_out[21],reg_P24_inst0_out[20]};
assign O[3][0] = {reg_P24_inst0_out[19],reg_P24_inst0_out[18]};
assign O[2][2] = {reg_P24_inst0_out[17],reg_P24_inst0_out[16]};
assign O[2][1] = {reg_P24_inst0_out[15],reg_P24_inst0_out[14]};
assign O[2][0] = {reg_P24_inst0_out[13],reg_P24_inst0_out[12]};
assign O[1][2] = {reg_P24_inst0_out[11],reg_P24_inst0_out[10]};
assign O[1][1] = {reg_P24_inst0_out[9],reg_P24_inst0_out[8]};
assign O[1][0] = {reg_P24_inst0_out[7],reg_P24_inst0_out[6]};
assign O[0][2] = {reg_P24_inst0_out[5],reg_P24_inst0_out[4]};
assign O[0][1] = {reg_P24_inst0_out[3],reg_P24_inst0_out[2]};
assign O[0][0] = {reg_P24_inst0_out[1],reg_P24_inst0_out[0]};
assign O[3][1] = {reg_P24_inst0_out[23],reg_P24_inst0_out[22],reg_P24_inst0_out[21]};
assign O[3][0] = {reg_P24_inst0_out[20],reg_P24_inst0_out[19],reg_P24_inst0_out[18]};
assign O[2][1] = {reg_P24_inst0_out[17],reg_P24_inst0_out[16],reg_P24_inst0_out[15]};
assign O[2][0] = {reg_P24_inst0_out[14],reg_P24_inst0_out[13],reg_P24_inst0_out[12]};
assign O[1][1] = {reg_P24_inst0_out[11],reg_P24_inst0_out[10],reg_P24_inst0_out[9]};
assign O[1][0] = {reg_P24_inst0_out[8],reg_P24_inst0_out[7],reg_P24_inst0_out[6]};
assign O[0][1] = {reg_P24_inst0_out[5],reg_P24_inst0_out[4],reg_P24_inst0_out[3]};
assign O[0][0] = {reg_P24_inst0_out[2],reg_P24_inst0_out[1],reg_P24_inst0_out[0]};
endmodule

module Mux4xArray3_Array2_Bit (
input [1:0] I0 [2:0],
input [1:0] I1 [2:0],
input [1:0] I2 [2:0],
input [1:0] I3 [2:0],
module Mux4xArray2_Array3_Bit (
input [2:0] I0 [1:0],
input [2:0] I1 [1:0],
input [2:0] I2 [1:0],
input [2:0] I3 [1:0],
input [1:0] S,
output [1:0] O [2:0]
output [2:0] O [1:0]
);
reg [5:0] coreir_commonlib_mux4x6_inst0_out;
always @(*) begin
if (S == 0) begin
coreir_commonlib_mux4x6_inst0_out = {I0[2],I0[1],I0[0]};
coreir_commonlib_mux4x6_inst0_out = {I0[1],I0[0]};
end else if (S == 1) begin
coreir_commonlib_mux4x6_inst0_out = {I1[2],I1[1],I1[0]};
coreir_commonlib_mux4x6_inst0_out = {I1[1],I1[0]};
end else if (S == 2) begin
coreir_commonlib_mux4x6_inst0_out = {I2[2],I2[1],I2[0]};
coreir_commonlib_mux4x6_inst0_out = {I2[1],I2[0]};
end else begin
coreir_commonlib_mux4x6_inst0_out = {I3[2],I3[1],I3[0]};
coreir_commonlib_mux4x6_inst0_out = {I3[1],I3[0]};
end
end

assign O[2] = {coreir_commonlib_mux4x6_inst0_out[5],coreir_commonlib_mux4x6_inst0_out[4]};
assign O[1] = {coreir_commonlib_mux4x6_inst0_out[3],coreir_commonlib_mux4x6_inst0_out[2]};
assign O[0] = {coreir_commonlib_mux4x6_inst0_out[1],coreir_commonlib_mux4x6_inst0_out[0]};
assign O[1] = {coreir_commonlib_mux4x6_inst0_out[5],coreir_commonlib_mux4x6_inst0_out[4],coreir_commonlib_mux4x6_inst0_out[3]};
assign O[0] = {coreir_commonlib_mux4x6_inst0_out[2],coreir_commonlib_mux4x6_inst0_out[1],coreir_commonlib_mux4x6_inst0_out[0]};
endmodule

module Main (
output [1:0] rdata [2:0],
output [2:0] rdata [1:0],
input [1:0] raddr,
input CLK
);
wire [1:0] Mux4xArray3_Array2_Bit_inst0_O [2:0];
wire [1:0] Register_inst0_O [3:0][2:0];
wire [1:0] Mux4xArray3_Array2_Bit_inst0_I0 [2:0];
assign Mux4xArray3_Array2_Bit_inst0_I0[2] = Register_inst0_O[0][2];
assign Mux4xArray3_Array2_Bit_inst0_I0[1] = Register_inst0_O[0][1];
assign Mux4xArray3_Array2_Bit_inst0_I0[0] = Register_inst0_O[0][0];
wire [1:0] Mux4xArray3_Array2_Bit_inst0_I1 [2:0];
assign Mux4xArray3_Array2_Bit_inst0_I1[2] = Register_inst0_O[1][2];
assign Mux4xArray3_Array2_Bit_inst0_I1[1] = Register_inst0_O[1][1];
assign Mux4xArray3_Array2_Bit_inst0_I1[0] = Register_inst0_O[1][0];
wire [1:0] Mux4xArray3_Array2_Bit_inst0_I2 [2:0];
assign Mux4xArray3_Array2_Bit_inst0_I2[2] = Register_inst0_O[2][2];
assign Mux4xArray3_Array2_Bit_inst0_I2[1] = Register_inst0_O[2][1];
assign Mux4xArray3_Array2_Bit_inst0_I2[0] = Register_inst0_O[2][0];
wire [1:0] Mux4xArray3_Array2_Bit_inst0_I3 [2:0];
assign Mux4xArray3_Array2_Bit_inst0_I3[2] = Register_inst0_O[3][2];
assign Mux4xArray3_Array2_Bit_inst0_I3[1] = Register_inst0_O[3][1];
assign Mux4xArray3_Array2_Bit_inst0_I3[0] = Register_inst0_O[3][0];
Mux4xArray3_Array2_Bit Mux4xArray3_Array2_Bit_inst0 (
.I0(Mux4xArray3_Array2_Bit_inst0_I0),
.I1(Mux4xArray3_Array2_Bit_inst0_I1),
.I2(Mux4xArray3_Array2_Bit_inst0_I2),
.I3(Mux4xArray3_Array2_Bit_inst0_I3),
wire [2:0] Mux4xArray2_Array3_Bit_inst0_O [1:0];
wire [2:0] Register_inst0_O [3:0][1:0];
wire [2:0] Mux4xArray2_Array3_Bit_inst0_I0 [1:0];
assign Mux4xArray2_Array3_Bit_inst0_I0[1] = Register_inst0_O[0][1];
assign Mux4xArray2_Array3_Bit_inst0_I0[0] = Register_inst0_O[0][0];
wire [2:0] Mux4xArray2_Array3_Bit_inst0_I1 [1:0];
assign Mux4xArray2_Array3_Bit_inst0_I1[1] = Register_inst0_O[1][1];
assign Mux4xArray2_Array3_Bit_inst0_I1[0] = Register_inst0_O[1][0];
wire [2:0] Mux4xArray2_Array3_Bit_inst0_I2 [1:0];
assign Mux4xArray2_Array3_Bit_inst0_I2[1] = Register_inst0_O[2][1];
assign Mux4xArray2_Array3_Bit_inst0_I2[0] = Register_inst0_O[2][0];
wire [2:0] Mux4xArray2_Array3_Bit_inst0_I3 [1:0];
assign Mux4xArray2_Array3_Bit_inst0_I3[1] = Register_inst0_O[3][1];
assign Mux4xArray2_Array3_Bit_inst0_I3[0] = Register_inst0_O[3][0];
Mux4xArray2_Array3_Bit Mux4xArray2_Array3_Bit_inst0 (
.I0(Mux4xArray2_Array3_Bit_inst0_I0),
.I1(Mux4xArray2_Array3_Bit_inst0_I1),
.I2(Mux4xArray2_Array3_Bit_inst0_I2),
.I3(Mux4xArray2_Array3_Bit_inst0_I3),
.S(raddr),
.O(Mux4xArray3_Array2_Bit_inst0_O)
.O(Mux4xArray2_Array3_Bit_inst0_O)
);
wire [1:0] Register_inst0_I [3:0][2:0];
assign Register_inst0_I[3][2] = Register_inst0_O[3][2];
wire [2:0] Register_inst0_I [3:0][1:0];
assign Register_inst0_I[3][1] = Register_inst0_O[3][1];
assign Register_inst0_I[3][0] = Register_inst0_O[3][0];
assign Register_inst0_I[2][2] = Register_inst0_O[2][2];
assign Register_inst0_I[2][1] = Register_inst0_O[2][1];
assign Register_inst0_I[2][0] = Register_inst0_O[2][0];
assign Register_inst0_I[1][2] = Register_inst0_O[1][2];
assign Register_inst0_I[1][1] = Register_inst0_O[1][1];
assign Register_inst0_I[1][0] = Register_inst0_O[1][0];
assign Register_inst0_I[0][2] = Register_inst0_O[0][2];
assign Register_inst0_I[0][1] = Register_inst0_O[0][1];
assign Register_inst0_I[0][0] = Register_inst0_O[0][0];
Register Register_inst0 (
.I(Register_inst0_I),
.O(Register_inst0_O),
.CLK(CLK)
);
assign rdata[2] = Mux4xArray3_Array2_Bit_inst0_O[2];
assign rdata[1] = Mux4xArray3_Array2_Bit_inst0_O[1];
assign rdata[0] = Mux4xArray3_Array2_Bit_inst0_O[0];
assign rdata[1] = Mux4xArray2_Array3_Bit_inst0_O[1];
assign rdata[0] = Mux4xArray2_Array3_Bit_inst0_O[0];
endmodule

Loading

0 comments on commit 8d68a6c

Please sign in to comment.