Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actually use the inferred ClosureKind from signature inference in coroutine-closures #123350

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
});
let closure_kind_ty = self.next_ty_var(TypeVariableOrigin {
// FIXME(eddyb) distinguish closure kind inference variables from the rest.
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
});

let closure_kind_ty = match expected_kind {
Some(kind) => Ty::from_closure_kind(tcx, kind),

// Create a type variable (for now) to represent the closure kind.
// It will be unified during the upvar inference phase (`upvar.rs`)
None => self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
}),
};

let coroutine_captures_by_ref_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
Expand Down Expand Up @@ -262,10 +269,17 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
},
);

let coroutine_kind_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
});
let coroutine_kind_ty = match expected_kind {
Some(kind) => Ty::from_coroutine_closure_kind(tcx, kind),

// Create a type variable (for now) to represent the closure kind.
// It will be unified during the upvar inference phase (`upvar.rs`)
None => self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
}),
};

let coroutine_upvars_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
Expand Down
50 changes: 39 additions & 11 deletions compiler/rustc_hir_typeck/src/upvar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
span: Span,
body_id: hir::BodyId,
body: &'tcx hir::Body<'tcx>,
capture_clause: hir::CaptureBy,
mut capture_clause: hir::CaptureBy,
) {
// Extract the type of the closure.
let ty = self.node_ty(closure_hir_id);
Expand Down Expand Up @@ -259,6 +259,28 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
)
.consume_body(body);

// If a coroutine is comes from a coroutine-closure that is `move`, but
// the coroutine-closure was inferred to be `FnOnce` during signature
// inference, then it's still possible that we try to borrow upvars from
// the coroutine-closure because they are not used by the coroutine body
// in a way that forces a move.
//
// This would lead to an impossible to satisfy situation, since `AsyncFnOnce`
// coroutine bodies can't borrow from their parent closure. To fix this,
// we force the inner coroutine to also be `move`. This only matters for
// coroutine-closures that are `move` since otherwise they themselves will
// be borrowing from the outer environment, so there's no self-borrows occuring.
if let UpvarArgs::Coroutine(..) = args
&& let hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) =
self.tcx.coroutine_kind(closure_def_id).expect("coroutine should have kind")
&& let parent_hir_id =
self.tcx.local_def_id_to_hir_id(self.tcx.local_parent(closure_def_id))
&& let parent_ty = self.node_ty(parent_hir_id)
&& let Some(ty::ClosureKind::FnOnce) = self.closure_kind(parent_ty)
{
capture_clause = self.tcx.hir_node(parent_hir_id).expect_closure().capture_clause;
}

debug!(
"For closure={:?}, capture_information={:#?}",
closure_def_id, delegate.capture_information
Expand Down Expand Up @@ -399,16 +421,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
);

// Additionally, we can now constrain the coroutine's kind type.
let ty::Coroutine(_, coroutine_args) =
*self.typeck_results.borrow().expr_ty(body.value).kind()
else {
bug!();
};
self.demand_eqtype(
span,
coroutine_args.as_coroutine().kind_ty(),
Ty::from_coroutine_closure_kind(self.tcx, closure_kind),
);
//
// We only do this if `infer_kind`, because if we have constrained
// the kind from closure signature inference, the kind inferred
// for the inner coroutine may actually be more restrictive.
if infer_kind {
let ty::Coroutine(_, coroutine_args) =
*self.typeck_results.borrow().expr_ty(body.value).kind()
else {
bug!();
};
self.demand_eqtype(
span,
coroutine_args.as_coroutine().kind_ty(),
Ty::from_coroutine_closure_kind(self.tcx, closure_kind),
);
}
}

self.log_closure_min_capture_info(closure_def_id, span);
Expand Down
20 changes: 14 additions & 6 deletions compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,17 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
return;
}

let ty::Coroutine(_, coroutine_args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
// We don't need to generate a by-move coroutine if the kind of the coroutine is
// already `FnOnce` -- that means that any upvars that the closure consumes have
// already been taken by-value.
let coroutine_kind = coroutine_args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
if coroutine_kind == ty::ClosureKind::FnOnce {
// We don't need to generate a by-move coroutine if the coroutine body was
// produced by the `CoroutineKindShim`, since it's already by-move.
if matches!(body.source.instance, ty::InstanceDef::CoroutineKindShim { .. }) {
return;
}

let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
let args = args.as_coroutine();

let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();

let parent_def_id = tcx.local_parent(coroutine_def_id);
let ty::CoroutineClosure(_, parent_args) =
*tcx.type_of(parent_def_id).instantiate_identity().kind()
Expand Down Expand Up @@ -128,6 +130,12 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
// the outer closure body -- we need to change the coroutine to take the
// upvar by value.
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
assert_ne!(
coroutine_kind,
ty::ClosureKind::FnOnce,
"`FnOnce` coroutine-closures return coroutines that capture from \
their body; it will always result in a borrowck error!"
);
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
}

Expand Down
34 changes: 34 additions & 0 deletions src/tools/miri/tests/pass/async-closure-captures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,38 @@ async fn async_main() {
};
call_once(c).await;
}

fn force_fnonce<T>(f: impl async FnOnce() -> T) -> impl async FnOnce() -> T {
f
}

