Skip to content

Commit

Permalink
refactor: update flags, add support for extra1
Browse files Browse the repository at this point in the history
  • Loading branch information
Fumuran committed Jun 21, 2023
1 parent dadfc24 commit f2a39c2
Showing 1 changed file with 83 additions and 68 deletions.
151 changes: 83 additions & 68 deletions constraints/miden-vm/stack.air
Original file line number Diff line number Diff line change
Expand Up @@ -4,89 +4,86 @@ mod StackAir

# Flags for the first bits (op_bits[6], op_bits[5], op_bits[4])

fn f_000(op_bits: vector[8]) -> scalar:
fn f_000(op_bits: vector[9]) -> scalar:
return !op_bits[6] & !op_bits[5] & !op_bits[4]

fn f_001(op_bits: vector[8]) -> scalar:
fn f_001(op_bits: vector[9]) -> scalar:
return !op_bits[6] & !op_bits[5] & op_bits[4]

fn f_010(op_bits: vector[8]) -> scalar:
fn f_010(op_bits: vector[9]) -> scalar:
return !op_bits[6] & op_bits[5] & !op_bits[4]

fn f_011(op_bits: vector[8]) -> scalar:
fn f_011(op_bits: vector[9]) -> scalar:
return !op_bits[6] & op_bits[5] & op_bits[4]

# This flag is equal to f_100
fn f_u32rc(op_bits: vector[8]) -> scalar:
fn f_u32rc(op_bits: vector[9]) -> scalar:
return op_bits[6] & !op_bits[5] & !op_bits[4]

fn f_101(op_bits: vector[8]) -> scalar:
return op_bits[6] & !op_bits[5] & op_bits[4]


# Flags for the four last bits (op_bits[3], op_bits[2], op_bits[1], op_bits[0])

fn f_x0000(op_bits: vector[8]) -> scalar:
fn f_x0000(op_bits: vector[9]) -> scalar:
return !op_bits[3] & !op_bits[2] & !op_bits[1] & !op_bits[0]

fn f_x0001(op_bits: vector[8]) -> scalar:
fn f_x0001(op_bits: vector[9]) -> scalar:
return !op_bits[3] & !op_bits[2] & !op_bits[1] & op_bits[0]

fn f_x0010(op_bits: vector[8]) -> scalar:
fn f_x0010(op_bits: vector[9]) -> scalar:
return !op_bits[3] & !op_bits[2] & op_bits[1] & !op_bits[0]

fn f_x0011(op_bits: vector[8]) -> scalar:
fn f_x0011(op_bits: vector[9]) -> scalar:
return !op_bits[3] & !op_bits[2] & op_bits[1] & op_bits[0]

fn f_x0100(op_bits: vector[8]) -> scalar:
fn f_x0100(op_bits: vector[9]) -> scalar:
return !op_bits[3] & op_bits[2] & !op_bits[1] & !op_bits[0]

fn f_x0101(op_bits: vector[8]) -> scalar:
fn f_x0101(op_bits: vector[9]) -> scalar:
return !op_bits[3] & op_bits[2] & !op_bits[1] & op_bits[0]

fn f_x0110(op_bits: vector[8]) -> scalar:
fn f_x0110(op_bits: vector[9]) -> scalar:
return !op_bits[3] & op_bits[2] & op_bits[1] & !op_bits[0]

fn f_x0111(op_bits: vector[8]) -> scalar:
fn f_x0111(op_bits: vector[9]) -> scalar:
return !op_bits[3] & op_bits[2] & op_bits[1] & op_bits[0]

fn f_x1000(op_bits: vector[8]) -> scalar:
fn f_x1000(op_bits: vector[9]) -> scalar:
return op_bits[3] & !op_bits[2] & !op_bits[1] & !op_bits[0]

fn f_x1001(op_bits: vector[8]) -> scalar:
fn f_x1001(op_bits: vector[9]) -> scalar:
return op_bits[3] & !op_bits[2] & !op_bits[1] & op_bits[0]

fn f_x1010(op_bits: vector[8]) -> scalar:
fn f_x1010(op_bits: vector[9]) -> scalar:
return op_bits[3] & !op_bits[2] & op_bits[1] & !op_bits[0]

