diff --git a/Cargo.toml b/Cargo.toml index 27d181a..ad1732f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,3 +29,4 @@ libc = "0.2" [dev-dependencies] rspirv = "0.11" +spirv-linker = "0.1" \ No newline at end of file diff --git a/examples/exports.hlsl b/examples/exports.hlsl new file mode 100644 index 0000000..0f78a95 --- /dev/null +++ b/examples/exports.hlsl @@ -0,0 +1,3 @@ +export int valueTwo() { + return 2; +} \ No newline at end of file diff --git a/examples/link.rs b/examples/link.rs new file mode 100644 index 0000000..81fef70 --- /dev/null +++ b/examples/link.rs @@ -0,0 +1,77 @@ +use hassle_rs::utils::DefaultIncludeHandler; +use hassle_rs::*; +use rspirv::binary::Disassemble; + +fn main() { + let dxc = Dxc::new(None).unwrap(); + let compiler = dxc.create_compiler().unwrap(); + let library = dxc.create_library().unwrap(); + let spirv = false; + + let args = &if spirv { + vec!["-spirv", "-fspv-target-env=vulkan1.1spirv1.4"] + } else { + vec![] + }; + + let exports = compiler.compile( + &library + .create_blob_with_encoding_from_str(include_str!("exports.hlsl")) + .unwrap(), + "exports.hlsl", + "", + "lib_6_6", + args, + Some(&mut DefaultIncludeHandler {}), + &[], + ); + let use_exports = compiler.compile( + &library + .create_blob_with_encoding_from_str(include_str!("use-export.hlsl")) + .unwrap(), + "use-exports.hlsl", + "", + "lib_6_6", + args, + Some(&mut DefaultIncludeHandler {}), + &[], + ); + + let exports = exports.ok().unwrap().get_result().unwrap(); + let use_exports = use_exports.ok().unwrap().get_result().unwrap(); + + if spirv { + let mut exports = rspirv::dr::load_bytes(exports).unwrap(); + let mut use_exports = rspirv::dr::load_bytes(use_exports).unwrap(); + let linked = + spirv_linker::link(&mut [&mut exports, &mut use_exports], &Default::default()).unwrap(); + println!("{}", exports.disassemble()); + println!("{}", use_exports.disassemble()); + println!("{}", linked.disassemble()); + } else { + let linker = dxc.create_linker().unwrap(); + + linker.register_library("exports", &exports).unwrap(); + linker.register_library("useExports", &use_exports).unwrap(); + + let binary = linker.link("copyCs", "cs_6_6", &["exports", "useExports"], &[]); + match binary { + Ok(spirv) => { + let spirv = spirv.get_result().unwrap().to_vec::(); + + println!("Outputting `linked.dxil`"); + println!("run `dxc -dumpbin linked.dxil` to disassemble"); + let _ = std::fs::write("./linked.dxil", spirv); + } + // Could very well happen that one needs to recompile or download a dxcompiler.dll + Err(result) => { + let error_blob = result.0.get_error_buffer().unwrap(); + + panic!( + "Failed to link to SPIR-V: {}", + library.get_blob_as_string(&error_blob.into()).unwrap() + ); + } + } + } +} diff --git a/examples/use-export.hlsl b/examples/use-export.hlsl new file mode 100644 index 0000000..3800894 --- /dev/null +++ b/examples/use-export.hlsl @@ -0,0 +1,10 @@ +RWTexture2D g_output : register(u0, space0); + +int valueTwo(); + +[numthreads(8, 8, 1)] +[shader("compute")] +void copyCs(uint3 dispatchThreadId : SV_DispatchThreadID) +{ + g_output[dispatchThreadId.xy] = valueTwo(); +} \ No newline at end of file diff --git a/src/utils.rs b/src/utils.rs index b86caab..9a1d0eb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -35,7 +35,7 @@ pub(crate) fn from_lpstr(string: LPCSTR) -> String { .to_owned() } -struct DefaultIncludeHandler {} +pub struct DefaultIncludeHandler {} impl DxcIncludeHandler for DefaultIncludeHandler { fn load_source(&mut self, filename: String) -> Option { diff --git a/src/wrapper.rs b/src/wrapper.rs index 7ff8fee..672a99d 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -218,6 +218,89 @@ impl DxcIncludeHandlerWrapper { } } +pub struct DxcLinker { + inner: IDxcLinker, + library: DxcLibrary, +} + +impl DxcLinker { + fn new(inner: IDxcLinker, library: DxcLibrary) -> Self { + Self { inner, library } + } + + pub fn register_library(&self, lib_name: &str, lib: &DxcBlob) -> Result<(), HRESULT> { + let lib_name = to_wide(lib_name); + + let hr = unsafe { + self.inner + .register_library(lib_name.as_ptr(), lib.inner.clone()) + }; + + if hr.is_err() { + Err(hr) + } else { + Ok(()) + } + } + + pub fn link( + &self, + entry_point: &str, + target_profile: &str, + lib_names: &[&str], + arguments: &[&str], + ) -> Result { + let entry_point = to_wide(entry_point); + let target_profile = to_wide(target_profile); + + let mut lib_names_wide = vec![]; + let mut lib_names_ptr = vec![]; + + for lib_name in lib_names { + lib_names_wide.push(to_wide(lib_name)); + } + + for wide in &lib_names_wide { + lib_names_ptr.push(wide.as_ptr()) + } + + let mut arguments_wide = vec![]; + let mut arguments_ptr = vec![]; + for arg in arguments { + arguments_wide.push(to_wide(arg)); + } + + for wide in &arguments_wide { + arguments_ptr.push(wide.as_ptr()); + } + + let mut result = None; + + let hr = unsafe { + self.inner.link( + entry_point.as_ptr(), + target_profile.as_ptr(), + lib_names_ptr.as_ptr(), + lib_names_ptr.len() as u32, + arguments_ptr.as_ptr(), + arguments_ptr.len() as u32, + &mut result, + ) + }; + + let result = result.unwrap(); + + let mut linker_error = 0u32; + let status_hr = unsafe { result.get_status(&mut linker_error) }; + + if !hr.is_err() && !status_hr.is_err() && linker_error == 0 { + Ok(DxcOperationResult::new(result)) + } else { + Err((DxcOperationResult::new(result), hr)) + } + } +} + pub struct DxcCompiler { inner: IDxcCompiler2, library: DxcLibrary, @@ -532,6 +615,17 @@ impl Dxc { Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? }) } + pub fn create_linker(&self) -> Result { + let mut linker = None; + + self.get_dxc_create_instance()?(&CLSID_DxcLinker, &IDxcLinker::IID, &mut linker) + .result()?; + Ok(DxcLinker::new( + linker.unwrap(), + self.create_library().unwrap(), + )) + } + pub fn create_compiler(&self) -> Result { let mut compiler = None;