// Capture something with `move`, but infer to `AsyncFnOnce`
{
let x = Hello(6);
let c = force_fnonce(async move || {
println!("{x:?}");
});
call_once(c).await;

let x = &Hello(7);
let c = force_fnonce(async move || {
println!("{x:?}");
});
call_once(c).await;
}

// Capture something by-ref, but infer to `AsyncFnOnce`
{
let x = Hello(8);
let c = force_fnonce(async || {
println!("{x:?}");
});
call_once(c).await;

let x = &Hello(9);
let c = force_fnonce(async || {
println!("{x:?}");
});
call_once(c).await;
}
}
4 changes: 4 additions & 0 deletions src/tools/miri/tests/pass/async-closure-captures.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ Hello(3)
Hello(4)
Hello(4)
Hello(5)
Hello(6)
Hello(7)
Hello(8)
Hello(9)
34 changes: 34 additions & 0 deletions tests/ui/async-await/async-closures/captures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,38 @@ async fn async_main() {
};
call_once(c).await;
}

fn force_fnonce<T>(f: impl async FnOnce() -> T) -> impl async FnOnce() -> T {
f
}

// Capture something with `move`, but infer to `AsyncFnOnce`
{
let x = Hello(6);
let c = force_fnonce(async move || {
println!("{x:?}");
});
call_once(c).await;

let x = &Hello(7);
let c = force_fnonce(async move || {
println!("{x:?}");
});
call_once(c).await;
}

// Capture something by-ref, but infer to `AsyncFnOnce`
{
let x = Hello(8);
let c = force_fnonce(async || {
println!("{x:?}");
});
call_once(c).await;

let x = &Hello(9);
let c = force_fnonce(async || {
println!("{x:?}");
});
call_once(c).await;
}
}
4 changes: 4 additions & 0 deletions tests/ui/async-await/async-closures/captures.run.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ Hello(3)
Hello(4)
Hello(4)
Hello(5)
Hello(6)
Hello(7)
Hello(8)
Hello(9)
10 changes: 7 additions & 3 deletions tests/ui/async-await/async-closures/wrong-fn-kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

#![feature(async_closure)]

fn main() {
fn needs_async_fn(_: impl async Fn()) {}
fn needs_async_fn(_: impl async Fn()) {}

fn a() {
let mut x = 1;
needs_async_fn(async || {
//~^ ERROR expected a closure that implements the `async Fn` trait, but this closure only implements `async FnMut`
//~^ ERROR cannot borrow `x` as mutable, as it is a captured variable in a `Fn` closure
x += 1;
});
}

fn b() {
let x = String::new();
needs_async_fn(move || async move {
//~^ ERROR expected a closure that implements the `async Fn` trait, but this closure only implements `async FnOnce`
println!("{x}");
});
}

fn main() {}
49 changes: 23 additions & 26 deletions tests/ui/async-await/async-closures/wrong-fn-kind.stderr
Original file line number Diff line number Diff line change
@@ -1,26 +1,5 @@
error[E0525]: expected a closure that implements the `async Fn` trait, but this closure only implements `async FnMut`
--> $DIR/wrong-fn-kind.rs:9:20
|
LL | needs_async_fn(async || {
| -------------- -^^^^^^^
| | |
| _____|______________this closure implements `async FnMut`, not `async Fn`
| | |
| | required by a bound introduced by this call
LL | |
LL | | x += 1;
| | - closure is `async FnMut` because it mutates the variable `x` here
LL | | });
| |_____- the requirement to implement `async Fn` derives from here
|
note: required by a bound in `needs_async_fn`
--> $DIR/wrong-fn-kind.rs:6:31
|
LL | fn needs_async_fn(_: impl async Fn()) {}
| ^^^^^^^^^^ required by this bound in `needs_async_fn`

error[E0525]: expected a closure that implements the `async Fn` trait, but this closure only implements `async FnOnce`
--> $DIR/wrong-fn-kind.rs:15:20
--> $DIR/wrong-fn-kind.rs:17:20
|
LL | needs_async_fn(move || async move {
| -------------- -^^^^^^
Expand All @@ -35,11 +14,29 @@ LL | | });
| |_____- the requirement to implement `async Fn` derives from here
|
note: required by a bound in `needs_async_fn`
--> $DIR/wrong-fn-kind.rs:6:31
--> $DIR/wrong-fn-kind.rs:5:27
|
LL | fn needs_async_fn(_: impl async Fn()) {}
| ^^^^^^^^^^ required by this bound in `needs_async_fn`

error[E0596]: cannot borrow `x` as mutable, as it is a captured variable in a `Fn` closure
--> $DIR/wrong-fn-kind.rs:9:29
|
LL | fn needs_async_fn(_: impl async Fn()) {}
| ^^^^^^^^^^ required by this bound in `needs_async_fn`
LL | fn needs_async_fn(_: impl async Fn()) {}
| --------------- change this to accept `FnMut` instead of `Fn`
...
LL | needs_async_fn(async || {
| _____--------------_--------_^
| | | |
| | | in this closure
| | expects `Fn` instead of `FnMut`
LL | |
LL | | x += 1;
| | - mutable borrow occurs due to use of `x` in closure
LL | | });
| |_____^ cannot borrow as mutable

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0525`.
Some errors have detailed explanations: E0525, E0596.
For more information about an error, try `rustc --explain E0525`.
Loading