Skip to content

Commit

Permalink
[hlsl-out] Bake countLeadingZeros
Browse files Browse the repository at this point in the history
  • Loading branch information
evahop committed Jan 30, 2023
1 parent 36eafee commit 8db7860
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 16 deletions.
9 changes: 4 additions & 5 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1114,14 +1114,13 @@ impl<'a, W: Write> Writer<'a, W> {
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
self.need_bake_expressions.insert(fun_handle);
}

let (fun_handle, expr) = expr;
match *expr {
Expression::Math {
fun: crate::MathFunction::Dot,
Expand Down
3 changes: 2 additions & 1 deletion src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod writer;
use std::fmt::Error as FmtError;
use thiserror::Error;

use crate::proc;
use crate::{back, proc};

#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
Expand Down Expand Up @@ -280,4 +280,5 @@ pub struct Writer<'a, W> {
named_expressions: crate::NamedExpressions,
wrapped: Wrapped,
temp_access_chain: Vec<storage::SubAccess>,
need_bake_expressions: back::NeedBakeExpressions,
}
45 changes: 37 additions & 8 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
named_expressions: crate::NamedExpressions::default(),
wrapped: super::Wrapped::default(),
temp_access_chain: Vec::new(),
need_bake_expressions: Default::default(),
}
}

Expand All @@ -93,6 +94,34 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.entry_point_io.clear();
self.named_expressions.clear();
self.wrapped.clear();
self.need_bake_expressions.clear();
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
/// Clears `need_bake_expressions` set before adding to it
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
use crate::Expression;
self.need_bake_expressions.clear();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(fun_handle);
}

match *expr {
Expression::Math {
fun: crate::MathFunction::CountLeadingZeros,
arg,
..
} => {
self.need_bake_expressions.insert(arg);
}
_ => (),
}
}
}

pub fn write(
Expand Down Expand Up @@ -244,7 +273,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// before writing all statements and expressions.
self.write_wrapped_functions(module, &ctx)?;

self.write_function(module, name.as_str(), function, &ctx)?;
self.write_function(module, name.as_str(), function, &ctx, info)?;

writeln!(self.out)?;
}
Expand Down Expand Up @@ -296,7 +325,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}

let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
self.write_function(module, &name, &ep.function, &ctx)?;
self.write_function(module, &name, &ep.function, &ctx, info)?;

if index < module.entry_points.len() - 1 {
writeln!(self.out)?;
Expand Down Expand Up @@ -1039,9 +1068,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
name: &str,
func: &crate::Function,
func_ctx: &back::FunctionCtx<'_>,
info: &valid::FunctionInfo,
) -> BackendResult {
// Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

self.update_expressions_to_bake(func, info);

// Write modifier
if let Some(crate::FunctionResult {
binding:
Expand Down Expand Up @@ -1290,15 +1322,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
Some(self.namer.call(name))
} else if self.need_bake_expressions.contains(&handle) {
Some(format!("_expr{}", handle.index()))
} else if info.ref_count == 0 {
Some(self.namer.call(""))
} else {
let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
if min_ref_count <= info.ref_count {
Some(format!("_expr{}", handle.index()))
} else {
None
}
None
};

if let Some(name) = expr_name {
Expand Down
6 changes: 4 additions & 2 deletions tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ void main()
uint first_leading_bit_abs = firstbithigh(abs(0u));
int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1));
uint clz_b = asuint(31 - firstbithigh(1u));
int2 clz_c = ((-1).xx < (0).xx ? (0).xx : (31).xx - firstbithigh((-1).xx));
uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx));
int2 _expr20 = (-1).xx;
int2 clz_c = (_expr20 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr20));
uint2 _expr23 = (1u).xx;
uint2 clz_d = asuint((31).xx - firstbithigh(_expr23));
}

0 comments on commit 8db7860

Please sign in to comment.