Skip to content

Commit

Permalink
Added input and output variables and matrix stride and row_majorness
Browse files Browse the repository at this point in the history
  • Loading branch information
Rob2309 committed Jan 26, 2022
1 parent 75cdb08 commit ce8e3b3
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "spirv-layout"
version = "1.0.0"
version = "0.3.0"
authors = [ "Robin Quint" ]
edition = "2021"
description = "SPIRV reflection utility for deriving Vulkan DescriptorSetLayouts"
Expand Down
53 changes: 43 additions & 10 deletions examples/reflect-shader/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::slice;

use spirv_layout::{Module, PushConstantVariable, Type, Variable};
use spirv_layout::{LocationVariable, Module, PushConstantVariable, Type, UniformVariable};

const PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -13,9 +13,19 @@ fn main() {
let words = unsafe { slice::from_raw_parts(bytes.as_ptr() as *const u32, bytes.len() / 4) };
let module = Module::from_words(words).unwrap();

println!("=== INPUTS ===");
for var in module.get_inputs() {
print_location_var(&module, var);
}

println!("=== OUTPUTS ===");
for var in module.get_outputs() {
print_location_var(&module, var);
}

println!("=== UNIFORMS ===");
for var in module.get_uniforms() {
print_var(&module, var);
print_uniform_var(&module, var);
}

println!("=== PUSH CONSTANTS ===");
Expand All @@ -25,12 +35,8 @@ fn main() {
}
}

