Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

[spv-in] Bubble up loop breaks, out of switch cases. #2323

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 181 additions & 17 deletions src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{

use super::{Error, Instruction, LookupExpression, LookupHelper as _};
use crate::front::Emitter;
use std::cell::Cell;

pub type BlockId = u32;

Expand Down Expand Up @@ -128,7 +129,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
type_arena: &module.types,
type_arena: &mut module.types,
global_arena: &module.global_variables,
arguments: &fun.arguments,
parameter_sampling: &mut parameters_sampling,
Expand Down Expand Up @@ -576,24 +577,54 @@ impl<'function> BlockContext<'function> {
}

/// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
fn lower(mut self) -> crate::Block {
fn lower(self) -> crate::Block {
/// Smaller context type, with only the subset of `BlockContext`'s fields
/// that are needed for `lower_impl` below.
struct BlockLowerContext<'a, 'function> {
expressions: &'function mut Arena<crate::Expression>,
local_arena: &'function mut Arena<crate::LocalVariable>,
const_arena: &'function mut Arena<crate::Constant>,
type_arena: &'function mut crate::UniqueArena<crate::Type>,

blocks: crate::FastHashMap<spirv::Word, crate::Block>,
bodies: &'a [super::Body],
}

/// Helper type used for tracking the control-flow constructs which
/// support the [`Statement::Break`](crate::Statement::Break) syntax.
#[derive(Copy, Clone, Debug)]
enum Breakable<'a> {
Loop,
Switch {
/// This `Cell<Option<...>>` is set to `Some(loop_break_cond_var_ptr)`
/// when a loop `break` is found nested in the switch `case`s,
/// and `loop_break_cond_var_ptr` is a pointer to a `bool` local,
/// which (dynamically) tracks whether the nested loop `break`
/// was actually reached (and so the `switch` gets to handle
/// `break`-ing out of the `loop`, after the `switch` itself).
bubbled_up_loop_break_cond_var_ptr: &'a Cell<Option<Handle<crate::Expression>>>,
},
}

fn lower_impl(
blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
bodies: &[super::Body],
lower_ctx: &mut BlockLowerContext<'_, '_>,
body_idx: BodyIndex,
innermost_breakable: Option<Breakable>,
) -> crate::Block {
let mut block = crate::Block::new();

for item in bodies[body_idx].data.iter() {
for item in lower_ctx.bodies[body_idx].data.iter() {
match *item {
super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
super::BodyFragment::BlockId(id) => {
block.append(lower_ctx.blocks.get_mut(&id).unwrap())
}
super::BodyFragment::If {
condition,
accept,
reject,
} => {
let accept = lower_impl(blocks, bodies, accept);
let reject = lower_impl(blocks, bodies, reject);
let accept = lower_impl(lower_ctx, accept, innermost_breakable);
let reject = lower_impl(lower_ctx, reject, innermost_breakable);

block.push(
crate::Statement::If {
Expand All @@ -609,8 +640,11 @@ impl<'function> BlockContext<'function> {
continuing,
break_if,
} => {
let body = lower_impl(blocks, bodies, body);
let continuing = lower_impl(blocks, bodies, continuing);
let body = lower_impl(lower_ctx, body, Some(Breakable::Loop));
// NOTE(eddyb) the `continuing {...}` block cannot `break`,
// but this is checked in the validator, and so it's allowed
// here (where we could only panic, which is worse UX).
let continuing = lower_impl(lower_ctx, continuing, Some(Breakable::Loop));

block.push(
crate::Statement::Loop {
Expand All @@ -626,10 +660,17 @@ impl<'function> BlockContext<'function> {
ref cases,
default,
} => {
let bubbled_up_loop_break_cond_var_ptr = &Cell::new(None);
let mut ir_cases: Vec<_> = cases
.iter()
.map(|&(value, body_idx)| {
let body = lower_impl(blocks, bodies, body_idx);
let body = lower_impl(
lower_ctx,
body_idx,
Some(Breakable::Switch {
bubbled_up_loop_break_cond_var_ptr,
}),
);

// Handle simple cases that would make a fallthrough statement unreachable code
let fall_through = body.last().map_or(true, |s| !s.is_terminator());
Expand All @@ -643,7 +684,13 @@ impl<'function> BlockContext<'function> {
.collect();
ir_cases.push(crate::SwitchCase {
value: crate::SwitchValue::Default,
body: lower_impl(blocks, bodies, default),
body: lower_impl(
lower_ctx,
default,
Some(Breakable::Switch {
bubbled_up_loop_break_cond_var_ptr,
}),
),
fall_through: false,
});

Expand All @@ -653,20 +700,137 @@ impl<'function> BlockContext<'function> {
cases: ir_cases,
},
crate::Span::default(),
)
}
super::BodyFragment::Break => {
block.push(crate::Statement::Break, crate::Span::default())
);

if let Some(loop_break_cond_var_ptr) =
bubbled_up_loop_break_cond_var_ptr.get()
{
match innermost_breakable.expect("stray loop `break`") {
Breakable::Loop => {}
Breakable::Switch { .. } => {
unreachable!(
"loop `break` from multiple levels of nested switch"
)
}
}

let mut emitter = Emitter::default();
emitter.start(lower_ctx.expressions);
let loop_break_cond = lower_ctx.expressions.append(
crate::Expression::Load {
pointer: loop_break_cond_var_ptr,
},
crate::Span::default(),
);
block.extend(emitter.finish(lower_ctx.expressions));

block.push(
crate::Statement::If {
condition: loop_break_cond,
accept: crate::Block::from_vec(vec![crate::Statement::Break]),
reject: crate::Block::new(),
},
crate::Span::default(),
);
}
}
super::BodyFragment::Continue => {
block.push(crate::Statement::Continue, crate::Span::default())
}
super::BodyFragment::LoopBreak => {
match innermost_breakable.expect("stray loop `break`") {
Breakable::Loop => {}
Breakable::Switch {
bubbled_up_loop_break_cond_var_ptr,
} => {
let bool_ty = lower_ctx.type_arena.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
crate::Span::default(),
);
let mut bool_const = |value| {
lower_ctx.const_arena.fetch_or_append(
crate::Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::boolean(value),
},
crate::Span::default(),
)
};
if bubbled_up_loop_break_cond_var_ptr.get().is_none() {
let local = lower_ctx.local_arena.append(
crate::LocalVariable {
name: None,
ty: bool_ty,
init: Some(bool_const(false)),
},
crate::Span::default(),
);
let local_ptr = lower_ctx.expressions.append(
crate::Expression::LocalVariable(local),
crate::Span::default(),
);
bubbled_up_loop_break_cond_var_ptr.set(Some(local_ptr));
}

// Store a `true` in the local variable that the
// parent `switch` will read and use to decide
// whether to `break` out of its parent `loop`.
block.push(
crate::Statement::Store {
pointer: bubbled_up_loop_break_cond_var_ptr.get().unwrap(),
value: lower_ctx.expressions.append(
crate::Expression::Constant(bool_const(true)),
crate::Span::default(),
),
},
crate::Span::default(),
);
}
}
block.push(crate::Statement::Break, crate::Span::default())
}
super::BodyFragment::SwitchBreak => {
match innermost_breakable.expect("stray switch `break`") {
Breakable::Loop => {
unreachable!("switch `break` from nested loop")
}
Breakable::Switch { .. } => {}
}
block.push(crate::Statement::Break, crate::Span::default())
}
}
}

block
}

lower_impl(&mut self.blocks, &self.bodies, 0)
let Self {
expressions,
local_arena,
const_arena,
type_arena,
blocks,
ref bodies,
..
} = self;
lower_impl(
&mut BlockLowerContext {
expressions,
local_arena,
const_arena,
type_arena,
blocks,
bodies,
},
0,
None,
)
}
}
10 changes: 5 additions & 5 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,9 @@ enum BodyFragment {
cases: Vec<(i32, BodyIndex)>,
default: BodyIndex,
},
Break,
Continue,
LoopBreak,
SwitchBreak,
}

/// An intermediate representation of a Naga [`Block`].
Expand Down Expand Up @@ -505,7 +506,7 @@ struct BlockContext<'function> {
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
type_arena: &'function mut UniqueArena<crate::Type>,
/// Global arena of the module being processed
global_arena: &'function Arena<crate::GlobalVariable>,
/// Arguments of the function currently being processed
Expand Down Expand Up @@ -1251,9 +1252,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
fn merger(body: &mut Body, target: &MergeBlockInformation) {
body.data.push(match *target {
MergeBlockInformation::LoopContinue => BodyFragment::Continue,
MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => {
BodyFragment::Break
}
MergeBlockInformation::LoopMerge => BodyFragment::LoopBreak,
MergeBlockInformation::SwitchMerge => BodyFragment::SwitchBreak,

// Finishing a selection merge means just falling off the end of
// the `accept` or `reject` block of the `If` statement.
Expand Down
Binary file added tests/in/spv/loop-break-from-switch.spv
Binary file not shown.
65 changes: 65 additions & 0 deletions tests/in/spv/loop-break-from-switch.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
;; Ensure that `break`s out of `switch`es and out of `loop`s are correctly
;; distinguished, and more specifically that a loop break from inside a switch
;; `case` *does not* become a switch break, but actually breaks out of the loop.
;;
;; The SPIR-V below was made by converting this WGSL to SPIR-V with `naga`, and
;; then optimizing the result of that with `spirv-opt -O` (which is what turned
;; the `return 0;`, into a branch to the loop merge target i.e. a loop break):
;; ```wgsl
;; @fragment
;; fn main(@location(0) dyn_case: i32) -> @location(0) i32 {
;; loop {
;; switch(dyn_case) {
;; case 0: {
;; return 0;
;; }
;; default: {
;; break;
;; }
;; }
;; return -9;
;; }
;; return -9999;
;; }
;; ```

OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %13 "main" %8 %11
OpExecutionMode %13 OriginUpperLeft
OpDecorate %8 Location 0
OpDecorate %8 Flat
OpDecorate %11 Location 0
%void = OpTypeVoid
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%int_n9 = OpConstant %int -9
%_ptr_Input_int = OpTypePointer Input %int
%8 = OpVariable %_ptr_Input_int Input
%_ptr_Output_int = OpTypePointer Output %int
%11 = OpVariable %_ptr_Output_int Output
%14 = OpTypeFunction %void
%13 = OpFunction %void None %14
%7 = OpLabel
%10 = OpLoad %int %8
OpBranch %16
%16 = OpLabel
OpLoopMerge %17 %19 None
OpBranch %18
%18 = OpLabel
OpSelectionMerge %20 None
OpSwitch %10 %22 0 %21
%21 = OpLabel
OpStore %11 %int_0
OpBranch %17
%22 = OpLabel
OpBranch %20
%20 = OpLabel
OpStore %11 %int_n9
OpBranch %17
%19 = OpLabel
OpBranch %16
%17 = OpLabel
OpReturn
OpFunctionEnd
Loading