Skip to content

Commit

Permalink
[hlsl-out] Bake countLeadingZeros
Browse files Browse the repository at this point in the history
  • Loading branch information
evahop committed Jan 30, 2023
1 parent 36eafee commit f8eaaca
Show file tree
Hide file tree
Showing 24 changed files with 236 additions and 423 deletions.
9 changes: 4 additions & 5 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1114,14 +1114,13 @@ impl<'a, W: Write> Writer<'a, W> {
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
self.need_bake_expressions.insert(fun_handle);
}

let (fun_handle, expr) = expr;
match *expr {
Expression::Math {
fun: crate::MathFunction::Dot,
Expand Down
3 changes: 2 additions & 1 deletion src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod writer;
use std::fmt::Error as FmtError;
use thiserror::Error;

use crate::proc;
use crate::{back, proc};

#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
Expand Down Expand Up @@ -280,4 +280,5 @@ pub struct Writer<'a, W> {
named_expressions: crate::NamedExpressions,
wrapped: Wrapped,
temp_access_chain: Vec<storage::SubAccess>,
need_bake_expressions: back::NeedBakeExpressions,
}
44 changes: 36 additions & 8 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
named_expressions: crate::NamedExpressions::default(),
wrapped: super::Wrapped::default(),
temp_access_chain: Vec::new(),
need_bake_expressions: Default::default(),
}
}

Expand All @@ -93,6 +94,34 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.entry_point_io.clear();
self.named_expressions.clear();
self.wrapped.clear();
self.need_bake_expressions.clear();
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
/// Clears `need_bake_expressions` set before adding to it
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
use crate::Expression;
self.need_bake_expressions.clear();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(fun_handle);
}

match *expr {
Expression::Math {
fun: crate::MathFunction::CountLeadingZeros,
arg,
..
} => {
self.need_bake_expressions.insert(arg);
}
_ => (),
}
}
}

pub fn write(
Expand Down Expand Up @@ -244,7 +273,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// before writing all statements and expressions.
self.write_wrapped_functions(module, &ctx)?;

self.write_function(module, name.as_str(), function, &ctx)?;
self.write_function(module, name.as_str(), function, &ctx, info)?;

writeln!(self.out)?;
}
Expand Down Expand Up @@ -296,7 +325,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}

let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
self.write_function(module, &name, &ep.function, &ctx)?;
self.write_function(module, &name, &ep.function, &ctx, info)?;

if index < module.entry_points.len() - 1 {
writeln!(self.out)?;
Expand Down Expand Up @@ -1039,6 +1068,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
name: &str,
func: &crate::Function,
func_ctx: &back::FunctionCtx<'_>,
info: &valid::FunctionInfo,
) -> BackendResult {
// Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax

Expand Down Expand Up @@ -1217,6 +1247,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "}}")?;

self.named_expressions.clear();
self.update_expressions_to_bake(func, info);

Ok(())
}
Expand Down Expand Up @@ -1290,15 +1321,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
Some(self.namer.call(name))
} else if self.need_bake_expressions.contains(&handle) {
Some(format!("_expr{}", handle.index()))
} else if info.ref_count == 0 {
Some(self.namer.call(""))
} else {
let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
if min_ref_count <= info.ref_count {
Some(format!("_expr{}", handle.index()))
} else {
None
}
None
};

if let Some(name) = expr_name {
Expand Down
67 changes: 23 additions & 44 deletions tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -142,35 +142,23 @@ void test_matrix_within_struct_accesses()
Baz t = (Baz)0;

idx = 1;
int _expr2 = idx;
idx = (_expr2 - 1);
idx = (idx - 1);
float3x2 unnamed = GetMatmOnBaz(baz);
float2 unnamed_1 = GetMatmOnBaz(baz)[0];
int _expr15 = idx;
float2 unnamed_2 = GetMatmOnBaz(baz)[_expr15];
float2 unnamed_2 = GetMatmOnBaz(baz)[idx];
float unnamed_3 = GetMatmOnBaz(baz)[0].y;
int _expr29 = idx;
float unnamed_4 = GetMatmOnBaz(baz)[0][_expr29];
int _expr34 = idx;
float unnamed_5 = GetMatmOnBaz(baz)[_expr34].y;
int _expr41 = idx;
int _expr43 = idx;
float unnamed_6 = GetMatmOnBaz(baz)[_expr41][_expr43];
float unnamed_4 = GetMatmOnBaz(baz)[0][idx];
float unnamed_5 = GetMatmOnBaz(baz)[idx].y;
float unnamed_6 = GetMatmOnBaz(baz)[idx][idx];
t = ConstructBaz(float3x2((1.0).xx, (2.0).xx, (3.0).xx));
int _expr55 = idx;
idx = (_expr55 + 1);
idx = (idx + 1);
SetMatmOnBaz(t, float3x2((6.0).xx, (5.0).xx, (4.0).xx));
t.m_0 = (9.0).xx;
int _expr72 = idx;
SetMatVecmOnBaz(t, (90.0).xx, _expr72);
SetMatVecmOnBaz(t, (90.0).xx, idx);
t.m_0[1] = 10.0;
int _expr85 = idx;
t.m_0[_expr85] = 20.0;
int _expr89 = idx;
SetMatScalarmOnBaz(t, 30.0, _expr89, 1);
int _expr95 = idx;
int _expr97 = idx;
SetMatScalarmOnBaz(t, 40.0, _expr95, _expr97);
t.m_0[idx] = 20.0;
SetMatScalarmOnBaz(t, 30.0, idx, 1);
SetMatScalarmOnBaz(t, 40.0, idx, idx);
return;
}

