Skip to content

Commit

Permalink
Prototype using const generic for simd_shuffle IDX array
Browse files Browse the repository at this point in the history
  • Loading branch information
oli-obk committed Sep 18, 2023
1 parent 078eb11 commit 1a01e57
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 45 deletions.
50 changes: 49 additions & 1 deletion compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn report_simd_type_validation_error(
pub(super) fn codegen_simd_intrinsic_call<'tcx>(
fx: &mut FunctionCx<'_, '_, 'tcx>,
intrinsic: Symbol,
_args: GenericArgsRef<'tcx>,
generic_args: GenericArgsRef<'tcx>,
args: &[mir::Operand<'tcx>],
ret: CPlace<'tcx>,
target: BasicBlock,
Expand Down Expand Up @@ -117,6 +117,54 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
});
}

// simd_shuffle_generic<T, U, const I: &[u32]>(x: T, y: T) -> U
sym::simd_shuffle_generic => {
let [x, y] = args else {
bug!("wrong number of args for intrinsic {intrinsic}");
};
let x = codegen_operand(fx, x);
let y = codegen_operand(fx, y);

if !x.layout().ty.is_simd() {
report_simd_type_validation_error(fx, intrinsic, span, x.layout().ty);
return;
}

let idx = generic_args[2]
.expect_const()
.eval(fx.tcx, ty::ParamEnv::reveal_all(), Some(span))
.unwrap()
.unwrap_branch();

assert_eq!(x.layout(), y.layout());
let layout = x.layout();

let (lane_count, lane_ty) = layout.ty.simd_size_and_type(fx.tcx);
let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);

assert_eq!(lane_ty, ret_lane_ty);
assert_eq!(idx.len() as u64, ret_lane_count);

let total_len = lane_count * 2;

let indexes =
idx.iter().map(|idx| idx.unwrap_leaf().try_to_u16().unwrap()).collect::<Vec<u16>>();

for &idx in &indexes {
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
}

for (out_idx, in_idx) in indexes.into_iter().enumerate() {
let in_lane = if u64::from(in_idx) < lane_count {
x.value_lane(fx, in_idx.into())
} else {
y.value_lane(fx, u64::from(in_idx) - lane_count)
};
let out_lane = ret.place_lane(fx, u64::try_from(out_idx).unwrap());
out_lane.write_cvalue(fx, in_lane);
}
}

// simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U
sym::simd_shuffle => {
let (x, y, idx) = match args {
Expand Down
57 changes: 55 additions & 2 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::*;
use rustc_hir as hir;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, LayoutOf};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, GenericArgsRef, Ty};
use rustc_middle::{bug, span_bug};
use rustc_span::{sym, symbol::kw, Span, Symbol};
use rustc_target::abi::{self, Align, HasDataLayout, Primitive};
Expand Down Expand Up @@ -376,7 +376,9 @@ impl<'ll, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}

_ if name.as_str().starts_with("simd_") => {
match generic_simd_intrinsic(self, name, callee_ty, args, ret_ty, llret_ty, span) {
match generic_simd_intrinsic(
self, name, callee_ty, fn_args, args, ret_ty, llret_ty, span,
) {
Ok(llval) => llval,
Err(()) => return,
}
Expand Down Expand Up @@ -911,6 +913,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
name: Symbol,
callee_ty: Ty<'tcx>,
fn_args: GenericArgsRef<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
ret_ty: Ty<'tcx>,
llret_ty: &'ll Type,
Expand Down Expand Up @@ -1030,6 +1033,56 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
));
}

if name == sym::simd_shuffle_generic {
let idx = fn_args[2]
.expect_const()
.eval(tcx, ty::ParamEnv::reveal_all(), Some(span))
.unwrap()
.unwrap_branch();
let n = idx.len() as u64;

require_simd!(ret_ty, InvalidMonomorphization::SimdReturn { span, name, ty: ret_ty });
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
require!(
out_len == n,
InvalidMonomorphization::ReturnLength { span, name, in_len: n, ret_ty, out_len }
);
require!(
in_elem == out_ty,
InvalidMonomorphization::ReturnElement { span, name, in_elem, in_ty, ret_ty, out_ty }
);

let total_len = in_len * 2;

let indices: Option<Vec<_>> = idx
.iter()
.enumerate()
.map(|(arg_idx, val)| {
let idx = val.unwrap_leaf().try_to_i32().unwrap();
if idx >= i32::try_from(total_len).unwrap() {
bx.sess().emit_err(InvalidMonomorphization::ShuffleIndexOutOfBounds {
span,
name,
arg_idx: arg_idx as u64,
total_len: total_len.into(),
});
None
} else {
Some(bx.const_i32(idx))
}
})
.collect();
let Some(indices) = indices else {
return Ok(bx.const_null(llret_ty));
};

return Ok(bx.shuffle_vector(
args[0].immediate(),
args[1].immediate(),
bx.const_vector(&indices),
));
}

