diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs b/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs index 08342cccbe943..20523b8c5ba6b 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs @@ -3,7 +3,7 @@ use crate::derive_data::ReflectEnum; use crate::enum_utility::{get_variant_constructors, EnumVariantConstructors}; use crate::field_attributes::DefaultBehavior; use crate::fq_std::{FQAny, FQClone, FQDefault, FQOption}; -use crate::utility::ident_or_index; +use crate::utility::{extend_where_clause, ident_or_index, WhereClauseOptions}; use crate::{ReflectMeta, ReflectStruct}; use proc_macro::TokenStream; use proc_macro2::Span; @@ -49,8 +49,20 @@ pub(crate) fn impl_enum(reflect_enum: &ReflectEnum) -> TokenStream { let (impl_generics, ty_generics, where_clause) = reflect_enum.meta().generics().split_for_impl(); + + // Add FromReflect bound for each active field + let where_from_reflect_clause = extend_where_clause( + where_clause, + &WhereClauseOptions { + active_types: reflect_enum.active_types().into_boxed_slice(), + ignored_types: reflect_enum.ignored_types().into_boxed_slice(), + active_trait_bounds: quote!(#bevy_reflect_path::FromReflect), + ignored_trait_bounds: quote!(#FQDefault), + }, + ); + TokenStream::from(quote! { - impl #impl_generics #bevy_reflect_path::FromReflect for #type_name #ty_generics #where_clause { + impl #impl_generics #bevy_reflect_path::FromReflect for #type_name #ty_generics #where_from_reflect_clause { fn from_reflect(#ref_value: &dyn #bevy_reflect_path::Reflect) -> #FQOption { if let #bevy_reflect_path::ReflectRef::Enum(#ref_value) = #bevy_reflect_path::Reflect::reflect_ref(#ref_value) { match #bevy_reflect_path::Enum::variant_name(#ref_value) { @@ -89,11 +101,11 @@ fn impl_struct_internal(reflect_struct: &ReflectStruct, is_tuple: bool) -> Token Ident::new("Struct", Span::call_site()) }; - let field_types = reflect_struct.active_types(); let MemberValuePair(active_members, active_values) = get_active_fields(reflect_struct, &ref_struct, &ref_struct_type, is_tuple); - let constructor = if reflect_struct.meta().traits().contains(REFLECT_DEFAULT) { + let is_defaultable = reflect_struct.meta().traits().contains(REFLECT_DEFAULT); + let constructor = if is_defaultable { quote!( let mut __this: Self = #FQDefault::default(); #( @@ -120,16 +132,19 @@ fn impl_struct_internal(reflect_struct: &ReflectStruct, is_tuple: bool) -> Token let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); // Add FromReflect bound for each active field - let mut where_from_reflect_clause = if where_clause.is_some() { - quote! {#where_clause} - } else if !active_members.is_empty() { - quote! {where} - } else { - quote! {} - }; - where_from_reflect_clause.extend(quote! { - #(#field_types: #bevy_reflect_path::FromReflect,)* - }); + let where_from_reflect_clause = extend_where_clause( + where_clause, + &WhereClauseOptions { + active_types: reflect_struct.active_types().into_boxed_slice(), + ignored_types: reflect_struct.ignored_types().into_boxed_slice(), + active_trait_bounds: quote!(#bevy_reflect_path::FromReflect), + ignored_trait_bounds: if is_defaultable { + quote!() + } else { + quote!(#FQDefault) + }, + }, + ); TokenStream::from(quote! { impl #impl_generics #bevy_reflect_path::FromReflect for #struct_name #ty_generics #where_from_reflect_clause diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs b/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs index c0675d5363ea5..2ba4229426a47 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs @@ -122,8 +122,9 @@ pub(crate) fn extend_where_clause( let active_trait_bounds = &where_clause_options.active_trait_bounds; let ignored_trait_bounds = &where_clause_options.ignored_trait_bounds; - let mut generic_where_clause = if where_clause.is_some() { - quote! {#where_clause} + let mut generic_where_clause = if let Some(where_clause) = where_clause { + let predicates = where_clause.predicates.iter(); + quote! {where #(#predicates,)*} } else if !(active_types.is_empty() && ignored_types.is_empty()) { quote! {where} } else { diff --git a/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/bounds.pass.rs b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/bounds.pass.rs new file mode 100644 index 0000000000000..2f7b3883b07f2 --- /dev/null +++ b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/bounds.pass.rs @@ -0,0 +1,282 @@ +use bevy_reflect::prelude::*; + +fn main() {} + +#[derive(Default)] +struct NonReflect; + +struct NonReflectNonDefault; + +mod structs { + use super::*; + + #[derive(Reflect)] + struct ReflectGeneric { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + struct FromReflectGeneric { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + struct DefaultGeneric { + foo: Option, + #[reflect(ignore)] + _ignored: NonReflectNonDefault, + } + + impl Default for DefaultGeneric { + fn default() -> Self { + Self { + foo: None, + _ignored: NonReflectNonDefault, + } + } + } + + #[derive(Reflect)] + struct ReflectBoundGeneric { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + struct FromReflectBoundGeneric { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + struct DefaultBoundGeneric { + foo: Option, + #[reflect(ignore)] + _ignored: NonReflectNonDefault, + } + + impl Default for DefaultBoundGeneric { + fn default() -> Self { + Self { + foo: None, + _ignored: NonReflectNonDefault, + } + } + } + + #[derive(Reflect)] + struct ReflectGenericWithWhere + where + T: Clone, + { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + struct FromReflectGenericWithWhere + where + T: Clone, + { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + struct DefaultGenericWithWhere + where + T: Clone, + { + foo: Option, + #[reflect(ignore)] + _ignored: NonReflectNonDefault, + } + + impl Default for DefaultGenericWithWhere + where + T: Clone, + { + fn default() -> Self { + Self { + foo: None, + _ignored: NonReflectNonDefault, + } + } + } + + #[derive(Reflect)] + #[rustfmt::skip] + struct ReflectGenericWithWhereNoTrailingComma + where + T: Clone + { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + #[rustfmt::skip] + struct FromReflectGenericWithWhereNoTrailingComma + where + T: Clone + { + foo: T, + #[reflect(ignore)] + _ignored: NonReflect, + } + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + #[rustfmt::skip] + struct DefaultGenericWithWhereNoTrailingComma + where + T: Clone + { + foo: Option, + #[reflect(ignore)] + _ignored: NonReflectNonDefault, + } + + impl Default for DefaultGenericWithWhereNoTrailingComma + where + T: Clone, + { + fn default() -> Self { + Self { + foo: None, + _ignored: NonReflectNonDefault, + } + } + } +} + +mod tuple_structs { + use super::*; + + #[derive(Reflect)] + struct ReflectGeneric(T, #[reflect(ignore)] NonReflect); + + #[derive(Reflect, FromReflect)] + struct FromReflectGeneric(T, #[reflect(ignore)] NonReflect); + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + struct DefaultGeneric(Option, #[reflect(ignore)] NonReflectNonDefault); + + impl Default for DefaultGeneric { + fn default() -> Self { + Self(None, NonReflectNonDefault) + } + } + + #[derive(Reflect)] + struct ReflectBoundGeneric(T, #[reflect(ignore)] NonReflect); + + #[derive(Reflect, FromReflect)] + struct FromReflectBoundGeneric(T, #[reflect(ignore)] NonReflect); + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + struct DefaultBoundGeneric(Option, #[reflect(ignore)] NonReflectNonDefault); + + impl Default for DefaultBoundGeneric { + fn default() -> Self { + Self(None, NonReflectNonDefault) + } + } + + #[derive(Reflect)] + struct ReflectGenericWithWhere(T, #[reflect(ignore)] NonReflect) + where + T: Clone; + + #[derive(Reflect, FromReflect)] + struct FromReflectGenericWithWhere(T, #[reflect(ignore)] NonReflect) + where + T: Clone; + + #[derive(Reflect, FromReflect)] + #[reflect(Default)] + struct DefaultGenericWithWhere(Option, #[reflect(ignore)] NonReflectNonDefault) + where + T: Clone; + + impl Default for DefaultGenericWithWhere + where + T: Clone, + { + fn default() -> Self { + Self(None, NonReflectNonDefault) + } + } +} + +mod enums { + use super::*; + + #[derive(Reflect)] + enum ReflectGeneric { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect, FromReflect)] + enum FromReflectGeneric { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect)] + enum ReflectBoundGeneric { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect, FromReflect)] + enum FromReflectBoundGeneric { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect)] + enum ReflectGenericWithWhere + where + T: Clone, + { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect, FromReflect)] + enum FromReflectGenericWithWhere + where + T: Clone, + { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect)] + #[rustfmt::skip] + enum ReflectGenericWithWhereNoTrailingComma + where + T: Clone + { + Foo(T, #[reflect(ignore)] NonReflect), + } + + #[derive(Reflect, FromReflect)] + #[rustfmt::skip] + enum FromReflectGenericWithWhereNoTrailingComma + where + T: Clone + { + Foo(T, #[reflect(ignore)] NonReflect), + } +}