Skip to content

Commit

Permalink
Auto merge of rust-lang#117703 - compiler-errors:recursive-async, r=<…
Browse files Browse the repository at this point in the history
…try>

Support async recursive calls (as long as they have indirection)

Before rust-lang#101692, we stored coroutine witness types directly inside of the coroutine. That means that a coroutine could not contain itself (as a witness field) without creating a cycle in the type representation of the coroutine, which we detected with the `OpaqueTypeExpander`, which is used to detect cycles when expanding opaque types after that are inferred to contain themselves.

After `-Zdrop-tracking-mir` was stabilized, we no longer store these generator witness fields directly, but instead behind a def-id based query. That means there is no technical obstacle in the compiler preventing coroutines from containing themselves per se, other than the fact that for a coroutine to have a non-infinite layout, it must contain itself wrapped in a layer of allocation indirection (like a `Box`).

This means that it should be valid for this code to work:

```
async fn async_fibonacci(i: u32) -> u32 {
    if i == 0 || i == 1 {
        i
    } else {
        Box::pin(async_fibonacci(i - 1)).await
          + Box::pin(async_fibonacci(i - 2)).await
    }
}
```

Whereas previously, you'd need to coerce the future to `Pin<Box<dyn Future<Output = ...>>` before `await`ing it, to prevent the async's desugared coroutine from containing itself across as await point.

This PR does two things:
1. Remove the behavior from `OpaqueTypeExpander` where it intentionally fetches and walks through the coroutine's witness fields. This was kept around after `-Zdrop-tracking-mir` was stabilized so we would not be introducing new stable behavior, and to preserve the much better diagnostics of async recursion compared to a layout error.
2. Reworks the way we report layout errors having to do with coroutines, to make up for the diagnostic regressions introduced by (1.). We actually do even better now, pointing out the call sites of the recursion!
  • Loading branch information
bors committed Nov 12, 2023
2 parents 2b603f9 + 8b4b867 commit ab5373f
Show file tree
Hide file tree
Showing 44 changed files with 291 additions and 196 deletions.
7 changes: 5 additions & 2 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
// debuggers and debugger extensions expect it to be called `__awaitee`. They use
// this name to identify what is being awaited by a suspended async functions.
let awaitee_ident = Ident::with_dummy_span(sym::__awaitee);
let (awaitee_pat, awaitee_pat_hid) =
self.pat_ident_binding_mode(span, awaitee_ident, hir::BindingAnnotation::MUT);
let (awaitee_pat, awaitee_pat_hid) = self.pat_ident_binding_mode(
gen_future_span,
awaitee_ident,
hir::BindingAnnotation::MUT,
);

let task_context_ident = Ident::with_dummy_span(sym::_task_context);

Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,15 @@ pub enum CoroutineKind {
Coroutine,
}

impl CoroutineKind {
pub fn is_fn_like(self) -> bool {
matches!(
self,
CoroutineKind::Async(CoroutineSource::Fn) | CoroutineKind::Gen(CoroutineSource::Fn)
)
}
}

impl fmt::Display for CoroutineKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down
25 changes: 7 additions & 18 deletions compiler/rustc_hir_analysis/src/check/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,12 @@ fn check_opaque(tcx: TyCtxt<'_>, id: hir::ItemId) {
return;
}

let args = GenericArgs::identity_for_item(tcx, item.owner_id);
let span = tcx.def_span(item.owner_id.def_id);

