Skip to content

Commit

Permalink
Support pointwise addition for arrays and tuples
Browse files Browse the repository at this point in the history
Fixes JelteF#342

We want to support examples like these:

```rust
struct StructRecursive {
    a: i32,
    b: [i32; 2],
    c: [[i32; 2]; 3],
    d: (i32, i32),
    e: ((u8, [i32; 3]), i32),
    f: ((u8, i32), (u8, ((i32, u64, ((u8, u8), u16)), u8))),
    g: i32,
}

struct TupleRecursive((i32, u8), [(i32, u8); 10]);
```

Supporting arrays and tuples inside of enums would also be useful, but
that's not in this PR.
  • Loading branch information
matthiasgoergens committed Mar 20, 2024
1 parent 2a001d6 commit 9250753
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 20 deletions.
79 changes: 60 additions & 19 deletions impl/src/add_helpers.rs
@@ -1,28 +1,69 @@
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Field, Ident, Index};
use syn::{Field, Ident, Index, Type, TypeArray, TypeTuple};

pub fn tuple_exprs(fields: &[&Field], method_ident: &Ident) -> Vec<TokenStream> {
let mut exprs = vec![];

for i in 0..fields.len() {
let i = Index::from(i);
// generates `self.0.add(rhs.0)`
let expr = quote! { self.#i.#method_ident(rhs.#i) };
exprs.push(expr);
}
exprs
let fields: Vec<&Type> = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
inner_tuple_exprs(0, &quote! {}, &fields, method_ident)
}

pub fn struct_exprs(fields: &[&Field], method_ident: &Ident) -> Vec<TokenStream> {
let mut exprs = vec![];
fields
.iter()
.map(|field| {
// It's safe to unwrap because struct fields always have an identifier
let field_path = field.ident.as_ref().unwrap();
elem_content(0, &quote! { .#field_path }, &field.ty, method_ident)
})
.collect()
}

pub fn inner_tuple_exprs(
// `depth` is needed for `index_var` generation for nested arrays
depth: usize,
field_path: &TokenStream,
fields: &[&Type],
method_ident: &Ident,
) -> Vec<TokenStream> {
fields
.iter()
.enumerate()
.map(|(i, ty)| {
let i = Index::from(i);
elem_content(depth + 1, &quote! { #field_path.#i }, ty, method_ident)
})
.collect()
}

pub fn elem_content(
depth: usize,
field_path: &TokenStream,
ty: &Type,
method_ident: &Ident,
) -> TokenStream {
match ty {
Type::Array(TypeArray { elem, .. }) => {
let index_var = Ident::new(&format!("i{}", depth), Span::call_site());
let fn_body = elem_content(
depth + 1,
&quote! { #field_path[#index_var] },
elem.as_ref(),
method_ident,
);

for field in fields {
// It's safe to unwrap because struct fields always have an identifier
let field_id = field.ident.as_ref().unwrap();
// generates `x: self.x.add(rhs.x)`
let expr = quote! { self.#field_id.#method_ident(rhs.#field_id) };
exprs.push(expr)
// generates `core::array::from_fn(|i0| self.x[i0].add(rhs.x[i0]))`
quote! { core::array::from_fn(|#index_var| #fn_body) }
}
Type::Tuple(TypeTuple { elems, .. }) => {
let exprs = inner_tuple_exprs(
depth + 1,
field_path,
&elems.iter().collect::<Vec<_>>(),
method_ident,
);
quote! { (#(#exprs),*) }
}
// generates `self.x.add(rhs.x)`
_ => quote! { self #field_path.#method_ident(rhs #field_path) },
}
exprs
}
3 changes: 2 additions & 1 deletion impl/src/add_like.rs
Expand Up @@ -18,6 +18,7 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

// TODO(Matthias): do we add support for arrays here?
let (output_type, block) = match input.data {
Data::Struct(ref data_struct) => match data_struct.fields {
Fields::Unnamed(ref fields) => (
Expand Down Expand Up @@ -53,7 +54,7 @@ pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
}
}

fn tuple_content<T: ToTokens>(
pub(crate) fn tuple_content<T: ToTokens>(
input_type: &T,
fields: &[&Field],
method_ident: &Ident,
Expand Down
25 changes: 25 additions & 0 deletions tests/add.rs
Expand Up @@ -22,3 +22,28 @@ enum MixedInts {
UnsignedTwo(u32),
Unit,
}

#[derive(Add)]
#[derive(Default)]
struct StructRecursive {
a: i32,
b: [i32; 2],
c: [[i32; 2]; 3],
d: (i32, i32),
e: ((u8, [i32; 3]), i32),
f: ((u8, i32), (u8, ((i32, u64, ((u8, u8), u16)), u8))),
g: i32,
}

#[test]
fn test_sanity() {
let mut a: StructRecursive = Default::default();
let mut b: StructRecursive = Default::default();
a.c[0][1] = 1;
b.c[0][1] = 2;
let c = a + b;
assert_eq!(c.c[0][1], 3);
}

#[derive(Add)]
struct TupleRecursive((i32, u8), [(i32, u8); 10]);

0 comments on commit 9250753

Please sign in to comment.