if name == sym::simd_shuffle {
// Make sure this is actually an array, since typeck only checks the length-suffixed
// version of this intrinsic.
Expand Down
44 changes: 23 additions & 21 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ fn equate_intrinsic_type<'tcx>(
it: &hir::ForeignItem<'_>,
n_tps: usize,
n_lts: usize,
n_cts: usize,
sig: ty::PolyFnSig<'tcx>,
) {
let (own_counts, span) = match &it.kind {
Expand Down Expand Up @@ -51,7 +52,7 @@ fn equate_intrinsic_type<'tcx>(

if gen_count_ok(own_counts.lifetimes, n_lts, "lifetime")
&& gen_count_ok(own_counts.types, n_tps, "type")
&& gen_count_ok(own_counts.consts, 0, "const")
&& gen_count_ok(own_counts.consts, n_cts, "const")
{
let fty = Ty::new_fn_ptr(tcx, sig);
let it_def_id = it.owner_id.def_id;
Expand Down Expand Up @@ -492,7 +493,7 @@ pub fn check_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>) {
};
let sig = tcx.mk_fn_sig(inputs, output, false, unsafety, Abi::RustIntrinsic);
let sig = ty::Binder::bind_with_vars(sig, bound_vars);
equate_intrinsic_type(tcx, it, n_tps, n_lts, sig)
equate_intrinsic_type(tcx, it, n_tps, n_lts, 0, sig)
}

/// Type-check `extern "platform-intrinsic" { ... }` functions.
Expand All @@ -504,9 +505,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)

let name = it.ident.name;

