Skip to content

Commit

Permalink
Do not emit undefined lshr/ashr for Neon shifts (#1238)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamieCunliffe authored Oct 22, 2021
1 parent 05e32d4 commit 299b5e2
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 21 deletions.
14 changes: 10 additions & 4 deletions crates/core_arch/src/aarch64/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2772,7 +2772,8 @@ pub unsafe fn vshld_n_u64<const N: i32>(a: u64) -> u64 {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
static_assert!(N : i32 where N >= 1 && N <= 64);
a >> N
let n: i32 = if N == 64 { 63 } else { N };
a >> n
}

/// Unsigned shift right
Expand All @@ -2782,7 +2783,12 @@ pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
static_assert!(N : i32 where N >= 1 && N <= 64);
a >> N
let n: i32 = if N == 64 {
return 0;
} else {
N
};
a >> n
}

/// Signed shift right and accumulate
Expand All @@ -2792,7 +2798,7 @@ pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
#[rustc_legacy_const_generics(2)]
pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
static_assert!(N : i32 where N >= 1 && N <= 64);
a + (b >> N)
a + vshrd_n_s64::<N>(b)
}

/// Unsigned shift right and accumulate
Expand All @@ -2802,7 +2808,7 @@ pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
#[rustc_legacy_const_generics(2)]
pub unsafe fn vsrad_n_u64<const N: i32>(a: u64, b: u64) -> u64 {
static_assert!(N : i32 where N >= 1 && N <= 64);
a + (b >> N)
a + vshrd_n_u64::<N>(b)
}

