Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

intrinsics::simd: add missing functions, avoid UB-triggering fast-math #121223

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
simd_reduce(fx, v, None, ret, &|fx, _ty, a, b| fx.bcx.ins().bxor(a, b));
}

sym::simd_reduce_min | sym::simd_reduce_min_nanless => {
sym::simd_reduce_min => {
intrinsic_args!(fx, args => (v); intrinsic);

if !v.layout().ty.is_simd() {
Expand All @@ -762,7 +762,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
});
}

sym::simd_reduce_max | sym::simd_reduce_max_nanless => {
sym::simd_reduce_max => {
intrinsic_args!(fx, args => (v); intrinsic);

if !v.layout().ty.is_simd() {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,7 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
self.vector_reduce(src, |a, b, context| context.new_binary_op(None, op, a.get_type(), a, b))
}

pub fn vector_reduce_fadd_fast(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
pub fn vector_reduce_fadd_reassoc(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
unimplemented!();
}

Expand All @@ -1772,7 +1772,7 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
unimplemented!();
}

pub fn vector_reduce_fmul_fast(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
pub fn vector_reduce_fmul_reassoc(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
unimplemented!();
}

Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_codegen_gcc/src/intrinsic/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,14 +989,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(

arith_red!(
simd_reduce_add_unordered: BinaryOp::Plus,
vector_reduce_fadd_fast,
vector_reduce_fadd_reassoc,
false,
add,
0.0 // TODO: Use this argument.
);
arith_red!(
simd_reduce_mul_unordered: BinaryOp::Mult,
vector_reduce_fmul_fast,
vector_reduce_fmul_reassoc,
false,
mul,
1.0
Expand Down Expand Up @@ -1041,9 +1041,6 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(

minmax_red!(simd_reduce_min: vector_reduce_min, vector_reduce_fmin);
minmax_red!(simd_reduce_max: vector_reduce_max, vector_reduce_fmax);
// TODO(sadlerap): revisit these intrinsics to generate more optimal reductions
minmax_red!(simd_reduce_min_nanless: vector_reduce_min, vector_reduce_fmin);
minmax_red!(simd_reduce_max_nanless: vector_reduce_max, vector_reduce_fmax);

macro_rules! bitwise_red {
($name:ident : $op:expr, $boolean:expr) => {
Expand Down
24 changes: 4 additions & 20 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1367,17 +1367,17 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub fn vector_reduce_fmul(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe { llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src) }
}
pub fn vector_reduce_fadd_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
pub fn vector_reduce_fadd_reassoc(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
llvm::LLVMRustSetAlgebraicMath(instr);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@saethlin you changed these to "algebraic", but I think they should be just "reassoc" to match what clang does. (And even that is questionable, see rust-lang/stdarch#1533. But that's a separate issue.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe the "algebraic" implementation is technically wrong. But it was an improvement to land because it doesn't explode so dramatically :p

llvm::LLVMRustSetAllowReassoc(instr);
instr
}
}
pub fn vector_reduce_fmul_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
pub fn vector_reduce_fmul_reassoc(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
llvm::LLVMRustSetAlgebraicMath(instr);
llvm::LLVMRustSetAllowReassoc(instr);
instr
}
}
Expand Down Expand Up @@ -1406,22 +1406,6 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
llvm::LLVMRustBuildVectorReduceFMax(self.llbuilder, src, /*NoNaNs:*/ false)
}
}
pub fn vector_reduce_fmin_fast(&mut self, src: &'ll Value) -> &'ll Value {
unsafe {
let instr =
llvm::LLVMRustBuildVectorReduceFMin(self.llbuilder, src, /*NoNaNs:*/ true);
llvm::LLVMRustSetFastMath(instr);
instr
}
}
pub fn vector_reduce_fmax_fast(&mut self, src: &'ll Value) -> &'ll Value {
unsafe {
let instr =
llvm::LLVMRustBuildVectorReduceFMax(self.llbuilder, src, /*NoNaNs:*/ true);
llvm::LLVMRustSetFastMath(instr);
instr
}
}
pub fn vector_reduce_min(&mut self, src: &'ll Value, is_signed: bool) -> &'ll Value {
unsafe { llvm::LLVMRustBuildVectorReduceMin(self.llbuilder, src, is_signed) }
}
Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1880,14 +1880,14 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
arith_red!(simd_reduce_mul_ordered: vector_reduce_mul, vector_reduce_fmul, true, mul, 1.0);
arith_red!(
simd_reduce_add_unordered: vector_reduce_add,
vector_reduce_fadd_algebraic,
vector_reduce_fadd_reassoc,
false,
add,
0.0
);
arith_red!(
simd_reduce_mul_unordered: vector_reduce_mul,
vector_reduce_fmul_algebraic,
vector_reduce_fmul_reassoc,
false,
mul,
1.0
Expand Down Expand Up @@ -1920,9 +1920,6 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
minmax_red!(simd_reduce_min: vector_reduce_min, vector_reduce_fmin);
minmax_red!(simd_reduce_max: vector_reduce_max, vector_reduce_fmax);

minmax_red!(simd_reduce_min_nanless: vector_reduce_min, vector_reduce_fmin_fast);
minmax_red!(simd_reduce_max_nanless: vector_reduce_max, vector_reduce_fmax_fast);

macro_rules! bitwise_red {
($name:ident : $red:ident, $boolean:expr) => {
if name == sym::$name {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,7 @@ extern "C" {

pub fn LLVMRustSetFastMath(Instr: &Value);
pub fn LLVMRustSetAlgebraicMath(Instr: &Value);
pub fn LLVMRustSetAllowReassoc(Instr: &Value);

// Miscellaneous instructions
pub fn LLVMRustGetInstrProfIncrementIntrinsic(M: &Module) -> &Value;
Expand Down
4 changes: 1 addition & 3 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,7 @@ pub fn check_platform_intrinsic_type(
| sym::simd_reduce_or
| sym::simd_reduce_xor
| sym::simd_reduce_min
| sym::simd_reduce_max
| sym::simd_reduce_min_nanless
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
| sym::simd_reduce_max => (2, 0, vec![param(0)], param(1)),
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
_ => {
Expand Down
14 changes: 14 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,20 @@ extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) {
}
}

// Enable the reassoc fast-math flag, allowing transformations that pretend
// floating-point addition and multiplication are associative.
//
// Note that this does NOT enable any flags which can cause a floating-point operation on
// well-defined inputs to return poison, and therefore this function can be used to build
// safe Rust intrinsics (such as fadd_algebraic).
//
// https://llvm.org/docs/LangRef.html#fast-math-flags
extern "C" void LLVMRustSetAllowReassoc(LLVMValueRef V) {
if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
I->setHasAllowReassoc(true);
}
}

extern "C" LLVMValueRef
LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
const char *Name, LLVMAtomicOrdering Order) {
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,9 +1553,7 @@ symbols! {
simd_reduce_and,
simd_reduce_any,
simd_reduce_max,
simd_reduce_max_nanless,
simd_reduce_min,
simd_reduce_min_nanless,
simd_reduce_mul_ordered,
simd_reduce_mul_unordered,
simd_reduce_or,
Expand Down
69 changes: 69 additions & 0 deletions library/core/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@
//! In this module, a "vector" is any `repr(simd)` type.

extern "platform-intrinsic" {
/// Insert an element into a vector, returning the updated vector.
///
/// `T` must be a vector with element type `U`.
///
/// # Safety
///
/// `idx` must be in-bounds of the vector.
pub fn simd_insert<T, U>(x: T, idx: u32, val: U) -> T;

/// Extract an element from a vector.
///
/// `T` must be a vector with element type `U`.
///
/// # Safety
///
/// `idx` must be in-bounds of the vector.
pub fn simd_extract<T, U>(x: T, idx: u32) -> U;

/// Add two simd vectors elementwise.
///
/// `T` must be a vector of integer or floating point primitive types.
Expand Down Expand Up @@ -317,6 +335,14 @@ extern "platform-intrinsic" {
/// Starting with the value `y`, add the elements of `x` and accumulate.
pub fn simd_reduce_add_ordered<T, U>(x: T, y: U) -> U;

/// Add elements within a vector in arbitrary order. May also be re-associated with
/// unordered additions on the inputs/outputs.
///
/// `T` must be a vector of integer or floating-point primitive types.
///
/// `U` must be the element type of `T`.
pub fn simd_reduce_add_unordered<T, U>(x: T) -> U;

/// Multiply elements within a vector from left to right.
///
/// `T` must be a vector of integer or floating-point primitive types.
Expand All @@ -326,6 +352,14 @@ extern "platform-intrinsic" {
/// Starting with the value `y`, multiply the elements of `x` and accumulate.
pub fn simd_reduce_mul_ordered<T, U>(x: T, y: U) -> U;

/// Add elements within a vector in arbitrary order. May also be re-associated with
/// unordered additions on the inputs/outputs.
///
/// `T` must be a vector of integer or floating-point primitive types.
///
/// `U` must be the element type of `T`.
pub fn simd_reduce_mul_unordered<T, U>(x: T) -> U;

/// Check if all mask values are true.
///
/// `T` must be a vector of integer primitive types.
Expand Down Expand Up @@ -518,4 +552,39 @@ extern "platform-intrinsic" {
///
/// `T` must be a vector of floats.
pub fn simd_fma<T>(x: T, y: T, z: T) -> T;

// Computes the sine of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fsin<T>(a: T) -> T;

// Computes the cosine of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fcos<T>(a: T) -> T;

// Computes the exponential function of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fexp<T>(a: T) -> T;

// Computes 2 raised to the power of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fexp2<T>(a: T) -> T;

// Computes the base 10 logarithm of each element.
///
/// `T` must be a vector of floats.
pub fn simd_flog10<T>(a: T) -> T;

// Computes the base 2 logarithm of each element.
///
/// `T` must be a vector of floats.
pub fn simd_flog2<T>(a: T) -> T;

// Computes the natural logarithm of each element.
///
/// `T` must be a vector of floats.
pub fn simd_flog<T>(a: T) -> T;
}
2 changes: 1 addition & 1 deletion tests/codegen/simd/issue-120720-reduce-nan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::arch::x86_64::*;
#[no_mangle]
#[target_feature(enable = "avx512f")] // Function-level target feature mismatches inhibit inlining
pub unsafe fn demo() -> bool {
// CHECK: %0 = tail call reassoc nsz arcp contract double @llvm.vector.reduce.fadd.v8f64(
// CHECK: %0 = tail call reassoc double @llvm.vector.reduce.fadd.v8f64(
// CHECK: %_0.i = fcmp uno double %0, 0.000000e+00
// CHECK: ret i1 %_0.i
let res = unsafe {
Expand Down
6 changes: 0 additions & 6 deletions tests/ui/simd/intrinsic/generic-reduction-pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ extern "platform-intrinsic" {
fn simd_reduce_mul_ordered<T, U>(x: T, acc: U) -> U;
fn simd_reduce_min<T, U>(x: T) -> U;
fn simd_reduce_max<T, U>(x: T) -> U;
fn simd_reduce_min_nanless<T, U>(x: T) -> U;
fn simd_reduce_max_nanless<T, U>(x: T) -> U;
fn simd_reduce_and<T, U>(x: T) -> U;
fn simd_reduce_or<T, U>(x: T) -> U;
fn simd_reduce_xor<T, U>(x: T) -> U;
Expand Down Expand Up @@ -127,10 +125,6 @@ fn main() {
assert_eq!(r, -2_f32);
let r: f32 = simd_reduce_max(x);
assert_eq!(r, 4_f32);
let r: f32 = simd_reduce_min_nanless(x);
assert_eq!(r, -2_f32);
let r: f32 = simd_reduce_max_nanless(x);
assert_eq!(r, 4_f32);
}

unsafe {
Expand Down
Loading