let (n_tps, inputs, output) = match name {
let (n_tps, n_cts, inputs, output) = match name {
sym::simd_eq | sym::simd_ne | sym::simd_lt | sym::simd_le | sym::simd_gt | sym::simd_ge => {
(2, vec![param(0), param(0)], param(1))
(2, 0, vec![param(0), param(0)], param(1))
}
sym::simd_add
| sym::simd_sub
Expand All @@ -522,8 +523,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_fmax
| sym::simd_fpow
| sym::simd_saturating_add
| sym::simd_saturating_sub => (1, vec![param(0), param(0)], param(0)),
sym::simd_arith_offset => (2, vec![param(0), param(1)], param(0)),
| sym::simd_saturating_sub => (1, 0, vec![param(0), param(0)], param(0)),
sym::simd_arith_offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::simd_neg
| sym::simd_bswap
| sym::simd_bitreverse
Expand All @@ -541,25 +542,25 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_ceil
| sym::simd_floor
| sym::simd_round
| sym::simd_trunc => (1, vec![param(0)], param(0)),
sym::simd_fpowi => (1, vec![param(0), tcx.types.i32], param(0)),
sym::simd_fma => (1, vec![param(0), param(0), param(0)], param(0)),
sym::simd_gather => (3, vec![param(0), param(1), param(2)], param(0)),
sym::simd_scatter => (3, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
sym::simd_insert => (2, vec![param(0), tcx.types.u32, param(1)], param(0)),
sym::simd_extract => (2, vec![param(0), tcx.types.u32], param(1)),
| sym::simd_trunc => (1, 0, vec![param(0)], param(0)),
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
sym::simd_cast
| sym::simd_as
| sym::simd_cast_ptr
| sym::simd_expose_addr
| sym::simd_from_exposed_addr => (2, vec![param(0)], param(1)),
sym::simd_bitmask => (2, vec![param(0)], param(1)),
| sym::simd_from_exposed_addr => (2, 0, vec![param(0)], param(1)),
sym::simd_bitmask => (2, 0, vec![param(0)], param(1)),
sym::simd_select | sym::simd_select_bitmask => {
(2, vec![param(0), param(1), param(1)], param(1))
(2, 0, vec![param(0), param(1), param(1)], param(1))
}
sym::simd_reduce_all | sym::simd_reduce_any => (1, vec![param(0)], tcx.types.bool),
sym::simd_reduce_all | sym::simd_reduce_any => (1, 0, vec![param(0)], tcx.types.bool),
sym::simd_reduce_add_ordered | sym::simd_reduce_mul_ordered => {
(2, vec![param(0), param(1)], param(1))
(2, 0, vec![param(0), param(1)], param(1))
}
sym::simd_reduce_add_unordered
| sym::simd_reduce_mul_unordered
Expand All @@ -569,8 +570,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_reduce_min
| sym::simd_reduce_max
| sym::simd_reduce_min_nanless
| sym::simd_reduce_max_nanless => (2, vec![param(0)], param(1)),
sym::simd_shuffle => (3, vec![param(0), param(0), param(1)], param(2)),
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
_ => {
let msg = format!("unrecognized platform-specific intrinsic function: `{name}`");
tcx.sess.struct_span_err(it.span, msg).emit();
Expand All @@ -580,5 +582,5 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)

let sig = tcx.mk_fn_sig(inputs, output, false, hir::Unsafety::Unsafe, Abi::PlatformIntrinsic);
let sig = ty::Binder::dummy(sig);
equate_intrinsic_type(tcx, it, n_tps, 0, sig)
equate_intrinsic_type(tcx, it, n_tps, 0, n_cts, sig)
}
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,7 @@ symbols! {
simd_shl,
simd_shr,
simd_shuffle,
simd_shuffle_generic,
simd_sub,
simd_trunc,
simd_xor,
Expand Down
5 changes: 3 additions & 2 deletions src/tools/miri/src/shims/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}

// The rest jumps to `ret` immediately.
this.emulate_intrinsic_by_name(intrinsic_name, args, dest)?;
this.emulate_intrinsic_by_name(intrinsic_name, instance.args, args, dest)?;

trace!("{:?}", this.dump_place(dest));
this.go_to_block(ret);
Expand All @@ -71,6 +71,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn emulate_intrinsic_by_name(
&mut self,
intrinsic_name: &str,
generic_args: ty::GenericArgsRef<'tcx>,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
Expand All @@ -80,7 +81,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
return this.emulate_atomic_intrinsic(name, args, dest);
}
if let Some(name) = intrinsic_name.strip_prefix("simd_") {
return this.emulate_simd_intrinsic(name, args, dest);
return this.emulate_simd_intrinsic(name, generic_args, args, dest);
}

match intrinsic_name {
Expand Down
33 changes: 33 additions & 0 deletions src/tools/miri/src/shims/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn emulate_simd_intrinsic(
&mut self,
intrinsic_name: &str,
generic_args: ty::GenericArgsRef<'tcx>,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
Expand Down Expand Up @@ -490,6 +491,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.write_immediate(val, &dest)?;
}
}
"shuffle_generic" => {
let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

let index = generic_args[2].expect_const().eval(*this.tcx, this.param_env(), Some(this.tcx.span)).unwrap().unwrap_branch();
let index_len = index.len();

assert_eq!(left_len, right_len);
assert_eq!(index_len as u64, dest_len);

for i in 0..dest_len {
let src_index: u64 = index[i as usize].unwrap_leaf()
.try_to_u32().unwrap()
.into();
let dest = this.project_index(&dest, i)?;

let val = if src_index < left_len {
this.read_immediate(&this.project_index(&left, src_index)?)?
} else if src_index < left_len.checked_add(right_len).unwrap() {
let right_idx = src_index.checked_sub(left_len).unwrap();
this.read_immediate(&this.project_index(&right, right_idx)?)?
} else {
span_bug!(
this.cur_span(),
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
);
};
this.write_immediate(*val, &dest)?;
}
}
"shuffle" => {
let [left, right, index] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
Expand Down
28 changes: 27 additions & 1 deletion tests/ui/simd/intrinsic/generic-elements.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// build-fail

#![feature(repr_simd, platform_intrinsics, rustc_attrs)]
#![feature(repr_simd, platform_intrinsics, rustc_attrs, adt_const_params)]
#![allow(incomplete_features)]

#[repr(simd)]
#[derive(Copy, Clone)]
Expand Down Expand Up @@ -35,6 +36,7 @@ extern "platform-intrinsic" {
fn simd_extract<T, E>(x: T, idx: u32) -> E;

fn simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U;
fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;
}

fn main() {
Expand Down Expand Up @@ -71,5 +73,29 @@ fn main() {
//~^ ERROR expected return type of length 4, found `i32x8` with length 8
simd_shuffle::<_, _, i32x2>(x, x, IDX8);
//~^ ERROR expected return type of length 8, found `i32x2` with length 2

const I2: &[u32] = &[0; 2];
simd_shuffle_generic::<i32, i32, I2>(0, 0);
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
const I4: &[u32] = &[0; 4];
simd_shuffle_generic::<i32, i32, I4>(0, 0);
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
const I8: &[u32] = &[0; 8];
simd_shuffle_generic::<i32, i32, I8>(0, 0);
//~^ ERROR expected SIMD input type, found non-SIMD `i32`

simd_shuffle_generic::<_, f32x2, I2>(x, x);
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x2` with element type `f32`
simd_shuffle_generic::<_, f32x4, I4>(x, x);
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x4` with element type `f32`
simd_shuffle_generic::<_, f32x8, I8>(x, x);
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x8` with element type `f32`

simd_shuffle_generic::<_, i32x8, I2>(x, x);
//~^ ERROR expected return type of length 2, found `i32x8` with length 8
simd_shuffle_generic::<_, i32x8, I4>(x, x);
//~^ ERROR expected return type of length 4, found `i32x8` with length 8
simd_shuffle_generic::<_, i32x2, I8>(x, x);
//~^ ERROR expected return type of length 8, found `i32x2` with length 2
}
}
Loading

0 comments on commit 1a01e57

Please sign in to comment.