/// Shift Left and Insert (immediate)
Expand Down
48 changes: 32 additions & 16 deletions crates/core_arch/src/arm_shared/neon/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21987,7 +21987,8 @@ pub unsafe fn vshll_n_u32<const N: i32>(a: uint32x2_t) -> uint64x2_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
static_assert!(N : i32 where N >= 1 && N <= 8);
simd_shr(a, vdup_n_s8(N.try_into().unwrap()))
let n: i32 = if N == 8 { 7 } else { N };
simd_shr(a, vdup_n_s8(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -21999,7 +22000,8 @@ pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
static_assert!(N : i32 where N >= 1 && N <= 8);
simd_shr(a, vdupq_n_s8(N.try_into().unwrap()))
let n: i32 = if N == 8 { 7 } else { N };
simd_shr(a, vdupq_n_s8(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22011,7 +22013,8 @@ pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
static_assert!(N : i32 where N >= 1 && N <= 16);
simd_shr(a, vdup_n_s16(N.try_into().unwrap()))
let n: i32 = if N == 16 { 15 } else { N };
simd_shr(a, vdup_n_s16(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22023,7 +22026,8 @@ pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
static_assert!(N : i32 where N >= 1 && N <= 16);
simd_shr(a, vdupq_n_s16(N.try_into().unwrap()))
let n: i32 = if N == 16 { 15 } else { N };
simd_shr(a, vdupq_n_s16(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22035,7 +22039,8 @@ pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
static_assert!(N : i32 where N >= 1 && N <= 32);
simd_shr(a, vdup_n_s32(N.try_into().unwrap()))
let n: i32 = if N == 32 { 31 } else { N };
simd_shr(a, vdup_n_s32(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22047,7 +22052,8 @@ pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
static_assert!(N : i32 where N >= 1 && N <= 32);
simd_shr(a, vdupq_n_s32(N.try_into().unwrap()))
let n: i32 = if N == 32 { 31 } else { N };
simd_shr(a, vdupq_n_s32(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22059,7 +22065,8 @@ pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
static_assert!(N : i32 where N >= 1 && N <= 64);
simd_shr(a, vdup_n_s64(N.try_into().unwrap()))
let n: i32 = if N == 64 { 63 } else { N };
simd_shr(a, vdup_n_s64(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22071,7 +22078,8 @@ pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
static_assert!(N : i32 where N >= 1 && N <= 64);
simd_shr(a, vdupq_n_s64(N.try_into().unwrap()))
let n: i32 = if N == 64 { 63 } else { N };
simd_shr(a, vdupq_n_s64(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22083,7 +22091,8 @@ pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
static_assert!(N : i32 where N >= 1 && N <= 8);
simd_shr(a, vdup_n_u8(N.try_into().unwrap()))
let n: i32 = if N == 8 { return vdup_n_u8(0); } else { N };
simd_shr(a, vdup_n_u8(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22095,7 +22104,8 @@ pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
static_assert!(N : i32 where N >= 1 && N <= 8);
simd_shr(a, vdupq_n_u8(N.try_into().unwrap()))
let n: i32 = if N == 8 { return vdupq_n_u8(0); } else { N };
simd_shr(a, vdupq_n_u8(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22107,7 +22117,8 @@ pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
static_assert!(N : i32 where N >= 1 && N <= 16);
simd_shr(a, vdup_n_u16(N.try_into().unwrap()))
let n: i32 = if N == 16 { return vdup_n_u16(0); } else { N };
simd_shr(a, vdup_n_u16(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22119,7 +22130,8 @@ pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
static_assert!(N : i32 where N >= 1 && N <= 16);
simd_shr(a, vdupq_n_u16(N.try_into().unwrap()))
let n: i32 = if N == 16 { return vdupq_n_u16(0); } else { N };
simd_shr(a, vdupq_n_u16(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22131,7 +22143,8 @@ pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
static_assert!(N : i32 where N >= 1 && N <= 32);
simd_shr(a, vdup_n_u32(N.try_into().unwrap()))
let n: i32 = if N == 32 { return vdup_n_u32(0); } else { N };
simd_shr(a, vdup_n_u32(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22143,7 +22156,8 @@ pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
static_assert!(N : i32 where N >= 1 && N <= 32);
simd_shr(a, vdupq_n_u32(N.try_into().unwrap()))
let n: i32 = if N == 32 { return vdupq_n_u32(0); } else { N };
simd_shr(a, vdupq_n_u32(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22155,7 +22169,8 @@ pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
static_assert!(N : i32 where N >= 1 && N <= 64);
simd_shr(a, vdup_n_u64(N.try_into().unwrap()))
let n: i32 = if N == 64 { return vdup_n_u64(0); } else { N };
simd_shr(a, vdup_n_u64(n.try_into().unwrap()))
}

/// Shift right
Expand All @@ -22167,7 +22182,8 @@ pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
#[rustc_legacy_const_generics(1)]
pub unsafe fn vshrq_n_u64<const N: i32>(a: uint64x2_t) -> uint64x2_t {
static_assert!(N : i32 where N >= 1 && N <= 64);
simd_shr(a, vdupq_n_u64(N.try_into().unwrap()))
let n: i32 = if N == 64 { return vdupq_n_u64(0); } else { N };
simd_shr(a, vdupq_n_u64(n.try_into().unwrap()))
}

/// Shift right narrow
Expand Down
3 changes: 2 additions & 1 deletion crates/stdarch-gen/neon.spec
Original file line number Diff line number Diff line change
Expand Up @@ -6785,7 +6785,8 @@ name = vshr
n-suffix
constn = N
multi_fn = static_assert-N-1-bits
multi_fn = simd_shr, a, {vdup-nself-noext, N.try_into().unwrap()}
multi_fn = fix_right_shift_imm-N-bits
multi_fn = simd_shr, a, {vdup-nself-noext, n.try_into().unwrap()}
a = 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64
n = 2
validate 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
Expand Down
22 changes: 22 additions & 0 deletions crates/stdarch-gen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2664,6 +2664,28 @@ fn get_call(
);
}
}
if fn_name.starts_with("fix_right_shift_imm") {
let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
let lim = if fn_format[2] == "bits" {
type_bits(in_t[1]).to_string()
} else {
fn_format[2].clone()
};
let fixed = if in_t[1].starts_with('u') {
format!("return vdup{nself}(0);", nself = type_to_n_suffix(in_t[1]))
} else {
(lim.parse::<i32>().unwrap() - 1).to_string()
};

return format!(
r#"let {name}: i32 = if {const_name} == {upper} {{ {fixed} }} else {{ N }};"#,
name = fn_format[1].to_lowercase(),
const_name = fn_format[1],
upper = lim,
fixed = fixed,
);
}

if fn_name.starts_with("matchn") {
let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
let len = match &*fn_format[1] {
Expand Down

0 comments on commit 299b5e2

Please sign in to comment.