Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement jump threading MIR opt #107009

Merged
merged 16 commits into from
Oct 23, 2023
Merged
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4242,6 +4242,7 @@ dependencies = [
"coverage_test_macros",
"either",
"itertools",
"rustc_arena",
"rustc_ast",
"rustc_attr",
"rustc_const_eval",
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ impl SwitchTargets {
Self { values: smallvec![value], targets: smallvec![then, else_] }
}

/// Inverse of `SwitchTargets::static_if`.
pub fn as_static_if(&self) -> Option<(u128, BasicBlock, BasicBlock)> {
if let &[value] = &self.values[..] && let &[then, else_] = &self.targets[..] {
Some((value, then, else_))
} else {
None
}
}

/// Returns the fallback target that is jumped to when none of the values match the operand.
pub fn otherwise(&self) -> BasicBlock {
*self.targets.last().unwrap()
Expand Down
149 changes: 124 additions & 25 deletions compiler/rustc_mir_dataflow/src/value_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,19 @@ impl<V: Clone> Clone for State<V> {
}
}

impl<V: Clone + HasTop + HasBottom> State<V> {
impl<V: Clone> State<V> {
pub fn new(init: V, map: &Map) -> State<V> {
let values = IndexVec::from_elem_n(init, map.value_count);
State(StateData::Reachable(values))
}

pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
match self.0 {
StateData::Unreachable => true,
StateData::Reachable(ref values) => values.iter().all(f),
}
}

pub fn is_reachable(&self) -> bool {
matches!(&self.0, StateData::Reachable(_))
}
Expand All @@ -472,7 +484,10 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
self.0 = StateData::Unreachable;
}

pub fn flood_all(&mut self) {
pub fn flood_all(&mut self)
where
V: HasTop,
{
self.flood_all_with(V::TOP)
}

Expand All @@ -481,28 +496,52 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
values.raw.fill(value);
}

/// Assign `value` to all places that are contained in `place` or may alias one.
pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, None, &mut |vi| {
values[vi] = value.clone();
});
self.flood_with_tail_elem(place, None, map, value)
}

pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
/// Assign `TOP` to all places that are contained in `place` or may alias one.
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map)
where
V: HasTop,
{
self.flood_with(place, map, V::TOP)
}

/// Assign `value` to the discriminant of `place` and all places that may alias it.
pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |vi| {
values[vi] = value.clone();
});
self.flood_with_tail_elem(place, Some(TrackElem::Discriminant), map, value)
}

pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
/// Assign `TOP` to the discriminant of `place` and all places that may alias it.
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map)
where
V: HasTop,
{
self.flood_discr_with(place, map, V::TOP)
}

/// This method is the most general version of the `flood_*` method.
///
/// Assign `value` on the given place and all places that may alias it. In particular, when
/// the given place has a variant downcast, we invoke the function on all the other variants.
///
/// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
/// as such.
pub fn flood_with_tail_elem(
&mut self,
place: PlaceRef<'_>,
tail_elem: Option<TrackElem>,
map: &Map,
value: V,
) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
values[vi] = value.clone();
});
}

/// Low-level method that assigns to a place.
/// This does nothing if the place is not tracked.
///
Expand Down Expand Up @@ -553,44 +592,104 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
}

/// Helper method to interpret `target = result`.
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map)
where
V: HasTop,
{
self.flood(target, map);
if let Some(target) = map.find(target) {
self.insert_idx(target, result, map);
}
}

/// Helper method for assignments to a discriminant.
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map)
where
V: HasTop,
{
self.flood_discr(target, map);
if let Some(target) = map.find_discr(target) {
self.insert_idx(target, result, map);
}
}

/// Retrieve the value stored for a place, or `None` if it is not tracked.
pub fn try_get(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
let place = map.find(place)?;
self.try_get_idx(place, map)
}

/// Retrieve the discriminant stored for a place, or `None` if it is not tracked.
pub fn try_get_discr(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
let place = map.find_discr(place)?;
self.try_get_idx(place, map)
}

