Skip to content

Commit

Permalink
wgsl-in, glsl-in: Short circuit || and &&
Browse files Browse the repository at this point in the history
  • Loading branch information
adeline-sparks committed Aug 18, 2022
1 parent 48e7938 commit 55c898d
Show file tree
Hide file tree
Showing 19 changed files with 1,414 additions and 493 deletions.
114 changes: 114 additions & 0 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,120 @@ impl Context {
self.add_expression(Expression::Constant(constant), meta, body)
}
HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
// Logical operators must short circuit by emitting an if statement.
// Handle them as a special case.
if let BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr = op {
// Lower the lhs, then emit all expressions lowered so far.
let (mut left, left_meta) =
self.lower_expect_inner(stmt, parser, left, ExprPos::Rhs, body)?;
self.emit_restart(body);

// Lower the rhs into a special body for use in the if statement.
let mut right_body = Block::new();
let (mut right, right_meta) = self.lower_expect_inner(
stmt,
parser,
right,
ExprPos::Rhs,
&mut right_body,
)?;
self.emit_restart(&mut right_body);

// Type check and emit a conversion if necessary.
parser.typifier_grow(self, left, left_meta)?;
parser.typifier_grow(self, right, right_meta)?;
self.binary_implicit_conversion(
parser, &mut left, left_meta, &mut right, right_meta,
)?;
self.emit_end(body);

// Create a temporary local to hold the result of the operator.
let bool_ty = parser.module.types.insert(
Type {
name: None,
inner: TypeInner::Scalar {
kind: ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
Default::default(),
);
let local = self.locals.append(
LocalVariable {
name: None,
ty: bool_ty,
init: None,
},
meta,
);
let local_expr = self
.expressions
.append(Expression::LocalVariable(local), meta);

// Store the result of the RHS to the local.
right_body.push(
Statement::Store {
pointer: local_expr,
value: right,
},
right_meta,
);

// Create a value representing the result of the operator if it short circuits and does not evaluate the RHS.
let short_circuit_value = match op {
BinaryOperator::LogicalAnd => false,
BinaryOperator::LogicalOr => true,
_ => unreachable!(),
};
let short_circuit_constant = parser.module.constants.fetch_or_append(
Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::boolean(short_circuit_value),
},
Default::default(),
);
let short_circuit_expr = self
.expressions
.append(Expression::Constant(short_circuit_constant), meta);

// Create a short circuit body, which assigns the short circuit value to the local.
let mut short_circuit_body = Block::new();
short_circuit_body.push(
Statement::Store {
pointer: local_expr,
value: short_circuit_expr,
},
meta,
);

// Add an if statement which either evaluates the RHS block or the short circuit block.
let (accept, reject) = match op {
BinaryOperator::LogicalAnd => (right_body, short_circuit_body),
BinaryOperator::LogicalOr => (short_circuit_body, right_body),
_ => unimplemented!(),
};
body.push(
Statement::If {
condition: left,
accept,
reject,
},
meta,
);

// The result of lowering this operator is just the local in which the result is stored.
self.emit_start();
let load_local_expr = self.expressions.append(
Expression::Load {
pointer: local_expr,
},
meta,
);

return Ok((Some(load_local_expr), meta));
};

let (mut left, left_meta) =
self.lower_expect_inner(stmt, parser, left, ExprPos::Rhs, body)?;
let (mut right, right_meta) =
Expand Down
124 changes: 121 additions & 3 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ struct ExpressionContext<'input, 'temp, 'out> {
types: &'out mut UniqueArena<crate::Type>,
constants: &'out mut Arena<crate::Constant>,
global_vars: &'out Arena<crate::GlobalVariable>,
local_vars: &'out Arena<crate::LocalVariable>,
local_vars: &'out mut Arena<crate::LocalVariable>,
arguments: &'out [crate::FunctionArgument],
functions: &'out Arena<crate::Function>,
block: &'temp mut crate::Block,
Expand Down Expand Up @@ -897,6 +897,124 @@ impl<'a> ExpressionContext<'a, '_, '_> {
Ok(accumulator)
}

fn parse_binary_short_circuit_op(
&mut self,
lexer: &mut Lexer<'a>,
classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
mut parser: impl FnMut(
&mut Lexer<'a>,
ExpressionContext<'a, '_, '_>,
) -> Result<TypedExpression, Error<'a>>,
) -> Result<TypedExpression, Error<'a>> {
let start = lexer.current_byte_offset() as u32;
let mut accumulator = parser(lexer, self.reborrow())?;
while let Some(op) = classifier(lexer.peek().0) {
let _ = lexer.next();

// Apply load rule to the lhs.
let left = self.apply_load_rule(accumulator);

// Emit all previous expressions, and prepare a body for the rhs.
self.block.extend(self.emitter.finish(self.expressions));
let mut rhs_body = crate::Block::new();
let mut rhs_ctx = self.reborrow();
rhs_ctx.block = &mut rhs_body;
rhs_ctx.emitter.start(rhs_ctx.expressions);

// Parse the rhs using the rhs body.
let unloaded_right = parser(lexer, rhs_ctx)?;
let end = lexer.current_byte_offset() as u32;
let span = NagaSpan::new(start, end);
let right = self.apply_load_rule(unloaded_right);
rhs_body.extend(self.emitter.finish(self.expressions));

// Create a temporary local to store the result of the operator.
let local_ty = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
Default::default(),
);
let local = self.local_vars.append(
crate::LocalVariable {
name: None,
ty: local_ty,
init: None,
},
span,
);

// Assign the rhs expression to the local variable inside the rhs body.
let local_expr = self
.expressions
.append(crate::Expression::LocalVariable(local), span);

rhs_body.push(
crate::Statement::Store {
pointer: local_expr,
value: right,
},
span,
);

// Make a short circuit body which assigns the local variable to the short circuit value.
let short_circuit_value = match op {
crate::BinaryOperator::LogicalAnd => false,
crate::BinaryOperator::LogicalOr => true,
_ => unreachable!(),
};

let short_circuit_constant = self.constants.fetch_or_append(
crate::Constant {
name: None,
specialization: None,
inner: ConstantInner::boolean(short_circuit_value),
},
Default::default(),
);
let short_circuit_expr = self
.expressions
.append(crate::Expression::Constant(short_circuit_constant), span);

let mut short_circuit_body = crate::Block::new();
short_circuit_body.push(
crate::Statement::Store {
pointer: local_expr,
value: short_circuit_expr,
},
span,
);

// Append an if statement. This implements the short circuiting behavior, by deciding whether to
// evaluate the rhs based on the value of the lhs.
let (accept, reject) = match op {
crate::BinaryOperator::LogicalAnd => (rhs_body, short_circuit_body),
crate::BinaryOperator::LogicalOr => (short_circuit_body, rhs_body),
_ => unreachable!(),
};
self.block.push(
crate::Statement::If {
condition: left,
accept,
reject,
},
span,
);

// Prepare to resume parsing expressions.
self.emitter.start(self.expressions);
accumulator = TypedExpression {
handle: local_expr,
is_reference: true,
};
}
Ok(accumulator)
}

