Skip to content

Commit

Permalink
Add f16 inline ASM support for 32-bit ARM
Browse files Browse the repository at this point in the history
  • Loading branch information
beetrees committed Jun 16, 2024
1 parent 12b33d3 commit 0a6a421
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 6 deletions.
58 changes: 58 additions & 0 deletions compiler/rustc_codegen_llvm/src/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,26 @@ fn llvm_fixup_input<'ll, 'tcx>(
value
}
}
(
InlineAsmRegClass::Arm(
ArmInlineAsmRegClass::dreg
| ArmInlineAsmRegClass::dreg_low8
| ArmInlineAsmRegClass::dreg_low16,
),
Abi::Vector { element, count: 4 },
) if element.primitive() == Primitive::Float(Float::F16) => {
bx.bitcast(value, bx.type_f64())
}
(
InlineAsmRegClass::Arm(
ArmInlineAsmRegClass::qreg
| ArmInlineAsmRegClass::qreg_low4
| ArmInlineAsmRegClass::qreg_low8,
),
Abi::Vector { element, count: 8 },
) if element.primitive() == Primitive::Float(Float::F16) => {
bx.bitcast(value, bx.type_vector(bx.type_i16(), 8))
}
(InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
match s.primitive() {
// MIPS only supports register-length arithmetics.
Expand Down Expand Up @@ -1130,6 +1150,26 @@ fn llvm_fixup_output<'ll, 'tcx>(
value
}
}
(
InlineAsmRegClass::Arm(
ArmInlineAsmRegClass::dreg
| ArmInlineAsmRegClass::dreg_low8
| ArmInlineAsmRegClass::dreg_low16,
),
Abi::Vector { element, count: 4 },
) if element.primitive() == Primitive::Float(Float::F16) => {
bx.bitcast(value, bx.type_vector(bx.type_f16(), 4))
}
(
InlineAsmRegClass::Arm(
ArmInlineAsmRegClass::qreg
| ArmInlineAsmRegClass::qreg_low4
| ArmInlineAsmRegClass::qreg_low8,
),
Abi::Vector { element, count: 8 },
) if element.primitive() == Primitive::Float(Float::F16) => {
bx.bitcast(value, bx.type_vector(bx.type_f16(), 8))
}
(InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
match s.primitive() {
// MIPS only supports register-length arithmetics.
Expand Down Expand Up @@ -1233,6 +1273,24 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
layout.llvm_type(cx)
}
}
(
InlineAsmRegClass::Arm(
ArmInlineAsmRegClass::dreg
| ArmInlineAsmRegClass::dreg_low8
| ArmInlineAsmRegClass::dreg_low16,
),
Abi::Vector { element, count: 4 },
) if element.primitive() == Primitive::Float(Float::F16) => cx.type_f64(),
(
InlineAsmRegClass::Arm(
ArmInlineAsmRegClass::qreg
| ArmInlineAsmRegClass::qreg_low4
| ArmInlineAsmRegClass::qreg_low8,
),
Abi::Vector { element, count: 8 },
) if element.primitive() == Primitive::Float(Float::F16) => {
cx.type_vector(cx.type_i16(), 8)
}
(InlineAsmRegClass::Mips(MipsInlineAsmRegClass::reg), Abi::Scalar(s)) => {
match s.primitive() {
// MIPS only supports register-length arithmetics.
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_target/src/asm/arm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ impl ArmInlineAsmRegClass {
_arch: InlineAsmArch,
) -> &'static [(InlineAsmType, Option<Symbol>)] {
match self {
Self::reg => types! { _: I8, I16, I32, F32; },
Self::sreg | Self::sreg_low16 => types! { vfp2: I32, F32; },
Self::reg => types! { _: I8, I16, I32, F16, F32; },
Self::sreg | Self::sreg_low16 => types! { vfp2: I32, F16, F32; },
Self::dreg_low16 | Self::dreg_low8 => types! {
vfp2: I64, F64, VecI8(8), VecI16(4), VecI32(2), VecI64(1), VecF32(2);
vfp2: I64, F64, VecI8(8), VecI16(4), VecI32(2), VecI64(1), VecF16(4), VecF32(2);
},
Self::dreg => types! {
d32: I64, F64, VecI8(8), VecI16(4), VecI32(2), VecI64(1), VecF32(2);
d32: I64, F64, VecI8(8), VecI16(4), VecI32(2), VecI64(1), VecF16(4), VecF32(2);
},
Self::qreg | Self::qreg_low8 | Self::qreg_low4 => types! {
neon: VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF32(4);
neon: VecI8(16), VecI16(8), VecI32(4), VecI64(2), VecF16(8), VecF32(4);
},
}
}
Expand Down
87 changes: 86 additions & 1 deletion tests/assembly/asm/arm-types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//@ compile-flags: -C opt-level=0
//@ needs-llvm-components: arm

#![feature(no_core, lang_items, rustc_attrs, repr_simd)]
#![feature(no_core, lang_items, rustc_attrs, repr_simd, f16)]
#![crate_type = "rlib"]
#![no_core]
#![allow(asm_sub_register, non_camel_case_types)]
Expand Down Expand Up @@ -38,6 +38,8 @@ pub struct i32x2(i32, i32);
#[repr(simd)]
pub struct i64x1(i64);
#[repr(simd)]
pub struct f16x4(f16, f16, f16, f16);
#[repr(simd)]
pub struct f32x2(f32, f32);
#[repr(simd)]
pub struct i8x16(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8);
Expand All @@ -48,11 +50,14 @@ pub struct i32x4(i32, i32, i32, i32);
#[repr(simd)]
pub struct i64x2(i64, i64);
#[repr(simd)]
pub struct f16x8(f16, f16, f16, f16, f16, f16, f16, f16);
#[repr(simd)]
pub struct f32x4(f32, f32, f32, f32);

impl Copy for i8 {}
impl Copy for i16 {}
impl Copy for i32 {}
impl Copy for f16 {}
impl Copy for f32 {}
impl Copy for i64 {}
impl Copy for f64 {}
Expand All @@ -61,11 +66,13 @@ impl Copy for i8x8 {}
impl Copy for i16x4 {}
impl Copy for i32x2 {}
impl Copy for i64x1 {}
impl Copy for f16x4 {}
impl Copy for f32x2 {}
impl Copy for i8x16 {}
impl Copy for i16x8 {}
impl Copy for i32x4 {}
impl Copy for i64x2 {}
impl Copy for f16x8 {}
impl Copy for f32x4 {}

extern "C" {
Expand Down Expand Up @@ -152,6 +159,12 @@ check!(reg_i16 i16 reg "mov");
// CHECK: @NO_APP
check!(reg_i32 i32 reg "mov");

// CHECK-LABEL: reg_f16:
// CHECK: @APP
// CHECK: mov {{[a-z0-9]+}}, {{[a-z0-9]+}}
// CHECK: @NO_APP
check!(reg_f16 f16 reg "mov");

// CHECK-LABEL: reg_f32:
// CHECK: @APP
// CHECK: mov {{[a-z0-9]+}}, {{[a-z0-9]+}}
Expand All @@ -170,6 +183,12 @@ check!(reg_ptr ptr reg "mov");
// CHECK: @NO_APP
check!(sreg_i32 i32 sreg "vmov.f32");

// CHECK-LABEL: sreg_f16:
// CHECK: @APP
// CHECK: vmov.f32 s{{[0-9]+}}, s{{[0-9]+}}
// CHECK: @NO_APP
check!(sreg_f16 f16 sreg "vmov.f32");

// CHECK-LABEL: sreg_f32:
// CHECK: @APP
// CHECK: vmov.f32 s{{[0-9]+}}, s{{[0-9]+}}
Expand All @@ -188,6 +207,12 @@ check!(sreg_ptr ptr sreg "vmov.f32");
// CHECK: @NO_APP
check!(sreg_low16_i32 i32 sreg_low16 "vmov.f32");

// CHECK-LABEL: sreg_low16_f16:
// CHECK: @APP
// CHECK: vmov.f32 s{{[0-9]+}}, s{{[0-9]+}}
// CHECK: @NO_APP
check!(sreg_low16_f16 f16 sreg_low16 "vmov.f32");

// CHECK-LABEL: sreg_low16_f32:
// CHECK: @APP
// CHECK: vmov.f32 s{{[0-9]+}}, s{{[0-9]+}}
Expand Down Expand Up @@ -230,6 +255,12 @@ check!(dreg_i32x2 i32x2 dreg "vmov.f64");
// CHECK: @NO_APP
check!(dreg_i64x1 i64x1 dreg "vmov.f64");

// CHECK-LABEL: dreg_f16x4:
// CHECK: @APP
// CHECK: vmov.f64 d{{[0-9]+}}, d{{[0-9]+}}
// CHECK: @NO_APP
check!(dreg_f16x4 f16x4 dreg "vmov.f64");

// CHECK-LABEL: dreg_f32x2:
// CHECK: @APP
// CHECK: vmov.f64 d{{[0-9]+}}, d{{[0-9]+}}
Expand Down Expand Up @@ -272,6 +303,12 @@ check!(dreg_low16_i32x2 i32x2 dreg_low16 "vmov.f64");
// CHECK: @NO_APP
check!(dreg_low16_i64x1 i64x1 dreg_low16 "vmov.f64");

// CHECK-LABEL: dreg_low16_f16x4:
// CHECK: @APP
// CHECK: vmov.f64 d{{[0-9]+}}, d{{[0-9]+}}
// CHECK: @NO_APP
check!(dreg_low16_f16x4 f16x4 dreg_low16 "vmov.f64");

// CHECK-LABEL: dreg_low16_f32x2:
// CHECK: @APP
// CHECK: vmov.f64 d{{[0-9]+}}, d{{[0-9]+}}
Expand Down Expand Up @@ -314,6 +351,12 @@ check!(dreg_low8_i32x2 i32x2 dreg_low8 "vmov.f64");
// CHECK: @NO_APP
check!(dreg_low8_i64x1 i64x1 dreg_low8 "vmov.f64");

// CHECK-LABEL: dreg_low8_f16x4:
// CHECK: @APP
// CHECK: vmov.f64 d{{[0-9]+}}, d{{[0-9]+}}
// CHECK: @NO_APP
check!(dreg_low8_f16x4 f16x4 dreg_low8 "vmov.f64");

// CHECK-LABEL: dreg_low8_f32x2:
// CHECK: @APP
// CHECK: vmov.f64 d{{[0-9]+}}, d{{[0-9]+}}
Expand Down Expand Up @@ -344,6 +387,12 @@ check!(qreg_i32x4 i32x4 qreg "vmov");
// CHECK: @NO_APP
check!(qreg_i64x2 i64x2 qreg "vmov");

// CHECK-LABEL: qreg_f16x8:
// CHECK: @APP
// CHECK: vorr q{{[0-9]+}}, q{{[0-9]+}}, q{{[0-9]+}}
// CHECK: @NO_APP
check!(qreg_f16x8 f16x8 qreg "vmov");

// CHECK-LABEL: qreg_f32x4:
// CHECK: @APP
// CHECK: vorr q{{[0-9]+}}, q{{[0-9]+}}, q{{[0-9]+}}
Expand Down Expand Up @@ -374,6 +423,12 @@ check!(qreg_low8_i32x4 i32x4 qreg_low8 "vmov");
// CHECK: @NO_APP
check!(qreg_low8_i64x2 i64x2 qreg_low8 "vmov");

// CHECK-LABEL: qreg_low8_f16x8:
// CHECK: @APP
// CHECK: vorr q{{[0-9]+}}, q{{[0-9]+}}, q{{[0-9]+}}
// CHECK: @NO_APP
check!(qreg_low8_f16x8 f16x8 qreg_low8 "vmov");

// CHECK-LABEL: qreg_low8_f32x4:
// CHECK: @APP
// CHECK: vorr q{{[0-9]+}}, q{{[0-9]+}}, q{{[0-9]+}}
Expand Down Expand Up @@ -404,6 +459,12 @@ check!(qreg_low4_i32x4 i32x4 qreg_low4 "vmov");
// CHECK: @NO_APP
check!(qreg_low4_i64x2 i64x2 qreg_low4 "vmov");

// CHECK-LABEL: qreg_low4_f16x8:
// CHECK: @APP
// CHECK: vorr q{{[0-9]+}}, q{{[0-9]+}}, q{{[0-9]+}}
// CHECK: @NO_APP
check!(qreg_low4_f16x8 f16x8 qreg_low4 "vmov");

// CHECK-LABEL: qreg_low4_f32x4:
// CHECK: @APP
// CHECK: vorr q{{[0-9]+}}, q{{[0-9]+}}, q{{[0-9]+}}
Expand All @@ -428,6 +489,12 @@ check_reg!(r0_i16 i16 "r0" "mov");
// CHECK: @NO_APP
check_reg!(r0_i32 i32 "r0" "mov");

// CHECK-LABEL: r0_f16:
// CHECK: @APP
// CHECK: mov r0, r0
// CHECK: @NO_APP
check_reg!(r0_f16 f16 "r0" "mov");

// CHECK-LABEL: r0_f32:
// CHECK: @APP
// CHECK: mov r0, r0
Expand All @@ -446,6 +513,12 @@ check_reg!(r0_ptr ptr "r0" "mov");
// CHECK: @NO_APP
check_reg!(s0_i32 i32 "s0" "vmov.f32");

// CHECK-LABEL: s0_f16:
// CHECK: @APP
// CHECK: vmov.f32 s0, s0
// CHECK: @NO_APP
check_reg!(s0_f16 f16 "s0" "vmov.f32");

// CHECK-LABEL: s0_f32:
// CHECK: @APP
// CHECK: vmov.f32 s0, s0
Expand Down Expand Up @@ -494,6 +567,12 @@ check_reg!(d0_i32x2 i32x2 "d0" "vmov.f64");
// CHECK: @NO_APP
check_reg!(d0_i64x1 i64x1 "d0" "vmov.f64");

// CHECK-LABEL: d0_f16x4:
// CHECK: @APP
// CHECK: vmov.f64 d0, d0
// CHECK: @NO_APP
check_reg!(d0_f16x4 f16x4 "d0" "vmov.f64");

// CHECK-LABEL: d0_f32x2:
// CHECK: @APP
// CHECK: vmov.f64 d0, d0
Expand Down Expand Up @@ -524,6 +603,12 @@ check_reg!(q0_i32x4 i32x4 "q0" "vmov");
// CHECK: @NO_APP
check_reg!(q0_i64x2 i64x2 "q0" "vmov");

// CHECK-LABEL: q0_f16x8:
// CHECK: @APP
// CHECK: vorr q0, q0, q0
// CHECK: @NO_APP
check_reg!(q0_f16x8 f16x8 "q0" "vmov");

// CHECK-LABEL: q0_f32x4:
// CHECK: @APP
// CHECK: vorr q0, q0, q0
Expand Down

0 comments on commit 0a6a421

Please sign in to comment.