From c7d02151f08d6285683795289b5725b827d836d1 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 25 Jan 2023 18:00:11 +0100 Subject: [PATCH] add support for zero-initializing workgroup memory --- benches/criterion.rs | 1 + src/back/glsl/mod.rs | 49 +- src/back/hlsl/mod.rs | 3 + src/back/hlsl/writer.rs | 57 ++ src/back/mod.rs | 11 + src/back/msl/mod.rs | 3 + src/back/msl/writer.rs | 237 +++++++- src/back/spv/mod.rs | 14 + src/back/spv/writer.rs | 197 ++++++- tests/in/access.param.ron | 1 + tests/in/binding-arrays.param.ron | 2 + tests/in/bitcast.params.ron | 1 + tests/in/bits.param.ron | 1 + tests/in/boids.param.ron | 1 + .../in/bounds-check-image-restrict.param.ron | 1 + tests/in/bounds-check-image-rzsw.param.ron | 1 + tests/in/extra.param.ron | 1 + tests/in/functions-webgl.param.ron | 1 + tests/in/interface.param.ron | 2 + tests/in/interpolate.param.ron | 1 + tests/in/multiview_webgl.param.ron | 1 + tests/in/padding.param.ron | 1 + tests/in/push-constants.param.ron | 2 + tests/in/quad.param.ron | 1 + tests/in/skybox.param.ron | 3 + tests/in/workgroup-var-init.param.ron | 22 + tests/in/workgroup-var-init.wgsl | 15 + .../access.assign_through_ptr.Compute.glsl | 5 + tests/out/glsl/globals.main.Compute.glsl | 6 + .../glsl/workgroup-var-init.main.Compute.glsl | 28 + tests/out/hlsl/access.hlsl | 6 +- tests/out/hlsl/globals.hlsl | 7 +- tests/out/hlsl/interface.hlsl | 6 +- tests/out/hlsl/workgroup-var-init.hlsl | 535 ++++++++++++++++++ tests/out/hlsl/workgroup-var-init.hlsl.config | 3 + tests/out/msl/access.msl | 7 +- tests/out/msl/globals.msl | 8 +- tests/out/msl/interface.msl | 4 + tests/out/msl/workgroup-var-init.msl | 40 ++ tests/out/spv/access.spvasm | 26 +- tests/out/spv/globals.spvasm | 136 +++-- tests/out/spv/interface.compute.spvasm | 41 +- tests/out/spv/workgroup-var-init.spvasm | 78 +++ tests/out/wgsl/workgroup-var-init.wgsl | 16 + tests/snapshots.rs | 5 + 45 files changed, 1502 insertions(+), 85 deletions(-) create mode 
100644 tests/in/workgroup-var-init.param.ron create mode 100644 tests/in/workgroup-var-init.wgsl create mode 100644 tests/out/glsl/workgroup-var-init.main.Compute.glsl create mode 100644 tests/out/hlsl/workgroup-var-init.hlsl create mode 100644 tests/out/hlsl/workgroup-var-init.hlsl.config create mode 100644 tests/out/msl/workgroup-var-init.msl create mode 100644 tests/out/spv/workgroup-var-init.spvasm create mode 100644 tests/out/wgsl/workgroup-var-init.wgsl diff --git a/benches/criterion.rs b/benches/criterion.rs index 81c55c70cd..0bf2478993 100644 --- a/benches/criterion.rs +++ b/benches/criterion.rs @@ -242,6 +242,7 @@ fn backends(c: &mut Criterion) { version: naga::back::glsl::Version::new_gles(320), writer_flags: naga::back::glsl::WriterFlags::empty(), binding_map: Default::default(), + zero_initialize_workgroup_memory: true, }; for &(ref module, ref info) in inputs.iter() { for ep in module.entry_points.iter() { diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index ede4bf8481..9b4bf0cf23 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -232,6 +232,8 @@ pub struct Options { pub writer_flags: WriterFlags, /// Map of resources association to binding locations. pub binding_map: BindingMap, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, } impl Default for Options { @@ -240,6 +242,7 @@ impl Default for Options { version: Version::new_gles(310), writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE, binding_map: BindingMap::default(), + zero_initialize_workgroup_memory: true, } } } @@ -1432,6 +1435,12 @@ impl<'a, W: Write> Writer<'a, W> { // Close the parentheses and open braces to start the function body writeln!(self.out, ") {{")?; + if self.options.zero_initialize_workgroup_memory + && ctx.ty.is_compute_entry_point(self.module) + { + self.write_workgroup_variables_initialization(&ctx)?; + } + // Compose the function arguments from globals, in case of an entry point. 
if let back::FunctionType::EntryPoint(ep_index) = ctx.ty { let stage = self.module.entry_points[ep_index as usize].stage; @@ -1520,6 +1529,42 @@ impl<'a, W: Write> Writer<'a, W> { Ok(()) } + fn write_workgroup_variables_initialization( + &mut self, + ctx: &back::FunctionCtx, + ) -> BackendResult { + let mut vars = self + .module + .global_variables + .iter() + .filter(|&(handle, var)| { + !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + .peekable(); + + if vars.peek().is_some() { + let level = back::Level(1); + + writeln!( + self.out, + "{}if (gl_GlobalInvocationID == uvec3(0u)) {{", + level + )?; + + for (handle, var) in vars { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}{} = ", level.next(), name)?; + self.write_zero_init_value(var.ty)?; + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{}}}", level)?; + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + + Ok(()) + } + /// Helper method that writes a list of comma separated `T` with a writer function `F` /// /// The writer function `F` receives a mutable reference to `self` that if needed won't cause @@ -3548,7 +3593,7 @@ impl<'a, W: Write> Writer<'a, W> { fn write_zero_init_value(&mut self, ty: Handle) -> BackendResult { let inner = &self.module.types[ty].inner; match *inner { - TypeInner::Scalar { kind, .. } => { + TypeInner::Scalar { kind, .. } | TypeInner::Atomic { kind, .. } => { self.write_zero_init_scalar(kind)?; } TypeInner::Vector { kind, .. 
} => { @@ -3593,7 +3638,7 @@ impl<'a, W: Write> Writer<'a, W> { } write!(self.out, ")")?; } - _ => {} // TODO: + _ => unreachable!(), } Ok(()) diff --git a/src/back/hlsl/mod.rs b/src/back/hlsl/mod.rs index 333ea2cf1a..1031270390 100644 --- a/src/back/hlsl/mod.rs +++ b/src/back/hlsl/mod.rs @@ -191,6 +191,8 @@ pub struct Options { pub special_constants_binding: Option, /// Bind target of the push constant buffer pub push_constants_target: Option, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, } impl Default for Options { @@ -201,6 +203,7 @@ impl Default for Options { fake_missing_bindings: true, special_constants_binding: None, push_constants_target: None, + zero_initialize_workgroup_memory: true, } } } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 59050455a0..664b326a16 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1077,6 +1077,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // Write function name write!(self.out, " {}(", name)?; + let need_workgroup_variables_initialization = + self.need_workgroup_variables_initialization(func_ctx, module); + // Write function arguments for non entry point functions match func_ctx.ty { back::FunctionType::Function(handle) => { @@ -1129,6 +1132,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_semantic(binding, Some((stage, Io::Input)))?; } } + + if need_workgroup_variables_initialization { + if !func.arguments.is_empty() { + write!(self.out, ", ")?; + } + write!( + self.out, + "uint3 __global_invocation_id : SV_DispatchThreadID" + )?; + } } } } @@ -1151,6 +1164,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out)?; writeln!(self.out, "{{")?; + if need_workgroup_variables_initialization { + self.write_workgroup_variables_initialization(func_ctx, module)?; + } + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { self.write_ep_arguments_initialization(module, func, 
index)?; } @@ -1204,6 +1221,46 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Ok(()) } + fn need_workgroup_variables_initialization( + &mut self, + func_ctx: &back::FunctionCtx, + module: &Module, + ) -> bool { + self.options.zero_initialize_workgroup_memory + && func_ctx.ty.is_compute_entry_point(module) + && module.global_variables.iter().any(|(handle, var)| { + !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + } + + fn write_workgroup_variables_initialization( + &mut self, + func_ctx: &back::FunctionCtx, + module: &Module, + ) -> BackendResult { + let level = back::Level(1); + + writeln!( + self.out, + "{}if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {{", + level + )?; + + let vars = module.global_variables.iter().filter(|&(handle, var)| { + !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }); + + for (handle, var) in vars { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}{} = ", level.next(), name)?; + self.write_default_init(module, var.ty)?; + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{}}}", level)?; + self.write_barrier(crate::Barrier::WORK_GROUP, level) + } + /// Helper method used to write statements /// /// # Notes diff --git a/src/back/mod.rs b/src/back/mod.rs index d8e016c008..56223ac2bb 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -47,6 +47,17 @@ enum FunctionType { EntryPoint(crate::proc::EntryPointIndex), } +impl FunctionType { + fn is_compute_entry_point(&self, module: &crate::Module) -> bool { + match *self { + FunctionType::EntryPoint(index) => { + module.entry_points[index as usize].stage == crate::ShaderStage::Compute + } + _ => false, + } + } +} + /// Helper structure that stores data needed when writing the function struct FunctionCtx<'a> { /// The current function being written diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 4c0177173e..679e5ed29b 100644 --- a/src/back/msl/mod.rs +++ 
b/src/back/msl/mod.rs @@ -209,6 +209,8 @@ pub struct Options { /// Bounds checking policies. #[cfg_attr(feature = "deserialize", serde(default))] pub bounds_check_policies: index::BoundsCheckPolicies, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, } impl Default for Options { @@ -220,6 +222,7 @@ impl Default for Options { spirv_cross_compatibility: false, fake_missing_bindings: true, bounds_check_policies: index::BoundsCheckPolicies::default(), + zero_initialize_workgroup_memory: true, } } } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 331f7a2db5..5b505b358a 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3618,6 +3618,8 @@ impl Writer { is_first_argument = false; } + let mut global_invocation_id = None; + // Then pass the remaining arguments not included in the varyings // struct. // @@ -3629,7 +3631,7 @@ impl Writer { let mut flattened_member_names = FastHashMap::default(); for &(ref name_key, ty, binding) in flattened_arguments.iter() { let binding = match binding { - Some(ref binding @ &crate::Binding::BuiltIn { .. }) => binding, + Some(binding @ &crate::Binding::BuiltIn { .. 
}) => binding, _ => continue, }; let name = if let NameKey::StructMember(ty, index) = *name_key { @@ -3642,6 +3644,11 @@ impl Writer { } else { &self.names[name_key] }; + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) { + global_invocation_id = Some(name_key); + } + let ty_name = TypeContext { handle: ty, module, @@ -3662,6 +3669,23 @@ impl Writer { writeln!(self.out)?; } + let need_workgroup_variables_initialization = + self.need_workgroup_variables_initialization(options, ep, module, fun_info); + + if need_workgroup_variables_initialization && global_invocation_id.is_none() { + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + writeln!( + self.out, + "{} {}::uint3 __global_invocation_id [[thread_position_in_grid]]", + separator, NAMESPACE + )?; + } + // Those global variables used by this entry point and its callees // get passed as arguments. `Private` globals are an exception, they // don't outlive this invocation, so we declare them below as locals @@ -3744,6 +3768,15 @@ impl Writer { // end of the entry point argument list writeln!(self.out, ") {{")?; + if need_workgroup_variables_initialization { + self.write_workgroup_variables_initialization( + module, + mod_info, + fun_info, + global_invocation_id, + )?; + } + // Metal doesn't support private mutable variables outside of functions, // so we put them here, just like the locals. for (handle, var) in module.global_variables.iter() { @@ -3939,6 +3972,208 @@ impl Writer { } } +/// Initializing workgroup variables is more tricky for Metal because we have to deal +/// with atomics at the type-level (which don't have a copy constructor). 
+mod workgroup_mem_init { + use crate::EntryPoint; + + use super::*; + + enum Access { + GlobalVariable(Handle<crate::GlobalVariable>), + StructMember(Handle<crate::Type>, u32), + Array(usize), + } + + impl Access { + fn write<W: Write>( + &self, + writer: &mut W, + names: &FastHashMap<NameKey, String>, + ) -> Result<(), core::fmt::Error> { + match *self { + Access::GlobalVariable(handle) => { + write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) + } + Access::StructMember(handle, index) => { + write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) + } + Access::Array(depth) => write!(writer, ".{}[__i{}]", WRAPPED_ARRAY_FIELD, depth), + } + } + } + + struct AccessStack { + stack: Vec<Access>, + array_depth: usize, + } + + impl AccessStack { + const fn new() -> Self { + Self { + stack: Vec::new(), + array_depth: 0, + } + } + + fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R { + let array_depth = self.array_depth; + self.stack.push(Access::Array(array_depth)); + self.array_depth += 1; + let res = cb(self, array_depth); + self.stack.pop(); + self.array_depth -= 1; + res + } + + fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R { + self.stack.push(new); + let res = cb(self); + self.stack.pop(); + res + } + + fn write<W: Write>( + &self, + writer: &mut W, + names: &FastHashMap<NameKey, String>, + ) -> Result<(), core::fmt::Error> { + for next in self.stack.iter() { + next.write(writer, names)?; + } + Ok(()) + } + } + + impl<W: Write> Writer<W> { + pub(super) fn need_workgroup_variables_initialization( + &mut self, + options: &Options, + ep: &EntryPoint, + module: &crate::Module, + fun_info: &valid::FunctionInfo, + ) -> bool { + options.zero_initialize_workgroup_memory + && ep.stage == crate::ShaderStage::Compute + && module.global_variables.iter().any(|(handle, var)| { + !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + } + + pub(super) fn write_workgroup_variables_initialization( + &mut self, + module: &crate::Module, + module_info: &valid::ModuleInfo, + fun_info: 
&valid::FunctionInfo, + global_invocation_id: Option<&NameKey>, + ) -> BackendResult { + let level = back::Level(1); + + writeln!( + self.out, + "{}if ({}::all({} == {}::uint3(0u))) {{", + level, + NAMESPACE, + global_invocation_id + .map(|name_key| self.names[name_key].as_str()) + .unwrap_or("__global_invocation_id"), + NAMESPACE, + )?; + + let mut access_stack = AccessStack::new(); + + let vars = module.global_variables.iter().filter(|&(handle, var)| { + !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }); + + for (handle, var) in vars { + access_stack.enter(Access::GlobalVariable(handle), |access_stack| { + self.write_workgroup_variable_initialization( + module, + module_info, + var.ty, + access_stack, + level.next(), + ) + })?; + } + + writeln!(self.out, "{}}}", level)?; + self.write_barrier(crate::Barrier::WORK_GROUP, level) + } + + fn write_workgroup_variable_initialization( + &mut self, + module: &crate::Module, + module_info: &valid::ModuleInfo, + ty: Handle, + access_stack: &mut AccessStack, + level: back::Level, + ) -> BackendResult { + if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { + write!(self.out, "{}", level)?; + access_stack.write(&mut self.out, &self.names)?; + writeln!(self.out, " = {{}};")?; + } else { + match module.types[ty].inner { + crate::TypeInner::Atomic { .. } => { + write!( + self.out, + "{}{}::atomic_store_explicit({}", + level, NAMESPACE, ATOMIC_REFERENCE + )?; + access_stack.write(&mut self.out, &self.names)?; + writeln!(self.out, ", 0, {}::memory_order_relaxed);", NAMESPACE)?; + } + crate::TypeInner::Array { base, size, .. 
} => { + let count = match size.to_indexable_length(module).expect("Bad array size") + { + proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Dynamic => unreachable!(), + }; + + access_stack.enter_array(|access_stack, array_depth| { + writeln!( + self.out, + "{}for (int __i{} = 0; __i{} < {}; __i{}++) {{", + level, array_depth, array_depth, count, array_depth + )?; + self.write_workgroup_variable_initialization( + module, + module_info, + base, + access_stack, + level.next(), + )?; + writeln!(self.out, "{}}}", level)?; + BackendResult::Ok(()) + })?; + } + crate::TypeInner::Struct { ref members, .. } => { + for (index, member) in members.iter().enumerate() { + access_stack.enter( + Access::StructMember(ty, index as u32), + |access_stack| { + self.write_workgroup_variable_initialization( + module, + module_info, + member.ty, + access_stack, + level, + ) + }, + )?; + } + } + _ => unreachable!(), + } + } + + Ok(()) + } + } +} + #[test] fn test_stack_size() { use crate::valid::{Capabilities, ValidationFlags}; diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 544f5ca4f5..e19d825a05 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -582,6 +582,7 @@ pub struct Writer { annotations: Vec, flags: WriterFlags, bounds_check_policies: BoundsCheckPolicies, + zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, void_type: Word, //TODO: convert most of these into vectors, addressable by handle indices lookup_type: crate::FastHashMap, @@ -630,6 +631,15 @@ pub struct BindingInfo { // Using `BTreeMap` instead of `HashMap` so that we can hash itself. pub type BindingMap = std::collections::BTreeMap; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ZeroInitializeWorkgroupMemoryMode { + /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3 + Native, + /// Via assignments + barrier + Polyfill, + None, +} + #[derive(Debug, Clone)] pub struct Options { /// (Major, Minor) target version of the SPIR-V. 
@@ -650,6 +660,9 @@ pub struct Options { /// How should generate code handle array, vector, matrix, or image texel /// indices that are out of range? pub bounds_check_policies: BoundsCheckPolicies, + + /// Dictates the way workgroup variables should be zero initialized + pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, } impl Default for Options { @@ -666,6 +679,7 @@ impl Default for Options { binding_map: BindingMap::default(), capabilities: None, bounds_check_policies: crate::proc::BoundsCheckPolicies::default(), + zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill, } } } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index d61ae08bfe..f19d83e454 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -65,6 +65,7 @@ impl Writer { annotations: vec![], flags: options.flags, bounds_check_policies: options.bounds_check_policies, + zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory, void_type, lookup_type: crate::FastHashMap::default(), lookup_function: crate::FastHashMap::default(), @@ -102,6 +103,7 @@ impl Writer { // Copied from the old Writer: flags: self.flags, bounds_check_policies: self.bounds_check_policies, + zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory, capabilities_available: take(&mut self.capabilities_available), binding_map: take(&mut self.binding_map), @@ -248,6 +250,16 @@ impl Writer { self.get_type_id(local_type.into()) } + pub(super) fn get_uint3_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let lookup_type = LookupType::Local(LocalType::Value { vector_size: None, @@ -267,6 +279,25 @@ impl Writer { } } + pub(super) fn get_uint3_pointer_type_id(&mut self, 
class: spirv::StorageClass) -> Word { + let lookup_type = LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: Some(class), + }); + if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let ty_id = self.get_uint3_type_id(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + } + } + pub(super) fn get_bool_type_id(&mut self) -> Word { let local_type = LocalType::Value { vector_size: None, @@ -277,6 +308,16 @@ impl Writer { self.get_type_id(local_type.into()) } + pub(super) fn get_bool3_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Bool, + width: 1, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) { self.annotations .push(Instruction::decorate(id, decoration, operands)); @@ -326,6 +367,8 @@ impl Writer { results: Vec::new(), }; + let mut global_invocation_id = None; + let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len()); for argument in ir_function.arguments.iter() { let class = spirv::StorageClass::Input; @@ -356,6 +399,11 @@ impl Writer { prelude .body .push(Instruction::load(argument_type_id, id, varying_id, None)); + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) { + global_invocation_id = Some(id); + } + id } else if let crate::TypeInner::Struct { ref members, .. 
} = ir_module.types[argument.ty].inner @@ -380,6 +428,10 @@ impl Writer { .body .push(Instruction::load(type_id, id, varying_id, None)); constituent_ids.push(id); + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) { + global_invocation_id = Some(id); + } } prelude.body.push(Instruction::composite_construct( argument_type_id, @@ -595,10 +647,39 @@ impl Writer { } } - let main_id = context.gen_id(); + let next_id = context.gen_id(); + context .function - .consume(prelude, Instruction::branch(main_id)); + .consume(prelude, Instruction::branch(next_id)); + + let workgroup_vars_init_exit_block_id = + match (context.writer.zero_initialize_workgroup_memory, interface) { + ( + super::ZeroInitializeWorkgroupMemoryMode::Polyfill, + Some( + ref mut interface @ FunctionInterface { + stage: crate::ShaderStage::Compute, + .. + }, + ), + ) => context.writer.generate_workgroup_vars_init_block( + next_id, + ir_module, + info, + global_invocation_id, + interface, + context.function, + ), + _ => None, + }; + + let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id { + exit_id + } else { + next_id + }; + context.write_block( main_id, &ir_function.body, @@ -1143,6 +1224,113 @@ impl Writer { )); } + fn generate_workgroup_vars_init_block( + &mut self, + entry_id: Word, + ir_module: &crate::Module, + info: &FunctionInfo, + global_invocation_id: Option, + interface: &mut FunctionInterface, + function: &mut Function, + ) -> Option { + let body = ir_module + .global_variables + .iter() + .filter(|&(handle, var)| { + !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + .map(|(handle, var)| { + // It's safe to use `var_id` here, not `access_id`, because only + // variables in the `Uniform` and `StorageBuffer` address spaces + // get wrapped, and we're initializing `WorkGroup` variables. 
+ let var_id = self.global_variables[handle.index()].var_id; + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let init_word = self.write_constant_null(var_type_id); + Instruction::store(var_id, init_word, None) + }) + .collect::>(); + + if body.is_empty() { + return None; + } + + let uint3_type_id = self.get_uint3_type_id(); + + let mut pre_if_block = Block::new(entry_id); + + let global_invocation_id = if let Some(global_invocation_id) = global_invocation_id { + global_invocation_id + } else { + let varying_id = self.id_gen.next(); + let class = spirv::StorageClass::Input; + let pointer_type_id = self.get_uint3_pointer_type_id(class); + + Instruction::variable(pointer_type_id, varying_id, class, None) + .to_words(&mut self.logical_layout.declarations); + + self.decorate( + varying_id, + spirv::Decoration::BuiltIn, + &[spirv::BuiltIn::GlobalInvocationId as u32], + ); + + interface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + pre_if_block + .body + .push(Instruction::load(uint3_type_id, id, varying_id, None)); + + id + }; + + let zero_id = self.write_constant_null(uint3_type_id); + let bool3_type_id = self.get_bool3_type_id(); + + let eq_id = self.id_gen.next(); + pre_if_block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool3_type_id, + eq_id, + global_invocation_id, + zero_id, + )); + + let condition_id = self.id_gen.next(); + let bool_type_id = self.get_bool_type_id(); + pre_if_block.body.push(Instruction::relational( + spirv::Op::All, + bool_type_id, + condition_id, + eq_id, + )); + + let merge_id = self.id_gen.next(); + pre_if_block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let accept_id = self.id_gen.next(); + function.consume( + pre_if_block, + Instruction::branch_conditional(condition_id, accept_id, merge_id), + ); + + let accept_block = Block { + label_id: accept_id, + body, + }; + function.consume(accept_block, Instruction::branch(merge_id)); + + let mut 
post_if_block = Block::new(merge_id); + + self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block); + + let next_id = self.id_gen.next(); + function.consume(post_if_block, Instruction::branch(next_id)); + Some(next_id) + } + /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface. /// /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s @@ -1414,8 +1602,9 @@ impl Writer { } }; - let init_word = match global_variable.space { - crate::AddressSpace::Private => { + let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) { + (crate::AddressSpace::Private, _) + | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => { init_word.or_else(|| Some(self.write_constant_null(inner_type_id))) } _ => init_word, diff --git a/tests/in/access.param.ron b/tests/in/access.param.ron index 8408e4cb6b..af45e4f970 100644 --- a/tests/in/access.param.ron +++ b/tests/in/access.param.ron @@ -33,5 +33,6 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/binding-arrays.param.ron b/tests/in/binding-arrays.param.ron index 2d3e15263b..ad5b0ee319 100644 --- a/tests/in/binding-arrays.param.ron +++ b/tests/in/binding-arrays.param.ron @@ -15,6 +15,7 @@ }, fake_missing_bindings: true, special_constants_binding: None, + zero_initialize_workgroup_memory: true, ), msl: ( lang_version: (2, 0), @@ -29,6 +30,7 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: true, + zero_initialize_workgroup_memory: true, ), spv: ( version: (1, 1), diff --git a/tests/in/bitcast.params.ron b/tests/in/bitcast.params.ron index fb45784e55..324cd4a518 100644 --- a/tests/in/bitcast.params.ron +++ b/tests/in/bitcast.params.ron @@ -11,5 +11,6 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), ) diff --git 
a/tests/in/bits.param.ron b/tests/in/bits.param.ron index fb45784e55..324cd4a518 100644 --- a/tests/in/bits.param.ron +++ b/tests/in/bits.param.ron @@ -11,5 +11,6 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/boids.param.ron b/tests/in/boids.param.ron index 91181f7ddf..e6d752fca8 100644 --- a/tests/in/boids.param.ron +++ b/tests/in/boids.param.ron @@ -19,5 +19,6 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/bounds-check-image-restrict.param.ron b/tests/in/bounds-check-image-restrict.param.ron index e91f7fc24d..fedcf8407d 100644 --- a/tests/in/bounds-check-image-restrict.param.ron +++ b/tests/in/bounds-check-image-restrict.param.ron @@ -10,5 +10,6 @@ version: Desktop(430), writer_flags: (bits: 0), binding_map: { }, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/bounds-check-image-rzsw.param.ron b/tests/in/bounds-check-image-rzsw.param.ron index 72ccb27e9d..40974004c7 100644 --- a/tests/in/bounds-check-image-rzsw.param.ron +++ b/tests/in/bounds-check-image-rzsw.param.ron @@ -10,5 +10,6 @@ version: Desktop(430), writer_flags: (bits: 0), binding_map: { }, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/extra.param.ron b/tests/in/extra.param.ron index ba9e087592..051dee6432 100644 --- a/tests/in/extra.param.ron +++ b/tests/in/extra.param.ron @@ -13,5 +13,6 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/functions-webgl.param.ron b/tests/in/functions-webgl.param.ron index 8ac016ad21..862e6a5d03 100644 --- a/tests/in/functions-webgl.param.ron +++ b/tests/in/functions-webgl.param.ron @@ -6,5 +6,6 @@ ), writer_flags: (bits: 0), binding_map: {}, + zero_initialize_workgroup_memory: true, ), ) diff --git 
a/tests/in/interface.param.ron b/tests/in/interface.param.ron index eeb45b0d7d..5df3e3b9c3 100644 --- a/tests/in/interface.param.ron +++ b/tests/in/interface.param.ron @@ -12,6 +12,7 @@ binding_map: {}, fake_missing_bindings: false, special_constants_binding: Some((space: 1, register: 0)), + zero_initialize_workgroup_memory: true, ), wgsl: ( explicit_types: true, @@ -22,6 +23,7 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), msl_pipeline: ( allow_point_size: true, diff --git a/tests/in/interpolate.param.ron b/tests/in/interpolate.param.ron index cb08368c87..a885c9032f 100644 --- a/tests/in/interpolate.param.ron +++ b/tests/in/interpolate.param.ron @@ -10,5 +10,6 @@ version: Desktop(400), writer_flags: (bits: 0), binding_map: {}, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/multiview_webgl.param.ron b/tests/in/multiview_webgl.param.ron index 720b51fbbc..a8bc096646 100644 --- a/tests/in/multiview_webgl.param.ron +++ b/tests/in/multiview_webgl.param.ron @@ -6,6 +6,7 @@ ), writer_flags: (bits: 0), binding_map: {}, + zero_initialize_workgroup_memory: true, ), glsl_multiview: Some(2), ) diff --git a/tests/in/padding.param.ron b/tests/in/padding.param.ron index 9e20176bac..41af2916a9 100644 --- a/tests/in/padding.param.ron +++ b/tests/in/padding.param.ron @@ -18,5 +18,6 @@ inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/push-constants.param.ron b/tests/in/push-constants.param.ron index a5cfe142c4..46874a15af 100644 --- a/tests/in/push-constants.param.ron +++ b/tests/in/push-constants.param.ron @@ -7,6 +7,7 @@ ), writer_flags: (bits: 0), binding_map: {}, + zero_initialize_workgroup_memory: true, ), hlsl: ( shader_model: V5_1, @@ -14,5 +15,6 @@ fake_missing_bindings: true, special_constants_binding: Some((space: 1, register: 0)), push_constants_target: 
Some((space: 0, register: 0)), + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/quad.param.ron b/tests/in/quad.param.ron index 568b961898..db93369f6d 100644 --- a/tests/in/quad.param.ron +++ b/tests/in/quad.param.ron @@ -11,5 +11,6 @@ ), writer_flags: (bits: 0), binding_map: {}, + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/skybox.param.ron b/tests/in/skybox.param.ron index b90f2e5a25..905721c914 100644 --- a/tests/in/skybox.param.ron +++ b/tests/in/skybox.param.ron @@ -35,6 +35,7 @@ ], spirv_cross_compatibility: false, fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, ), glsl: ( version: Embedded( @@ -46,6 +47,7 @@ (group: 0, binding: 0): 0, (group: 0, binding: 1): 0, }, + zero_initialize_workgroup_memory: true, ), hlsl: ( shader_model: V5_1, @@ -56,5 +58,6 @@ }, fake_missing_bindings: false, special_constants_binding: Some((space: 0, register: 1)), + zero_initialize_workgroup_memory: true, ), ) diff --git a/tests/in/workgroup-var-init.param.ron b/tests/in/workgroup-var-init.param.ron new file mode 100644 index 0000000000..fd10e95d60 --- /dev/null +++ b/tests/in/workgroup-var-init.param.ron @@ -0,0 +1,22 @@ +( + spv: ( + version: (1, 1), + debug: true, + adjust_coordinate_space: false, + ), + msl: ( + lang_version: (2, 0), + per_stage_map: ( + cs: ( + resources: { + (group: 0, binding: 0): (buffer: Some(0), mutable: true), + }, + sizes_buffer: None, + ), + ), + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), +) \ No newline at end of file diff --git a/tests/in/workgroup-var-init.wgsl b/tests/in/workgroup-var-init.wgsl new file mode 100644 index 0000000000..8df2bcf4a8 --- /dev/null +++ b/tests/in/workgroup-var-init.wgsl @@ -0,0 +1,15 @@ +struct WStruct { + arr: array<u32, 512>, + atom: atomic<i32>, + atom_arr: array<array<atomic<i32>, 8>, 8>, +} + +var<workgroup> w_mem: WStruct; + +@group(0) @binding(0) +var<storage, read_write> output: array<u32, 512>; + +@compute @workgroup_size(1) +fn 
main() { + output = w_mem.arr; +} \ No newline at end of file diff --git a/tests/out/glsl/access.assign_through_ptr.Compute.glsl b/tests/out/glsl/access.assign_through_ptr.Compute.glsl index aafc30b909..1306fe514d 100644 --- a/tests/out/glsl/access.assign_through_ptr.Compute.glsl +++ b/tests/out/glsl/access.assign_through_ptr.Compute.glsl @@ -37,6 +37,11 @@ void assign_through_ptr_fn(inout uint p) { } void main() { + if (gl_GlobalInvocationID == uvec3(0u)) { + val = 0u; + } + memoryBarrierShared(); + barrier(); assign_through_ptr_fn(val); return; } diff --git a/tests/out/glsl/globals.main.Compute.glsl b/tests/out/glsl/globals.main.Compute.glsl index e77585f0b9..b5b357105c 100644 --- a/tests/out/glsl/globals.main.Compute.glsl +++ b/tests/out/glsl/globals.main.Compute.glsl @@ -51,6 +51,12 @@ void test_msl_packed_vec3_() { } void main() { + if (gl_GlobalInvocationID == uvec3(0u)) { + wg = float[10](0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + at_1 = 0u; + } + memoryBarrierShared(); + barrier(); float Foo = 0.0; bool at = false; test_msl_packed_vec3_(); diff --git a/tests/out/glsl/workgroup-var-init.main.Compute.glsl b/tests/out/glsl/workgroup-var-init.main.Compute.glsl new file mode 100644 index 0000000000..da55fea958 --- /dev/null +++ b/tests/out/glsl/workgroup-var-init.main.Compute.glsl @@ -0,0 +1,28 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +struct WStruct { + uint arr[512]; + int atom; + int atom_arr[8][8]; +}; +shared WStruct w_mem; + +layout(std430) buffer type_1_block_0Compute { uint _group_0_binding_0_cs[512]; }; + + +void main() { + if (gl_GlobalInvocationID == uvec3(0u)) { + w_mem = WStruct(uint[512](0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 
0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u), 0, int[8][8](int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0), 
int[8](0, 0, 0, 0, 0, 0, 0, 0), int[8](0, 0, 0, 0, 0, 0, 0, 0))); + } + memoryBarrierShared(); + barrier(); + uint _e3[512] = w_mem.arr; + _group_0_binding_0_cs = _e3; + return; +} + diff --git a/tests/out/hlsl/access.hlsl b/tests/out/hlsl/access.hlsl index 68ba8ae28c..9f898b0e9a 100644 --- a/tests/out/hlsl/access.hlsl +++ b/tests/out/hlsl/access.hlsl @@ -326,8 +326,12 @@ void atomics() } [numthreads(1, 1, 1)] -void assign_through_ptr() +void assign_through_ptr(uint3 __global_invocation_id : SV_DispatchThreadID) { + if (all(__global_invocation_id == uint3(0u, 0u, 0u))) { + val = (uint)0; + } + GroupMemoryBarrierWithGroupSync(); assign_through_ptr_fn(val); return; } diff --git a/tests/out/hlsl/globals.hlsl b/tests/out/hlsl/globals.hlsl index d941f8e096..6047b8dca2 100644 --- a/tests/out/hlsl/globals.hlsl +++ b/tests/out/hlsl/globals.hlsl @@ -106,8 +106,13 @@ uint NagaBufferLength(ByteAddressBuffer buffer) } [numthreads(1, 1, 1)] -void main() +void main(uint3 __global_invocation_id : SV_DispatchThreadID) { + if (all(__global_invocation_id == uint3(0u, 0u, 0u))) { + wg = (float[10])0; + at_1 = (uint)0; + } + GroupMemoryBarrierWithGroupSync(); float Foo = (float)0; bool at = (bool)0; diff --git a/tests/out/hlsl/interface.hlsl b/tests/out/hlsl/interface.hlsl index 99d5ecccbc..34ef1ccb4a 100644 --- a/tests/out/hlsl/interface.hlsl +++ b/tests/out/hlsl/interface.hlsl @@ -75,8 +75,12 @@ FragmentOutput fragment(FragmentInput_fragment fragmentinput_fragment) } [numthreads(1, 1, 1)] -void compute(uint3 global_id : SV_DispatchThreadID, uint3 local_id : SV_GroupThreadID, uint local_index : SV_GroupIndex, uint3 wg_id : SV_GroupID, uint3 num_wgs : SV_GroupID) +void compute(uint3 global_id : SV_DispatchThreadID, uint3 local_id : SV_GroupThreadID, uint local_index : SV_GroupIndex, uint3 wg_id : SV_GroupID, uint3 num_wgs : SV_GroupID, uint3 __global_invocation_id : SV_DispatchThreadID) { + if (all(__global_invocation_id == uint3(0u, 0u, 0u))) { + output = (uint[1])0; + } + 
GroupMemoryBarrierWithGroupSync(); output[0] = ((((global_id.x + local_id.x) + local_index) + wg_id.x) + uint3(_NagaConstants.base_vertex, _NagaConstants.base_instance, _NagaConstants.other).x); return; } diff --git a/tests/out/hlsl/workgroup-var-init.hlsl b/tests/out/hlsl/workgroup-var-init.hlsl new file mode 100644 index 0000000000..db48d36c09 --- /dev/null +++ b/tests/out/hlsl/workgroup-var-init.hlsl @@ -0,0 +1,535 @@ + +struct WStruct { + uint arr[512]; + int atom; + int atom_arr[8][8]; +}; + +groupshared WStruct w_mem; +RWByteAddressBuffer output : register(u0); + +[numthreads(1, 1, 1)] +void main(uint3 __global_invocation_id : SV_DispatchThreadID) +{ + if (all(__global_invocation_id == uint3(0u, 0u, 0u))) { + w_mem = (WStruct)0; + } + GroupMemoryBarrierWithGroupSync(); + uint _expr3[512] = w_mem.arr; + { + uint _value2[512] = _expr3; + output.Store(0, asuint(_value2[0])); + output.Store(4, asuint(_value2[1])); + output.Store(8, asuint(_value2[2])); + output.Store(12, asuint(_value2[3])); + output.Store(16, asuint(_value2[4])); + output.Store(20, asuint(_value2[5])); + output.Store(24, asuint(_value2[6])); + output.Store(28, asuint(_value2[7])); + output.Store(32, asuint(_value2[8])); + output.Store(36, asuint(_value2[9])); + output.Store(40, asuint(_value2[10])); + output.Store(44, asuint(_value2[11])); + output.Store(48, asuint(_value2[12])); + output.Store(52, asuint(_value2[13])); + output.Store(56, asuint(_value2[14])); + output.Store(60, asuint(_value2[15])); + output.Store(64, asuint(_value2[16])); + output.Store(68, asuint(_value2[17])); + output.Store(72, asuint(_value2[18])); + output.Store(76, asuint(_value2[19])); + output.Store(80, asuint(_value2[20])); + output.Store(84, asuint(_value2[21])); + output.Store(88, asuint(_value2[22])); + output.Store(92, asuint(_value2[23])); + output.Store(96, asuint(_value2[24])); + output.Store(100, asuint(_value2[25])); + output.Store(104, asuint(_value2[26])); + output.Store(108, asuint(_value2[27])); + 
output.Store(112, asuint(_value2[28])); + output.Store(116, asuint(_value2[29])); + output.Store(120, asuint(_value2[30])); + output.Store(124, asuint(_value2[31])); + output.Store(128, asuint(_value2[32])); + output.Store(132, asuint(_value2[33])); + output.Store(136, asuint(_value2[34])); + output.Store(140, asuint(_value2[35])); + output.Store(144, asuint(_value2[36])); + output.Store(148, asuint(_value2[37])); + output.Store(152, asuint(_value2[38])); + output.Store(156, asuint(_value2[39])); + output.Store(160, asuint(_value2[40])); + output.Store(164, asuint(_value2[41])); + output.Store(168, asuint(_value2[42])); + output.Store(172, asuint(_value2[43])); + output.Store(176, asuint(_value2[44])); + output.Store(180, asuint(_value2[45])); + output.Store(184, asuint(_value2[46])); + output.Store(188, asuint(_value2[47])); + output.Store(192, asuint(_value2[48])); + output.Store(196, asuint(_value2[49])); + output.Store(200, asuint(_value2[50])); + output.Store(204, asuint(_value2[51])); + output.Store(208, asuint(_value2[52])); + output.Store(212, asuint(_value2[53])); + output.Store(216, asuint(_value2[54])); + output.Store(220, asuint(_value2[55])); + output.Store(224, asuint(_value2[56])); + output.Store(228, asuint(_value2[57])); + output.Store(232, asuint(_value2[58])); + output.Store(236, asuint(_value2[59])); + output.Store(240, asuint(_value2[60])); + output.Store(244, asuint(_value2[61])); + output.Store(248, asuint(_value2[62])); + output.Store(252, asuint(_value2[63])); + output.Store(256, asuint(_value2[64])); + output.Store(260, asuint(_value2[65])); + output.Store(264, asuint(_value2[66])); + output.Store(268, asuint(_value2[67])); + output.Store(272, asuint(_value2[68])); + output.Store(276, asuint(_value2[69])); + output.Store(280, asuint(_value2[70])); + output.Store(284, asuint(_value2[71])); + output.Store(288, asuint(_value2[72])); + output.Store(292, asuint(_value2[73])); + output.Store(296, asuint(_value2[74])); + output.Store(300, 
asuint(_value2[75])); + output.Store(304, asuint(_value2[76])); + output.Store(308, asuint(_value2[77])); + output.Store(312, asuint(_value2[78])); + output.Store(316, asuint(_value2[79])); + output.Store(320, asuint(_value2[80])); + output.Store(324, asuint(_value2[81])); + output.Store(328, asuint(_value2[82])); + output.Store(332, asuint(_value2[83])); + output.Store(336, asuint(_value2[84])); + output.Store(340, asuint(_value2[85])); + output.Store(344, asuint(_value2[86])); + output.Store(348, asuint(_value2[87])); + output.Store(352, asuint(_value2[88])); + output.Store(356, asuint(_value2[89])); + output.Store(360, asuint(_value2[90])); + output.Store(364, asuint(_value2[91])); + output.Store(368, asuint(_value2[92])); + output.Store(372, asuint(_value2[93])); + output.Store(376, asuint(_value2[94])); + output.Store(380, asuint(_value2[95])); + output.Store(384, asuint(_value2[96])); + output.Store(388, asuint(_value2[97])); + output.Store(392, asuint(_value2[98])); + output.Store(396, asuint(_value2[99])); + output.Store(400, asuint(_value2[100])); + output.Store(404, asuint(_value2[101])); + output.Store(408, asuint(_value2[102])); + output.Store(412, asuint(_value2[103])); + output.Store(416, asuint(_value2[104])); + output.Store(420, asuint(_value2[105])); + output.Store(424, asuint(_value2[106])); + output.Store(428, asuint(_value2[107])); + output.Store(432, asuint(_value2[108])); + output.Store(436, asuint(_value2[109])); + output.Store(440, asuint(_value2[110])); + output.Store(444, asuint(_value2[111])); + output.Store(448, asuint(_value2[112])); + output.Store(452, asuint(_value2[113])); + output.Store(456, asuint(_value2[114])); + output.Store(460, asuint(_value2[115])); + output.Store(464, asuint(_value2[116])); + output.Store(468, asuint(_value2[117])); + output.Store(472, asuint(_value2[118])); + output.Store(476, asuint(_value2[119])); + output.Store(480, asuint(_value2[120])); + output.Store(484, asuint(_value2[121])); + output.Store(488, 
asuint(_value2[122])); + output.Store(492, asuint(_value2[123])); + output.Store(496, asuint(_value2[124])); + output.Store(500, asuint(_value2[125])); + output.Store(504, asuint(_value2[126])); + output.Store(508, asuint(_value2[127])); + output.Store(512, asuint(_value2[128])); + output.Store(516, asuint(_value2[129])); + output.Store(520, asuint(_value2[130])); + output.Store(524, asuint(_value2[131])); + output.Store(528, asuint(_value2[132])); + output.Store(532, asuint(_value2[133])); + output.Store(536, asuint(_value2[134])); + output.Store(540, asuint(_value2[135])); + output.Store(544, asuint(_value2[136])); + output.Store(548, asuint(_value2[137])); + output.Store(552, asuint(_value2[138])); + output.Store(556, asuint(_value2[139])); + output.Store(560, asuint(_value2[140])); + output.Store(564, asuint(_value2[141])); + output.Store(568, asuint(_value2[142])); + output.Store(572, asuint(_value2[143])); + output.Store(576, asuint(_value2[144])); + output.Store(580, asuint(_value2[145])); + output.Store(584, asuint(_value2[146])); + output.Store(588, asuint(_value2[147])); + output.Store(592, asuint(_value2[148])); + output.Store(596, asuint(_value2[149])); + output.Store(600, asuint(_value2[150])); + output.Store(604, asuint(_value2[151])); + output.Store(608, asuint(_value2[152])); + output.Store(612, asuint(_value2[153])); + output.Store(616, asuint(_value2[154])); + output.Store(620, asuint(_value2[155])); + output.Store(624, asuint(_value2[156])); + output.Store(628, asuint(_value2[157])); + output.Store(632, asuint(_value2[158])); + output.Store(636, asuint(_value2[159])); + output.Store(640, asuint(_value2[160])); + output.Store(644, asuint(_value2[161])); + output.Store(648, asuint(_value2[162])); + output.Store(652, asuint(_value2[163])); + output.Store(656, asuint(_value2[164])); + output.Store(660, asuint(_value2[165])); + output.Store(664, asuint(_value2[166])); + output.Store(668, asuint(_value2[167])); + output.Store(672, 
asuint(_value2[168])); + output.Store(676, asuint(_value2[169])); + output.Store(680, asuint(_value2[170])); + output.Store(684, asuint(_value2[171])); + output.Store(688, asuint(_value2[172])); + output.Store(692, asuint(_value2[173])); + output.Store(696, asuint(_value2[174])); + output.Store(700, asuint(_value2[175])); + output.Store(704, asuint(_value2[176])); + output.Store(708, asuint(_value2[177])); + output.Store(712, asuint(_value2[178])); + output.Store(716, asuint(_value2[179])); + output.Store(720, asuint(_value2[180])); + output.Store(724, asuint(_value2[181])); + output.Store(728, asuint(_value2[182])); + output.Store(732, asuint(_value2[183])); + output.Store(736, asuint(_value2[184])); + output.Store(740, asuint(_value2[185])); + output.Store(744, asuint(_value2[186])); + output.Store(748, asuint(_value2[187])); + output.Store(752, asuint(_value2[188])); + output.Store(756, asuint(_value2[189])); + output.Store(760, asuint(_value2[190])); + output.Store(764, asuint(_value2[191])); + output.Store(768, asuint(_value2[192])); + output.Store(772, asuint(_value2[193])); + output.Store(776, asuint(_value2[194])); + output.Store(780, asuint(_value2[195])); + output.Store(784, asuint(_value2[196])); + output.Store(788, asuint(_value2[197])); + output.Store(792, asuint(_value2[198])); + output.Store(796, asuint(_value2[199])); + output.Store(800, asuint(_value2[200])); + output.Store(804, asuint(_value2[201])); + output.Store(808, asuint(_value2[202])); + output.Store(812, asuint(_value2[203])); + output.Store(816, asuint(_value2[204])); + output.Store(820, asuint(_value2[205])); + output.Store(824, asuint(_value2[206])); + output.Store(828, asuint(_value2[207])); + output.Store(832, asuint(_value2[208])); + output.Store(836, asuint(_value2[209])); + output.Store(840, asuint(_value2[210])); + output.Store(844, asuint(_value2[211])); + output.Store(848, asuint(_value2[212])); + output.Store(852, asuint(_value2[213])); + output.Store(856, 
asuint(_value2[214])); + output.Store(860, asuint(_value2[215])); + output.Store(864, asuint(_value2[216])); + output.Store(868, asuint(_value2[217])); + output.Store(872, asuint(_value2[218])); + output.Store(876, asuint(_value2[219])); + output.Store(880, asuint(_value2[220])); + output.Store(884, asuint(_value2[221])); + output.Store(888, asuint(_value2[222])); + output.Store(892, asuint(_value2[223])); + output.Store(896, asuint(_value2[224])); + output.Store(900, asuint(_value2[225])); + output.Store(904, asuint(_value2[226])); + output.Store(908, asuint(_value2[227])); + output.Store(912, asuint(_value2[228])); + output.Store(916, asuint(_value2[229])); + output.Store(920, asuint(_value2[230])); + output.Store(924, asuint(_value2[231])); + output.Store(928, asuint(_value2[232])); + output.Store(932, asuint(_value2[233])); + output.Store(936, asuint(_value2[234])); + output.Store(940, asuint(_value2[235])); + output.Store(944, asuint(_value2[236])); + output.Store(948, asuint(_value2[237])); + output.Store(952, asuint(_value2[238])); + output.Store(956, asuint(_value2[239])); + output.Store(960, asuint(_value2[240])); + output.Store(964, asuint(_value2[241])); + output.Store(968, asuint(_value2[242])); + output.Store(972, asuint(_value2[243])); + output.Store(976, asuint(_value2[244])); + output.Store(980, asuint(_value2[245])); + output.Store(984, asuint(_value2[246])); + output.Store(988, asuint(_value2[247])); + output.Store(992, asuint(_value2[248])); + output.Store(996, asuint(_value2[249])); + output.Store(1000, asuint(_value2[250])); + output.Store(1004, asuint(_value2[251])); + output.Store(1008, asuint(_value2[252])); + output.Store(1012, asuint(_value2[253])); + output.Store(1016, asuint(_value2[254])); + output.Store(1020, asuint(_value2[255])); + output.Store(1024, asuint(_value2[256])); + output.Store(1028, asuint(_value2[257])); + output.Store(1032, asuint(_value2[258])); + output.Store(1036, asuint(_value2[259])); + output.Store(1040, 
asuint(_value2[260])); + output.Store(1044, asuint(_value2[261])); + output.Store(1048, asuint(_value2[262])); + output.Store(1052, asuint(_value2[263])); + output.Store(1056, asuint(_value2[264])); + output.Store(1060, asuint(_value2[265])); + output.Store(1064, asuint(_value2[266])); + output.Store(1068, asuint(_value2[267])); + output.Store(1072, asuint(_value2[268])); + output.Store(1076, asuint(_value2[269])); + output.Store(1080, asuint(_value2[270])); + output.Store(1084, asuint(_value2[271])); + output.Store(1088, asuint(_value2[272])); + output.Store(1092, asuint(_value2[273])); + output.Store(1096, asuint(_value2[274])); + output.Store(1100, asuint(_value2[275])); + output.Store(1104, asuint(_value2[276])); + output.Store(1108, asuint(_value2[277])); + output.Store(1112, asuint(_value2[278])); + output.Store(1116, asuint(_value2[279])); + output.Store(1120, asuint(_value2[280])); + output.Store(1124, asuint(_value2[281])); + output.Store(1128, asuint(_value2[282])); + output.Store(1132, asuint(_value2[283])); + output.Store(1136, asuint(_value2[284])); + output.Store(1140, asuint(_value2[285])); + output.Store(1144, asuint(_value2[286])); + output.Store(1148, asuint(_value2[287])); + output.Store(1152, asuint(_value2[288])); + output.Store(1156, asuint(_value2[289])); + output.Store(1160, asuint(_value2[290])); + output.Store(1164, asuint(_value2[291])); + output.Store(1168, asuint(_value2[292])); + output.Store(1172, asuint(_value2[293])); + output.Store(1176, asuint(_value2[294])); + output.Store(1180, asuint(_value2[295])); + output.Store(1184, asuint(_value2[296])); + output.Store(1188, asuint(_value2[297])); + output.Store(1192, asuint(_value2[298])); + output.Store(1196, asuint(_value2[299])); + output.Store(1200, asuint(_value2[300])); + output.Store(1204, asuint(_value2[301])); + output.Store(1208, asuint(_value2[302])); + output.Store(1212, asuint(_value2[303])); + output.Store(1216, asuint(_value2[304])); + output.Store(1220, 
asuint(_value2[305])); + output.Store(1224, asuint(_value2[306])); + output.Store(1228, asuint(_value2[307])); + output.Store(1232, asuint(_value2[308])); + output.Store(1236, asuint(_value2[309])); + output.Store(1240, asuint(_value2[310])); + output.Store(1244, asuint(_value2[311])); + output.Store(1248, asuint(_value2[312])); + output.Store(1252, asuint(_value2[313])); + output.Store(1256, asuint(_value2[314])); + output.Store(1260, asuint(_value2[315])); + output.Store(1264, asuint(_value2[316])); + output.Store(1268, asuint(_value2[317])); + output.Store(1272, asuint(_value2[318])); + output.Store(1276, asuint(_value2[319])); + output.Store(1280, asuint(_value2[320])); + output.Store(1284, asuint(_value2[321])); + output.Store(1288, asuint(_value2[322])); + output.Store(1292, asuint(_value2[323])); + output.Store(1296, asuint(_value2[324])); + output.Store(1300, asuint(_value2[325])); + output.Store(1304, asuint(_value2[326])); + output.Store(1308, asuint(_value2[327])); + output.Store(1312, asuint(_value2[328])); + output.Store(1316, asuint(_value2[329])); + output.Store(1320, asuint(_value2[330])); + output.Store(1324, asuint(_value2[331])); + output.Store(1328, asuint(_value2[332])); + output.Store(1332, asuint(_value2[333])); + output.Store(1336, asuint(_value2[334])); + output.Store(1340, asuint(_value2[335])); + output.Store(1344, asuint(_value2[336])); + output.Store(1348, asuint(_value2[337])); + output.Store(1352, asuint(_value2[338])); + output.Store(1356, asuint(_value2[339])); + output.Store(1360, asuint(_value2[340])); + output.Store(1364, asuint(_value2[341])); + output.Store(1368, asuint(_value2[342])); + output.Store(1372, asuint(_value2[343])); + output.Store(1376, asuint(_value2[344])); + output.Store(1380, asuint(_value2[345])); + output.Store(1384, asuint(_value2[346])); + output.Store(1388, asuint(_value2[347])); + output.Store(1392, asuint(_value2[348])); + output.Store(1396, asuint(_value2[349])); + output.Store(1400, 
asuint(_value2[350])); + output.Store(1404, asuint(_value2[351])); + output.Store(1408, asuint(_value2[352])); + output.Store(1412, asuint(_value2[353])); + output.Store(1416, asuint(_value2[354])); + output.Store(1420, asuint(_value2[355])); + output.Store(1424, asuint(_value2[356])); + output.Store(1428, asuint(_value2[357])); + output.Store(1432, asuint(_value2[358])); + output.Store(1436, asuint(_value2[359])); + output.Store(1440, asuint(_value2[360])); + output.Store(1444, asuint(_value2[361])); + output.Store(1448, asuint(_value2[362])); + output.Store(1452, asuint(_value2[363])); + output.Store(1456, asuint(_value2[364])); + output.Store(1460, asuint(_value2[365])); + output.Store(1464, asuint(_value2[366])); + output.Store(1468, asuint(_value2[367])); + output.Store(1472, asuint(_value2[368])); + output.Store(1476, asuint(_value2[369])); + output.Store(1480, asuint(_value2[370])); + output.Store(1484, asuint(_value2[371])); + output.Store(1488, asuint(_value2[372])); + output.Store(1492, asuint(_value2[373])); + output.Store(1496, asuint(_value2[374])); + output.Store(1500, asuint(_value2[375])); + output.Store(1504, asuint(_value2[376])); + output.Store(1508, asuint(_value2[377])); + output.Store(1512, asuint(_value2[378])); + output.Store(1516, asuint(_value2[379])); + output.Store(1520, asuint(_value2[380])); + output.Store(1524, asuint(_value2[381])); + output.Store(1528, asuint(_value2[382])); + output.Store(1532, asuint(_value2[383])); + output.Store(1536, asuint(_value2[384])); + output.Store(1540, asuint(_value2[385])); + output.Store(1544, asuint(_value2[386])); + output.Store(1548, asuint(_value2[387])); + output.Store(1552, asuint(_value2[388])); + output.Store(1556, asuint(_value2[389])); + output.Store(1560, asuint(_value2[390])); + output.Store(1564, asuint(_value2[391])); + output.Store(1568, asuint(_value2[392])); + output.Store(1572, asuint(_value2[393])); + output.Store(1576, asuint(_value2[394])); + output.Store(1580, 
asuint(_value2[395])); + output.Store(1584, asuint(_value2[396])); + output.Store(1588, asuint(_value2[397])); + output.Store(1592, asuint(_value2[398])); + output.Store(1596, asuint(_value2[399])); + output.Store(1600, asuint(_value2[400])); + output.Store(1604, asuint(_value2[401])); + output.Store(1608, asuint(_value2[402])); + output.Store(1612, asuint(_value2[403])); + output.Store(1616, asuint(_value2[404])); + output.Store(1620, asuint(_value2[405])); + output.Store(1624, asuint(_value2[406])); + output.Store(1628, asuint(_value2[407])); + output.Store(1632, asuint(_value2[408])); + output.Store(1636, asuint(_value2[409])); + output.Store(1640, asuint(_value2[410])); + output.Store(1644, asuint(_value2[411])); + output.Store(1648, asuint(_value2[412])); + output.Store(1652, asuint(_value2[413])); + output.Store(1656, asuint(_value2[414])); + output.Store(1660, asuint(_value2[415])); + output.Store(1664, asuint(_value2[416])); + output.Store(1668, asuint(_value2[417])); + output.Store(1672, asuint(_value2[418])); + output.Store(1676, asuint(_value2[419])); + output.Store(1680, asuint(_value2[420])); + output.Store(1684, asuint(_value2[421])); + output.Store(1688, asuint(_value2[422])); + output.Store(1692, asuint(_value2[423])); + output.Store(1696, asuint(_value2[424])); + output.Store(1700, asuint(_value2[425])); + output.Store(1704, asuint(_value2[426])); + output.Store(1708, asuint(_value2[427])); + output.Store(1712, asuint(_value2[428])); + output.Store(1716, asuint(_value2[429])); + output.Store(1720, asuint(_value2[430])); + output.Store(1724, asuint(_value2[431])); + output.Store(1728, asuint(_value2[432])); + output.Store(1732, asuint(_value2[433])); + output.Store(1736, asuint(_value2[434])); + output.Store(1740, asuint(_value2[435])); + output.Store(1744, asuint(_value2[436])); + output.Store(1748, asuint(_value2[437])); + output.Store(1752, asuint(_value2[438])); + output.Store(1756, asuint(_value2[439])); + output.Store(1760, 
asuint(_value2[440])); + output.Store(1764, asuint(_value2[441])); + output.Store(1768, asuint(_value2[442])); + output.Store(1772, asuint(_value2[443])); + output.Store(1776, asuint(_value2[444])); + output.Store(1780, asuint(_value2[445])); + output.Store(1784, asuint(_value2[446])); + output.Store(1788, asuint(_value2[447])); + output.Store(1792, asuint(_value2[448])); + output.Store(1796, asuint(_value2[449])); + output.Store(1800, asuint(_value2[450])); + output.Store(1804, asuint(_value2[451])); + output.Store(1808, asuint(_value2[452])); + output.Store(1812, asuint(_value2[453])); + output.Store(1816, asuint(_value2[454])); + output.Store(1820, asuint(_value2[455])); + output.Store(1824, asuint(_value2[456])); + output.Store(1828, asuint(_value2[457])); + output.Store(1832, asuint(_value2[458])); + output.Store(1836, asuint(_value2[459])); + output.Store(1840, asuint(_value2[460])); + output.Store(1844, asuint(_value2[461])); + output.Store(1848, asuint(_value2[462])); + output.Store(1852, asuint(_value2[463])); + output.Store(1856, asuint(_value2[464])); + output.Store(1860, asuint(_value2[465])); + output.Store(1864, asuint(_value2[466])); + output.Store(1868, asuint(_value2[467])); + output.Store(1872, asuint(_value2[468])); + output.Store(1876, asuint(_value2[469])); + output.Store(1880, asuint(_value2[470])); + output.Store(1884, asuint(_value2[471])); + output.Store(1888, asuint(_value2[472])); + output.Store(1892, asuint(_value2[473])); + output.Store(1896, asuint(_value2[474])); + output.Store(1900, asuint(_value2[475])); + output.Store(1904, asuint(_value2[476])); + output.Store(1908, asuint(_value2[477])); + output.Store(1912, asuint(_value2[478])); + output.Store(1916, asuint(_value2[479])); + output.Store(1920, asuint(_value2[480])); + output.Store(1924, asuint(_value2[481])); + output.Store(1928, asuint(_value2[482])); + output.Store(1932, asuint(_value2[483])); + output.Store(1936, asuint(_value2[484])); + output.Store(1940, 
asuint(_value2[485])); + output.Store(1944, asuint(_value2[486])); + output.Store(1948, asuint(_value2[487])); + output.Store(1952, asuint(_value2[488])); + output.Store(1956, asuint(_value2[489])); + output.Store(1960, asuint(_value2[490])); + output.Store(1964, asuint(_value2[491])); + output.Store(1968, asuint(_value2[492])); + output.Store(1972, asuint(_value2[493])); + output.Store(1976, asuint(_value2[494])); + output.Store(1980, asuint(_value2[495])); + output.Store(1984, asuint(_value2[496])); + output.Store(1988, asuint(_value2[497])); + output.Store(1992, asuint(_value2[498])); + output.Store(1996, asuint(_value2[499])); + output.Store(2000, asuint(_value2[500])); + output.Store(2004, asuint(_value2[501])); + output.Store(2008, asuint(_value2[502])); + output.Store(2012, asuint(_value2[503])); + output.Store(2016, asuint(_value2[504])); + output.Store(2020, asuint(_value2[505])); + output.Store(2024, asuint(_value2[506])); + output.Store(2028, asuint(_value2[507])); + output.Store(2032, asuint(_value2[508])); + output.Store(2036, asuint(_value2[509])); + output.Store(2040, asuint(_value2[510])); + output.Store(2044, asuint(_value2[511])); + } + return; +} diff --git a/tests/out/hlsl/workgroup-var-init.hlsl.config b/tests/out/hlsl/workgroup-var-init.hlsl.config new file mode 100644 index 0000000000..246c485cf7 --- /dev/null +++ b/tests/out/hlsl/workgroup-var-init.hlsl.config @@ -0,0 +1,3 @@ +vertex=() +fragment=() +compute=(main:cs_5_1 ) diff --git a/tests/out/msl/access.msl b/tests/out/msl/access.msl index 02f922fc8b..30b2bdbc61 100644 --- a/tests/out/msl/access.msl +++ b/tests/out/msl/access.msl @@ -240,8 +240,13 @@ kernel void atomics( kernel void assign_through_ptr( - threadgroup uint& val + metal::uint3 __global_invocation_id [[thread_position_in_grid]] +, threadgroup uint& val ) { + if (metal::all(__global_invocation_id == metal::uint3(0u))) { + val = {}; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); assign_through_ptr_fn(val); 
return; } diff --git a/tests/out/msl/globals.msl b/tests/out/msl/globals.msl index b821cc9d5f..62dc1c4acc 100644 --- a/tests/out/msl/globals.msl +++ b/tests/out/msl/globals.msl @@ -62,7 +62,8 @@ void test_msl_packed_vec3_( } kernel void main_( - threadgroup type_2& wg + metal::uint3 __global_invocation_id [[thread_position_in_grid]] +, threadgroup type_2& wg , threadgroup metal::atomic_uint& at_1 , device FooStruct& alignment [[user(fake0)]] , device type_6 const& dummy [[user(fake0)]] @@ -73,6 +74,11 @@ kernel void main_( , constant type_15& global_nested_arrays_of_matrices_4x2_ [[user(fake0)]] , constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { + if (metal::all(__global_invocation_id == metal::uint3(0u))) { + wg = {}; + metal::atomic_store_explicit(&at_1, 0, metal::memory_order_relaxed); + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); float Foo = {}; bool at = {}; test_msl_packed_vec3_(alignment); diff --git a/tests/out/msl/interface.msl b/tests/out/msl/interface.msl index 7eea0a4f00..368c8218cb 100644 --- a/tests/out/msl/interface.msl +++ b/tests/out/msl/interface.msl @@ -76,6 +76,10 @@ kernel void compute_( , metal::uint3 num_wgs [[threadgroups_per_grid]] , threadgroup type_4& output ) { + if (metal::all(global_id == metal::uint3(0u))) { + output = {}; + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); output.inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x; return; } diff --git a/tests/out/msl/workgroup-var-init.msl b/tests/out/msl/workgroup-var-init.msl new file mode 100644 index 0000000000..8f8bb14f12 --- /dev/null +++ b/tests/out/msl/workgroup-var-init.msl @@ -0,0 +1,40 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + +struct type_1 { + uint inner[512]; +}; +struct type_3 { + metal::atomic_int inner[8]; +}; +struct type_4 { + type_3 inner[8]; +}; +struct WStruct { + type_1 arr; + metal::atomic_int atom; + type_4 atom_arr; +}; + +kernel void main_( + 
metal::uint3 __global_invocation_id [[thread_position_in_grid]] +, threadgroup WStruct& w_mem +, device type_1& output [[buffer(0)]] +) { + if (metal::all(__global_invocation_id == metal::uint3(0u))) { + w_mem.arr = {}; + metal::atomic_store_explicit(&w_mem.atom, 0, metal::memory_order_relaxed); + for (int __i0 = 0; __i0 < 8; __i0++) { + for (int __i1 = 0; __i1 < 8; __i1++) { + metal::atomic_store_explicit(&w_mem.atom_arr.inner[__i0].inner[__i1], 0, metal::memory_order_relaxed); + } + } + } + metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); + type_1 _e3 = w_mem.arr; + for(int _i=0; _i<512; ++_i) output.inner[_i] = _e3.inner[_i]; + return; +} diff --git a/tests/out/spv/access.spvasm b/tests/out/spv/access.spvasm index 49d1eb4a64..128e0158b0 100644 --- a/tests/out/spv/access.spvasm +++ b/tests/out/spv/access.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 325 +; Bound: 338 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" @@ -9,7 +9,7 @@ OpMemoryModel Logical GLSL450 OpEntryPoint Vertex %237 "foo_vert" %232 %235 OpEntryPoint Fragment %279 "foo_frag" %278 OpEntryPoint GLCompute %298 "atomics" -OpEntryPoint GLCompute %322 "assign_through_ptr" +OpEntryPoint GLCompute %322 "assign_through_ptr" %325 OpExecutionMode %279 OriginUpperLeft OpExecutionMode %298 LocalSize 1 1 1 OpExecutionMode %322 LocalSize 1 1 1 @@ -103,6 +103,7 @@ OpMemberDecorate %81 0 Offset 0 OpDecorate %232 BuiltIn VertexIndex OpDecorate %235 BuiltIn Position OpDecorate %278 Location 0 +OpDecorate %325 BuiltIn GlobalInvocationId %2 = OpTypeVoid %4 = OpTypeInt 32 0 %3 = OpConstant %4 0 @@ -228,6 +229,13 @@ OpDecorate %278 Location 0 %296 = OpConstantNull %6 %300 = OpTypePointer StorageBuffer %6 %303 = OpConstant %4 64 +%324 = OpConstantNull %4 +%326 = OpTypePointer Input %35 +%325 = OpVariable %326 Input +%328 = OpConstantNull %35 +%330 = OpTypeBool +%329 = OpTypeVector %330 3 +%335 = OpConstant %4 264 %91 
= OpFunction %2 None %92 %90 = OpLabel %84 = OpVariable %85 Function %86 @@ -494,6 +502,18 @@ OpFunctionEnd %321 = OpLabel OpBranch %323 %323 = OpLabel -%324 = OpFunctionCall %2 %223 %83 +%327 = OpLoad %35 %325 +%331 = OpIEqual %329 %327 %328 +%332 = OpAll %330 %331 +OpSelectionMerge %333 None +OpBranchConditional %332 %334 %333 +%334 = OpLabel +OpStore %83 %324 +OpBranch %333 +%333 = OpLabel +OpControlBarrier %30 %30 %335 +OpBranch %336 +%336 = OpLabel +%337 = OpFunctionCall %2 %223 %83 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/spv/globals.spvasm b/tests/out/spv/globals.spvasm index ba39fa05c5..0a6cf9d726 100644 --- a/tests/out/spv/globals.spvasm +++ b/tests/out/spv/globals.spvasm @@ -1,12 +1,12 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 172 +; Bound: 186 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %111 "main" +OpEntryPoint GLCompute %111 "main" %129 OpExecutionMode %111 LocalSize 1 1 1 OpDecorate %25 ArrayStride 4 OpMemberDecorate %27 0 Offset 0 @@ -48,6 +48,7 @@ OpDecorate %64 DescriptorSet 0 OpDecorate %64 Binding 7 OpDecorate %65 Block OpMemberDecorate %65 0 Offset 0 +OpDecorate %129 BuiltIn GlobalInvocationId %2 = OpTypeVoid %4 = OpTypeBool %3 = OpConstantTrue %4 @@ -131,23 +132,31 @@ OpMemberDecorate %65 0 Offset 0 %119 = OpTypePointer Uniform %32 %121 = OpTypePointer Uniform %35 %123 = OpTypePointer Uniform %38 -%127 = OpTypePointer Workgroup %11 -%128 = OpTypePointer Uniform %37 -%129 = OpTypePointer Uniform %36 -%132 = OpTypePointer Uniform %34 -%133 = OpTypePointer Uniform %33 -%134 = OpTypePointer Uniform %30 -%139 = OpConstant %6 7 -%145 = OpConstant %6 6 -%147 = OpTypePointer StorageBuffer %28 -%148 = OpConstant %6 1 -%151 = OpConstant %6 5 -%153 = OpTypePointer Uniform %30 -%154 = OpTypePointer Uniform %11 -%155 = OpConstant %6 3 -%158 = OpConstant %6 4 -%160 = OpTypePointer StorageBuffer 
%11 -%171 = OpConstant %6 256 +%126 = OpConstantNull %25 +%127 = OpConstantNull %6 +%128 = OpTypeVector %6 3 +%130 = OpTypePointer Input %128 +%129 = OpVariable %130 Input +%132 = OpConstantNull %128 +%133 = OpTypeVector %4 3 +%138 = OpConstant %6 264 +%141 = OpTypePointer Workgroup %11 +%142 = OpTypePointer Uniform %37 +%143 = OpTypePointer Uniform %36 +%146 = OpTypePointer Uniform %34 +%147 = OpTypePointer Uniform %33 +%148 = OpTypePointer Uniform %30 +%153 = OpConstant %6 7 +%159 = OpConstant %6 6 +%161 = OpTypePointer StorageBuffer %28 +%162 = OpConstant %6 1 +%165 = OpConstant %6 5 +%167 = OpTypePointer Uniform %30 +%168 = OpTypePointer Uniform %11 +%169 = OpConstant %6 3 +%172 = OpConstant %6 4 +%174 = OpTypePointer StorageBuffer %11 +%185 = OpConstant %6 256 %69 = OpFunction %2 None %70 %68 = OpFunctionParameter %26 %67 = OpLabel @@ -201,44 +210,57 @@ OpFunctionEnd %124 = OpAccessChain %123 %64 %79 OpBranch %125 %125 = OpLabel -%126 = OpFunctionCall %2 %76 -%130 = OpAccessChain %129 %124 %79 %79 -%131 = OpLoad %36 %130 -%135 = OpAccessChain %134 %122 %79 %79 %79 -%136 = OpLoad %30 %135 -%137 = OpMatrixTimesVector %28 %131 %136 -%138 = OpCompositeExtract %11 %137 0 -%140 = OpAccessChain %127 %42 %139 -OpStore %140 %138 -%141 = OpLoad %32 %120 -%142 = OpLoad %26 %118 -%143 = OpMatrixTimesVector %28 %141 %142 -%144 = OpCompositeExtract %11 %143 0 -%146 = OpAccessChain %127 %42 %145 -OpStore %146 %144 -%149 = OpAccessChain %85 %114 %148 %148 -%150 = OpLoad %11 %149 -%152 = OpAccessChain %127 %42 %151 -OpStore %152 %150 -%156 = OpAccessChain %154 %116 %79 %155 -%157 = OpLoad %11 %156 -%159 = OpAccessChain %127 %42 %158 -OpStore %159 %157 -%161 = OpAccessChain %160 %112 %148 -%162 = OpLoad %11 %161 -%163 = OpAccessChain %127 %42 %155 -OpStore %163 %162 -%164 = OpAccessChain %85 %112 %79 %79 -%165 = OpLoad %11 %164 -%166 = OpAccessChain %127 %42 %23 -OpStore %166 %165 -%167 = OpAccessChain %160 %112 %148 -OpStore %167 %22 -%168 = OpArrayLength %6 %49 0 -%169 = 
OpConvertUToF %11 %168 -%170 = OpAccessChain %127 %42 %148 -OpStore %170 %169 -OpAtomicStore %44 %9 %171 %23 +%131 = OpLoad %128 %129 +%134 = OpIEqual %133 %131 %132 +%135 = OpAll %4 %134 +OpSelectionMerge %136 None +OpBranchConditional %135 %137 %136 +%137 = OpLabel +OpStore %42 %126 +OpStore %44 %127 +OpBranch %136 +%136 = OpLabel +OpControlBarrier %23 %23 %138 +OpBranch %139 +%139 = OpLabel +%140 = OpFunctionCall %2 %76 +%144 = OpAccessChain %143 %124 %79 %79 +%145 = OpLoad %36 %144 +%149 = OpAccessChain %148 %122 %79 %79 %79 +%150 = OpLoad %30 %149 +%151 = OpMatrixTimesVector %28 %145 %150 +%152 = OpCompositeExtract %11 %151 0 +%154 = OpAccessChain %141 %42 %153 +OpStore %154 %152 +%155 = OpLoad %32 %120 +%156 = OpLoad %26 %118 +%157 = OpMatrixTimesVector %28 %155 %156 +%158 = OpCompositeExtract %11 %157 0 +%160 = OpAccessChain %141 %42 %159 +OpStore %160 %158 +%163 = OpAccessChain %85 %114 %162 %162 +%164 = OpLoad %11 %163 +%166 = OpAccessChain %141 %42 %165 +OpStore %166 %164 +%170 = OpAccessChain %168 %116 %79 %169 +%171 = OpLoad %11 %170 +%173 = OpAccessChain %141 %42 %172 +OpStore %173 %171 +%175 = OpAccessChain %174 %112 %162 +%176 = OpLoad %11 %175 +%177 = OpAccessChain %141 %42 %169 +OpStore %177 %176 +%178 = OpAccessChain %85 %112 %79 %79 +%179 = OpLoad %11 %178 +%180 = OpAccessChain %141 %42 %23 +OpStore %180 %179 +%181 = OpAccessChain %174 %112 %162 +OpStore %181 %22 +%182 = OpArrayLength %6 %49 0 +%183 = OpConvertUToF %11 %182 +%184 = OpAccessChain %141 %42 %162 +OpStore %184 %183 +OpAtomicStore %44 %9 %185 %23 OpStore %104 %10 OpStore %107 %24 OpReturn diff --git a/tests/out/spv/interface.compute.spvasm b/tests/out/spv/interface.compute.spvasm index 6c1dac372d..f99174536e 100644 --- a/tests/out/spv/interface.compute.spvasm +++ b/tests/out/spv/interface.compute.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 49 +; Bound: 58 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -48,8 
+48,12 @@ OpDecorate %33 BuiltIn NumWorkgroups %31 = OpVariable %24 Input %33 = OpVariable %24 Input %36 = OpTypeFunction %2 -%38 = OpTypePointer Workgroup %6 -%47 = OpConstant %6 0 +%38 = OpConstantNull %16 +%39 = OpConstantNull %17 +%40 = OpTypeVector %15 3 +%45 = OpConstant %6 264 +%47 = OpTypePointer Workgroup %6 +%56 = OpConstant %6 0 %35 = OpFunction %2 None %36 %22 = OpLabel %25 = OpLoad %17 %23 @@ -59,15 +63,26 @@ OpDecorate %33 BuiltIn NumWorkgroups %34 = OpLoad %17 %33 OpBranch %37 %37 = OpLabel -%39 = OpCompositeExtract %6 %25 0 -%40 = OpCompositeExtract %6 %27 0 -%41 = OpIAdd %6 %39 %40 -%42 = OpIAdd %6 %41 %30 -%43 = OpCompositeExtract %6 %32 0 -%44 = OpIAdd %6 %42 %43 -%45 = OpCompositeExtract %6 %34 0 -%46 = OpIAdd %6 %44 %45 -%48 = OpAccessChain %38 %20 %47 -OpStore %48 %46 +%41 = OpIEqual %40 %25 %39 +%42 = OpAll %15 %41 +OpSelectionMerge %43 None +OpBranchConditional %42 %44 %43 +%44 = OpLabel +OpStore %20 %38 +OpBranch %43 +%43 = OpLabel +OpControlBarrier %11 %11 %45 +OpBranch %46 +%46 = OpLabel +%48 = OpCompositeExtract %6 %25 0 +%49 = OpCompositeExtract %6 %27 0 +%50 = OpIAdd %6 %48 %49 +%51 = OpIAdd %6 %50 %30 +%52 = OpCompositeExtract %6 %32 0 +%53 = OpIAdd %6 %51 %52 +%54 = OpCompositeExtract %6 %34 0 +%55 = OpIAdd %6 %53 %54 +%57 = OpAccessChain %47 %20 %56 +OpStore %57 %55 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/spv/workgroup-var-init.spvasm b/tests/out/spv/workgroup-var-init.spvasm new file mode 100644 index 0000000000..1cccc889dc --- /dev/null +++ b/tests/out/spv/workgroup-var-init.spvasm @@ -0,0 +1,78 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 41 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %17 "main" %25 +OpExecutionMode %17 LocalSize 1 1 1 +OpSource GLSL 450 +OpMemberName %10 0 "arr" +OpMemberName %10 1 "atom" +OpMemberName %10 2 "atom_arr" +OpName %10 "WStruct" 
+OpName %11 "w_mem" +OpName %13 "output" +OpName %17 "main" +OpDecorate %7 ArrayStride 4 +OpDecorate %8 ArrayStride 4 +OpDecorate %9 ArrayStride 32 +OpMemberDecorate %10 0 Offset 0 +OpMemberDecorate %10 1 Offset 2048 +OpMemberDecorate %10 2 Offset 2052 +OpDecorate %13 DescriptorSet 0 +OpDecorate %13 Binding 0 +OpDecorate %14 Block +OpMemberDecorate %14 0 Offset 0 +OpDecorate %25 BuiltIn GlobalInvocationId +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpConstant %4 512 +%5 = OpConstant %4 8 +%6 = OpTypeInt 32 0 +%7 = OpTypeArray %6 %3 +%8 = OpTypeArray %4 %5 +%9 = OpTypeArray %8 %5 +%10 = OpTypeStruct %7 %4 %9 +%12 = OpTypePointer Workgroup %10 +%11 = OpVariable %12 Workgroup +%14 = OpTypeStruct %7 +%15 = OpTypePointer StorageBuffer %14 +%13 = OpVariable %15 StorageBuffer +%18 = OpTypeFunction %2 +%19 = OpTypePointer StorageBuffer %7 +%20 = OpConstant %6 0 +%23 = OpConstantNull %10 +%24 = OpTypeVector %6 3 +%26 = OpTypePointer Input %24 +%25 = OpVariable %26 Input +%28 = OpConstantNull %24 +%30 = OpTypeBool +%29 = OpTypeVector %30 3 +%35 = OpConstant %6 2 +%36 = OpConstant %6 264 +%38 = OpTypePointer Workgroup %7 +%17 = OpFunction %2 None %18 +%16 = OpLabel +%21 = OpAccessChain %19 %13 %20 +OpBranch %22 +%22 = OpLabel +%27 = OpLoad %24 %25 +%31 = OpIEqual %29 %27 %28 +%32 = OpAll %30 %31 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpStore %11 %23 +OpBranch %33 +%33 = OpLabel +OpControlBarrier %35 %35 %36 +OpBranch %37 +%37 = OpLabel +%39 = OpAccessChain %38 %11 %20 +%40 = OpLoad %7 %39 +OpStore %21 %40 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/workgroup-var-init.wgsl b/tests/out/wgsl/workgroup-var-init.wgsl new file mode 100644 index 0000000000..6197108cb6 --- /dev/null +++ b/tests/out/wgsl/workgroup-var-init.wgsl @@ -0,0 +1,16 @@ +struct WStruct { + arr: array, + atom: atomic, + atom_arr: array,8>,8>, +} + +var w_mem: WStruct; +@group(0) @binding(0) +var output: array; + +@compute 
@workgroup_size(1, 1, 1) +fn main() { + let _e3 = w_mem.arr; + output = _e3; + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 02e77a6cf2..00c263946e 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -227,6 +227,7 @@ fn write_output_spv( }, bounds_check_policies, binding_map: params.binding_map.clone(), + zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill, }; if params.separate_entry_points { @@ -556,6 +557,10 @@ fn convert_wgsl() { ("lexical-scopes", Targets::WGSL), ("type-alias", Targets::WGSL), ("module-scope", Targets::WGSL), + ( + "workgroup-var-init", + Targets::WGSL | Targets::GLSL | Targets::SPIRV | Targets::HLSL | Targets::METAL, + ), ]; for &(name, targets) in inputs.iter() {