exhaust_macros/
lib.rs

1use std::iter;
2
3use itertools::izip;
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
6use quote::{quote, ToTokens as _};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, parse_quote, DeriveInput};
10
11mod common;
12use common::ExhaustContext;
13
14mod fields;
15use fields::{exhaust_iter_fields, ExhaustFields};
16
17use crate::common::ConstructorSyntax;
18
19// Note: documentation is on the reexport so that it can have working links.
20#[proc_macro_derive(Exhaust)]
21pub fn derive_exhaust(input: TokenStream) -> TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23    derive_impl(input)
24        .unwrap_or_else(|err| err.to_compile_error())
25        .into()
26}
27
28/// Generate an impl of Exhaust for a built-in tuple type.
29/// This macro is only useful within the `exhaust` crate.
30#[proc_macro]
31#[doc(hidden)]
32pub fn impl_exhaust_for_tuples(input: TokenStream) -> TokenStream {
33    let input = parse_macro_input!(input as syn::LitInt);
34    tuple_impls_up_to(input.base10_parse().unwrap())
35        .unwrap_or_else(|err| err.to_compile_error())
36        .into()
37}
38
39fn derive_impl(input: DeriveInput) -> Result<TokenStream2, syn::Error> {
40    let DeriveInput {
41        ident: item_type_name,
42        attrs: _,
43        vis,
44        generics,
45        data,
46    } = input;
47
48    let item_type_name_str = &item_type_name.to_string();
49    let factory_type_name = common::generated_type_name(item_type_name_str, "Factory");
50    let iterator_type_name = common::generated_type_name(item_type_name_str, "Iter");
51
52    let ctx = ExhaustContext {
53        vis,
54        generics,
55        iterator_type_name,
56        item_type: ConstructorSyntax::Braced(item_type_name.to_token_stream()),
57        factory_type: ConstructorSyntax::Braced(factory_type_name.to_token_stream()),
58        exhaust_crate_path: syn::parse_quote! { ::exhaust },
59    };
60    let ExhaustContext {
61        iterator_type_name,
62        exhaust_crate_path,
63        ..
64    } = &ctx;
65
66    let (iterator_and_factory_decl, from_factory_body) = match data {
67        syn::Data::Struct(s) => exhaust_iter_struct(s, &ctx),
68        syn::Data::Enum(e) => exhaust_iter_enum(e, &ctx),
69        syn::Data::Union(syn::DataUnion { union_token, .. }) => Err(syn::Error::new(
70            union_token.span,
71            "derive(Exhaust) does not support unions",
72        )),
73    }?;
74
75    let (impl_generics, ty_generics, augmented_where_predicates) =
76        ctx.generics_with_bounds(syn::parse_quote! {});
77
78    Ok(quote! {
79        // rust-analyzer (but not rustc) sometimes produces lints on macro generated code it
80        // shouldn't. We don't expect to actually hit this case normally, but in general,
81        // we don't want to *ever* bother our users with unfixable warnings about weird names.
82        #[allow(nonstandard_style)]
83        // This anonymous constant allows us to make all our generated types be public-in-private,
84        // without altering the meaning of any paths they use.
85        const _: () = {
86            impl #impl_generics #exhaust_crate_path::Exhaust for #item_type_name #ty_generics
87            where #augmented_where_predicates {
88                type Iter = #iterator_type_name #ty_generics;
89                type Factory = #factory_type_name #ty_generics;
90                fn exhaust_factories() -> Self::Iter {
91                    ::core::default::Default::default()
92                }
93                fn from_factory(factory: Self::Factory) -> Self {
94                    #from_factory_body
95                }
96            }
97
98            #iterator_and_factory_decl
99        };
100    })
101}
102
103fn tuple_impls_up_to(size: u64) -> Result<TokenStream2, syn::Error> {
104    (2..=size).map(tuple_impl).collect()
105}
106
107/// Generate an impl of Exhaust for a built-in tuple type.
108///
109/// This is almost but not quite identical to [`exhaust_iter_struct`], due to the syntax
110/// of tuples and due to it being used from the same crate (so that access is via
111/// crate::Exhaust instead of ::exhaust::Exhaust).
112fn tuple_impl(size: u64) -> Result<TokenStream2, syn::Error> {
113    if size < 2 {
114        return Err(syn::Error::new(
115            Span::call_site(),
116            "tuple type of size less than 2 not supported",
117        ));
118    }
119
120    let value_type_vars: Vec<Ident> = (0..size)
121        .map(|i| Ident::new(&format!("T{i}"), Span::mixed_site()))
122        .collect();
123    let factory_value_vars: Vec<Ident> = (0..size)
124        .map(|i| Ident::new(&format!("factory{i}"), Span::mixed_site()))
125        .collect();
126    let synthetic_fields: syn::Fields = syn::Fields::Unnamed(syn::FieldsUnnamed {
127        paren_token: syn::token::Paren(Span::mixed_site()),
128        unnamed: value_type_vars
129            .iter()
130            .map(|type_var| syn::Field {
131                attrs: vec![],
132                vis: parse_quote! { pub },
133                mutability: syn::FieldMutability::None,
134                ident: None,
135                colon_token: None,
136                ty: syn::Type::Verbatim(type_var.to_token_stream()),
137            })
138            .collect(),
139    });
140
141    // Synthesize a good-enough context to use the derive tools.
142    let ctx: ExhaustContext = ExhaustContext {
143        vis: parse_quote! { pub },
144        generics: syn::Generics {
145            lt_token: None,
146            params: value_type_vars
147                .iter()
148                .map(|var| {
149                    syn::GenericParam::Type(syn::TypeParam {
150                        attrs: vec![],
151                        ident: var.clone(),
152                        colon_token: None,
153                        bounds: Punctuated::default(),
154                        eq_token: None,
155                        default: None,
156                    })
157                })
158                .collect(),
159            gt_token: None,
160            where_clause: None,
161        },
162        item_type: ConstructorSyntax::Tuple,
163        factory_type: ConstructorSyntax::Tuple,
164        iterator_type_name: common::generated_type_name("Tuple", "Iter"),
165        exhaust_crate_path: parse_quote! { crate },
166    };
167
168    let iterator_type_name = &ctx.iterator_type_name;
169
170    // Generate the field-exhausting iteration logic
171    let ExhaustFields {
172        state_field_decls,
173        factory_field_decls: _, // unused because we use tuples instead
174        initializers,
175        cloners,
176        field_pats,
177        advance,
178    } = exhaust_iter_fields(
179        &ctx,
180        &synthetic_fields,
181        &quote! {},
182        &ConstructorSyntax::Tuple,
183    );
184
185    let iterator_impls = ctx.impl_iterator_and_factory_traits(
186        quote! {
187            match self {
188                Self { #field_pats } => {
189                    #advance
190                }
191            }
192        },
193        quote! { Self { #initializers } },
194        quote! {
195            let Self { #field_pats } = self;
196            Self { #cloners }
197        },
198    );
199
200    let iterator_doc = ctx.iterator_doc();
201
202    Ok(quote! {
203        const _: () = {
204            impl<#( #value_type_vars , )*> crate::Exhaust for ( #( #value_type_vars , )* )
205            where #( #value_type_vars : crate::Exhaust, )*
206            {
207                type Iter = #iterator_type_name <#( #value_type_vars , )*>;
208                type Factory = (#(
209                    <#value_type_vars as crate::Exhaust>::Factory,
210                )*);
211                fn exhaust_factories() -> Self::Iter {
212                    ::core::default::Default::default()
213                }
214                fn from_factory(factory: Self::Factory) -> Self {
215                    let (#( #factory_value_vars , )*) = factory;
216                    (#(
217                        <#value_type_vars as crate::Exhaust>::from_factory(#factory_value_vars),
218                    )*)
219                }
220            }
221
222            #[doc = #iterator_doc]
223            pub struct #iterator_type_name <#( #value_type_vars , )*>
224            where #( #value_type_vars : crate::Exhaust, )*
225            {
226                #state_field_decls
227            }
228
229            #iterator_impls
230        };
231    })
232}
233
234fn exhaust_iter_struct(
235    s: syn::DataStruct,
236    ctx: &ExhaustContext,
237) -> Result<(TokenStream2, TokenStream2), syn::Error> {
238    let vis = &ctx.vis;
239    let exhaust_crate_path = &ctx.exhaust_crate_path;
240    let (impl_generics, ty_generics, augmented_where_predicates) =
241        ctx.generics_with_bounds(syn::parse_quote! {});
242    let iterator_type_name = &ctx.iterator_type_name;
243    let factory_type_name = &ctx.factory_type.path()?;
244    let factory_type = &ctx.factory_type.parameterized(&ctx.generics);
245
246    let factory_state_struct_type = ctx.generated_type_name("FactoryState")?;
247    let factory_state_ctor = ConstructorSyntax::Braced(factory_state_struct_type.to_token_stream());
248
249    let ExhaustFields {
250        state_field_decls,
251        factory_field_decls,
252        initializers,
253        cloners,
254        field_pats,
255        advance,
256    } = if s.fields.is_empty() {
257        let empty_state_expr = factory_state_ctor.value_expr([].iter(), [].iter());
258        ExhaustFields {
259            state_field_decls: quote! { done: bool, },
260            factory_field_decls: syn::Fields::Unit,
261            initializers: quote! { done: false, },
262            cloners: quote! { done: *done, },
263            field_pats: quote! { done, },
264            advance: quote! {
265                if *done {
266                    ::core::option::Option::None
267                } else {
268                    *done = true;
269                    ::core::option::Option::Some(#factory_type_name(#empty_state_expr))
270                }
271            },
272        }
273    } else {
274        exhaust_iter_fields(
275            ctx,
276            &s.fields,
277            ctx.factory_type.path()?,
278            &factory_state_ctor,
279        )
280    };
281
282    // Note: The iterator must have trait bounds because its fields, being of type
283    // `<SomeOtherTy as Exhaust>::Iter`, require that `SomeOtherTy: Exhaust`.
284
285    let impls = ctx.impl_iterator_and_factory_traits(
286        quote! {
287            match self {
288                Self { #field_pats } => {
289                    #advance
290                }
291            }
292        },
293        quote! { Self { #initializers } },
294        quote! {
295            let Self { #field_pats } = self;
296            Self { #cloners }
297        },
298    );
299
300    let factory_struct_clone_arm = common::clone_like_struct_conversion(
301        &s.fields,
302        factory_state_ctor.path()?,
303        factory_state_ctor.path()?,
304        &quote! { ref },
305        |expr| quote! { ::core::clone::Clone::clone(#expr) },
306    );
307
308    let factory_to_self_transform = common::clone_like_struct_conversion(
309        &s.fields,
310        factory_state_ctor.path()?,
311        ctx.item_type.path()?,
312        &quote! {},
313        |expr| quote! { #exhaust_crate_path::Exhaust::from_factory(#expr) },
314    );
315
316    // Generate factory state struct with the same syntax type as the original
317    // (for elegance, not because it matters functionally).
318    // This struct is always wrapped in a newtype struct to hide implementation details reliably.
319    let factory_state_struct_decl = match &factory_field_decls {
320        syn::Fields::Unit | syn::Fields::Unnamed(_) => quote! {
321            #vis struct #factory_state_struct_type #ty_generics #factory_field_decls
322            where #augmented_where_predicates;
323
324        },
325
326        syn::Fields::Named(_) => quote! {
327            #vis struct #factory_state_struct_type #ty_generics
328            where #augmented_where_predicates
329            #factory_field_decls
330        },
331    };
332
333    Ok((
334        quote! {
335            // Struct that is exposed as the `<Self as Exhaust>::Iter` type.
336            // A wrapper struct is not needed because it always has at least one private field.
337            #vis struct #iterator_type_name #ty_generics
338            where #augmented_where_predicates {
339                #state_field_decls
340            }
341
342            // Struct that is exposed as the `<Self as Exhaust>::Factory` type.
343            #vis struct #factory_type_name #ty_generics (#factory_state_struct_type #ty_generics)
344            where #augmented_where_predicates;
345
346            #impls
347
348            #factory_state_struct_decl
349
350            // A manual impl of Clone is required to *not* have a `Clone` bound on the generics.
351            impl #impl_generics ::core::clone::Clone for #factory_type
352            where #augmented_where_predicates {
353                fn clone(&self) -> Self {
354                    Self(match self.0 {
355                        #factory_struct_clone_arm
356                    })
357                }
358            }
359
360        },
361        quote! {
362            match factory.0 {
363                #factory_to_self_transform
364            }
365        },
366    ))
367}
368
369fn exhaust_iter_enum(
370    e: syn::DataEnum,
371    ctx: &ExhaustContext,
372) -> Result<(TokenStream2, TokenStream2), syn::Error> {
373    let vis = &ctx.vis;
374    let exhaust_crate_path = &ctx.exhaust_crate_path;
375    let iterator_type_name = &ctx.iterator_type_name;
376    let factory_outer_type_path = &ctx.factory_type.path()?;
377    let factory_type = &ctx.factory_type.parameterized(&ctx.generics);
378
379    // These enum types are both wrapped in structs,
380    // so that the user of the macro cannot depend on its implementation details.
381    let iter_state_enum_type = ctx.generated_type_name("IterState")?;
382    let factory_state_enum_type = ctx.generated_type_name("FactoryState")?.to_token_stream();
383    let factory_state_ctor = ConstructorSyntax::Braced(factory_state_enum_type.clone());
384
385    // One ident per variant of the original enum.
386    let state_enum_progress_variants: Vec<Ident> = e
387        .variants
388        .iter()
389        .map(|v| {
390            // Renaming the variant serves two purposes: less confusing error/debug text,
391            // and disambiguating from the “Done” variant.
392            Ident::new(&format!("Exhaust{}", v.ident), v.span())
393        })
394        .collect();
395
396    // TODO: ensure no name conflict, perhaps by renaming the others
397    let done_variant = Ident::new("Done", Span::mixed_site());
398
399    // All variants of our generated enum, which are equal to the original enum
400    // plus a "done" variant.
401    #[allow(clippy::type_complexity)]
402    let (
403        state_enum_variant_decls,
404        state_enum_variant_initializers,
405        state_enum_variant_cloners,
406        state_enum_field_pats,
407        state_enum_variant_advancers,
408        mut factory_variant_decls,
409    ): (
410        Vec<TokenStream2>,
411        Vec<TokenStream2>,
412        Vec<TokenStream2>,
413        Vec<TokenStream2>,
414        Vec<TokenStream2>,
415        Vec<TokenStream2>,
416    ) = itertools::multiunzip(e
417        .variants
418        .iter()
419        .zip(state_enum_progress_variants.iter())
420        .map(|(target_variant, state_ident)| {
421            let target_variant_ident = &target_variant.ident;
422            let fields::ExhaustFields {
423                state_field_decls,
424                factory_field_decls,
425                initializers: state_fields_init,
426                cloners: state_fields_clone,
427                field_pats,
428                advance,
429            } = if target_variant.fields.is_empty() {
430                // TODO: don't even construct this dummy value (needs refactoring)
431                fields::ExhaustFields {
432                    state_field_decls: quote! {},
433                    factory_field_decls: syn::Fields::Unit,
434                    initializers: quote! {},
435                    cloners: quote! {},
436                    field_pats: quote! {},
437                    advance: quote! {
438                        compile_error!("can't happen: fieldless ExhaustFields not used")
439                    },
440                }
441            } else {
442                fields::exhaust_iter_fields(
443                    ctx,
444                    &target_variant.fields,
445                    factory_outer_type_path,
446                    &factory_state_ctor.with_variant(target_variant_ident),
447                )
448            };
449
450            (
451                quote! {
452                    #state_ident {
453                        #state_field_decls
454                    }
455                },
456                quote! {
457                    #iter_state_enum_type :: #state_ident { #state_fields_init }
458                },
459                quote! {
460                    #iter_state_enum_type :: #state_ident { #field_pats } =>
461                        #iter_state_enum_type :: #state_ident { #state_fields_clone }
462                },
463                field_pats,
464                advance,
465                quote! {
466                    #target_variant_ident #factory_field_decls
467                },
468            )
469        })
470        .chain(iter::once((
471            done_variant.to_token_stream(),
472            quote! {
473                // iterator construction
474                #iter_state_enum_type :: #done_variant {}
475            },
476            quote! {
477                // clone() match arm
478                #iter_state_enum_type :: #done_variant {} => #iter_state_enum_type :: #done_variant {}
479            },
480            quote! {},
481            quote! { compile_error!("done advancer not used") },
482            quote! { compile_error!("done factory variant not used") },
483        ))));
484
485    factory_variant_decls.pop(); // no Done arm in the factory enum
486
487    let first_state_variant_initializer = &state_enum_variant_initializers[0];
488
489    // Match arms to advance the iterator.
490    let variant_next_arms = izip!(
491        e.variants.iter(),
492        state_enum_progress_variants.iter(),
493        state_enum_field_pats.iter(),
494        state_enum_variant_initializers.iter().skip(1),
495        state_enum_variant_advancers.iter(),
496    )
497    .map(
498        |(target_enum_variant, state_ident, pats, next_state_initializer, field_advancer)| {
499            let target_variant_ident = &target_enum_variant.ident;
500            let advancer = if target_enum_variant.fields.is_empty() {
501                let factory_state_expr = factory_state_ctor
502                    .with_variant(target_variant_ident)
503                    .value_expr([].iter(), [].iter());
504                quote! {
505                    self.0 = #next_state_initializer;
506                    ::core::option::Option::Some(#factory_outer_type_path(#factory_state_expr))
507                }
508            } else {
509                quote! {
510                    let maybe_variant = { #field_advancer };
511                    match maybe_variant {
512                        ::core::option::Option::Some(v) => ::core::option::Option::Some(v),
513                        ::core::option::Option::None => {
514                            self.0 = #next_state_initializer;
515                            // TODO: recursion is a kludge here; rewrite as loop{}
516                            ::core::iter::Iterator::next(self)
517                        }
518                    }
519                }
520            };
521            quote! {
522                #iter_state_enum_type::#state_ident { #pats } => {
523                    #advancer
524                }
525            }
526        },
527    );
528
529    let factory_enum_variant_clone_arms: Vec<TokenStream2> = common::clone_like_match_arms(
530        &e.variants,
531        &factory_state_enum_type,
532        &factory_state_enum_type,
533        &quote! { ref },
534        |expr| quote! { ::core::clone::Clone::clone(#expr) },
535    );
536    let factory_to_self_transform = common::clone_like_match_arms(
537        &e.variants,
538        &factory_state_enum_type,
539        ctx.item_type.path()?,
540        &quote! {},
541        |expr| quote! { #exhaust_crate_path::Exhaust::from_factory(#expr) },
542    );
543
544    let (impl_generics, ty_generics, augmented_where_predicates) =
545        ctx.generics_with_bounds(syn::parse_quote! {});
546
547    let impls = ctx.impl_iterator_and_factory_traits(
548        quote! {
549            match &mut self.0 {
550                #( #variant_next_arms , )*
551                #iter_state_enum_type::#done_variant => ::core::option::Option::None,
552            }
553        },
554        quote! {
555            Self(#first_state_variant_initializer)
556        },
557        quote! {
558            Self(match &self.0 {
559                #( #state_enum_variant_cloners , )*
560            })
561        },
562    );
563
564    let iterator_decl = quote! {
565        // Struct that is exposed as the `<Self as Exhaust>::Iter` type.
566        #vis struct #iterator_type_name #ty_generics
567        (#iter_state_enum_type #ty_generics)
568        where #augmented_where_predicates;
569
570        // Struct that is exposed as the `<Self as Exhaust>::Factory` type.
571        #vis struct #factory_outer_type_path #ty_generics (#factory_state_enum_type #ty_generics)
572        where #augmented_where_predicates;
573
574        #impls
575
576        // Enum wrapped in #factory_type_name with the actual data.
577        enum #factory_state_enum_type #ty_generics
578        where #augmented_where_predicates { #( #factory_variant_decls ,)* }
579
580        // A manual impl of Clone is required to *not* have a `Clone` bound on the generics.
581        impl #impl_generics ::core::clone::Clone for #factory_type
582        where #augmented_where_predicates {
583            fn clone(&self) -> Self {
584                #![allow(unreachable_code)] // in case of empty enum
585                Self(match self.0 {
586                    #( #factory_enum_variant_clone_arms , )*
587                })
588            }
589        }
590
591        enum #iter_state_enum_type #ty_generics
592        where #augmented_where_predicates
593        {
594            #( #state_enum_variant_decls , )*
595        }
596    };
597
598    let from_factory_body = quote! {
599        match factory.0 {
600            #( #factory_to_self_transform , )*
601        }
602    };
603
604    Ok((iterator_decl, from_factory_body))
605}