From 47225e8700d8e10848b4f8df8e551fc7e438240f Mon Sep 17 00:00:00 2001 From: Nicholas Nethercote Date: Tue, 7 Mar 2023 15:42:50 +1100 Subject: [PATCH 1/3] Introduce `DeepRejectCtxt::substs_refs_may_unify`. It factors out a repeated code pattern. --- compiler/rustc_middle/src/ty/fast_reject.rs | 15 +++++++++++---- .../src/solve/project_goals.rs | 5 +---- .../src/solve/trait_goals.rs | 9 ++++----- .../rustc_trait_selection/src/traits/coherence.rs | 5 +++-- .../src/traits/select/mod.rs | 6 ++++-- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/compiler/rustc_middle/src/ty/fast_reject.rs b/compiler/rustc_middle/src/ty/fast_reject.rs index 669d50a7fda67..6fe91b0cfa0a6 100644 --- a/compiler/rustc_middle/src/ty/fast_reject.rs +++ b/compiler/rustc_middle/src/ty/fast_reject.rs @@ -1,6 +1,6 @@ use crate::mir::Mutability; use crate::ty::subst::GenericArgKind; -use crate::ty::{self, Ty, TyCtxt, TypeVisitableExt}; +use crate::ty::{self, SubstsRef, Ty, TyCtxt, TypeVisitableExt}; use rustc_hir::def_id::DefId; use std::fmt::Debug; use std::hash::Hash; @@ -188,6 +188,15 @@ pub struct DeepRejectCtxt { } impl DeepRejectCtxt { + pub fn substs_refs_may_unify<'tcx>( + self, + obligation_substs: SubstsRef<'tcx>, + impl_substs: SubstsRef<'tcx>, + ) -> bool { + iter::zip(obligation_substs, impl_substs) + .all(|(obl, imp)| self.generic_args_may_unify(obl, imp)) + } + pub fn generic_args_may_unify<'tcx>( self, obligation_arg: ty::GenericArg<'tcx>, @@ -258,9 +267,7 @@ impl DeepRejectCtxt { }, ty::Adt(obl_def, obl_substs) => match k { &ty::Adt(impl_def, impl_substs) => { - obl_def == impl_def - && iter::zip(obl_substs, impl_substs) - .all(|(obl, imp)| self.generic_args_may_unify(obl, imp)) + obl_def == impl_def && self.substs_refs_may_unify(obl_substs, impl_substs) } _ => false, }, diff --git a/compiler/rustc_trait_selection/src/solve/project_goals.rs b/compiler/rustc_trait_selection/src/solve/project_goals.rs index 2b104703aab2f..91b56fe3522ca 100644 --- a/compiler/rustc_trait_selection/src/solve/project_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/project_goals.rs @@ -17,7 +17,6 @@ use rustc_middle::ty::ProjectionPredicate; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_middle::ty::{ToPredicate, TypeVisitableExt}; use rustc_span::{sym, DUMMY_SP}; -use std::iter; impl<'tcx> EvalCtxt<'_, 'tcx> { #[instrument(level = "debug", skip(self), ret)] @@ -144,9 +143,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> { let goal_trait_ref = goal.predicate.projection_ty.trait_ref(tcx); let impl_trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap(); let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup }; - if iter::zip(goal_trait_ref.substs, impl_trait_ref.skip_binder().substs) - .any(|(goal, imp)| !drcx.generic_args_may_unify(goal, imp)) - { + if !drcx.substs_refs_may_unify(goal_trait_ref.substs, impl_trait_ref.skip_binder().substs) { return Err(NoSolution); } diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index 718c82c8f1f28..f522a8f7e65d1 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -1,7 +1,5 @@ //! Dealing with trait goals, i.e. `T: Trait<'a, U>`. -use std::iter; - use super::{assembly, EvalCtxt, SolverMode}; use rustc_hir::def_id::DefId; use rustc_hir::LangItem; @@ -41,9 +39,10 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { let impl_trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap(); let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup }; - if iter::zip(goal.predicate.trait_ref.substs, impl_trait_ref.skip_binder().substs) - .any(|(goal, imp)| !drcx.generic_args_may_unify(goal, imp)) - { + if !drcx.substs_refs_may_unify( + goal.predicate.trait_ref.substs, + impl_trait_ref.skip_binder().substs, + ) { return Err(NoSolution); } diff --git a/compiler/rustc_trait_selection/src/traits/coherence.rs b/compiler/rustc_trait_selection/src/traits/coherence.rs index 4e5e756dc4af4..d360158fdf818 100644 --- a/compiler/rustc_trait_selection/src/traits/coherence.rs +++ b/compiler/rustc_trait_selection/src/traits/coherence.rs @@ -79,8 +79,9 @@ pub fn overlapping_impls( let impl1_ref = tcx.impl_trait_ref(impl1_def_id); let impl2_ref = tcx.impl_trait_ref(impl2_def_id); let may_overlap = match (impl1_ref, impl2_ref) { - (Some(a), Some(b)) => iter::zip(a.skip_binder().substs, b.skip_binder().substs) - .all(|(arg1, arg2)| drcx.generic_args_may_unify(arg1, arg2)), + (Some(a), Some(b)) => { + drcx.substs_refs_may_unify(a.skip_binder().substs, b.skip_binder().substs) + } (None, None) => { let self_ty1 = tcx.type_of(impl1_def_id).skip_binder(); let self_ty2 = tcx.type_of(impl2_def_id).skip_binder(); diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 98c3e7c13ac6e..827c107c8b121 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -2542,8 +2542,10 @@ impl<'tcx> SelectionContext<'_, 'tcx> { // substitution if we find that any of the input types, when // simplified, do not match. let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup }; - iter::zip(obligation.predicate.skip_binder().trait_ref.substs, impl_trait_ref.substs) - .any(|(obl, imp)| !drcx.generic_args_may_unify(obl, imp)) + !drcx.substs_refs_may_unify( + obligation.predicate.skip_binder().trait_ref.substs, + impl_trait_ref.substs, + ) } /// Normalize `where_clause_trait_ref` and try to match it against From 9fa69473fd34d9d8974ebe722e21fd6feae6c986 Mon Sep 17 00:00:00 2001 From: Nicholas Nethercote Date: Tue, 7 Mar 2023 16:02:01 +1100 Subject: [PATCH 2/3] Inline and remove `generic_args_may_unify`. It has a single callsite. --- compiler/rustc_middle/src/ty/fast_reject.rs | 31 ++++++++------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/compiler/rustc_middle/src/ty/fast_reject.rs b/compiler/rustc_middle/src/ty/fast_reject.rs index 6fe91b0cfa0a6..0a6e94248e6e6 100644 --- a/compiler/rustc_middle/src/ty/fast_reject.rs +++ b/compiler/rustc_middle/src/ty/fast_reject.rs @@ -193,26 +193,19 @@ impl DeepRejectCtxt { obligation_substs: SubstsRef<'tcx>, impl_substs: SubstsRef<'tcx>, ) -> bool { - iter::zip(obligation_substs, impl_substs) - .all(|(obl, imp)| self.generic_args_may_unify(obl, imp)) - } - - pub fn generic_args_may_unify<'tcx>( - self, - obligation_arg: ty::GenericArg<'tcx>, - impl_arg: ty::GenericArg<'tcx>, - ) -> bool { - match (obligation_arg.unpack(), impl_arg.unpack()) { - // We don't fast reject based on regions for now. - (GenericArgKind::Lifetime(_), GenericArgKind::Lifetime(_)) => true, - (GenericArgKind::Type(obl), GenericArgKind::Type(imp)) => { - self.types_may_unify(obl, imp) - } - (GenericArgKind::Const(obl), GenericArgKind::Const(imp)) => { - self.consts_may_unify(obl, imp) + iter::zip(obligation_substs, impl_substs).all(|(obl, imp)| { + match (obl.unpack(), imp.unpack()) { + // We don't fast reject based on regions for now. + (GenericArgKind::Lifetime(_), GenericArgKind::Lifetime(_)) => true, + (GenericArgKind::Type(obl), GenericArgKind::Type(imp)) => { + self.types_may_unify(obl, imp) + } + (GenericArgKind::Const(obl), GenericArgKind::Const(imp)) => { + self.consts_may_unify(obl, imp) + } + _ => bug!("kind mismatch: {obl} {imp}"), } - _ => bug!("kind mismatch: {obligation_arg} {impl_arg}"), - } + }) } pub fn types_may_unify<'tcx>(self, obligation_ty: Ty<'tcx>, impl_ty: Ty<'tcx>) -> bool { From 03923661afd565efbd9d000c1bc2eaa98206e2c8 Mon Sep 17 00:00:00 2001 From: Nicholas Nethercote Date: Thu, 9 Mar 2023 11:00:05 +1100 Subject: [PATCH 3/3] Inline and remove `SelectionContext::fast_reject_trait_refs`. Because it has a single call site, and it lets us move a small amount of the work outside the loop. --- .../src/traits/select/candidate_assembly.rs | 6 ++++-- .../src/traits/select/mod.rs | 16 ---------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 234d773d64d78..47a351590b1de 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -11,7 +11,7 @@ use hir::LangItem; use rustc_hir as hir; use rustc_infer::traits::ObligationCause; use rustc_infer::traits::{Obligation, SelectionError, TraitObligation}; -use rustc_middle::ty::fast_reject::TreatProjections; +use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams, TreatProjections}; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; use crate::traits; @@ -344,6 +344,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { return; } + let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup }; + let obligation_substs = obligation.predicate.skip_binder().trait_ref.substs; self.tcx().for_each_relevant_impl( obligation.predicate.def_id(), obligation.predicate.skip_binder().trait_ref.self_ty(), @@ -352,7 +354,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { // consider a "quick reject". This avoids creating more types // and so forth that we need to. let impl_trait_ref = self.tcx().impl_trait_ref(impl_def_id).unwrap(); - if self.fast_reject_trait_refs(obligation, &impl_trait_ref.0) { + if !drcx.substs_refs_may_unify(obligation_substs, impl_trait_ref.0.substs) { return; } if self.reject_fn_ptr_impls( diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 827c107c8b121..3ed3dd2d20d84 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -45,7 +45,6 @@ use rustc_infer::traits::TraitEngineExt; use rustc_middle::dep_graph::{DepKind, DepNodeIndex}; use rustc_middle::mir::interpret::ErrorHandled; use rustc_middle::ty::abstract_const::NotConstEvaluatable; -use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; use rustc_middle::ty::fold::BottomUpFolder; use rustc_middle::ty::relate::TypeRelation; use rustc_middle::ty::SubstsRef; @@ -2533,21 +2532,6 @@ impl<'tcx> SelectionContext<'_, 'tcx> { Ok(Normalized { value: impl_substs, obligations: nested_obligations }) } - fn fast_reject_trait_refs( - &mut self, - obligation: &TraitObligation<'tcx>, - impl_trait_ref: &ty::TraitRef<'tcx>, - ) -> bool { - // We can avoid creating type variables and doing the full - // substitution if we find that any of the input types, when - // simplified, do not match. - let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup }; - !drcx.substs_refs_may_unify( - obligation.predicate.skip_binder().trait_ref.substs, - impl_trait_ref.substs, - ) - } - /// Normalize `where_clause_trait_ref` and try to match it against /// `obligation`. If successful, return any predicates that /// result from the normalization.