if tcx.type_of(item.owner_id.def_id).instantiate_identity().references_error() {
return;
}
if check_opaque_for_cycles(tcx, item.owner_id.def_id, args, span, &origin).is_err() {
if check_opaque_for_cycles(tcx, item.owner_id.def_id, span).is_err() {
return;
}

Expand All @@ -233,16 +232,16 @@ fn check_opaque(tcx: TyCtxt<'_>, id: hir::ItemId) {
pub(super) fn check_opaque_for_cycles<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
args: GenericArgsRef<'tcx>,
span: Span,
origin: &hir::OpaqueTyOrigin,
) -> Result<(), ErrorGuaranteed> {
let args = GenericArgs::identity_for_item(tcx, def_id);
if tcx.try_expand_impl_trait_type(def_id.to_def_id(), args).is_err() {
let reported = match origin {
hir::OpaqueTyOrigin::AsyncFn(..) => async_opaque_type_cycle_error(tcx, span),
_ => opaque_type_cycle_error(tcx, def_id, span),
};
let reported = opaque_type_cycle_error(tcx, def_id, span);
Err(reported)
} else if let Err(&LayoutError::Cycle(guar)) =
tcx.layout_of(tcx.param_env(def_id).and(Ty::new_opaque(tcx, def_id.to_def_id(), args)))
{
Err(guar)
} else {
Ok(())
}
Expand Down Expand Up @@ -1324,16 +1323,6 @@ pub(super) fn check_mod_item_types(tcx: TyCtxt<'_>, module_def_id: LocalModDefId
}
}

fn async_opaque_type_cycle_error(tcx: TyCtxt<'_>, span: Span) -> ErrorGuaranteed {
struct_span_err!(tcx.sess, span, E0733, "recursion in an `async fn` requires boxing")
.span_label(span, "recursive `async fn`")
.note("a recursive `async fn` must be rewritten to return a boxed `dyn Future`")
.note(
"consider using the `async_recursion` crate: https://crates.io/crates/async_recursion",
)
.emit()
}

/// Emit an error for recursive opaque types.
///
/// If this is a return `impl Trait`, find the item's return expressions and point at them. For
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_middle/src/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait Key: Sized {
None
}

fn ty_adt_id(&self) -> Option<DefId> {
fn ty_def_id(&self) -> Option<DefId> {
None
}
}
Expand Down Expand Up @@ -406,9 +406,10 @@ impl<'tcx> Key for Ty<'tcx> {
DUMMY_SP
}

fn ty_adt_id(&self) -> Option<DefId> {
match self.kind() {
fn ty_def_id(&self) -> Option<DefId> {
match *self.kind() {
ty::Adt(adt, _) => Some(adt.did()),
ty::Coroutine(def_id, ..) => Some(def_id),
_ => None,
}
}
Expand Down Expand Up @@ -452,6 +453,10 @@ impl<'tcx, T: Key> Key for ty::ParamEnvAnd<'tcx, T> {
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.value.default_span(tcx)
}

fn ty_def_id(&self) -> Option<DefId> {
self.value.ty_def_id()
}
}

impl Key for Symbol {
Expand Down Expand Up @@ -550,7 +555,7 @@ impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
DUMMY_SP
}

fn ty_adt_id(&self) -> Option<DefId> {
fn ty_def_id(&self) -> Option<DefId> {
match self.1.value.kind() {
ty::Adt(adt, _) => Some(adt.did()),
_ => None,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,8 @@ rustc_queries! {
) -> Result<ty::layout::TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
depth_limit
desc { "computing layout of `{}`", key.value }
// we emit our own error during query cycle handling
cycle_delay_bug
}

/// Compute a `FnAbi` suitable for indirect calls, i.e. to `fn` pointers.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct DynamicQuery<'tcx, C: QueryCache> {
fn(tcx: TyCtxt<'tcx>, key: &C::Key, index: SerializedDepNodeIndex) -> bool,
pub hash_result: HashResult<C::Value>,
pub value_from_cycle_error:
fn(tcx: TyCtxt<'tcx>, cycle: &[QueryInfo], guar: ErrorGuaranteed) -> C::Value,
fn(tcx: TyCtxt<'tcx>, cycle_error: &CycleError, guar: ErrorGuaranteed) -> C::Value,
pub format_value: fn(&C::Value) -> String,
}

Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ pub enum LayoutError<'tcx> {
SizeOverflow(Ty<'tcx>),
NormalizationFailure(Ty<'tcx>, NormalizationError<'tcx>),
ReferencesError(ErrorGuaranteed),
Cycle,
Cycle(ErrorGuaranteed),
}

impl<'tcx> LayoutError<'tcx> {
Expand All @@ -226,7 +226,7 @@ impl<'tcx> LayoutError<'tcx> {
Unknown(_) => middle_unknown_layout,
SizeOverflow(_) => middle_values_too_big,
NormalizationFailure(_, _) => middle_cannot_be_normalized,
Cycle => middle_cycle,
Cycle(_) => middle_cycle,
ReferencesError(_) => middle_layout_references_error,
}
}
Expand All @@ -240,7 +240,7 @@ impl<'tcx> LayoutError<'tcx> {
NormalizationFailure(ty, e) => {
E::NormalizationFailure { ty, failure_ty: e.get_type_for_failure() }
}
Cycle => E::Cycle,
Cycle(_) => E::Cycle,
ReferencesError(_) => E::ReferencesError,
}
}
Expand All @@ -261,7 +261,7 @@ impl<'tcx> fmt::Display for LayoutError<'tcx> {
t,
e.get_type_for_failure()
),
LayoutError::Cycle => write!(f, "a cycle occurred during layout computation"),
LayoutError::Cycle(_) => write!(f, "a cycle occurred during layout computation"),
LayoutError::ReferencesError(_) => write!(f, "the type has an unknown layout"),
}
}
Expand Down Expand Up @@ -333,7 +333,7 @@ impl<'tcx> SizeSkeleton<'tcx> {
Err(err @ LayoutError::Unknown(_)) => err,
// We can't extract SizeSkeleton info from other layout errors
Err(
e @ LayoutError::Cycle
e @ LayoutError::Cycle(_)
| e @ LayoutError::SizeOverflow(_)
| e @ LayoutError::NormalizationFailure(..)
| e @ LayoutError::ReferencesError(_),
Expand Down
29 changes: 11 additions & 18 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,9 +747,13 @@ impl<'tcx> TyCtxt<'tcx> {
match def_kind {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method",
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
rustc_hir::CoroutineKind::Async(..) => "async closure",
rustc_hir::CoroutineKind::Coroutine => "coroutine",
rustc_hir::CoroutineKind::Gen(..) => "gen closure",
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "async fn",
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "async block",
hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "async closure",
hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "gen fn",
hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "gen block",
hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "gen closure",
hir::CoroutineKind::Coroutine => "coroutine",
},
_ => def_kind.descr(def_id),
}
Expand All @@ -765,9 +769,9 @@ impl<'tcx> TyCtxt<'tcx> {
match def_kind {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "a",
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
rustc_hir::CoroutineKind::Async(..) => "an",
rustc_hir::CoroutineKind::Coroutine => "a",
rustc_hir::CoroutineKind::Gen(..) => "a",
hir::CoroutineKind::Async(..) => "an",
hir::CoroutineKind::Coroutine => "a",
hir::CoroutineKind::Gen(..) => "a",
},
_ => def_kind.article(),
}
Expand Down Expand Up @@ -850,18 +854,7 @@ impl<'tcx> OpaqueTypeExpander<'tcx> {
}
let args = args.fold_with(self);
if !self.check_recursion || self.seen_opaque_tys.insert(def_id) {
let expanded_ty = match self.expanded_cache.get(&(def_id, args)) {
Some(expanded_ty) => *expanded_ty,
None => {
for bty in self.tcx.coroutine_hidden_types(def_id) {
let hidden_ty = bty.instantiate(self.tcx, args);
self.fold_ty(hidden_ty);
}
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
self.expanded_cache.insert((def_id, args), expanded_ty);
expanded_ty
}
};
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
if self.check_recursion {
self.seen_opaque_tys.remove(&def_id);
}
Expand Down
Loading

0 comments on commit ab5373f

Please sign in to comment.