Skip to content

Commit

Permalink
Add countLeadingZeros (#2226)
Browse files Browse the repository at this point in the history
* Add countLeadingZeros

* [glsl-out] Bake countLeadingZeros

* [hlsl-out] Bake countLeadingZeros

* [hlsl-out] Update Baked expressions

* Remove unnecessary bake for sints

* [glsl-out] CountLeadingZeros without findMSB

* Don't check negatives when uint

* Perform the type conv after mix

* use log2

* fix clippy lints

---------

Co-authored-by: teoxoy <[email protected]>
  • Loading branch information
evahop and teoxoy authored Jan 31, 2023
1 parent a2b39e4 commit 6be394d
Show file tree
Hide file tree
Showing 17 changed files with 311 additions and 66 deletions.
96 changes: 72 additions & 24 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1114,33 +1114,33 @@ impl<'a, W: Write> Writer<'a, W> {
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
}
// if the expression is a Dot product with integer arguments,
// then the args needs baking as well
if let (
fun_handle,
&Expression::Math {
fun: crate::MathFunction::Dot,
arg,
arg1,
..
},
) = expr
{
let inner = info[fun_handle].ty.inner_with(&self.module.types);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
self.need_bake_expressions.insert(fun_handle);
}

if let Expression::Math { fun, arg, arg1, .. } = *expr {
match fun {
crate::MathFunction::Dot => {
// if the expression is a Dot product with integer arguments,
// then the args needs baking as well
let inner = info[fun_handle].ty.inner_with(&self.module.types);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
_ => {}
}
crate::MathFunction::CountLeadingZeros => {
self.need_bake_expressions.insert(arg);
}
_ => {}
}
}
}
Expand Down Expand Up @@ -2928,6 +2928,54 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountLeadingZeros => {
match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Vector { size, kind, .. } => {
let s = back::vector_size_str(size);

if let crate::ScalarKind::Uint = kind {
write!(self.out, "uvec{s}(")?;
} else {
write!(self.out, "ivec{s}(")?;
}

write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") + 0.5)), ")?;

if let crate::ScalarKind::Uint = kind {
write!(self.out, "vec{s}(32.0), lessThanEqual(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ", uvec{s}(0u))))")?;
} else {
write!(self.out, "mix(vec{s}(0.0), vec{s}(32.0), equal(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ", ivec{s}(0))), lessThanEqual(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ", ivec{s}(0))))")?;
}
}
crate::TypeInner::Scalar { kind, .. } => {
write!(self.out, "(")?;
self.write_expr(arg, ctx)?;

if let crate::ScalarKind::Uint = kind {
write!(self.out, " == 0u ? 32u : uint(")?;
} else {
write!(self.out, " <= 0 ? (")?;
self.write_expr(arg, ctx)?;
write!(self.out, " == 0 ? 32 : 0) : int(")?;
}

write!(self.out, "31.0 - floor(log2(float(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") + 0.5))))")?;
}
_ => unreachable!(),
};

return Ok(());
}
Mf::CountOneBits => "bitCount",
Mf::ReverseBits => "bitfieldReverse",
Mf::ExtractBits => "bitfieldExtract",
Expand Down
3 changes: 2 additions & 1 deletion src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod writer;
use std::fmt::Error as FmtError;
use thiserror::Error;

use crate::proc;
use crate::{back, proc};

#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
Expand Down Expand Up @@ -280,4 +280,5 @@ pub struct Writer<'a, W> {
named_expressions: crate::NamedExpressions,
wrapped: Wrapped,
temp_access_chain: Vec<storage::SubAccess>,
need_bake_expressions: back::NeedBakeExpressions,
}
96 changes: 88 additions & 8 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
named_expressions: crate::NamedExpressions::default(),
wrapped: super::Wrapped::default(),
temp_access_chain: Vec::new(),
need_bake_expressions: Default::default(),
}
}

Expand All @@ -93,6 +94,46 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.entry_point_io.clear();
self.named_expressions.clear();
self.wrapped.clear();
self.need_bake_expressions.clear();
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
/// Clears `need_bake_expressions` set before adding to it
fn update_expressions_to_bake(
&mut self,
module: &Module,
func: &crate::Function,
info: &valid::FunctionInfo,
) {
use crate::Expression;
self.need_bake_expressions.clear();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(fun_handle);
}

if let Expression::Math { fun, arg, .. } = *expr {
match fun {
crate::MathFunction::Asinh
| crate::MathFunction::Acosh
| crate::MathFunction::Atanh
| crate::MathFunction::Unpack2x16float => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::CountLeadingZeros => {
let inner = info[fun_handle].ty.inner_with(&module.types);
if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
self.need_bake_expressions.insert(arg);
}
}
_ => {}
}
}
}
}

pub fn write(
Expand Down Expand Up @@ -244,7 +285,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// before writing all statements and expressions.
self.write_wrapped_functions(module, &ctx)?;

self.write_function(module, name.as_str(), function, &ctx)?;
self.write_function(module, name.as_str(), function, &ctx, info)?;

writeln!(self.out)?;
}
Expand Down Expand Up @@ -296,7 +337,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}