Expand All @@ -191,32 +179,22 @@ void test_matrix_within_array_within_struct_accesses()
float4x2 unnamed_7[2] = ((float4x2[2])nested_mat_cx2_.am);
float4x2 unnamed_8 = ((float4x2)nested_mat_cx2_.am[0]);
float2 unnamed_9 = nested_mat_cx2_.am[0]._0;
int _expr24 = idx_1;
float2 unnamed_10 = __get_col_of_mat4x2(nested_mat_cx2_.am[0], _expr24);
float2 unnamed_10 = __get_col_of_mat4x2(nested_mat_cx2_.am[0], idx_1);
float unnamed_11 = nested_mat_cx2_.am[0]._0.y;
int _expr42 = idx_1;
float unnamed_12 = nested_mat_cx2_.am[0]._0[_expr42];
int _expr49 = idx_1;
float unnamed_13 = __get_col_of_mat4x2(nested_mat_cx2_.am[0], _expr49).y;
int _expr58 = idx_1;
int _expr60 = idx_1;
float unnamed_14 = __get_col_of_mat4x2(nested_mat_cx2_.am[0], _expr58)[_expr60];
float unnamed_12 = nested_mat_cx2_.am[0]._0[idx_1];
float unnamed_13 = __get_col_of_mat4x2(nested_mat_cx2_.am[0], idx_1).y;
float unnamed_14 = __get_col_of_mat4x2(nested_mat_cx2_.am[0], idx_1)[idx_1];
t_1 = ConstructMatCx2InArray(Constructarray2_float4x2_(float4x2(float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0)), float4x2(float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0))));
int _expr66 = idx_1;
idx_1 = (_expr66 + 1);
idx_1 = (idx_1 + 1);
t_1.am = (__mat4x2[2])Constructarray2_float4x2_(float4x2(float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0)), float4x2(float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0), float2(0.0, 0.0)));
t_1.am[0] = (__mat4x2)float4x2((8.0).xx, (7.0).xx, (6.0).xx, (5.0).xx);
t_1.am[0]._0 = (9.0).xx;
int _expr93 = idx_1;
__set_col_of_mat4x2(t_1.am[0], _expr93, (90.0).xx);
float2 _expr89 = (9.0).xx;
t_1.am[0]._0 = _expr89;
__set_col_of_mat4x2(t_1.am[0], idx_1, (90.0).xx);
t_1.am[0]._0.y = 10.0;
int _expr110 = idx_1;
t_1.am[0]._0[_expr110] = 20.0;
int _expr116 = idx_1;
__set_el_of_mat4x2(t_1.am[0], _expr116, 1, 30.0);
int _expr124 = idx_1;
int _expr126 = idx_1;
__set_el_of_mat4x2(t_1.am[0], _expr124, _expr126, 40.0);
t_1.am[0]._0[idx_1] = 20.0;
__set_el_of_mat4x2(t_1.am[0], idx_1, 1, 30.0);
__set_el_of_mat4x2(t_1.am[0], idx_1, idx_1, 40.0);
return;
}

Expand Down Expand Up @@ -282,8 +260,9 @@ ret_Constructarray2_uint2_ Constructarray2_uint2_(uint2 arg0, uint2 arg1) {
float4 foo_frag() : SV_Target0
{
bar.Store(8+16+0, asuint(1.0));
float4x3 _expr16 = float4x3((0.0).xxx, (1.0).xxx, (2.0).xxx, (3.0).xxx);
{
float4x3 _value2 = float4x3((0.0).xxx, (1.0).xxx, (2.0).xxx, (3.0).xxx);
float4x3 _value2 = _expr16;
bar.Store3(0+0, asuint(_value2[0]));
bar.Store3(0+16, asuint(_value2[1]));
bar.Store3(0+32, asuint(_value2[2]));
Expand Down
Loading

0 comments on commit f8eaaca

Please sign in to comment.