fn f_x1011(op_bits: vector[8]) -> scalar:
fn f_x1011(op_bits: vector[9]) -> scalar:
return op_bits[3] & !op_bits[2] & op_bits[1] & op_bits[0]

fn f_x1100(op_bits: vector[8]) -> scalar:
fn f_x1100(op_bits: vector[9]) -> scalar:
return op_bits[3] & op_bits[2] & !op_bits[1] & !op_bits[0]

fn f_x1101(op_bits: vector[8]) -> scalar:
fn f_x1101(op_bits: vector[9]) -> scalar:
return op_bits[3] & op_bits[2] & !op_bits[1] & op_bits[0]

fn f_x1110(op_bits: vector[8]) -> scalar:
fn f_x1110(op_bits: vector[9]) -> scalar:
return op_bits[3] & op_bits[2] & op_bits[1] & !op_bits[0]

fn f_x1111(op_bits: vector[8]) -> scalar:
fn f_x1111(op_bits: vector[9]) -> scalar:
return op_bits[3] & op_bits[2] & op_bits[1] & op_bits[0]


# Composite flags

fn f_shr(op_bits: vector[8]) -> scalar:
fn f_shr(op_bits: vector[9]) -> scalar:
return !op_bits[6] & op_bits[5] & op_bits[4] + f_u32split(op_bits) + f_push(op_bits)

# hasher[5] = op_helpers[3], where hahser[] are decoder columns, which are the same as helper[] -- columns from the stack
fn f_shl(op_bits: vector[8], op_helpers: vector[6]) -> scalar:
fn f_shl(op_bits: vector[9], op_helpers: vector[6]) -> scalar:
let f_add3_mad = op_bits[6] & !op_bits[5] & !op_bits[4] & op_bits[3] & op_bits[2]
let f_split_loop = op_bits[6] & !op_bits[5] & op_bits[4] & op_bits[3] & op_bits[2]
return !op_bits[6] & op_bits[5] & !op_bits[4] + f_add3_mad + f_split_loop + f_repeat(op_bits) + f_end(op_bits) * op_helpers[3]

fn f_ctrl(op_bits: vector[8]) -> scalar:
fn f_ctrl(op_bits: vector[9]) -> scalar:
# flag for SPAN, JOIN, SPLIT, LOOP
let f_sjsl = op_bits[6] & !op_bits[5] & op_bits[4] & op_bits[3]

Expand All @@ -96,7 +93,7 @@ fn f_ctrl(op_bits: vector[8]) -> scalar:
return f_sjsl + f_errh + f_call(op_bits) + f_syscall(op_bits)


