Fix capture analysis for by-move closure bodies #123349

Merged on Apr 3, 2024 (2 commits)
149 changes: 118 additions & 31 deletions compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -1,7 +1,66 @@
-//! A MIR pass which duplicates a coroutine's body and removes any derefs which
-//! would be present for upvars that are taken by-ref. The result of which will
-//! be a coroutine body that takes all of its upvars by-move, and which we stash
-//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
+//! This pass constructs a second coroutine body sufficient for return from
+//! `FnOnce`/`AsyncFnOnce` implementations for coroutine-closures (e.g. async closures).
+//!
+//! Consider an async closure like:
+//! ```rust
+//! #![feature(async_closure)]
+//!
+//! let x = vec![1, 2, 3];
+//!
+//! let closure = async move || {
+//!     println!("{x:#?}");
+//! };
+//! ```
+//!
+//! This desugars to something like:
+//! ```rust,ignore (invalid-borrowck)
+//! let x = vec![1, 2, 3];
+//!
+//! let closure = move || {
+//!     async {
+//!         println!("{x:#?}");
+//!     }
+//! };
+//! ```
+//!
+//! Important to note here is that while the outer closure *moves* `x: Vec<i32>`
+//! into its upvars, the inner `async` coroutine simply captures a ref of `x`.
+//! This is the "magic" of async closures -- the futures that they return are
+//! allowed to borrow from their parent closure's upvars.
+//!
+//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`,
+//! since all async closures implement that too)? Well, recall the signature:
+//! ```
+//! use std::future::Future;
+//! pub trait AsyncFnOnce<Args>
+//! {
+//!     type CallOnceFuture: Future<Output = Self::Output>;
+//!     type Output;
+//!     fn async_call_once(
+//!         self,
+//!         args: Args
+//!     ) -> Self::CallOnceFuture;
+//! }
+//! ```
+//!
+//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`.
+//! How do we deal with the fact that the coroutine is supposed to take a reference
+//! to the captured `x` from the parent closure, when that parent closure has been
+//! destroyed?
+//!
+//! This is the second piece of magic of async closures. We can simply create a
+//! *second* `async` coroutine body where that `x` that was previously captured
+//! by reference is now captured by value. This means that we consume the outer
+//! closure and return a new coroutine that will hold onto all of these captures,
+//! and drop them when it is finished (i.e. after it has been `.await`ed).
+//!
+//! We do this with the analysis below, which detects the captures that come from
+//! borrowing from the outer closure, and we simply peel off a `deref` projection
+//! from them. This second body is stored alongside the first body, and optimized
+//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
+//! we use this "by move" body instead.

+use itertools::Itertools;
+
 use rustc_data_structures::unord::UnordSet;
 use rustc_hir as hir;
@@ -14,6 +73,8 @@ pub struct ByMoveBody;

 impl<'tcx> MirPass<'tcx> for ByMoveBody {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
+        // We only need to generate by-move coroutine bodies for coroutines that come
+        // from coroutine-closures.
         let Some(coroutine_def_id) = body.source.def_id().as_local() else {
             return;
         };
@@ -22,44 +83,70 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
         else {
             return;
         };
+
+        // Also, let's skip processing any bodies with errors, since there's no guarantee
+        // the MIR body will be constructed well.
         let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
         if coroutine_ty.references_error() {
             return;
         }
-        let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };

-        let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
+        let ty::Coroutine(_, coroutine_args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
+        // We don't need to generate a by-move coroutine if the kind of the coroutine is
+        // already `FnOnce` -- that means that any upvars that the closure consumes have
+        // already been taken by-value.
+        let coroutine_kind = coroutine_args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
         if coroutine_kind == ty::ClosureKind::FnOnce {
             return;
         }

+        let parent_def_id = tcx.local_parent(coroutine_def_id);
+        let ty::CoroutineClosure(_, parent_args) =
+            *tcx.type_of(parent_def_id).instantiate_identity().kind()
+        else {
+            bug!();
+        };
+        let parent_closure_args = parent_args.as_coroutine_closure();
+        let num_args = parent_closure_args
+            .coroutine_closure_sig()
+            .skip_binder()
+            .tupled_inputs_ty
+            .tuple_fields()
+            .len();
+
         let mut by_ref_fields = UnordSet::default();
-        let by_move_upvars = Ty::new_tup_from_iter(
-            tcx,
-            tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
-                if capture.is_by_ref() {
-                    by_ref_fields.insert(FieldIdx::from_usize(idx));
-                }
-                capture.place.ty()
-            }),
-        );
-        let by_move_coroutine_ty = Ty::new_coroutine(
-            tcx,
-            coroutine_def_id.to_def_id(),
-            ty::CoroutineArgs::new(
+        for (idx, (coroutine_capture, parent_capture)) in tcx
+            .closure_captures(coroutine_def_id)
+            .iter()
+            // By construction we capture all the args first.
+            .skip(num_args)
+            .zip_eq(tcx.closure_captures(parent_def_id))
+            .enumerate()
+        {
+            // This upvar is captured by-move from the parent closure, but by-ref
+            // from the inner async block. That means that it's being borrowed from
+            // the outer closure body -- we need to change the coroutine to take the
+            // upvar by value.
+            if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
+                by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
+            }
+
+            // Make sure we're actually talking about the same capture.
+            // FIXME(async_closures): We could look at the `hir::Upvar` instead?
+            assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
+        }
+
+        let by_move_coroutine_ty = tcx
+            .instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig())
+            .to_coroutine_given_kind_and_upvars(
                 tcx,

Review comment (Member, Author):

Since we're already getting the parent `CoroutineClosure`, there's no need to manually construct a `Coroutine` from its parts. Just use the helper function that we already have for computing the `FnOnce`/by-move coroutine.

-                ty::CoroutineArgsParts {
-                    parent_args: args.as_coroutine().parent_args(),
-                    kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
-                    resume_ty: args.as_coroutine().resume_ty(),
-                    yield_ty: args.as_coroutine().yield_ty(),
-                    return_ty: args.as_coroutine().return_ty(),
-                    witness: args.as_coroutine().witness(),
-                    tupled_upvars_ty: by_move_upvars,
-                },
-            )
-            .args,
-        );
+                parent_closure_args.parent_args(),
+                coroutine_def_id.to_def_id(),
+                ty::ClosureKind::FnOnce,
+                tcx.lifetimes.re_erased,
+                parent_closure_args.tupled_upvars_ty(),
+                parent_closure_args.coroutine_captures_by_ref_ty(),
+            );

         let mut by_move_body = body.clone();
         MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
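
The capture comparison in `run_pass` can be modeled outside the compiler. What follows is only a minimal sketch: `CaptureMode` and `fields_to_make_by_move` are hypothetical names that do not exist in rustc, and plain `usize` indices stand in for `FieldIdx`. It shows the rule the new loop applies: skip the closure arguments, pair each remaining coroutine capture with the parent closure's capture, and record a field only when the coroutine takes it by reference while the parent holds it by value.

```rust
/// Hypothetical stand-in for how a single upvar is captured; rustc's real
/// capture records carry far more information than this.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CaptureMode {
    ByRef,
    ByMove,
}

/// Returns the coroutine field indices that the by-move body must take by
/// value (hypothetical helper modeling the loop in `run_pass`).
fn fields_to_make_by_move(
    num_args: usize,
    coroutine_captures: &[CaptureMode],
    parent_captures: &[CaptureMode],
) -> Vec<usize> {
    // The coroutine captures the closure's arguments first, so skip them.
    let upvar_captures = &coroutine_captures[num_args..];
    // Mirrors `zip_eq`: both sides must describe the same upvars.
    assert_eq!(upvar_captures.len(), parent_captures.len());

    upvar_captures
        .iter()
        .zip(parent_captures)
        .enumerate()
        // Borrowed by the coroutine but owned by the parent closure: this is
        // a borrow out of the closure itself, which `FnOnce` cannot keep alive.
        .filter(|(_, (coroutine, parent))| {
            **coroutine == CaptureMode::ByRef && **parent == CaptureMode::ByMove
        })
        // Field indices count the skipped arguments as well.
        .map(|(idx, _)| num_args + idx)
        .collect()
}
```

With one argument, coroutine upvars captured as (by-ref, by-ref, by-move) and parent upvars captured as (by-move, by-ref, by-move), only field 1 is returned. The second upvar is by-ref on both sides, so it borrows from outside the closure and can stay a reference.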
91 changes: 91 additions & 0 deletions src/tools/miri/tests/pass/async-closure-captures.rs
@@ -0,0 +1,91 @@
+// Same as rustc's `tests/ui/async-await/async-closures/captures.rs`, keep in sync
+
+#![feature(async_closure, noop_waker)]
+
+use std::future::Future;
+use std::pin::pin;
+use std::task::*;
+
+pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
+    let mut fut = pin!(fut);
+    let ctx = &mut Context::from_waker(Waker::noop());
+
+    loop {
+        match fut.as_mut().poll(ctx) {
+            Poll::Pending => {}
+            Poll::Ready(t) => break t,
+        }
+    }
+}
+
+fn main() {
+    block_on(async_main());
+}
+
+async fn call<T>(f: &impl async Fn() -> T) -> T {
+    f().await
+}
+
+async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
+    f().await
+}
+
+#[derive(Debug)]
+#[allow(unused)]
+struct Hello(i32);
+
+async fn async_main() {
+    // Capture something by-ref
+    {
+        let x = Hello(0);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(1);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something and consume it (force to `AsyncFnOnce`)
+    {
+        let x = Hello(2);
+        let c = async || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, don't consume it
+    {
+        let x = Hello(3);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(4);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, also consume it (so `AsyncFnOnce`)
+    {
+        let x = Hello(5);
+        let c = async move || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+}
10 changes: 10 additions & 0 deletions src/tools/miri/tests/pass/async-closure-captures.stdout
@@ -0,0 +1,10 @@
+Hello(0)
+Hello(0)
+Hello(1)
+Hello(1)
+Hello(2)
+Hello(3)
+Hello(3)
+Hello(4)
+Hello(4)
+Hello(5)
82 changes: 82 additions & 0 deletions tests/ui/async-await/async-closures/captures.rs
@@ -0,0 +1,82 @@
+//@ aux-build:block-on.rs
+//@ edition:2021
+//@ run-pass
+//@ check-run-results
+
+// Same as miri's `tests/pass/async-closure-captures.rs`, keep in sync
+
+#![feature(async_closure)]
+
+extern crate block_on;
+
+fn main() {
+    block_on::block_on(async_main());
+}
+
+async fn call<T>(f: &impl async Fn() -> T) -> T {
+    f().await
+}
+
+async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
+    f().await
+}
+
+#[derive(Debug)]
+#[allow(unused)]
+struct Hello(i32);
+
+async fn async_main() {
+    // Capture something by-ref
+    {
+        let x = Hello(0);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(1);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something and consume it (force to `AsyncFnOnce`)
+    {
+        let x = Hello(2);
+        let c = async || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, don't consume it
+    {
+        let x = Hello(3);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(4);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, also consume it (so `AsyncFnOnce`)
+    {
+        let x = Hello(5);
+        let c = async move || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+}
10 changes: 10 additions & 0 deletions tests/ui/async-await/async-closures/captures.run.stdout
@@ -0,0 +1,10 @@
+Hello(0)
+Hello(0)
+Hello(1)
+Hello(1)
+Hello(2)
+Hello(3)
+Hello(3)
+Hello(4)
+Hello(4)
+Hello(5)
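
Each non-consuming closure in these tests prints its value twice: once through `call(&c)` and once through `call_once(c)`. A rough stable-Rust analogue of the two coroutine bodies behind those calls is sketched here. It assumes a hypothetical hand-written `Closure` struct and a `NoopWake` helper, neither of which comes from the PR; in the compiler, both bodies are derived from a single `async` closure rather than written by hand.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical waker that does nothing; enough for futures that never wait.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls a future in a busy loop until it completes.
fn block_on<T>(fut: impl Future<Output = T>) -> T {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            break value;
        }
    }
}

#[derive(Debug)]
struct Hello(i32);

/// Hypothetical hand-written stand-in for `async move || format!("{x:?}")`.
struct Closure {
    x: Hello,
}

impl Closure {
    /// Analogue of the by-ref body: the future borrows `x` from the closure,
    /// so the closure must outlive the future.
    fn call(&self) -> impl Future<Output = String> + '_ {
        let x = &self.x;
        async move { format!("{x:?}") }
    }

    /// Analogue of the by-move body: the closure is consumed and the future
    /// owns `x`, dropping it when the future finishes.
    fn call_once(self) -> impl Future<Output = String> + 'static {
        let x = self.x;
        async move { format!("{x:?}") }
    }
}
```

Calling `block_on(c.call())` and then `block_on(c.call_once())` on `Closure { x: Hello(3) }` yields `"Hello(3)"` both times, matching the paired lines in the expected output. The difference is only in ownership: the second future is `'static` because nothing in it refers back to the consumed closure.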