/// Retrieve the slice length stored for a place, or `None` if it is not tracked.
pub fn try_get_len(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
let place = map.find_len(place)?;
self.try_get_idx(place, map)
}

/// Retrieve the value stored for a place index, or `None` if it is not tracked.
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
match &self.0 {
StateData::Reachable(values) => {
map.places[place].value_index.map(|v| values[v].clone())
}
StateData::Unreachable => None,
}
}

/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
pub fn get(&self, place: PlaceRef<'_>, map: &Map) -> V {
map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::TOP)
///
/// This method returns ⊥ if the place is tracked and the state is unreachable.
pub fn get(&self, place: PlaceRef<'_>, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
// Because this is unreachable, we can return any value we want.
StateData::Unreachable => V::BOTTOM,
}
}

/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
match map.find_discr(place) {
Some(place) => self.get_idx(place, map),
None => V::TOP,
///
/// This method returns ⊥ the current state is unreachable.
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
// Because this is unreachable, we can return any value we want.
StateData::Unreachable => V::BOTTOM,
}
}

/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
pub fn get_len(&self, place: PlaceRef<'_>, map: &Map) -> V {
match map.find_len(place) {
Some(place) => self.get_idx(place, map),
None => V::TOP,
///
/// This method returns ⊥ the current state is unreachable.
pub fn get_len(&self, place: PlaceRef<'_>, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
// Because this is unreachable, we can return any value we want.
StateData::Unreachable => V::BOTTOM,
}
}

/// Retrieve the value stored for a place index, or ⊤ if it is not tracked.
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
///
/// This method returns ⊥ the current state is unreachable.
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V
where
V: HasBottom + HasTop,
{
match &self.0 {
StateData::Reachable(values) => {
map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_transform/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
tracing = "0.1"
either = "1"
rustc_ast = { path = "../rustc_ast" }
rustc_arena = { path = "../rustc_arena" }
rustc_attr = { path = "../rustc_attr" }
rustc_data_structures = { path = "../rustc_data_structures" }
rustc_errors = { path = "../rustc_errors" }
Expand Down
98 changes: 98 additions & 0 deletions compiler/rustc_mir_transform/src/cost_checker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};

const INSTR_COST: usize = 5;
const CALL_PENALTY: usize = 25;
const LANDINGPAD_PENALTY: usize = 50;
const RESUME_PENALTY: usize = 45;

/// Verify that the callee body is compatible with the caller.
#[derive(Clone)]
pub(crate) struct CostChecker<'b, 'tcx> {
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
cost: usize,
callee_body: &'b Body<'tcx>,
instance: Option<ty::Instance<'tcx>>,
}

impl<'b, 'tcx> CostChecker<'b, 'tcx> {
pub fn new(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
instance: Option<ty::Instance<'tcx>>,
callee_body: &'b Body<'tcx>,
) -> CostChecker<'b, 'tcx> {
CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
}

pub fn cost(&self) -> usize {
self.cost
}

fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
if let Some(instance) = self.instance {
instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
} else {
v
}
}
}

impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
// Don't count StorageLive/StorageDead in the inlining cost.
match statement.kind {
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Deinit(_)
| StatementKind::Nop => {}
_ => self.cost += INSTR_COST,
}
}

fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
let tcx = self.tcx;
match terminator.kind {
TerminatorKind::Drop { ref place, unwind, .. } => {
// If the place doesn't actually need dropping, treat it like a regular goto.
let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
if ty.needs_drop(tcx, self.param_env) {
self.cost += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
} else {
self.cost += INSTR_COST;
}
}
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
let fn_ty = self.instantiate_ty(f.const_.ty());
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
// Don't give intrinsics the extra penalty for calls
INSTR_COST
} else {
CALL_PENALTY
};
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Assert { unwind, .. } => {
self.cost += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
TerminatorKind::InlineAsm { unwind, .. } => {
self.cost += INSTR_COST;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
_ => self.cost += INSTR_COST,
}
}
}
Loading