fn compute_op_flags(op_bits: vector[8]) -> vector[88]:
fn compute_op_flags(op_bits: vector[9]) -> vector[88]:
return [
# No stack shift operations
f_000(op_bits) & f_x0000(op_bits), # noop
Expand Down Expand Up @@ -178,25 +175,27 @@ fn compute_op_flags(op_bits: vector[8]) -> vector[88]:
f_u32rc(op_bits) & op_bits[3] & op_bits[2] & !op_bits[1], # u32add3
f_u32rc(op_bits) & op_bits[3] & op_bits[2] & op_bits[1], # u32madd


# High-degree operations
f_101(op_bits) & !op_bits[3] & !op_bits[2] & !op_bits[1], # hperm
f_101(op_bits) & !op_bits[3] & !op_bits[2] & op_bits[1], # mpverify
f_101(op_bits) & !op_bits[3] & op_bits[2] & !op_bits[1], # pipe
f_101(op_bits) & !op_bits[3] & op_bits[2] & op_bits[1], # mstream
f_101(op_bits) & op_bits[3] & !op_bits[2] & !op_bits[1], # span
f_101(op_bits) & op_bits[3] & !op_bits[2] & op_bits[1], # join
f_101(op_bits) & op_bits[3] & op_bits[2] & !op_bits[1], # split
f_101(op_bits) & op_bits[3] & op_bits[2] & op_bits[1], # loop
op_bits[7] & f_x0000(op_bits), # hperm
op_bits[7] & f_x0001(op_bits), # mpverify
op_bits[7] & f_x0010(op_bits), # pipe
op_bits[7] & f_x0011(op_bits), # mstream
op_bits[7] & f_x0100(op_bits), # split
op_bits[7] & f_x0101(op_bits), # loop
op_bits[7] & f_x0110(op_bits), # span
op_bits[7] & f_x0111(op_bits), # join


# Very high-degree operations
op_bits[7] & !op_bits[4] & !op_bits[3] & !op_bits[2], # mrupdate
op_bits[7] & !op_bits[4] & !op_bits[3] & op_bits[2], # push
op_bits[7] & !op_bits[4] & op_bits[3] & !op_bits[2], # syscall
op_bits[7] & !op_bits[4] & op_bits[3] & op_bits[2], # call
op_bits[7] & op_bits[4] & !op_bits[3] & !op_bits[2], # end
op_bits[7] & op_bits[4] & !op_bits[3] & op_bits[2], # repeat
op_bits[7] & op_bits[4] & op_bits[3] & !op_bits[2], # respan
op_bits[7] & op_bits[4] & op_bits[3] & op_bits[2], # halt
op_bits[8] & !op_bits[4] & !op_bits[3] & !op_bits[2], # mrupdate
op_bits[8] & !op_bits[4] & !op_bits[3] & op_bits[2], # push
op_bits[8] & !op_bits[4] & op_bits[3] & !op_bits[2], # syscall
op_bits[8] & !op_bits[4] & op_bits[3] & op_bits[2], # call
op_bits[8] & op_bits[4] & !op_bits[3] & !op_bits[2], # end
op_bits[8] & op_bits[4] & !op_bits[3] & op_bits[2], # repeat
op_bits[8] & op_bits[4] & op_bits[3] & !op_bits[2], # respan
op_bits[8] & op_bits[4] & op_bits[3] & op_bits[2], # halt
]


Expand All @@ -215,23 +214,31 @@ ev check_element_validity([op_helpers[6]]):

# Enforces that the last bit of the opcode (op_bits[0]) is always set to 0. This evaluator is used
# for u32 operations where the last bit of the opcode is not used in computation of the flag.
ev b0_is_zero([op_bits[8]]):
enf op_bits[6] & !op_bits[5] & op_bits[0] = 0
ev b0_is_zero([op_bits[9]]):
enf op_bits[6] & !op_bits[5] & !op_bits[4] & op_bits[0] = 0

# Enforces that the last two bits of the opcode (op_bits[0] and op_bits[1]) are always set to 0.
# This evaluator is used for very-high degree operations where the last two bits of the opcode are
# not used in computation of the flag.
ev b0_b1_is_zero([op_bits[8]]):
ev b0_b1_is_zero([op_bits[9]]):
enf op_bits[6] & op_bits[5] & op_bits[0] = 0
enf op_bits[6] & op_bits[5] & op_bits[1] = 0

# Enforces that register extra0 is set to 1 when high-degree operations are executed.
ev extra0([op_bits[9]]):
op_bits[7] = 1 when op_bits[6] & !op_bits[5] & op_bits[4]

# Enforces that register extra1 is set to 1 when very high-degree operations are executed.
ev extra1([op_bits[9]]):
op_bits[8] = 1 when op_bits[6] & op_bits[5]


### Stack Air Constraints #########################################################################

# Enforces the constraints on the stack.
# TODO: add docs for columns
# stack_helpers consists of [bookkeeping[0], bookkeeping[1], h0]
# op_bits consists of [op_bits[7], extra]
# op_bits consists of [op_bits[7], extra0, extra1]
ev stack_constraints([stack_top[16], stack_helpers[3], op_bits[8], op_helpers[6], clk, fmp]):
let op_flags = compute_op_flags(op_bits)

Expand Down Expand Up @@ -607,57 +614,57 @@ ev clk([s[16], clk]):

# u32 operations

ev u32add([s[16], op_bits[8], op_helpers[6]]):
ev u32add([s[16], op_bits[9], op_helpers[6]]):
enf s[0] + s[1] = 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
enf s[0]' = op_helpers[2]
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[i]' = s[i] for i in 2..16
enf b0_is_zero([op_bits])

