Skip to content

Commit

Permalink
Merge pull request #423 from rust-lang/bitmask-again-again-again
Browse files Browse the repository at this point in the history
Implement special swizzles for masks and remove `{to,from}_bitmask_vector`
  • Loading branch information
calebzulawski authored Jun 5, 2024
2 parents 7cd6f95 + bd92b7c commit 8c31005
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 158 deletions.
42 changes: 0 additions & 42 deletions crates/core_simd/src/masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,48 +308,6 @@ where
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}

/// Create a bitmask vector from a mask.
///
/// The result is a byte vector with as many elements as the mask. Each bit is set if the
/// corresponding element in the mask is `true`, starting from the least significant bit of
/// the first byte. The remaining bits are unset.
///
/// The bits are packed into the first N bits of the vector:
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::mask32x8;
/// let mask = mask32x8::from_array([true, false, true, false, false, false, true, false]);
/// assert_eq!(mask.to_bitmask_vector()[0], 0b01000101);
/// ```
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
    // Delegate to the backend mask representation wrapped by this type.
    self.0.to_bitmask_vector()
}

/// Build a mask out of a bitmask vector.
///
/// Every set bit turns the corresponding mask element `true`; every unset bit leaves the
/// corresponding element `false`.
///
/// The bits are read from the first N bits of the vector:
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{mask32x8, u8x8};
/// let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
/// assert_eq!(
///     mask32x8::from_bitmask_vector(bitmask),
///     mask32x8::from_array([true, false, true, false, false, false, true, false]),
/// );
/// ```
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
    // The backend representation does the unpacking; this type only wraps the result.
    let inner = mask_impl::Mask::from_bitmask_vector(bitmask);
    Self(inner)
}

/// Find the index of the first set element.
///
/// ```
Expand Down
17 changes: 0 additions & 17 deletions crates/core_simd/src/masks/bitmask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,6 @@ where
unsafe { Self(core::intrinsics::simd::simd_bitmask(value), PhantomData) }
}

/// Copy the packed mask bytes into the front of a zeroed byte vector.
///
/// Bytes of the result beyond the packed representation stay zero.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
    let packed: &[u8] = self.0.as_ref();
    let mut vector = Simd::splat(0);
    vector.as_mut_array()[..packed.len()].copy_from_slice(packed);
    vector
}

/// Build the packed mask from the leading bytes of a bitmask vector.
///
/// Only as many bytes as the packed representation holds are read; the rest of `bitmask`
/// is ignored.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
    let mut packed = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
    let dst: &mut [u8] = packed.as_mut();
    let count = dst.len();
    dst.copy_from_slice(&bitmask.as_array()[..count]);
    Self(packed, PhantomData)
}

#[inline]
pub fn to_bitmask_integer(self) -> u64 {
let mut bitmask = [0u8; 8];
Expand Down
56 changes: 0 additions & 56 deletions crates/core_simd/src/masks/full_masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,62 +140,6 @@ where
unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) }
}

/// Pack the mask into bits and store them at the front of a zeroed byte vector.
///
/// On big-endian targets the bytes produced by the intrinsic are bit-reversed (and the
/// trailing partial byte shifted down) before being copied into the result.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
    // Compute the bitmask
    // Safety: Bytes is the right size array
    let mut packed: <LaneCount<N> as SupportedLaneCount>::BitMask =
        unsafe { core::intrinsics::simd::simd_bitmask(self.0) };

    // LLVM assumes bit order should match endianness
    if cfg!(target_endian = "big") {
        for byte in packed.as_mut() {
            *byte = byte.reverse_bits();
        }
        // After reversing, the bits of a partial last byte sit in its high end.
        let tail = N % 8;
        if tail > 0 {
            packed.as_mut()[N / 8] >>= 8 - tail;
        }
    }

    let packed: &[u8] = packed.as_ref();
    let mut bitmask = Simd::splat(0);
    bitmask.as_mut_array()[..packed.len()].copy_from_slice(packed);
    bitmask
}

