Skip to content

Commit

Permalink
feat: generate enum default as newtype (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
Millione authored Mar 28, 2024
1 parent ce01c86 commit 7f088f7
Show file tree
Hide file tree
Showing 23 changed files with 198 additions and 526 deletions.
38 changes: 2 additions & 36 deletions pilota-build/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use crate::{
rir,
},
symbol::{DefId, EnumRepr, FileId},
tags::EnumMode,
Context, Symbol,
};

Expand Down Expand Up @@ -258,31 +257,11 @@ where
}

pub fn write_enum(&self, def_id: DefId, stream: &mut String, e: &middle::rir::Enum) {
if self
.node_tags(def_id)
.unwrap()
.get::<EnumMode>()
.filter(|s| **s == EnumMode::NewType)
.is_some()
{
if e.repr.is_some() {
return self.write_enum_as_new_type(def_id, stream, e);
}
let name = self.rust_name(def_id);

let mut repr = if e.variants.is_empty() {
quote! {}
} else {
match e.repr {
Some(EnumRepr::I32) => quote! {
#[repr(i32)]
},
None => quote! {},
}
};

if e.repr.is_some() {
repr.extend(quote! { #[derive(Copy)] })
}
let mut keep = true;
let mut variants = e
.variants
Expand All @@ -306,21 +285,9 @@ where
format!("({fields})")
};

let discr = v
.discr
.map(|x| {
let x = isize::try_from(x).unwrap();
let x = match e.repr {
Some(EnumRepr::I32) => x as i32,
None => panic!(),
};
format!("={x}")
})
.unwrap_or_default();

format!(
r#"{attrs}
{name} {fields_stream} {discr},"#
{name} {fields_stream},"#
)
})
})
Expand All @@ -332,7 +299,6 @@ where
stream.push_str(&format! {
r#"
#[derive(Clone, PartialEq)]
{repr}
pub enum {name} {{
{variants}
}}
Expand Down
14 changes: 2 additions & 12 deletions pilota-build/src/codegen/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
},
rir::EnumVariant,
symbol::{DefId, EnumRepr, Symbol},
tags::{thrift::EntryMessage, EnumMode},
tags::thrift::EntryMessage,
ty::TyKind,
};

