Skip to content

Commit

Permalink
Support skipping fields in Queryable derive (falls back to Default impl)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ten0 committed Oct 27, 2021
1 parent 9cb852b commit e1eed5b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 19 deletions.
18 changes: 13 additions & 5 deletions diesel_derives/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,19 @@ impl Field {
}
}

pub fn ty_for_deserialize(&self) -> Result<Cow<syn::Type>, Diagnostic> {
if let Some(meta) = self.flags.nested_item("deserialize_as")? {
meta.ty_value().map(Cow::Owned)
} else {
Ok(Cow::Borrowed(&self.ty))
/// Returns None if has `default` flag
pub fn ty_for_deserialize(&self) -> Result<Option<Cow<syn::Type>>, Diagnostic> {
match (
self.flags.nested_item("deserialize_as")?,
self.has_flag("skip"),
) {
(Some(meta), false) => Ok(Some(Cow::Owned(meta.ty_value()?))),
(None, false) => Ok(Some(Cow::Borrowed(&self.ty))),
(None, true) => Ok(None),
(Some(_), true) => Err(self
.flags
.span()
.error("Cannot have both `deserialize_as` and `skip` attributes")),
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions diesel_derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ pub fn derive_query_id(input: TokenStream) -> TokenStream {
/// Then `Type` is converted via
/// [`.try_into`](https://doc.rust-lang.org/stable/std/convert/trait.TryInto.html#tymethod.try_into)
/// into the field type. By default this derive will deserialize directly into the field type
/// * `[diesel(skip)]`, instead of deserializing anything into that type, deserialization will
/// behave as if this field wasn't here, and it will simply be set using the `Default` impl
/// of the type of the field.
///
/// # Examples
///
Expand Down
22 changes: 15 additions & 7 deletions diesel_derives/src/queryable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno
let field_ty = model
.fields()
.iter()
.map(Field::ty_for_deserialize)
.filter_map(|f| Field::ty_for_deserialize(f).transpose())
.collect::<Result<Vec<_>, _>>()?;
let field_ty = &field_ty;
let build_expr = model.fields().iter().enumerate().map(|(i, f)| {
let i = syn::Index::from(i);
f.name.assign(parse_quote!(row.#i.try_into()?))
});
let sql_type = (0..model.fields().len())
let build_expr = {
let mut non_default_field_idx = 0;
model.fields().iter().map(move |f| {
f.name.assign(if f.has_flag("skip") {
parse_quote!(std::default::Default::default())
} else {
let i = syn::Index::from(non_default_field_idx);
non_default_field_idx += 1;
parse_quote!(row.#i.try_into()?)
})
})
};
let sql_type = (0..field_ty.len())
.map(|i| {
let i = syn::Ident::new(&format!("__ST{}", i), proc_macro2::Span::call_site());
quote!(#i)
Expand All @@ -32,7 +40,7 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno
generics
.params
.push(parse_quote!(__DB: diesel::backend::Backend));
for id in 0..model.fields().len() {
for id in 0..field_ty.len() {
let ident = syn::Ident::new(&format!("__ST{}", id), proc_macro2::Span::call_site());
generics.params.push(parse_quote!(#ident));
}
Expand Down
46 changes: 39 additions & 7 deletions diesel_derives/tests/queryable.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,70 @@
use diesel::dsl::sql;
use diesel::sql_types::Integer;
use diesel::sql_types::{Int4, Int8};
use diesel::*;

use helpers::connection;

#[test]
fn named_struct_definition() {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Queryable)]
#[derive(Debug, PartialEq, Eq, Queryable)]
struct MyStruct {
foo: i32,
bar: i32,
}

let conn = &mut connection();
let data = select(sql::<(Integer, Integer)>("1, 2")).get_result(conn);
let data = select(sql::<(Int4, Int4)>("1, 2")).get_result(conn);
assert_eq!(Ok(MyStruct { foo: 1, bar: 2 }), data);
}

#[test]
fn tuple_struct() {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Queryable)]
#[derive(Debug, PartialEq, Eq, Queryable)]
struct MyStruct(#[column_name = "foo"] i32, #[column_name = "bar"] i32);

let conn = &mut connection();
let data = select(sql::<(Integer, Integer)>("1, 2")).get_result(conn);
let data = select(sql::<(Int4, Int4)>("1, 2")).get_result(conn);
assert_eq!(Ok(MyStruct(1, 2)), data);
}

#[test]
fn tuple_struct_without_column_name_annotations() {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Queryable)]
#[derive(Debug, PartialEq, Eq, Queryable)]
struct MyStruct(i32, i32);

let conn = &mut connection();
let data = select(sql::<(Integer, Integer)>("1, 2")).get_result(conn);
let data = select(sql::<(Int4, Int4)>("1, 2")).get_result(conn);
assert_eq!(Ok(MyStruct(1, 2)), data);
}

#[test]
fn named_struct_definition_with_skip() {
#[derive(Debug, PartialEq, Eq, Queryable)]
struct MyStruct {
foo: i32,
#[diesel(skip)]
should_be_default: Vec<i32>,
bar: i64,
}

let conn = &mut connection();
let data = select(sql::<(Int4, Int8)>("1, 2::int8")).get_result(conn);
assert_eq!(
Ok(MyStruct {
foo: 1,
should_be_default: Vec::default(),
bar: 2,
}),
data
);
}

#[test]
fn tuple_struct_with_skip() {
#[derive(Debug, PartialEq, Eq, Queryable)]
struct MyStruct(i32, #[diesel(skip)] Option<i32>, i64);

let conn = &mut connection();
let data = select(sql::<(Int4, Int8)>("1, 2::int8")).get_result(conn);
assert_eq!(Ok(MyStruct(1, None, 2)), data);
}

0 comments on commit e1eed5b

Please sign in to comment.