Skip to content

Commit

Permalink
[wgsl-in] Handle modf and frexp
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall committed Aug 22, 2023
1 parent 3da9355 commit f203dbf
Show file tree
Hide file tree
Showing 32 changed files with 727 additions and 154 deletions.
5 changes: 5 additions & 0 deletions src/back/glsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,9 @@ pub const RESERVED_KEYWORDS: &[&str] = &[
// entry point name (should not be shadowed)
//
"main",
// Naga utilities:
super::MODF_FUNCTION,
super::MODF_STRUCT,
super::FREXP_FUNCTION,
super::FREXP_STRUCT,
];
57 changes: 55 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
/// of detail for bounds checking in `ImageLoad`
const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";

pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const FREXP_STRUCT: &str = "naga_frexp_struct";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const MODF_STRUCT: &str = "naga_modf_struct";

/// Mapping between resources and bindings.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;

Expand Down Expand Up @@ -604,6 +609,38 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

if self.module.special_types.frexp_result.is_some() {
writeln!(
self.out,
"struct {FREXP_STRUCT} {{
float fract;
int exp;
}};
{FREXP_STRUCT} {FREXP_FUNCTION}(float arg) {{
int exp;
float fract = frexp(arg, exp);
return {FREXP_STRUCT}(fract, exp);
}}"
)?;
}

if self.module.special_types.modf_result.is_some() {
writeln!(
self.out,
"struct {MODF_STRUCT} {{
float fract;
float whole;
}};
{MODF_STRUCT} {MODF_FUNCTION}(float arg) {{
float whole;
float fract = modf(arg, whole);
return {MODF_STRUCT}(fract, whole);
}}"
)?;
}

// Write struct types.
//
// This are always ordered because the IR is structured in a way that
Expand Down Expand Up @@ -860,6 +897,8 @@ impl<'a, W: Write> Writer<'a, W> {
| TypeInner::Sampler { .. }
| TypeInner::AccelerationStructure
| TypeInner::RayQuery
| TypeInner::ModfResult
| TypeInner::FrexpResult
| TypeInner::BindingArray { .. } => {
return Err(Error::Custom(format!("Unable to write type {inner:?}")))
}
Expand All @@ -885,6 +924,14 @@ impl<'a, W: Write> Writer<'a, W> {
}
// glsl array has the size separated from the base type
TypeInner::Array { base, .. } => self.write_type(base),
TypeInner::FrexpResult => {
write!(self.out, "{FREXP_STRUCT}")?;
Ok(())
}
TypeInner::ModfResult => {
write!(self.out, "{MODF_STRUCT}")?;
Ok(())
}
ref other => self.write_value_type(other),
}
}
Expand Down Expand Up @@ -2325,6 +2372,12 @@ impl<'a, W: Write> Writer<'a, W> {
&self.names[&NameKey::StructMember(ty, index)]
)?
}
TypeInner::FrexpResult => {
write!(self.out, ".{}", if index == 0 { "fract" } else { "exp" })?
}
TypeInner::ModfResult => {
write!(self.out, ".{}", if index == 0 { "fract" } else { "whole" })?
}
ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
}
}
Expand Down Expand Up @@ -2985,8 +3038,8 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Round => "roundEven",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down
48 changes: 48 additions & 0 deletions src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,54 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}

pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
if module.special_types.frexp_result.is_some() {
let function_name = super::writer::FREXP_FUNCTION;
let struct_name = super::writer::FREXP_STRUCT;
writeln!(
self.out,
"struct {struct_name} {{
float fract;
int exp;
}};
{struct_name} {function_name}(in float arg) {{
float exp;
float fract = frexp(arg, exp);
{struct_name} result;
result.exp = exp;
result.fract = fract;
return result;
}}"
)?;
// Write extra new line
writeln!(self.out)?;
}
if module.special_types.modf_result.is_some() {
let function_name = super::writer::MODF_FUNCTION;
let struct_name = super::writer::MODF_STRUCT;
writeln!(
self.out,
"struct {struct_name} {{
float fract;
float whole;
}};
{struct_name} {function_name}(in float arg) {{
float whole;
float fract = modf(arg, whole);
{struct_name} result;
result.whole = whole;
result.fract = fract;
return result;
}}"
)?;
// Write extra new line
writeln!(self.out)?;
}
Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
&mut self,
Expand Down
5 changes: 5 additions & 0 deletions src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,11 @@ pub const RESERVED: &[&str] = &[
"TextureBuffer",
"ConstantBuffer",
"RayQuery",
// Naga utilities
super::writer::FREXP_FUNCTION,
super::writer::FREXP_STRUCT,
super::writer::MODF_FUNCTION,
super::writer::MODF_STRUCT,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
25 changes: 23 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";

pub(crate) const FREXP_FUNCTION: &str = "frexp_modf";
pub(crate) const FREXP_STRUCT: &str = "frexp_modf_result";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const MODF_STRUCT: &str = "naga_modf_result";

struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
Expand Down Expand Up @@ -244,6 +249,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.const_expressions)?;

// Write all named constants
Expand Down Expand Up @@ -1058,6 +1065,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
self.write_array_size(module, base, size)?;
}
TypeInner::FrexpResult => {
write!(self.out, "struct {FREXP_STRUCT}")?;
}
TypeInner::ModfResult => {
write!(self.out, "struct {MODF_STRUCT}")?;
}
_ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
}

