Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔗 Expose the Dxc linker #71

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ libc = "0.2"

[dev-dependencies]
rspirv = "0.11"
spirv-linker = "0.1"
3 changes: 3 additions & 0 deletions examples/exports.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export int valueTwo() {
return 2;
}
77 changes: 77 additions & 0 deletions examples/link.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use hassle_rs::utils::DefaultIncludeHandler;
use hassle_rs::*;
use rspirv::binary::Disassemble;

fn main() {
let dxc = Dxc::new(None).unwrap();
let compiler = dxc.create_compiler().unwrap();
let library = dxc.create_library().unwrap();
let spirv = false;

let args = &if spirv {
vec!["-spirv", "-fspv-target-env=vulkan1.1spirv1.4"]
} else {
vec![]
};

let exports = compiler.compile(
&library
.create_blob_with_encoding_from_str(include_str!("exports.hlsl"))
.unwrap(),
"exports.hlsl",
"",
"lib_6_6",
args,
Some(&mut DefaultIncludeHandler {}),
&[],
);
let use_exports = compiler.compile(
&library
.create_blob_with_encoding_from_str(include_str!("use-export.hlsl"))
.unwrap(),
"use-exports.hlsl",
"",
"lib_6_6",
args,
Some(&mut DefaultIncludeHandler {}),
&[],
);

let exports = exports.ok().unwrap().get_result().unwrap();
let use_exports = use_exports.ok().unwrap().get_result().unwrap();
Comment on lines +40 to +41
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh it really is time for me to PR that error handling simplification, this Result<DxcOperationResult, (DxcOperationResult, HRESULT)> is so useless 😬

I think you even need to call get_status() etc :(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, but I didn't feel like cleaning it up in this PR. There's a bit of an ambiguity there too because there are multiple possible failure operations making this quite complex. (e.g. link and compile can fail without an actual compiler or linker error)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah definitely not. I was more so thinking of the other way around: any of these might return an error code (including .compile() or .link()), but DxcOperationResult might still contain an error blob so that there's more context why the HRESULT was false. It just doesn't make sense that there's also get_status() giving you two more HRESULTs on top of the function return.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can probably replace all of this with a thiserror type enum that just represents all of these error cases, and potentially just unwraps the DxcOperationResult into this new enum.


if spirv {
let mut exports = rspirv::dr::load_bytes(exports).unwrap();
let mut use_exports = rspirv::dr::load_bytes(use_exports).unwrap();
let linked =
spirv_linker::link(&mut [&mut exports, &mut use_exports], &Default::default()).unwrap();
println!("{}", exports.disassemble());
println!("{}", use_exports.disassemble());
println!("{}", linked.disassemble());
} else {
let linker = dxc.create_linker().unwrap();

linker.register_library("exports", &exports).unwrap();
linker.register_library("useExports", &use_exports).unwrap();

let binary = linker.link("copyCs", "cs_6_6", &["exports", "useExports"], &[]);
match binary {
Ok(spirv) => {
let spirv = spirv.get_result().unwrap().to_vec::<u8>();

println!("Outputting `linked.dxil`");
println!("run `dxc -dumpbin linked.dxil` to disassemble");
let _ = std::fs::write("./linked.dxil", spirv);
}
// Could very well happen that one needs to recompile or download a dxcompiler.dll
Err(result) => {
let error_blob = result.0.get_error_buffer().unwrap();

panic!(
"Failed to link to SPIR-V: {}",
library.get_blob_as_string(&error_blob.into()).unwrap()
);
}
}
}
}
10 changes: 10 additions & 0 deletions examples/use-export.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
RWTexture2D<float4> g_output : register(u0, space0);

int valueTwo();

[numthreads(8, 8, 1)]
[shader("compute")]
void copyCs(uint3 dispatchThreadId : SV_DispatchThreadID)
{
g_output[dispatchThreadId.xy] = valueTwo();
}
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub(crate) fn from_lpstr(string: LPCSTR) -> String {
.to_owned()
}

struct DefaultIncludeHandler {}
pub struct DefaultIncludeHandler {}

impl DxcIncludeHandler for DefaultIncludeHandler {
fn load_source(&mut self, filename: String) -> Option<String> {
Expand Down
94 changes: 94 additions & 0 deletions src/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,89 @@ impl DxcIncludeHandlerWrapper {
}
}

pub struct DxcLinker {
inner: IDxcLinker,
library: DxcLibrary,
}

impl DxcLinker {
fn new(inner: IDxcLinker, library: DxcLibrary) -> Self {
Self { inner, library }
}

pub fn register_library(&self, lib_name: &str, lib: &DxcBlob) -> Result<(), HRESULT> {
let lib_name = to_wide(lib_name);

let hr = unsafe {
self.inner
.register_library(lib_name.as_ptr(), lib.inner.clone())
};

if hr.is_err() {
Err(hr)
} else {
Ok(())
}
}

pub fn link(
&self,
entry_point: &str,
target_profile: &str,
lib_names: &[&str],
arguments: &[&str],
) -> Result<DxcOperationResult, (DxcOperationResult, HRESULT)> {
let entry_point = to_wide(entry_point);
let target_profile = to_wide(target_profile);

let mut lib_names_wide = vec![];
let mut lib_names_ptr = vec![];

for lib_name in lib_names {
lib_names_wide.push(to_wide(lib_name));
}

for wide in &lib_names_wide {
lib_names_ptr.push(wide.as_ptr())
}

let mut arguments_wide = vec![];
let mut arguments_ptr = vec![];
for arg in arguments {
arguments_wide.push(to_wide(arg));
}

for wide in &arguments_wide {
arguments_ptr.push(wide.as_ptr());
}

let mut result = None;

let hr = unsafe {
self.inner.link(
entry_point.as_ptr(),
target_profile.as_ptr(),
lib_names_ptr.as_ptr(),
lib_names_ptr.len() as u32,
arguments_ptr.as_ptr(),
arguments_ptr.len() as u32,
&mut result,
)
};

let result = result.unwrap();

let mut linker_error = 0u32;
let status_hr = unsafe { result.get_status(&mut linker_error) };

if !hr.is_err() && !status_hr.is_err() && linker_error == 0 {
Comment on lines +293 to +296
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like, how many errors do they want us to check, and at which point will DxcOperationResult::get_error_buffer() return sensible help text?

Ok(DxcOperationResult::new(result))
} else {
Err((DxcOperationResult::new(result), hr))
}
}
}

pub struct DxcCompiler {
inner: IDxcCompiler2,
library: DxcLibrary,
Expand Down Expand Up @@ -532,6 +615,17 @@ impl Dxc {
Ok(unsafe { self.dxc_lib.get(b"DxcCreateInstance\0")? })
}

pub fn create_linker(&self) -> Result<DxcLinker> {
let mut linker = None;

self.get_dxc_create_instance()?(&CLSID_DxcLinker, &IDxcLinker::IID, &mut linker)
.result()?;
Ok(DxcLinker::new(
linker.unwrap(),
self.create_library().unwrap(),
))
}

pub fn create_compiler(&self) -> Result<DxcCompiler> {
let mut compiler = None;

Expand Down