Skip to content

Commit

Permalink
interpret: add sanity check in dyn upcast to double-check what codege…
Browse files Browse the repository at this point in the history
…n does
  • Loading branch information
RalfJung committed Jul 17, 2024
1 parent 4cd8dc6 commit dc7e6f9
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 47 deletions.
41 changes: 36 additions & 5 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
(ty::Dynamic(data_a, _, ty::Dyn), ty::Dynamic(data_b, _, ty::Dyn)) => {
let val = self.read_immediate(src)?;
if data_a.principal() == data_b.principal() {
// A NOP cast that doesn't actually change anything, should be allowed even with mismatching vtables.
// (But currently mismatching vtables violate the validity invariant so UB is triggered anyway.)
return self.write_immediate(*val, dest);
}
// Take apart the old pointer, and find the dynamic type.
let (old_data, old_vptr) = val.to_scalar_pair();
let old_data = old_data.to_pointer(self)?;
let old_vptr = old_vptr.to_pointer(self)?;
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;

// Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces
// our destination trait.
if cfg!(debug_assertions) {
let vptr_entry_idx =
self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
let vtable_entries = self.vtable_entries(data_a.principal(), ty);
if let Some(entry_idx) = vptr_entry_idx {
let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
vtable_entries.get(entry_idx)
else {
span_bug!(
self.cur_span(),
"invalid vtable entry index in {} -> {} upcast",
src_pointee_ty,
dest_pointee_ty
);
};
let erased_trait_ref = upcast_trait_ref
.map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r));
assert!(
data_b
.principal()
.is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b))
);
} else {
// In this case codegen would keep using the old vtable. We don't want to do
// that as it has the wrong trait. The reason codegen can do this is that
// one vtable is a prefix of the other, so we double-check that.
let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
};
}

// Get the destination trait vtable and return that.
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
}
Expand Down
30 changes: 30 additions & 0 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ use std::cell::Cell;
use std::{fmt, mem};

use either::{Either, Left, Right};
use rustc_infer::infer::at::ToTrace;
use rustc_infer::traits::ObligationCause;
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::{debug, info, info_span, instrument, trace};

use rustc_errors::DiagCtxtHandle;
use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData};
use rustc_index::IndexVec;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::mir;
use rustc_middle::mir::interpret::{
CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo,
Expand Down Expand Up @@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
}

/// Check if the two things are equal in the current param_env, using an infctx to get proper
/// equality checks.
pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
where
T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
{
// Fast path: compare directly.
if a == b {
return true;
}
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let a = ocx.normalize(&cause, self.param_env, a);
let b = ocx.normalize(&cause, self.param_env, b);
if ocx.eq(&cause, self.param_env, a, b).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return true;
}
}
return false;
}

/// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
/// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
/// and is primarily intended for the panic machinery.
Expand Down
17 changes: 6 additions & 11 deletions compiler/rustc_const_eval/src/interpret/terminator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::borrow::Cow;

use either::Either;
use rustc_middle::ty::TyCtxt;
use tracing::trace;

use rustc_middle::{
Expand Down Expand Up @@ -867,7 +866,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
};

// Obtain the underlying trait we are working on, and the adjusted receiver argument.
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
let (trait_, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
receiver_place.layout.ty.kind()
{
let recv = self.unpack_dyn_star(&receiver_place, data)?;
Expand Down Expand Up @@ -898,20 +897,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
};

// Now determine the actual method to call. We can do that in two different ways and
// compare them to ensure everything fits.
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
};
// Now determine the actual method to call. Usually we use the easy way of just
// looking up the method at index `idx`.
let vtable_entries = self.vtable_entries(trait_, dyn_ty);
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
};
trace!("Virtual call dispatches to {fn_inst:#?}");
// We can also do the lookup based on `def_id`` and `dyn_ty`, and check that that

Check failure on line 908 in compiler/rustc_const_eval/src/interpret/terminator.rs

View workflow job for this annotation

GitHub Actions / PR - mingw-check-tidy

2-line comment block with odd number of backticks
// produces the same result.
if cfg!(debug_assertions) {
let tcx = *self.tcx;

Expand Down
48 changes: 23 additions & 25 deletions compiler/rustc_const_eval/src/interpret/traits.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::ObligationCause;
use rustc_middle::mir::interpret::{InterpResult, Pointer};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
use rustc_target::abi::{Align, Size};
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::trace;

use super::util::ensure_monomorphic_enough;
Expand Down Expand Up @@ -47,35 +44,36 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok((layout.size, layout.align.abi))
}

pub(super) fn vtable_entries(
&self,
trait_: Option<ty::PolyExistentialTraitRef<'tcx>>,
dyn_ty: Ty<'tcx>,
) -> &'tcx [VtblEntry<'tcx>] {
if let Some(trait_) = trait_ {
let trait_ref = trait_.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
}
}

/// Check that the given vtable trait is valid for a pointer/reference/place with the given
/// expected trait type.
pub(super) fn check_vtable_for_type(
&self,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx> {
// Fast path: if they are equal, it's all fine.
if expected_trait.principal() == vtable_trait {
return Ok(());
}
if let (Some(expected_trait), Some(vtable_trait)) =
(expected_trait.principal(), vtable_trait)
{
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return Ok(());
}
}
let eq = match (expected_trait.principal(), vtable_trait) {
(Some(a), Some(b)) => self.eq_in_param_env(a, b),
(None, None) => true,
_ => false,
};
if !eq {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
Ok(())
}

/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
Expand Down
10 changes: 6 additions & 4 deletions src/tools/miri/tests/fail/dyn-upcast-trait-mismatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ impl Baz for i32 {
}

fn main() {
let baz: &dyn Baz = &1;
let baz_fake: *const dyn Bar = unsafe { std::mem::transmute(baz) };
let _err = baz_fake as *const dyn Foo;
//~^ERROR: using vtable for trait `Baz` but trait `Bar` was expected
unsafe {
let baz: &dyn Baz = &1;
let baz_fake: *const dyn Bar = std::mem::transmute(baz);
let _err = baz_fake as *const dyn Foo;
//~^ERROR: using vtable for trait `Baz` but trait `Bar` was expected
}
}
4 changes: 2 additions & 2 deletions src/tools/miri/tests/fail/dyn-upcast-trait-mismatch.stderr
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error: Undefined Behavior: using vtable for trait `Baz` but trait `Bar` was expected
--> $DIR/dyn-upcast-trait-mismatch.rs:LL:CC
|
LL | let _err = baz_fake as *const dyn Foo;
| ^^^^^^^^ using vtable for trait `Baz` but trait `Bar` was expected
LL | let _err = baz_fake as *const dyn Foo;
| ^^^^^^^^ using vtable for trait `Baz` but trait `Bar` was expected
|
= help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior
= help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information
Expand Down

0 comments on commit dc7e6f9

Please sign in to comment.