/// Expand the leading bytes of a bitmask vector into a full-width mask.
///
/// On big-endian targets the bytes are bit-reversed (and the trailing partial byte shifted
/// down) before being handed to the select intrinsic.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
    let mut packed = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
    let count = packed.as_ref().len();
    packed.as_mut().copy_from_slice(&bitmask.as_array()[..count]);

    // LLVM assumes bit order should match endianness
    if cfg!(target_endian = "big") {
        for byte in packed.as_mut() {
            *byte = byte.reverse_bits();
        }
        let tail = N % 8;
        if tail > 0 {
            packed.as_mut()[N / 8] >>= 8 - tail;
        }
    }

    // Safety: Bytes is the right size array
    unsafe {
        let all_true = Self::splat(true).to_int();
        let all_false = Self::splat(false).to_int();

        // Compute the regular mask
        Self::from_int_unchecked(core::intrinsics::simd::simd_select_bitmask(
            packed, all_true, all_false,
        ))
    }
}

#[inline]
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
where
Expand Down
180 changes: 179 additions & 1 deletion crates/core_simd/src/swizzle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ where
///
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::Simd;
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::Simd;
/// let a = Simd::from_array([0, 4, 1, 5]);
/// let b = Simd::from_array([2, 6, 3, 7]);
/// let (x, y) = a.deinterleave(b);
Expand Down Expand Up @@ -383,4 +385,180 @@ where
}
Resize::<N>::concat_swizzle(self, Simd::splat(value))
}

/// Extract a vector from another vector.
///
/// The result holds the `LEN` consecutive elements of `self` beginning at index `START`.
/// The range is checked while the swizzle indices are evaluated as a constant:
/// `START + LEN` must not exceed the length of `self`.
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::u32x4;
/// let x = u32x4::from_array([0, 1, 2, 3]);
/// assert_eq!(x.extract::<1, 2>().to_array(), [1, 2]);
/// ```
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn extract<const START: usize, const LEN: usize>(self) -> Simd<T, LEN>
where
    LaneCount<LEN>: SupportedLaneCount,
{
    struct Extract<const N: usize, const START: usize>;
    impl<const N: usize, const START: usize, const LEN: usize> Swizzle<LEN> for Extract<N, START> {
        const INDEX: [usize; LEN] = const {
            assert!(START + LEN <= N, "index out of bounds");
            // Seed every slot with `START`, then add each slot's own position.
            let mut lanes = [START; LEN];
            let mut offset = 1;
            while offset < LEN {
                lanes[offset] += offset;
                offset += 1;
            }
            lanes
        };
    }
    Extract::<N, START>::swizzle(self)
}
}

impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    /// Reverse the order of the elements in the mask.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original inputs"]
    pub fn reverse(self) -> Self {
        // Safety: swizzles are safe for masks, they only move elements that are already valid
        unsafe { Self::from_int_unchecked(self.to_int().reverse()) }
    }

    /// Rotates the mask such that the first `OFFSET` elements of the mask move to the end
    /// while the last `N - OFFSET` elements move to the front. After calling `rotate_elements_left`,
    /// the element previously at index `OFFSET` will become the first element in the mask.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original inputs"]
    pub fn rotate_elements_left<const OFFSET: usize>(self) -> Self {
        // Safety: swizzles are safe for masks, they only move elements that are already valid
        unsafe { Self::from_int_unchecked(self.to_int().rotate_elements_left::<OFFSET>()) }
    }

    /// Rotates the mask such that the first `N - OFFSET` elements of the mask move to
    /// the end while the last `OFFSET` elements move to the front. After calling `rotate_elements_right`,
    /// the element previously at index `N - OFFSET` will become the first element in the mask.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original inputs"]
    pub fn rotate_elements_right<const OFFSET: usize>(self) -> Self {
        // Safety: swizzles are safe for masks, they only move elements that are already valid
        unsafe { Self::from_int_unchecked(self.to_int().rotate_elements_right::<OFFSET>()) }
    }

    /// Interleave two masks.
    ///
    /// The resulting masks contain elements taken alternately from `self` and `other`, first
    /// filling the first result, and then the second.
    ///
    /// The reverse of this operation is [`Mask::deinterleave`].
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x4;
    /// let a = mask32x4::from_array([false, true, false, true]);
    /// let b = mask32x4::from_array([false, false, true, true]);
    /// let (x, y) = a.interleave(b);
    /// assert_eq!(x.to_array(), [false, false, true, false]);
    /// assert_eq!(y.to_array(), [false, true, true, true]);
    /// ```
    #[inline]
    #[must_use = "method returns new masks and does not mutate the original inputs"]
    pub fn interleave(self, other: Self) -> (Self, Self) {
        let (lo, hi) = self.to_int().interleave(other.to_int());
        // Safety: swizzles are safe for masks, they only move elements that are already valid
        unsafe { (Self::from_int_unchecked(lo), Self::from_int_unchecked(hi)) }
    }

    /// Deinterleave two masks.
    ///
    /// The first result takes every other element of `self` and then `other`, starting with
    /// the first element.
    ///
    /// The second result takes every other element of `self` and then `other`, starting with
    /// the second element.
    ///
    /// The reverse of this operation is [`Mask::interleave`].
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x4;
    /// let a = mask32x4::from_array([false, true, false, true]);
    /// let b = mask32x4::from_array([false, false, true, true]);
    /// let (x, y) = a.deinterleave(b);
    /// assert_eq!(x.to_array(), [false, false, false, true]);
    /// assert_eq!(y.to_array(), [true, true, false, true]);
    /// ```
    #[inline]
    #[must_use = "method returns new masks and does not mutate the original inputs"]
    pub fn deinterleave(self, other: Self) -> (Self, Self) {
        let (even, odd) = self.to_int().deinterleave(other.to_int());
        // Safety: swizzles are safe for masks, they only move elements that are already valid
        unsafe {
            (
                Self::from_int_unchecked(even),
                Self::from_int_unchecked(odd),
            )
        }
    }

    /// Resize a mask.
    ///
    /// If `M` > `N`, extends the length of a mask, setting the new elements to `value`.
    /// If `M` < `N`, truncates the mask to the first `M` elements.
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x4;
    /// let x = mask32x4::from_array([false, true, true, false]);
    /// assert_eq!(x.resize::<8>(true).to_array(), [false, true, true, false, true, true, true, true]);
    /// assert_eq!(x.resize::<2>(true).to_array(), [false, true]);
    /// ```
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original inputs"]
    pub fn resize<const M: usize>(self, value: bool) -> Mask<T, M>
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // Safety: swizzles are safe for masks; any new element is `T::TRUE` or `T::FALSE`
        unsafe {
            Mask::<T, M>::from_int_unchecked(self.to_int().resize::<M>(if value {
                T::TRUE
            } else {
                T::FALSE
            }))
        }
    }

    /// Extract a mask from another mask.
    ///
    /// The result holds the `LEN` consecutive elements of `self` beginning at index `START`.
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x4;
    /// let x = mask32x4::from_array([false, true, true, false]);
    /// assert_eq!(x.extract::<1, 2>().to_array(), [true, true]);
    /// ```
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original inputs"]
    pub fn extract<const START: usize, const LEN: usize>(self) -> Mask<T, LEN>
    where
        LaneCount<LEN>: SupportedLaneCount,
    {
        // Safety: swizzles are safe for masks, they only move elements that are already valid
        unsafe { Mask::<T, LEN>::from_int_unchecked(self.to_int().extract::<START, LEN>()) }
    }
}
7 changes: 7 additions & 0 deletions crates/core_simd/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ where
///
/// When the element is disabled, that memory location is not accessed and the corresponding
/// value from `or` is passed through.
///
/// # Safety
/// Enabled loads must not exceed the length of `slice`.
#[must_use]
#[inline]
pub unsafe fn load_select_unchecked(
Expand All @@ -459,6 +462,9 @@ where
///
/// When the element is disabled, that memory location is not accessed and the corresponding
/// value from `or` is passed through.
///
/// # Safety
/// Enabled `ptr` elements must be safe to read as if by `std::ptr::read`.
#[must_use]
#[inline]
pub unsafe fn load_select_ptr(
Expand Down Expand Up @@ -1214,6 +1220,7 @@ fn lane_indices<const N: usize>() -> Simd<usize, N>
where
LaneCount<N>: SupportedLaneCount,
{
#![allow(clippy::needless_range_loop)]
let mut index = [0; N];
for i in 0..N {
index[i] = i;
Expand Down
Loading

0 comments on commit 8c31005

Please sign in to comment.