fn parse_binary_splat_op(
&mut self,
lexer: &mut Lexer<'a>,
Expand Down Expand Up @@ -2683,15 +2801,15 @@ impl Parser {
) -> Result<(TypedExpression, Span), Error<'a>> {
self.push_scope(Scope::GeneralExpr, lexer);
// logical_or_expression
let handle = context.parse_binary_op(
let handle = context.parse_binary_short_circuit_op(
lexer,
|token| match token {
Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
_ => None,
},
// logical_and_expression
|lexer, mut context| {
context.parse_binary_op(
context.parse_binary_short_circuit_op(
lexer,
|token| match token {
Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
Expand Down
12 changes: 12 additions & 0 deletions tests/in/glsl/short_circuit.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#version 460 core

bool a() { return true; }
bool b() { return true; }
bool c() { return true; }
bool d() { return true; }

void main() {
bool out1 = a() || b() || c();
bool out2 = a() && b() && c();
bool out3 = (a() || b()) && (c() || d());
}
2 changes: 2 additions & 0 deletions tests/in/short_circuit.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
(
)
11 changes: 11 additions & 0 deletions tests/in/short_circuit.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
fn a() -> bool { return true; }
fn b() -> bool { return true; }
fn c() -> bool { return true; }
fn d() -> bool { return true; }

@compute @workgroup_size(1)
fn main() {
_ = a() || b() || c();
_ = a() && b() && c();
_ = (a() || b()) && (c() || d());
}
16 changes: 14 additions & 2 deletions tests/out/glsl/operators.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,22 @@ float constructors() {
}

void logical() {
bool local = false;
bool local_1 = false;
bool unnamed_11 = (!true);
bvec2 unnamed_12 = not(bvec2(true));
bool unnamed_13 = (true || false);
bool unnamed_14 = (true && false);
if (true) {
local = true;
} else {
local = false;
}
bool unnamed_13 = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool unnamed_14 = local_1;
bool unnamed_15 = (true || false);
bvec3 unnamed_16 = bvec3(bvec3(true).x || bvec3(false).x, bvec3(true).y || bvec3(false).y, bvec3(true).z || bvec3(false).z);
bool unnamed_17 = (true && false);
Expand Down
Loading

0 comments on commit 55c898d

Please sign in to comment.