ev u32sub([s[16], op_bits[8], op_helpers[6]]):
ev u32sub([s[16], op_bits[9], op_helpers[6]]):
enf s[1] = s[0] + s[1]' + 2^32 * s[0]'
enf (s[0]')^2 - s[0]' = 0
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[i]' = s[i] for i in 2..16
enf b0_is_zero([op_bits])

ev u32mul([s[16], op_bits[8], op_helpers[6]]):
ev u32mul([s[16], op_bits[9], op_helpers[6]]):
enf s[0] * s[1] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
enf check_element_validity([op_helpers])
enf s[i]' = s[i] for i in 2..16
enf b0_is_zero([op_bits])

ev u32div([s[16], op_bits[8], op_helpers[6]]):
ev u32div([s[16], op_bits[9], op_helpers[6]]):
enf s[1] = s[0] * s[1]' + s[0]'
enf s[1] - s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[0] - s[0]' - 1 = 2^16 * op_helpers[2] + op_helpers[3]
enf s[i]' = s[i] for i in 2..16
enf b0_is_zero([op_bits])

ev u32split([s[16], op_bits[8], op_helpers[6]]):
ev u32split([s[16], op_bits[9], op_helpers[6]]):
enf s[0] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
enf check_element_validity([op_helpers])
enf s[i + 1]' = s[i] for i in 1..15
enf b0_is_zero([op_bits])

ev u32assert2([s[16], op_bits[8], op_helpers[6]]):
ev u32assert2([s[16], op_bits[9], op_helpers[6]]):
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[i]' = s[i] for i in 0..16
enf b0_is_zero([op_bits])

ev u32add3([s[16], op_bits[8], op_helpers[6]]):
ev u32add3([s[16], op_bits[9], op_helpers[6]]):
enf s[0] + s[1] + s[2] = 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
enf s[0]' = op_helpers[2]
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[i]' = s[i + 1] for i in 2..15
enf b0_is_zero([op_bits])

ev u32madd([s[16], op_bits[8], op_helpers[6]]):
ev u32madd([s[16], op_bits[9], op_helpers[6]]):
enf s[0] * s[1] + s[2] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
Expand All @@ -668,28 +675,31 @@ ev u32madd([s[16], op_bits[8], op_helpers[6]]):

# High-degree operations

# Bus constraint is implemented in a separate file
ev hperm([s[16], op_bits[8], op_helpers[6]]):
ev hperm([s[16], op_bits[9], op_helpers[6]]):
enf s[i]' = s[i] for i in 12..16
enf b0_is_zero([op_bits])
enf extra0(op_bits)
# Bus constraint is implemented in a separate file

# Bus constraint is implemented in a separate file
ev mpverify([s[16], op_bits[8], op_helpers[6]]):
ev mpverify([s[16], op_bits[9], op_helpers[6]]):
enf s[i]' = s[i] for i in 0..16
enf b0_is_zero([op_bits])
enf extra0(op_bits)
# Bus constraint is implemented in a separate file


# TODO: add constraints
ev pipe([s[16], op_bits[8], op_helpers[6]]):
ev pipe([s[16], op_bits[9], op_helpers[6]]):


# Bus constraint is implemented in a separate file
ev mstream([s[16], op_bits[8], op_helpers[6]]):
ev mstream([s[16], op_bits[9], op_helpers[6]]):
enf s[12]' = s[12] + 2
enf s[i]' = s[i] for i in 8..12
enf s[i]' = s[i] for i in 13..16
enf extra0(op_bits)
# Bus constraint is implemented in a separate file


# TODO: add constraints
ev span([s[16], op_bits[8], op_helpers[6]])
ev span([s[16], op_bits[9], op_helpers[6]])


# TODO: add constraints
Expand All @@ -706,13 +716,18 @@ ev loop()

# Very high-degree operations

# Bus constraint is implemented in a separate file
ev mrupdate([s[16], op_bits[8], op_helpers[6]]):
ev mrupdate([s[16], op_bits[9], op_helpers[6]]):
enf s[i]' = s[i] for i in 4..16
enf b0_b1_is_zero([op_bits])
enf extra1(op_bits)
# Bus constraint is implemented in a separate file


ev push([s[16]]):
enf s[i + 1]' = s[i] for i in 0..15
enf b0_b1_is_zero([op_bits])
enf extra1(op_bits)


# TODO: add constraints
ev syscall():
Expand Down

0 comments on commit f2a39c2

Please sign in to comment.