Expand Down Expand Up @@ -2276,6 +2289,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&writer.names[&NameKey::StructMember(ty, index)]
)?
}
TypeInner::FrexpResult => {
write!(writer.out, ".{}", if index == 0 { "fract" } else { "exp" })?
}
TypeInner::ModfResult => write!(
writer.out,
".{}",
if index == 0 { "fract" } else { "whole" }
)?,
ref other => {
return Err(Error::Custom(format!("Cannot index {other:?}")))
}
Expand Down Expand Up @@ -2665,8 +2686,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Round => Function::Regular("round"),
Mf::Fract => Function::Regular("frac"),
Mf::Trunc => Function::Regular("trunc"),
Mf::Modf => Function::Regular("modf"),
Mf::Frexp => Function::Regular("frexp"),
Mf::Modf => Function::Regular(MODF_FUNCTION),
Mf::Frexp => Function::Regular(FREXP_FUNCTION),
Mf::Ldexp => Function::Regular("ldexp"),
// exponent
Mf::Exp => Function::Regular("exp"),
Expand Down
4 changes: 4 additions & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,8 @@ pub const RESERVED: &[&str] = &[
// Naga utilities
"DefaultConstructible",
"clamped_lod_e",
super::writer::FREXP_FUNCTION,
super::writer::FREXP_STRUCT,
super::writer::MODF_FUNCTION,
super::writer::MODF_STRUCT,
];
90 changes: 87 additions & 3 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const FREXP_STRUCT: &str = "naga_frexp_result";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const MODF_STRUCT: &str = "naga_modf_result";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
Expand Down Expand Up @@ -140,7 +145,15 @@ impl<'a> Display for TypeContext<'a> {
// so just print the element type here.
write!(out, "{sub}")
}
crate::TypeInner::Struct { .. } => unreachable!(),
crate::TypeInner::Struct { .. } => {
unreachable!()
}
crate::TypeInner::FrexpResult { .. } => {
write!(out, "{}", FREXP_STRUCT)
}
crate::TypeInner::ModfResult { .. } => {
write!(out, "{}", MODF_STRUCT)
}
crate::TypeInner::Image {
dim,
arrayed,
Expand Down Expand Up @@ -452,6 +465,8 @@ impl crate::Type {
| Ti::Sampler { .. }
| Ti::AccelerationStructure
| Ti::RayQuery
| Ti::FrexpResult
| Ti::ModfResult
| Ti::BindingArray { .. } => false,
}
}
Expand Down Expand Up @@ -1635,6 +1650,24 @@ impl<W: Write> Writer<W> {
write!(self.out, "{NAMESPACE}::{op}")?;
self.put_call_parameters(iter::once(argument), context)?;
}
crate::Expression::Math {
fun: crate::MathFunction::Frexp,
arg,
..
} => {
write!(self.out, "{FREXP_FUNCTION}(")?;
self.put_expression(arg, context, false)?;
write!(self.out, ")")?;
}
crate::Expression::Math {
fun: crate::MathFunction::Modf,
arg,
..
} => {
write!(self.out, "{MODF_FUNCTION}(")?;
self.put_expression(arg, context, false)?;
write!(self.out, ")")?;
}
crate::Expression::Math {
fun,
arg,
Expand All @@ -1644,7 +1677,7 @@ impl<W: Write> Writer<W> {
} => {
use crate::MathFunction as Mf;

let scalar_argument = match *context.resolve_type(arg) {
let scalar_argument: bool = match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => true,
_ => false,
};
Expand Down Expand Up @@ -2018,7 +2051,9 @@ impl<W: Write> Writer<W> {
base_inner = &context.module.types[base].inner;
}
match *base_inner {
crate::TypeInner::Struct { .. } => (base, None),
crate::TypeInner::Struct { .. }
| crate::TypeInner::FrexpResult
| crate::TypeInner::ModfResult { .. } => (base, None),
_ => (base, Some(index::GuardedIndex::Known(index))),
}
}
Expand Down Expand Up @@ -2133,6 +2168,20 @@ impl<W: Write> Writer<W> {
write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
}
}
crate::TypeInner::FrexpResult | crate::TypeInner::ModfResult => {
self.put_access_chain(base, policy, context)?;
write!(
self.out,
".{}",
if index == 0 {
"fract"
} else if *base_ty == crate::TypeInner::FrexpResult {
"exp"
} else {
"whole"
}
)?;
}
_ => {
self.put_subscripted_access_chain(
base,
Expand Down Expand Up @@ -3236,6 +3285,41 @@ impl<W: Write> Writer<W> {
}
}
}

if module.special_types.frexp_result.is_some() {
writeln!(
self.out,
"
struct {FREXP_STRUCT} {{
float fract;
int exp;
}};
struct {FREXP_STRUCT} {FREXP_FUNCTION}(float arg) {{
int exp;
float fract = {NAMESPACE}::frexp(arg, exp);
return {FREXP_STRUCT}{{ fract, exp }};
}};"
)?;
}

if module.special_types.modf_result.is_some() {
writeln!(
self.out,
"
struct {MODF_STRUCT} {{
float fract;
float whole;
}};
struct {MODF_STRUCT} {MODF_FUNCTION}(float arg) {{
float whole;
float fract = {NAMESPACE}::modf(arg, whole);
return {MODF_STRUCT}{{ fract, whole }};
}};"
)?;
}

Ok(())
}

Expand Down
Loading

0 comments on commit f203dbf

Please sign in to comment.