From 0c6bcc6563e913b74eb07f08d553057ee21df011 Mon Sep 17 00:00:00 2001 From: James Sturtevant Date: Tue, 17 Dec 2024 17:01:27 -0800 Subject: [PATCH] Improve WitException by generating custom exception Types Signed-off-by: James Sturtevant --- crates/csharp/src/lib.rs | 61 ++++++++++++++++++++++++++++++++--- tests/runtime/results/wasm.cs | 47 ++++++++------------------- 2 files changed, 70 insertions(+), 38 deletions(-) diff --git a/crates/csharp/src/lib.rs b/crates/csharp/src/lib.rs index bae63c208..3065dfa8e 100644 --- a/crates/csharp/src/lib.rs +++ b/crates/csharp/src/lib.rs @@ -175,6 +175,7 @@ impl CSharp { direction, usings: HashSet::::new(), interop_usings: HashSet::::new(), + custom_exception_types: HashSet::new(), } } @@ -253,6 +254,8 @@ impl WorldGenerator for CSharp { } } + gen.add_custom_exception_types(resolve); + // for anonymous types gen.define_interface_types(id); @@ -916,6 +919,7 @@ struct InterfaceGenerator<'a> { direction: Direction, usings: HashSet, interop_usings: HashSet, + custom_exception_types: HashSet, } impl InterfaceGenerator<'_> { @@ -1377,6 +1381,46 @@ impl InterfaceGenerator<'_> { } } + fn use_custom_exception(&mut self, ty: Type) -> Option { + match ty { + Type::Id(id) => { + let inner_ty = &self.resolve.types[id]; + match &inner_ty.kind { + TypeDefKind::Enum(_) | TypeDefKind::Record(_) | TypeDefKind::Variant(_) => { + self.custom_exception_types.insert(id.clone()); + Some(ty) + } + _ => None, + } + } + _ => None, + } + } + + fn add_custom_exception_types(&mut self, resolve: &Resolve) { + let access = self.gen.access_modifier(); + + for custom_type in self.custom_exception_types.iter() { + let ty = &resolve.types[*custom_type]; + let name = ty.name.as_ref().unwrap().to_upper_camel_case(); + match ty.kind { + TypeDefKind::Enum(_) | TypeDefKind::Record(_) | TypeDefKind::Variant(_) => { + uwrite!( + self.src, + " + {access} class {name}Exception : WitException {{ + {access} {name} {name}Value {{get {{return ({name})this.Value; }} }} + + {access} {name}Exception({name} v, uint level) : base(v, level){{}} + }} + " + ); + } + _ => {} + } + } + } + fn type_name_with_qualifier(&mut self, ty: &Type, qualifier: bool) -> String { match ty { Type::Bool => "bool".to_owned(), @@ -2952,10 +2996,9 @@ impl Bindgen for FunctionBindgen<'_, '_> { 1 => { let mut payload_is_void = false; let mut previous = operands[0].clone(); - let mut vars = Vec::with_capacity(self.results.len()); + let mut vars: Vec::<(String, Option)> = Vec::<(String, Option)>::with_capacity(self.results.len()); if let Direction::Import = self.gen.direction { for ty in &self.results { - vars.push(previous.clone()); let tmp = self.locals.tmp("tmp"); uwrite!( self.src, @@ -2964,21 +3007,31 @@ impl Bindgen for FunctionBindgen<'_, '_> { var {tmp} = {previous}.AsOk; " ); - previous = tmp; + let TypeDefKind::Result(result) = &self.gen.resolve.types[*ty].kind else { unreachable!(); }; + let exception_name = result.err + .and_then(|ty| self.gen.use_custom_exception(ty)) + .map(|ty| self.gen.type_name_with_qualifier(&ty, true)); + vars.push((previous.clone(), exception_name)); payload_is_void = result.ok.is_none(); + previous = tmp; } } uwriteln!(self.src, "return {};", if payload_is_void { "" } else { &previous }); for (level, var) in vars.iter().enumerate().rev() { self.gen.gen.needs_wit_exception = true; + let (var_name, exception_name) = var; + let exception_name = match exception_name { + Some(name) => &format!("{}Exception",name), + None => "WitException", + }; uwrite!( self.src, "\ }} else {{ - throw new WitException({var}.AsErr!, {level}); + throw new {exception_name}({var_name}.AsErr!, {level}); }} " ); diff --git a/tests/runtime/results/wasm.cs b/tests/runtime/results/wasm.cs index 81ca3296b..4ef825835 100644 --- a/tests/runtime/results/wasm.cs +++ b/tests/runtime/results/wasm.cs @@ -13,17 +13,8 @@ public static float EnumError(float a) { try { return ResultsWorld.wit.imports.test.results.TestInterop.EnumError(a); - } catch (WitException e) { - switch ((ResultsWorld.wit.imports.test.results.ITest.E) e.Value) { - case ResultsWorld.wit.imports.test.results.ITest.E.A: - throw new WitException(ITest.E.A, 0); - case ResultsWorld.wit.imports.test.results.ITest.E.B: - throw new WitException(ITest.E.B, 0); - case ResultsWorld.wit.imports.test.results.ITest.E.C: - throw new WitException(ITest.E.C, 0); - default: - throw new Exception("unreachable"); - } + } catch (ResultsWorld.wit.imports.test.results.ITest.EException e) { + throw new WitException(e.EValue, 0); } } @@ -31,9 +22,8 @@ public static float RecordError(float a) { try { return ResultsWorld.wit.imports.test.results.TestInterop.RecordError(a); - } catch (WitException e) { - var value = (ResultsWorld.wit.imports.test.results.ITest.E2) e.Value; - throw new WitException(new ITest.E2(value.line, value.column), 0); + } catch (ResultsWorld.wit.imports.test.results.ITest.E2Exception e) { + throw new WitException(new ITest.E2(e.E2Value.line, e.E2Value.column), 0); } } @@ -41,26 +31,15 @@ public static float VariantError(float a) { try { return ResultsWorld.wit.imports.test.results.TestInterop.VariantError(a); - } catch (WitException e) { - var value = (ResultsWorld.wit.imports.test.results.ITest.E3) e.Value; - switch (value.Tag) { - case ResultsWorld.wit.imports.test.results.ITest.E3.Tags.E1: - switch (value.AsE1) { - case ResultsWorld.wit.imports.test.results.ITest.E.A: - throw new WitException(ITest.E3.E1(ITest.E.A), 0); - case ResultsWorld.wit.imports.test.results.ITest.E.B: - throw new WitException(ITest.E3.E1(ITest.E.B), 0); - case ResultsWorld.wit.imports.test.results.ITest.E.C: - throw new WitException(ITest.E3.E1(ITest.E.C), 0); - default: - throw new Exception("unreachable"); - } - case ResultsWorld.wit.imports.test.results.ITest.E3.Tags.E2: { - throw new WitException(ITest.E3.E2(new ITest.E2(value.AsE2.line, value.AsE2.column)), 0); - } - default: - throw new Exception("unreachable"); - } + } catch (ResultsWorld.wit.imports.test.results.ITest.E3Exception e) + when (e.E3Value.Tag == ResultsWorld.wit.imports.test.results.ITest.E3.Tags.E1) { + throw new WitException(ITest.E3.E1((ITest.E)Enum.Parse(typeof(ITest.E), e.E3Value.AsE1.ToString())), 0); + } catch (ResultsWorld.wit.imports.test.results.ITest.E3Exception e) + when (e.E3Value.Tag == ResultsWorld.wit.imports.test.results.ITest.E3.Tags.E2) { + throw new WitException(ITest.E3.E2(new ITest.E2(e.E3Value.AsE2.line, e.E3Value.AsE2.column)), 0); + } + catch { + throw new Exception("unreachable"); } }