From 5b65f11e3cb491707f9cdbb78388972367feb52d Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:35:22 +0200 Subject: [PATCH] [refactor] make use of `resolve_type` wherever possible --- src/back/glsl/mod.rs | 46 ++++++++++----------- src/back/hlsl/help.rs | 11 +++-- src/back/hlsl/storage.rs | 2 +- src/back/hlsl/writer.rs | 86 +++++++++++++++------------------------- src/back/mod.rs | 2 +- src/back/msl/writer.rs | 6 +-- 6 files changed, 61 insertions(+), 92 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index c6f211183a..60431e986e 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1828,8 +1828,7 @@ impl<'a, W: Write> Writer<'a, W> { // This is where we can generate intermediate constants for some expression types. Statement::Emit(ref range) => { for handle in range.clone() { - let info = &ctx.info[handle]; - let ptr_class = info.ty.inner_with(&self.module.types).pointer_space(); + let ptr_class = ctx.resolve_type(handle, &self.module.types).pointer_space(); let expr_name = if ptr_class.is_some() { // GLSL can't save a pointer-valued expression in a variable, // but we shouldn't ever need to: they should never be named expressions, @@ -1859,7 +1858,7 @@ impl<'a, W: Write> Writer<'a, W> { if let TypeInner::Image { class: crate::ImageClass::Sampled { .. }, .. - } = *ctx.info[image].ty.inner_with(&self.module.types) + } = *ctx.resolve_type(image, &self.module.types) { if let proc::BoundsCheckPolicy::Restrict = self.policies.image_load { write!(self.out, "{level}")?; @@ -2225,7 +2224,7 @@ impl<'a, W: Write> Writer<'a, W> { } => { write!(self.out, "{level}")?; let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); - let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + let res_ty = ctx.resolve_type(result, &self.module.types); self.write_value_type(res_ty)?; write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); @@ -2484,7 +2483,7 @@ impl<'a, W: Write> Writer<'a, W> { level, depth_ref, } => { - let dim = match *ctx.info[image].ty.inner_with(&self.module.types) { + let dim = match *ctx.resolve_type(image, &self.module.types) { TypeInner::Image { dim, .. } => dim, _ => unreachable!(), }; @@ -2545,7 +2544,7 @@ impl<'a, W: Write> Writer<'a, W> { // We need to get the coordinates vector size to later build a vector that's `size + 1` // if `depth_ref` is some, if it isn't a vector we panic as that's not a valid expression - let mut coord_dim = match *ctx.info[coordinate].ty.inner_with(&self.module.types) { + let mut coord_dim = match *ctx.resolve_type(coordinate, &self.module.types) { TypeInner::Vector { size, .. } => size as u8, TypeInner::Scalar { .. } => 1, _ => unreachable!(), @@ -2672,7 +2671,7 @@ impl<'a, W: Write> Writer<'a, W> { use crate::ImageClass; // This will only panic if the module is invalid - let (dim, class) = match *ctx.info[image].ty.inner_with(&self.module.types) { + let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { TypeInner::Image { dim, arrayed: _, @@ -2704,7 +2703,7 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(image, ctx)?; if let Some(expr) = level { let cast_to_int = matches!( - *ctx.info[expr].ty.inner_with(&self.module.types), + *ctx.resolve_type(expr, &self.module.types), crate::TypeInner::Scalar { kind: crate::ScalarKind::Uint, .. @@ -2779,7 +2778,7 @@ impl<'a, W: Write> Writer<'a, W> { let operator_or_fn = match op { crate::UnaryOperator::Negate => "-", crate::UnaryOperator::LogicalNot => { - match *ctx.info[expr].ty.inner_with(&self.module.types) { + match *ctx.resolve_type(expr, &self.module.types) { TypeInner::Vector { .. } => "not", _ => "!", } @@ -2805,8 +2804,8 @@ impl<'a, W: Write> Writer<'a, W> { // implemented as a function call use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti}; - let left_inner = ctx.info[left].ty.inner_with(&self.module.types); - let right_inner = ctx.info[right].ty.inner_with(&self.module.types); + let left_inner = ctx.resolve_type(left, &self.module.types); + let right_inner = ctx.resolve_type(right, &self.module.types); let function = match (left_inner, right_inner) { (&Ti::Vector { kind, .. }, &Ti::Vector { .. }) => match op { @@ -2935,7 +2934,7 @@ impl<'a, W: Write> Writer<'a, W> { accept, reject, } => { - let cond_ty = ctx.info[condition].ty.inner_with(&self.module.types); + let cond_ty = ctx.resolve_type(condition, &self.module.types); let vec_select = if let TypeInner::Vector { .. } = *cond_ty { true } else { @@ -3025,7 +3024,7 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(arg, ctx)?; - match *ctx.info[arg].ty.inner_with(&self.module.types) { + match *ctx.resolve_type(arg, &self.module.types) { crate::TypeInner::Vector { size, .. } => write!( self.out, ", vec{}(0.0), vec{0}(1.0)", @@ -3072,7 +3071,7 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Log2 => "log2", Mf::Pow => "pow", // geometry - Mf::Dot => match *ctx.info[arg].ty.inner_with(&self.module.types) { + Mf::Dot => match *ctx.resolve_type(arg, &self.module.types) { crate::TypeInner::Vector { kind: crate::ScalarKind::Float, .. @@ -3128,7 +3127,7 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Determinant => "determinant", // bits Mf::CountTrailingZeros => { - match *ctx.info[arg].ty.inner_with(&self.module.types) { + match *ctx.resolve_type(arg, &self.module.types) { crate::TypeInner::Vector { size, kind, .. } => { let s = back::vector_size_str(size); if let crate::ScalarKind::Uint = kind { @@ -3158,7 +3157,7 @@ impl<'a, W: Write> Writer<'a, W> { } Mf::CountLeadingZeros => { if self.options.version.supports_integer_functions() { - match *ctx.info[arg].ty.inner_with(&self.module.types) { + match *ctx.resolve_type(arg, &self.module.types) { crate::TypeInner::Vector { size, kind, .. } => { let s = back::vector_size_str(size); @@ -3189,7 +3188,7 @@ impl<'a, W: Write> Writer<'a, W> { _ => unreachable!(), }; } else { - match *ctx.info[arg].ty.inner_with(&self.module.types) { + match *ctx.resolve_type(arg, &self.module.types) { crate::TypeInner::Vector { size, kind, .. } => { let s = back::vector_size_str(size); @@ -3262,7 +3261,7 @@ impl<'a, W: Write> Writer<'a, W> { // Check if the argument is an unsigned integer and return the vector size // in case it's a vector - let maybe_uint_size = match *ctx.info[arg].ty.inner_with(&self.module.types) { + let maybe_uint_size = match *ctx.resolve_type(arg, &self.module.types) { crate::TypeInner::Scalar { kind: crate::ScalarKind::Uint, .. @@ -3349,7 +3348,7 @@ impl<'a, W: Write> Writer<'a, W> { kind: target_kind, convert, } => { - let inner = ctx.info[expr].ty.inner_with(&self.module.types); + let inner = ctx.resolve_type(expr, &self.module.types); match convert { Some(width) => { // this is similar to `write_type`, but with the target kind @@ -3515,7 +3514,7 @@ impl<'a, W: Write> Writer<'a, W> { } // Otherwise write just the expression (and the 1D hack if needed) None => { - let uvec_size = match *ctx.info[coordinate].ty.inner_with(&self.module.types) { + let uvec_size = match *ctx.resolve_type(coordinate, &self.module.types) { TypeInner::Scalar { kind: crate::ScalarKind::Uint, .. @@ -3563,7 +3562,7 @@ impl<'a, W: Write> Writer<'a, W> { // so we don't need to generate bounds checks (OpenGL 4.2 Core ยง3.9.20) // This will only panic if the module is invalid - let dim = match *ctx.info[image].ty.inner_with(&self.module.types) { + let dim = match *ctx.resolve_type(image, &self.module.types) { TypeInner::Image { dim, .. } => dim, _ => unreachable!(), }; @@ -3626,7 +3625,7 @@ impl<'a, W: Write> Writer<'a, W> { // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`). // This will only panic if the module is invalid - let (dim, class) = match *ctx.info[image].ty.inner_with(&self.module.types) { + let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { TypeInner::Image { dim, arrayed: _, @@ -3891,8 +3890,7 @@ impl<'a, W: Write> Writer<'a, W> { } } - let base_ty_res = &ctx.info[named].ty; - let resolved = base_ty_res.inner_with(&self.module.types); + let resolved = ctx.resolve_type(named, &self.module.types); write!(self.out, " {name}")?; if let TypeInner::Array { base, size, .. } = *resolved { diff --git a/src/back/hlsl/help.rs b/src/back/hlsl/help.rs index 2d725514b2..fcb9949fe1 100644 --- a/src/back/hlsl/help.rs +++ b/src/back/hlsl/help.rs @@ -244,7 +244,7 @@ impl<'a, W: Write> super::Writer<'a, W> { const MIP_LEVEL_PARAM: &str = "mip_level"; // Write function return type and name - let ret_ty = func_ctx.info[expr_handle].ty.inner_with(&module.types); + let ret_ty = func_ctx.resolve_type(expr_handle, &module.types); self.write_value_type(module, ret_ty)?; write!(self.out, " ")?; self.write_wrapped_image_query_function_name(wiq)?; @@ -891,7 +891,7 @@ impl<'a, W: Write> super::Writer<'a, W> { } } crate::Expression::ImageQuery { image, query } => { - let wiq = match *func_ctx.info[image].ty.inner_with(&module.types) { + let wiq = match *func_ctx.resolve_type(image, &module.types) { crate::TypeInner::Image { dim, arrayed, @@ -912,9 +912,8 @@ impl<'a, W: Write> super::Writer<'a, W> { // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage` // since they will later be used by the fn `write_storage_load` crate::Expression::Load { pointer } => { - let pointer_space = func_ctx.info[pointer] - .ty - .inner_with(&module.types) + let pointer_space = func_ctx + .resolve_type(pointer, &module.types) .pointer_space(); if let Some(crate::AddressSpace::Storage { .. }) = pointer_space { @@ -1016,7 +1015,7 @@ impl<'a, W: Write> super::Writer<'a, W> { if extra == 0 { self.write_expr(module, coordinate, func_ctx)?; } else { - let num_coords = match *func_ctx.info[coordinate].ty.inner_with(&module.types) { + let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) { crate::TypeInner::Scalar { .. } => 1, crate::TypeInner::Vector { size, .. } => size as usize, _ => unreachable!(), diff --git a/src/back/hlsl/storage.rs b/src/back/hlsl/storage.rs index d4eeefe3e1..1e02e9e502 100644 --- a/src/back/hlsl/storage.rs +++ b/src/back/hlsl/storage.rs @@ -466,7 +466,7 @@ impl super::Writer<'_, W> { } }; - let parent = match *func_ctx.info[next_expr].ty.inner_with(&module.types) { + let parent = match *func_ctx.resolve_type(next_expr, &module.types) { crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members), crate::TypeInner::Array { stride, .. } => Parent::Array { stride }, diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 5286eb6ddb..f26604476a 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1320,8 +1320,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { match *stmt { Statement::Emit(ref range) => { for handle in range.clone() { - let info = &func_ctx.info[handle]; - let ptr_class = info.ty.inner_with(&module.types).pointer_space(); + let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space(); let expr_name = if ptr_class.is_some() { // HLSL can't save a pointer-valued expression in a variable, // but we shouldn't ever need to: they should never be named expressions, @@ -1441,7 +1440,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } Statement::Store { pointer, value } => { - let ty_inner = func_ctx.info[pointer].ty.inner_with(&module.types); + let ty_inner = func_ctx.resolve_type(pointer, &module.types); if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() { let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; self.write_storage_store( @@ -1467,8 +1466,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } let get_members = |expr: Handle| { - let base_ty_res = &func_ctx.info[expr].ty; - let resolved = base_ty_res.inner_with(&module.types); + let resolved = func_ctx.resolve_type(expr, &module.types); match *resolved { TypeInner::Pointer { base, .. } => match module.types[base].inner { TypeInner::Struct { ref members, .. } => Some(members), @@ -1484,7 +1482,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { let mut current_expr = pointer; for _ in 0..3 { - let resolved = func_ctx.info[current_expr].ty.inner_with(&module.types); + let resolved = func_ctx.resolve_type(current_expr, &module.types); match (resolved, &func_ctx.expressions[current_expr]) { ( @@ -1634,7 +1632,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { let mut current_expr = pointer; for _ in 0..3 { - let resolved = func_ctx.info[current_expr].ty.inner_with(&module.types); + let resolved = func_ctx.resolve_type(current_expr, &module.types); match (resolved, &func_ctx.expressions[current_expr]) { ( &TypeInner::ValuePointer { @@ -1726,8 +1724,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }) = get_inner_matrix_of_struct_array_member( module, pointer, func_ctx, false, ) { - let mut resolved = - func_ctx.info[pointer].ty.inner_with(&module.types); + let mut resolved = func_ctx.resolve_type(pointer, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { resolved = &module.types[base].inner; } @@ -1854,9 +1851,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }; // Validation ensures that `pointer` has a `Pointer` type. - let pointer_space = func_ctx.info[pointer] - .ty - .inner_with(&module.types) + let pointer_space = func_ctx + .resolve_type(pointer, &module.types) .pointer_space() .unwrap(); @@ -2163,11 +2159,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { op: crate::BinaryOperator::Multiply, left, right, - } if func_ctx.info[left].ty.inner_with(&module.types).is_matrix() - || func_ctx.info[right] - .ty - .inner_with(&module.types) - .is_matrix() => + } if func_ctx.resolve_type(left, &module.types).is_matrix() + || func_ctx.resolve_type(right, &module.types).is_matrix() => { // We intentionally flip the order of multiplication as our matrices are implicitly transposed. write!(self.out, "mul(")?; @@ -2196,10 +2189,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { op: crate::BinaryOperator::Modulo, left, right, - } if func_ctx.info[left] - .ty - .inner_with(&module.types) - .scalar_kind() + } if func_ctx.resolve_type(left, &module.types).scalar_kind() == Some(crate::ScalarKind::Float) => { write!(self.out, "fmod(")?; @@ -2216,10 +2206,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } Expression::Access { base, index } => { - if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr] - .ty - .inner_with(&module.types) - .pointer_space() + if let Some(crate::AddressSpace::Storage { .. }) = + func_ctx.resolve_type(expr, &module.types).pointer_space() { // do nothing, the chain is written on `Load`/`Store` } else { @@ -2243,8 +2231,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { return Ok(()); } - let base_ty_res = &func_ctx.info[base].ty; - let resolved = base_ty_res.inner_with(&module.types); + let resolved = func_ctx.resolve_type(base, &module.types); let non_uniform_qualifier = match *resolved { TypeInner::BindingArray { .. } => { @@ -2268,10 +2255,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } Expression::AccessIndex { base, index } => { - if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr] - .ty - .inner_with(&module.types) - .pointer_space() + if let Some(crate::AddressSpace::Storage { .. }) = + func_ctx.resolve_type(expr, &module.types).pointer_space() { // do nothing, the chain is written on `Load`/`Store` } else { @@ -2450,7 +2435,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { dim, arrayed, class, - } = *func_ctx.info[image].ty.inner_with(&module.types) + } = *func_ctx.resolve_type(image, &module.types) { let wrapped_image_query = WrappedImageQuery { dim, @@ -2499,8 +2484,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; // return x component if return type is scalar - if let TypeInner::Scalar { .. } = *func_ctx.info[expr].ty.inner_with(&module.types) - { + if let TypeInner::Scalar { .. } = *func_ctx.resolve_type(expr, &module.types) { write!(self.out, ".x")?; } } @@ -2515,9 +2499,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? } Expression::Load { pointer } => { - match func_ctx.info[pointer] - .ty - .inner_with(&module.types) + match func_ctx + .resolve_type(pointer, &module.types) .pointer_space() { Some(crate::AddressSpace::Storage { .. }) => { @@ -2541,7 +2524,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ) .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx)) { - let mut resolved = func_ctx.info[pointer].ty.inner_with(&module.types); + let mut resolved = func_ctx.resolve_type(pointer, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { resolved = &module.types[base].inner; } @@ -2581,7 +2564,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { kind, convert, } => { - let inner = func_ctx.info[expr].ty.inner_with(&module.types); + let inner = func_ctx.resolve_type(expr, &module.types); match convert { Some(dst_width) => { match *inner { @@ -2956,11 +2939,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")? } Function::MissingIntOverload(fun_name) => { - let scalar_kind = &func_ctx.info[arg] - .ty - .inner_with(&module.types) - .scalar_kind(); - if let Some(ScalarKind::Sint) = *scalar_kind { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); + if let Some(ScalarKind::Sint) = scalar_kind { write!(self.out, "asint({fun_name}(asuint(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, ")))")?; @@ -2971,11 +2951,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } Function::MissingIntReturnType(fun_name) => { - let scalar_kind = &func_ctx.info[arg] - .ty - .inner_with(&module.types) - .scalar_kind(); - if let Some(ScalarKind::Sint) = *scalar_kind { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); + if let Some(ScalarKind::Sint) = scalar_kind { write!(self.out, "asint({fun_name}(")?; self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; @@ -2986,7 +2963,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } Function::CountTrailingZeros => { - match *func_ctx.info[arg].ty.inner_with(&module.types) { + match *func_ctx.resolve_type(arg, &module.types) { TypeInner::Vector { size, kind, .. } => { let s = match size { crate::VectorSize::Bi => ".xx", @@ -3021,7 +2998,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { return Ok(()); } Function::CountLeadingZeros => { - match *func_ctx.info[arg].ty.inner_with(&module.types) { + match *func_ctx.resolve_type(arg, &module.types) { TypeInner::Vector { size, kind, .. } => { let s = match size { crate::VectorSize::Bi => ".xx", @@ -3209,8 +3186,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } - let base_ty_res = &ctx.info[named].ty; - let resolved = base_ty_res.inner_with(&module.types); + let resolved = ctx.resolve_type(named, &module.types); write!(self.out, " {name}")?; // If rhs is a array type, we should write array size @@ -3287,7 +3263,7 @@ pub(super) fn get_inner_matrix_of_struct_array_member( let mut current_base = base; loop { - let mut resolved = func_ctx.info[current_base].ty.inner_with(&module.types); + let mut resolved = func_ctx.resolve_type(current_base, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { resolved = &module.types[base].inner; }; @@ -3344,7 +3320,7 @@ fn get_inner_matrix_of_global_uniform( let mut current_base = base; loop { - let mut resolved = func_ctx.info[current_base].ty.inner_with(&module.types); + let mut resolved = func_ctx.resolve_type(current_base, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { resolved = &module.types[base].inner; }; diff --git a/src/back/mod.rs b/src/back/mod.rs index 1314e14d77..8100b930e9 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -135,7 +135,7 @@ impl FunctionCtx<'_> { }; } crate::Expression::AccessIndex { base, index } => { - match *self.info[base].ty.inner_with(&module.types) { + match *self.resolve_type(base, &module.types) { crate::TypeInner::Struct { ref members, .. } => { if let Some(crate::Binding::BuiltIn(bi)) = members[index as usize].binding diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 91315eb5b7..09f7b1c73f 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -2624,11 +2624,7 @@ impl Writer { )?; } - let info = &context.expression.info[handle]; - let ptr_class = info - .ty - .inner_with(&context.expression.module.types) - .pointer_space(); + let ptr_class = context.expression.resolve_type(handle).pointer_space(); let expr_name = if ptr_class.is_some() { None // don't bake pointer expressions (just yet) } else if let Some(name) =