fn print_var(module: &Module, var: &Variable) {
if let Some(set) = var.set {
if let Some(binding) = var.binding {
print!("layout (set={}, binding={}) ", set, binding);
}
}
fn print_uniform_var(module: &Module, var: &UniformVariable) {
print!("layout (set={}, binding={}) ", var.set, var.binding);

print_type(module, module.get_type(var.type_id).unwrap());

Expand All @@ -57,6 +63,21 @@ fn print_pc_var(module: &Module, var: &PushConstantVariable) {
);
}

fn print_location_var(module: &Module, var: &LocationVariable) {
print!("layout (location={}) ", var.location);

print_type(module, module.get_type(var.type_id).unwrap());

println!(
"{};",
if let Some(name) = &var.name {
name
} else {
"<no-name>"
}
);
}

fn print_type(module: &Module, ty: &Type) {
match ty {
Type::Void => print!("void "),
Expand Down Expand Up @@ -86,10 +107,22 @@ fn print_type(module: &Module, ty: &Type) {
);

for elem in elements {
print!(" ");
print!(" layout(");
if let Some(offset) = elem.offset {
print!("layout(offset={}) ", offset);
print!("offset={}", offset);
}
if let Some(Type::Mat3 | Type::Mat4) = module.get_type(elem.type_id) {
print!(
", {}, stride={}",
if elem.row_major {
"row_major"
} else {
"col_major"
},
elem.stride
);
}
print!(") ");

print_type(module, module.get_type(elem.type_id).unwrap());

Expand Down
128 changes: 114 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ pub struct Module {
/// Stores information about all type declarations that exist in the given SPIRV module.
types: HashMap<u32, Type>,
/// Stores information about all uniform variables that exist in the given SPIRV module.
uniforms: Vec<Variable>,
uniforms: Vec<UniformVariable>,
/// Stores information about all push constant variables that exist in the given SPIRV module.
push_constants: Vec<PushConstantVariable>,
/// Stores information about all input variables that exist in the given SPIRV module.
inputs: Vec<LocationVariable>,
/// Stores information about all output variables that exist in the given SPIRV module.
outputs: Vec<LocationVariable>,
}

impl Module {
Expand Down Expand Up @@ -84,9 +88,9 @@ impl Module {
pointed_type_id,
}) = types.get(&var.type_id)
{
Some(Variable {
set: var.set,
binding: var.binding,
Some(UniformVariable {
set: var.set?,
binding: var.binding?,
type_id: *pointed_type_id, // for convenience, we store the pointed-to type instead of the pointer, since every uniform is a pointer
name: var.name.clone(),
})
Expand Down Expand Up @@ -114,10 +118,50 @@ impl Module {
})
.collect();

let inputs = vars
.iter()
.filter_map(|(_id, var)| {
if let Some(Type::Pointer {
storage_class: StorageClass::Input,
pointed_type_id,
}) = types.get(&var.type_id)
{
Some(LocationVariable {
location: var.location?,
type_id: *pointed_type_id,
name: var.name.clone(),
})
} else {
None
}
})
.collect();

let outputs = vars
.iter()
.filter_map(|(_id, var)| {
if let Some(Type::Pointer {
storage_class: StorageClass::Output,
pointed_type_id,
}) = types.get(&var.type_id)
{
Some(LocationVariable {
location: var.location?,
type_id: *pointed_type_id,
name: var.name.clone(),
})
} else {
None
}
})
.collect();

Ok(Self {
types,
uniforms,
push_constants,
inputs,
outputs,
})
}

Expand All @@ -127,7 +171,7 @@ impl Module {
}

/// Returns all uniform variables declared in the given SPIR-V module.
pub fn get_uniforms(&self) -> &[Variable] {
pub fn get_uniforms(&self) -> &[UniformVariable] {
&self.uniforms
}

Expand All @@ -136,6 +180,16 @@ impl Module {
&self.push_constants
}

/// Returns all input variables declared in the given SPIR-V module.
pub fn get_inputs(&self) -> &[LocationVariable] {
&self.inputs
}

/// Returns all output variables declared in the given SPIR-V module.
pub fn get_outputs(&self) -> &[LocationVariable] {
&self.outputs
}

/// Calculates the size of a primitive type or Struct.
///
/// # Returns
Expand Down Expand Up @@ -170,7 +224,7 @@ impl Module {
fn collect_decorations_and_names(
ops: &[Op],
types: &mut HashMap<u32, Type>,
vars: &mut HashMap<u32, Variable>,
vars: &mut HashMap<u32, RawVariable>,
) {
for op in ops {
match op {
Expand Down Expand Up @@ -203,16 +257,35 @@ impl Module {
target.set = Some(*set);
}
}
ops::Decoration::Location { loc } => {
if let Some(target) = vars.get_mut(&target.0) {
target.location = Some(*loc);
}
}
_ => {}
},
Op::OpMemberDecorate {
target,
member_index,
decoration: ops::Decoration::Offset { offset },
decoration,
} => {
if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
if elements.len() > *member_index as usize {
elements[*member_index as usize].offset = Some(*offset);
match decoration {
ops::Decoration::RowMajor {} => {
elements[*member_index as usize].row_major = true;
}
ops::Decoration::ColMajor {} => {
elements[*member_index as usize].row_major = false;
}
ops::Decoration::MatrixStride { stride } => {
elements[*member_index as usize].stride = *stride;
}
ops::Decoration::Offset { offset } => {
elements[*member_index as usize].offset = Some(*offset);
}
_ => {}
}
}
}
}
Expand All @@ -226,7 +299,7 @@ impl Module {
ops: &[Op],
types: &mut HashMap<u32, Type>,
constants: &mut HashMap<u32, u32>,
vars: &mut HashMap<u32, Variable>,
vars: &mut HashMap<u32, RawVariable>,
) -> SpirvResult<()> {
for op in ops {
match op {
Expand Down Expand Up @@ -381,6 +454,8 @@ impl Module {
name: None,
type_id: e.0,
offset: None,
row_major: true,
stride: 16,
})
.collect(),
},
Expand All @@ -399,6 +474,8 @@ impl Module {
ops::StorageClass::UniformConstant {}
| ops::StorageClass::Uniform {} => StorageClass::Uniform,
ops::StorageClass::PushConstant {} => StorageClass::PushConstant,
ops::StorageClass::Input {} => StorageClass::Input,
ops::StorageClass::Output {} => StorageClass::Output,
},
pointed_type_id: pointed_type.0,
},
Expand All @@ -423,9 +500,10 @@ impl Module {
} => {
vars.insert(
result.0,
Variable {
RawVariable {
set: None,
binding: None,
location: None,
type_id: result_type.0,
name: None,
},
Expand Down Expand Up @@ -511,6 +589,8 @@ pub struct StructMember {
pub name: Option<String>,
pub type_id: u32,
pub offset: Option<u32>,
pub row_major: bool,
pub stride: u32,
}

/// Describes what type of storage a pointer points to
Expand All @@ -524,15 +604,28 @@ pub enum StorageClass {
UniformConstant,
/// The pointer is a push constant
PushConstant,
/// The pointer is an input variable
Input,
/// The pointer is an output variable
Output,
}

#[derive(Debug, Clone)]
struct RawVariable {
set: Option<u32>,
binding: Option<u32>,
location: Option<u32>,
type_id: u32,
name: Option<String>,
}

/// Describes a variable declared in a SPIRV module
/// Describes a uniform variable declared in a SPIRV module
#[derive(Debug, Clone)]
pub struct Variable {
pub struct UniformVariable {
/// Which DescriptorSet the variable is contained in (if known)
pub set: Option<u32>,
pub set: u32,
/// Which DescriptorSet binding the variable is contained in (if known)
pub binding: Option<u32>,
pub binding: u32,
/// The type id of the variable's [`Type`]
pub type_id: u32,
/// The variables name (if known)
Expand All @@ -544,3 +637,10 @@ pub struct PushConstantVariable {
pub type_id: u32,
pub name: Option<String>,
}

#[derive(Debug, Clone)]
pub struct LocationVariable {
pub location: u32,
pub type_id: u32,
pub name: Option<String>,
}
4 changes: 4 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ enums!(
Decoration {
4 = RowMajor(),
5 = ColMajor(),
7 = MatrixStride(stride: u32),
30 = Location(loc: u32),
33 = Binding(binding: u32),
34 = DescriptorSet(set: u32),
35 = Offset(offset: u32),
Expand All @@ -230,7 +232,9 @@ enums!(

StorageClass {
0 = UniformConstant(),
1 = Input(),
2 = Uniform(),
3 = Output(),
9 = PushConstant(),
},
);

0 comments on commit ce8e3b3

Please sign in to comment.