Expand Down Expand Up @@ -503,17 +503,7 @@ impl CodegenBackend for ThriftBackend {
let keep = self.keep_unknown_fields.contains(&def_id);
let name = self.rust_name(def_id);
let is_entry_message = self.node_contains_tag::<EntryMessage>(def_id);
let v = match self
.cx
.node_tags(def_id)
.unwrap()
.get::<EnumMode>()
.copied()
.unwrap_or(EnumMode::Enum)
{
EnumMode::NewType => "self.inner()",
EnumMode::Enum => "*self as i32",
};
let v = "self.inner()";
match e.repr {
Some(EnumRepr::I32) => stream.push_str(&self.codegen_impl_message_with_helper(
def_id,
Expand Down
28 changes: 2 additions & 26 deletions pilota-build/src/codegen/thrift/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
db::RirDatabase,
middle::{rir, ty, ty::Ty},
symbol::EnumRepr,
tags::EnumMode,
DefId,
};

Expand Down Expand Up @@ -177,19 +176,7 @@ impl ThriftBackend {
.into()
}
ty::Path(p) if self.is_i32_enum(p.did) => {
let v = match self
.cx
.node_tags(p.did)
.unwrap()
.get::<EnumMode>()
.copied()
.unwrap_or(EnumMode::Enum)
{
EnumMode::NewType => format!("({ident}).inner()"),
EnumMode::Enum => format!("(*{ident}).into()"),
};

format!("protocol.write_i32_field({id}, {v})?;").into()
format!("protocol.write_i32_field({id}, ({ident}).inner())?;").into()
}
ty::Path(p) => match self.cx.expect_item(p.did).as_ref() {
rir::Item::NewType(nt) => {
Expand Down Expand Up @@ -306,18 +293,7 @@ impl ThriftBackend {
format!("protocol.map_field_len(Some({id}), {k_ttype}, {v_ttype}, {ident}, |protocol, key| {{ {add_key} }}, |protocol, val| {{ {add_val} }})").into()
}
ty::Path(p) if self.is_i32_enum(p.did) => {
let v = match self
.cx
.node_tags(p.did)
.unwrap()
.get::<EnumMode>()
.copied()
.unwrap_or(EnumMode::Enum)
{
EnumMode::NewType => format!("({ident}).inner()"),
EnumMode::Enum => format!("(*{ident}).into()"),
};
format!("protocol.i32_field_len(Some({id}), {v})").into()
format!("protocol.i32_field_len(Some({id}), ({ident}).inner())").into()
}
ty::Path(_) => format!("protocol.struct_field_len(Some({id}), {ident})").into(),
ty::Arc(ty) => self.codegen_field_size(ty, id, ident),
Expand Down
7 changes: 1 addition & 6 deletions pilota-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ pub use middle::{
rir, ty,
};
use parser::{protobuf::ProtobufParser, thrift::ThriftParser, ParseResult, Parser};
use plugin::{
AutoDerivePlugin, BoxedPlugin, EnumNumPlugin, ImplDefaultPlugin, PredicateResult,
WithAttrsPlugin,
};
use plugin::{AutoDerivePlugin, BoxedPlugin, ImplDefaultPlugin, PredicateResult, WithAttrsPlugin};
pub use plugin::{BoxClonePlugin, ClonePlugin, Plugin};
use resolve::{ResolveResult, Resolver};
use salsa::Durability;
Expand Down Expand Up @@ -105,7 +102,6 @@ impl Builder<MkThriftBackend, ThriftParser> {
plugins: vec![
Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
Box::new(ImplDefaultPlugin),
Box::new(EnumNumPlugin),
],
touches: Vec::default(),
ignore_unused: true,
Expand All @@ -127,7 +123,6 @@ impl Builder<MkProtobufBackend, ProtobufParser> {
plugins: vec![
Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
Box::new(ImplDefaultPlugin),
Box::new(EnumNumPlugin),
],
touches: Vec::default(),
ignore_unused: true,
Expand Down
24 changes: 11 additions & 13 deletions pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
db::{RirDatabase, RootDatabase},
rir::{self, Field, Item, ItemPath, Literal},
symbol::{DefId, FileId, IdentName, Symbol},
tags::{EnumMode, TagId, Tags},
tags::{TagId, Tags},
ty::{AdtDef, AdtKind, CodegenTy, Visitor},
Plugin, MAX_RESOLVE_DEPTH,
};
Expand Down Expand Up @@ -847,18 +847,16 @@ impl Context {
},
NodeKind::Variant(v) => {
let parent = self.node(def_id).unwrap().parent.unwrap();

if self
.node_tags(parent)
.unwrap()
.get::<EnumMode>()
.copied()
.unwrap_or(EnumMode::Enum)
== EnumMode::NewType
{
(&**v.name).shouty_snake_case(self.nonstandard_snake_case)
} else {
(&**v.name).variant_ident()
let item = self.expect_item(parent);
match &*item {
rir::Item::Enum(e) => {
if e.repr.is_some() {
(&**v.name).const_ident(self.nonstandard_snake_case)
} else {
(&**v.name).variant_ident()
}
}
_ => unreachable!(),
}
}
NodeKind::Field(f) => (&**f.name).field_ident(self.nonstandard_snake_case),
Expand Down
2 changes: 1 addition & 1 deletion pilota-build/src/parser/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ impl ThriftLower {
}

annotations.iter().for_each(
|annotation| with_tags!(annotation -> crate::tags::PilotaName | crate::tags::RustType | crate::tags::RustWrapperArc | crate::tags::SerdeAttribute | crate::tags::EnumMode),
|annotation| with_tags!(annotation -> crate::tags::PilotaName | crate::tags::RustType | crate::tags::RustWrapperArc | crate::tags::SerdeAttribute),
);

tags
Expand Down
78 changes: 1 addition & 77 deletions pilota-build/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ use std::{collections::HashSet, ops::DerefMut, sync::Arc};

use faststr::FastStr;
use itertools::Itertools;
use quote::quote;
use rustc_hash::FxHashMap;

use crate::{
db::RirDatabase,
middle::context::tls::CUR_ITEM,
rir::{EnumVariant, Field, Item, NodeKind},
symbol::{DefId, EnumRepr},
tags::EnumMode,
symbol::DefId,
ty::{self, Ty, Visitor},
Context,
};
Expand Down Expand Up @@ -412,77 +410,3 @@ impl Plugin for ImplDefaultPlugin {
walk_item(self, cx, def_id, item)
}
}

pub struct EnumNumPlugin;

impl Plugin for EnumNumPlugin {
fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
match &*item {
Item::Enum(e)
if e.repr.is_some()
&& cx
.node_tags(def_id)
.unwrap()
.get::<EnumMode>()
.copied()
.unwrap_or(EnumMode::Enum)
== EnumMode::Enum =>
{
let name_str = &*cx.rust_name(def_id);
let name = name_str;
let num_ty = match e.repr {
Some(EnumRepr::I32) => quote!(i32),
_ => return,
};
let variants = e
.variants
.iter()
.map(|v| {
let variant_name_str = cx.rust_name(v.did);
let variant_name = variant_name_str;
format!(
"{variant_name} => ::std::result::Result::Ok({name}::{variant_name}),\n"
)
})
.join("");

let nums = e
.variants
.iter()
.map(|v| {
let variant_name_str = cx.rust_name(v.did);
let variant_name = variant_name_str;
format!(
"const {variant_name}: {num_ty} = {name}::{variant_name} as {num_ty};"
)
})
.join("\n");

cx.with_adjust_mut(def_id, |adj| {
adj.add_nested_item(format!(r#"
impl ::std::convert::From<{name}> for {num_ty} {{
fn from(e: {name}) -> Self {{
e as _
}}
}}
impl ::std::convert::TryFrom<{num_ty}> for {name} {{
type Error = ::pilota::EnumConvertError<{num_ty}>;
#[allow(non_upper_case_globals)]
fn try_from(v: i32) -> ::std::result::Result<Self, ::pilota::EnumConvertError<{num_ty}>> {{
{nums}
match v {{
{variants}
_ => ::std::result::Result::Err(::pilota::EnumConvertError::InvalidNum(v, "{name_str}")),
}}
}}
}}"#).into(),
)
});
}
_ => {}
}
walk_item(self, cx, def_id, item)
}
}
19 changes: 10 additions & 9 deletions pilota-build/src/plugin/serde.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::tags::{EnumMode, SerdeAttribute};
use crate::tags::SerdeAttribute;

#[derive(Clone, Copy)]
pub struct SerdePlugin;
Expand Down Expand Up @@ -28,17 +28,18 @@ impl crate::Plugin for SerdePlugin {
adj.add_attrs(&[attr.into()]);
}
});

if cx.node_tags(def_id).unwrap().get::<EnumMode>().copied()
== Some(EnumMode::NewType)
{
cx.with_adjust_mut(def_id, |adj| {
adj.add_attrs(&["#[serde(transparent)]".into()]);
})
}
}
_ => {}
};

if let crate::rir::Item::Enum(e) = &*item {
if e.repr.is_some() {
cx.with_adjust_mut(def_id, |adj| {
adj.add_attrs(&["#[serde(transparent)]".into()]);
})
}
}

crate::plugin::walk_item(self, cx, def_id, item)
}

Expand Down
21 changes: 0 additions & 21 deletions pilota-build/src/tags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,27 +134,6 @@ impl Annotation for PilotaName {
#[derive(Debug)]
pub struct RustType(pub FastStr);

#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum EnumMode {
NewType,
Enum,
}

impl FromStr for EnumMode {
type Err = std::convert::Infallible;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"new_type" => Self::NewType,
_ => Self::Enum,
})
}
}

impl Annotation for EnumMode {
const KEY: &'static str = "pilota.enum_mode";
}

impl PartialEq<str> for RustType {
fn eq(&self, other: &str) -> bool {
self.0 == other
Expand Down
Loading

0 comments on commit 7f088f7

Please sign in to comment.