let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
self.write_function(module, &name, &ep.function, &ctx)?;
self.write_function(module, &name, &ep.function, &ctx, info)?;

if index < module.entry_points.len() - 1 {
writeln!(self.out)?;
Expand Down Expand Up @@ -1034,9 +1075,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
name: &str,
func: &crate::Function,
func_ctx: &back::FunctionCtx<'_>,
info: &valid::FunctionInfo,
) -> BackendResult {
// Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

self.update_expressions_to_bake(module, func, info);

// Write modifier
if let Some(crate::FunctionResult {
binding:
Expand Down Expand Up @@ -1284,15 +1328,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
Some(self.namer.call(name))
} else if self.need_bake_expressions.contains(&handle) {
Some(format!("_expr{}", handle.index()))
} else if info.ref_count == 0 {
Some(self.namer.call(""))
} else {
let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
if min_ref_count <= info.ref_count {
Some(format!("_expr{}", handle.index()))
} else {
None
}
None
};

if let Some(name) = expr_name {
Expand Down Expand Up @@ -2510,6 +2551,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack2x16float,
Regular(&'static str),
MissingIntOverload(&'static str),
CountLeadingZeros,
}

let fun = match fun {
Expand Down Expand Up @@ -2572,6 +2614,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
// bits
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Mf::CountOneBits => Function::MissingIntOverload("countbits"),
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Mf::FindLsb => Function::Regular("firstbitlow"),
Expand Down Expand Up @@ -2639,6 +2682,43 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;
}
}
Function::CountLeadingZeros => {
match *func_ctx.info[arg].ty.inner_with(&module.types) {
TypeInner::Vector { size, kind, .. } => {
let s = match size {
crate::VectorSize::Bi => ".xx",
crate::VectorSize::Tri => ".xxx",
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = kind {
write!(self.out, "asuint((31){s} - firstbithigh(")?;
} else {
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
" < (0){s} ? (0){s} : (31){s} - firstbithigh("
)?;
}
}
TypeInner::Scalar { kind, .. } => {
if let ScalarKind::Uint = kind {
write!(self.out, "asuint(31 - firstbithigh(")?;
} else {
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " < 0 ? 0 : 31 - firstbithigh(")?;
}
}
_ => unreachable!(),
}

self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;

return Ok(());
}
}
}
Expression::Swizzle {
Expand Down
1 change: 1 addition & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,7 @@ impl<W: Write> Writer<W> {
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountLeadingZeros => "clz",
Mf::CountOneBits => "popcount",
Mf::ReverseBits => "reverse_bits",
Mf::ExtractBits => "extract_bits",
Expand Down
69 changes: 67 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ impl<'w> BlockContext<'w> {
self.temp_list.resize(size as _, arg1_id);

let id = self.gen_id();
block.body.push(Instruction::composite_construct(
block.body.push(Instruction::constant_composite(
result_type_id,
id,
&self.temp_list,
Expand All @@ -735,7 +735,7 @@ impl<'w> BlockContext<'w> {
self.temp_list.resize(size as _, arg2_id);

let id = self.gen_id();
block.body.push(Instruction::composite_construct(
block.body.push(Instruction::constant_composite(
result_type_id,
id,
&self.temp_list,
Expand Down Expand Up @@ -888,6 +888,71 @@ impl<'w> BlockContext<'w> {
id,
arg0_id,
)),
Mf::CountLeadingZeros => {
let int = crate::ScalarValue::Sint(31);

let (int_type_id, int_id) = match *arg_ty {
crate::TypeInner::Vector { size, width, .. } => {
let ty = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Sint,
width,
pointer_space: None,
}));

self.temp_list.clear();
self.temp_list
.resize(size as _, self.writer.get_constant_scalar(int, width));

let id = self.gen_id();
block.body.push(Instruction::constant_composite(
ty,
id,
&self.temp_list,
));

(ty, id)
}
crate::TypeInner::Scalar { width, .. } => (
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Sint,
width,
pointer_space: None,
})),
self.writer.get_constant_scalar(int, width),
),
_ => unreachable!(),
};

block.body.push(Instruction::ext_inst(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FindUMsb,
int_type_id,
id,
&[arg0_id],
));

let sub_id = self.gen_id();
block.body.push(Instruction::binary(
spirv::Op::ISub,
int_type_id,
sub_id,
int_id,
id,
));

if let Some(crate::ScalarKind::Uint) = arg_scalar_kind {
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
self.gen_id(),
sub_id,
));
}

return Ok(());
}
Mf::CountOneBits => MathOp::Custom(Instruction::unary(
spirv::Op::BitCount,
result_type_id,
Expand Down
1 change: 1 addition & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,7 @@ impl<W: Write> Writer<W> {
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
// bits
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
Mf::CountOneBits => Function::Regular("countOneBits"),
Mf::ReverseBits => Function::Regular("reverseBits"),
Mf::ExtractBits => Function::Regular("extractBits"),
Expand Down
Loading

0 comments on commit 6be394d

Please sign in to comment.