Skip to content

Commit

Permalink
Reduce the amount of copying required to evaluated vector constants
Browse files Browse the repository at this point in the history
  • Loading branch information
tannergooding committed Jun 10, 2024
1 parent 1cca48e commit d057ab9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 153 deletions.
17 changes: 9 additions & 8 deletions src/coreclr/jit/simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ TBase EvaluateUnaryScalar(genTreeOps oper, TBase arg0)
}

template <typename TSimd, typename TBase>
void EvaluateUnarySimd(genTreeOps oper, bool scalar, TSimd* result, TSimd arg0)
void EvaluateUnarySimd(genTreeOps oper, bool scalar, TSimd* result, const TSimd& arg0)
{
uint32_t count = sizeof(TSimd) / sizeof(TBase);

Expand Down Expand Up @@ -445,7 +445,7 @@ void EvaluateUnarySimd(genTreeOps oper, bool scalar, TSimd* result, TSimd arg0)
}

template <typename TSimd>
void EvaluateUnarySimd(genTreeOps oper, bool scalar, var_types baseType, TSimd* result, TSimd arg0)
void EvaluateUnarySimd(genTreeOps oper, bool scalar, var_types baseType, TSimd* result, const TSimd& arg0)
{
switch (baseType)
{
Expand Down Expand Up @@ -725,7 +725,7 @@ TBase EvaluateBinaryScalar(genTreeOps oper, TBase arg0, TBase arg1)
}

template <typename TSimd, typename TBase>
void EvaluateBinarySimd(genTreeOps oper, bool scalar, TSimd* result, TSimd arg0, TSimd arg1)
void EvaluateBinarySimd(genTreeOps oper, bool scalar, TSimd* result, const TSimd& arg0, const TSimd& arg1)
{
uint32_t count = sizeof(TSimd) / sizeof(TBase);

Expand Down Expand Up @@ -758,7 +758,8 @@ void EvaluateBinarySimd(genTreeOps oper, bool scalar, TSimd* result, TSimd arg0,
}

template <typename TSimd>
void EvaluateBinarySimd(genTreeOps oper, bool scalar, var_types baseType, TSimd* result, TSimd arg0, TSimd arg1)
void EvaluateBinarySimd(
genTreeOps oper, bool scalar, var_types baseType, TSimd* result, const TSimd& arg0, const TSimd& arg1)
{
switch (baseType)
{
Expand Down Expand Up @@ -830,7 +831,7 @@ void EvaluateBinarySimd(genTreeOps oper, bool scalar, var_types baseType, TSimd*
}

template <typename TSimd>
double EvaluateGetElementFloating(var_types simdBaseType, TSimd arg0, int32_t arg1)
double EvaluateGetElementFloating(var_types simdBaseType, const TSimd& arg0, int32_t arg1)
{
switch (simdBaseType)
{
Expand All @@ -852,7 +853,7 @@ double EvaluateGetElementFloating(var_types simdBaseType, TSimd arg0, int32_t ar
}

template <typename TSimd>
int64_t EvaluateGetElementIntegral(var_types simdBaseType, TSimd arg0, int32_t arg1)
int64_t EvaluateGetElementIntegral(var_types simdBaseType, const TSimd& arg0, int32_t arg1)
{
switch (simdBaseType)
{
Expand Down Expand Up @@ -904,7 +905,7 @@ int64_t EvaluateGetElementIntegral(var_types simdBaseType, TSimd arg0, int32_t a
}

template <typename TSimd>
void EvaluateWithElementFloating(var_types simdBaseType, TSimd* result, TSimd arg0, int32_t arg1, double arg2)
void EvaluateWithElementFloating(var_types simdBaseType, TSimd* result, const TSimd& arg0, int32_t arg1, double arg2)
{
*result = arg0;

Expand All @@ -930,7 +931,7 @@ void EvaluateWithElementFloating(var_types simdBaseType, TSimd* result, TSimd ar
}

template <typename TSimd>
void EvaluateWithElementIntegral(var_types simdBaseType, TSimd* result, TSimd arg0, int32_t arg1, int64_t arg2)
void EvaluateWithElementIntegral(var_types simdBaseType, TSimd* result, const TSimd& arg0, int32_t arg1, int64_t arg2)
{
*result = arg0;

Expand Down
183 changes: 56 additions & 127 deletions src/coreclr/jit/valuenum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1856,33 +1856,33 @@ ValueNum ValueNumStore::VNForByrefCon(target_size_t cnsVal)
}

#if defined(FEATURE_SIMD)
ValueNum ValueNumStore::VNForSimd8Con(simd8_t cnsVal)
ValueNum ValueNumStore::VNForSimd8Con(const simd8_t& cnsVal)
{
return VnForConst(cnsVal, GetSimd8CnsMap(), TYP_SIMD8);
}

ValueNum ValueNumStore::VNForSimd12Con(simd12_t cnsVal)
ValueNum ValueNumStore::VNForSimd12Con(const simd12_t& cnsVal)
{
return VnForConst(cnsVal, GetSimd12CnsMap(), TYP_SIMD12);
}

ValueNum ValueNumStore::VNForSimd16Con(simd16_t cnsVal)
ValueNum ValueNumStore::VNForSimd16Con(const simd16_t& cnsVal)
{
return VnForConst(cnsVal, GetSimd16CnsMap(), TYP_SIMD16);
}

#if defined(TARGET_XARCH)
ValueNum ValueNumStore::VNForSimd32Con(simd32_t cnsVal)
ValueNum ValueNumStore::VNForSimd32Con(const simd32_t& cnsVal)
{
return VnForConst(cnsVal, GetSimd32CnsMap(), TYP_SIMD32);
}

ValueNum ValueNumStore::VNForSimd64Con(simd64_t cnsVal)
ValueNum ValueNumStore::VNForSimd64Con(const simd64_t& cnsVal)
{
return VnForConst(cnsVal, GetSimd64CnsMap(), TYP_SIMD64);
}

ValueNum ValueNumStore::VNForSimdMaskCon(simdmask_t cnsVal)
ValueNum ValueNumStore::VNForSimdMaskCon(const simdmask_t& cnsVal)
{
return VnForConst(cnsVal, GetSimdMaskCnsMap(), TYP_MASK);
}
Expand Down Expand Up @@ -2217,70 +2217,59 @@ ValueNum ValueNumStore::VNAllBitsForType(var_types typ)
}

#ifdef FEATURE_SIMD
ValueNum ValueNumStore::VNOneForSimdType(var_types simdType, var_types simdBaseType)
template <typename TSimd>
TSimd BroadcastConstantToSimd(ValueNumStore* vns, var_types baseType, ValueNum argVN)
{
assert(varTypeIsSIMD(simdType));
assert(vns->IsVNConstant(argVN));
assert(!varTypeIsSIMD(vns->TypeOfVN(argVN)));

simd_t simdVal = {};
int simdSize = genTypeSize(simdType);
TSimd result = {};

switch (simdBaseType)
switch (baseType)
{
case TYP_BYTE:
case TYP_UBYTE:
case TYP_FLOAT:
{
for (int i = 0; i < simdSize; i++)
{
simdVal.u8[i] = 1;
}
float arg = vns->GetConstantSingle(argVN);
BroadcastConstantToSimd<TSimd, float>(&result, arg);
break;
}

case TYP_SHORT:
case TYP_USHORT:
case TYP_DOUBLE:
{
for (int i = 0; i < (simdSize / 2); i++)
{
simdVal.u16[i] = 1;
}
double arg = vns->GetConstantDouble(argVN);
BroadcastConstantToSimd<TSimd, double>(&result, arg);
break;
}

case TYP_INT:
case TYP_UINT:
case TYP_BYTE:
case TYP_UBYTE:
{
for (int i = 0; i < (simdSize / 4); i++)
{
simdVal.u32[i] = 1;
}
uint8_t arg = static_cast<uint8_t>(vns->GetConstantInt32(argVN));
BroadcastConstantToSimd<TSimd, uint8_t>(&result, arg);
break;
}

case TYP_LONG:
case TYP_ULONG:
case TYP_SHORT:
case TYP_USHORT:
{
for (int i = 0; i < (simdSize / 8); i++)
{
simdVal.u64[i] = 1;
}
uint16_t arg = static_cast<uint16_t>(vns->GetConstantInt32(argVN));
BroadcastConstantToSimd<TSimd, uint16_t>(&result, arg);
break;
}

case TYP_FLOAT:
case TYP_INT:
case TYP_UINT:
{
for (int i = 0; i < (simdSize / 4); i++)
{
simdVal.f32[i] = 1.0f;
}
uint32_t arg = static_cast<uint32_t>(vns->GetConstantInt32(argVN));
BroadcastConstantToSimd<TSimd, uint32_t>(&result, arg);
break;
}

case TYP_DOUBLE:
case TYP_LONG:
case TYP_ULONG:
{
for (int i = 0; i < (simdSize / 8); i++)
{
simdVal.f64[i] = 1.0;
}
uint64_t arg = static_cast<uint64_t>(vns->GetConstantInt64(argVN));
BroadcastConstantToSimd<TSimd, uint64_t>(&result, arg);
break;
}

Expand All @@ -2290,42 +2279,46 @@ ValueNum ValueNumStore::VNOneForSimdType(var_types simdType, var_types simdBaseT
}
}

return result;
}

ValueNum ValueNumStore::VNOneForSimdType(var_types simdType, var_types simdBaseType)
{
assert(varTypeIsSIMD(simdType));

ValueNum oneVN = VNOneForType(simdBaseType);

switch (simdType)
{
case TYP_SIMD8:
{
simd8_t simd8Val;
memcpy(&simd8Val, &simdVal, sizeof(simd8_t));
return VNForSimd8Con(simd8Val);
simd8_t result = BroadcastConstantToSimd<simd8_t>(this, simdBaseType, oneVN);
return VNForSimd8Con(result);
}

case TYP_SIMD12:
{
simd12_t simd12Val;
memcpy(&simd12Val, &simdVal, sizeof(simd12_t));
return VNForSimd12Con(simd12Val);
simd12_t result = BroadcastConstantToSimd<simd12_t>(this, simdBaseType, oneVN);
return VNForSimd12Con(result);
}

case TYP_SIMD16:
{
simd16_t simd16Val;
memcpy(&simd16Val, &simdVal, sizeof(simd16_t));
return VNForSimd16Con(simd16Val);
simd16_t result = BroadcastConstantToSimd<simd16_t>(this, simdBaseType, oneVN);
return VNForSimd16Con(result);
}

#if defined(TARGET_XARCH)
case TYP_SIMD32:
{
simd32_t simd32Val;
memcpy(&simd32Val, &simdVal, sizeof(simd32_t));
return VNForSimd32Con(simd32Val);
simd32_t result = BroadcastConstantToSimd<simd32_t>(this, simdBaseType, oneVN);
return VNForSimd32Con(result);
}

case TYP_SIMD64:
{
simd64_t simd64Val;
memcpy(&simd64Val, &simdVal, sizeof(simd64_t));
return VNForSimd64Con(simd64Val);
simd64_t result = BroadcastConstantToSimd<simd64_t>(this, simdBaseType, oneVN);
return VNForSimd64Con(result);
}

case TYP_MASK:
Expand Down Expand Up @@ -6870,71 +6863,6 @@ void ValueNumStore::SetVNIsCheckedBound(ValueNum vn)
}

#ifdef FEATURE_HW_INTRINSICS
template <typename TSimd>
TSimd BroadcastConstantToSimd(ValueNumStore* vns, var_types baseType, ValueNum argVN)
{
assert(vns->IsVNConstant(argVN));
assert(!varTypeIsSIMD(vns->TypeOfVN(argVN)));

TSimd result = {};

switch (baseType)
{
case TYP_FLOAT:
{
float arg = vns->GetConstantSingle(argVN);
BroadcastConstantToSimd<TSimd, float>(&result, arg);
break;
}

case TYP_DOUBLE:
{
double arg = vns->GetConstantDouble(argVN);
BroadcastConstantToSimd<TSimd, double>(&result, arg);
break;
}

case TYP_BYTE:
case TYP_UBYTE:
{
uint8_t arg = static_cast<uint8_t>(vns->GetConstantInt32(argVN));
BroadcastConstantToSimd<TSimd, uint8_t>(&result, arg);
break;
}

case TYP_SHORT:
case TYP_USHORT:
{
uint16_t arg = static_cast<uint16_t>(vns->GetConstantInt32(argVN));
BroadcastConstantToSimd<TSimd, uint16_t>(&result, arg);
break;
}

case TYP_INT:
case TYP_UINT:
{
uint32_t arg = static_cast<uint32_t>(vns->GetConstantInt32(argVN));
BroadcastConstantToSimd<TSimd, uint32_t>(&result, arg);
break;
}

case TYP_LONG:
case TYP_ULONG:
{
uint64_t arg = static_cast<uint64_t>(vns->GetConstantInt64(argVN));
BroadcastConstantToSimd<TSimd, uint64_t>(&result, arg);
break;
}

default:
{
unreached();
}
}

return result;
}

simd8_t GetConstantSimd8(ValueNumStore* vns, var_types baseType, ValueNum argVN)
{
assert(vns->IsVNConstant(argVN));
Expand Down Expand Up @@ -7126,7 +7054,7 @@ ValueNum EvaluateBinarySimd(ValueNumStore* vns,
}

template <typename TSimd>
ValueNum EvaluateSimdGetElement(ValueNumStore* vns, var_types baseType, TSimd arg0, int32_t arg1)
ValueNum EvaluateSimdGetElement(ValueNumStore* vns, var_types baseType, const TSimd& arg0, int32_t arg1)
{
switch (baseType)
{
Expand Down Expand Up @@ -7617,7 +7545,8 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
if (TypeOfVN(arg1VN) == TYP_SIMD16)
{
if ((ni != NI_AVX2_ShiftLeftLogicalVariable) && (ni != NI_AVX2_ShiftRightArithmeticVariable) &&
(ni != NI_AVX512F_VL_ShiftRightArithmeticVariable) && (ni != NI_AVX2_ShiftRightLogicalVariable))
(ni != NI_AVX512F_VL_ShiftRightArithmeticVariable) &&
(ni != NI_AVX10v1_ShiftRightArithmeticVariable) && (ni != NI_AVX2_ShiftRightLogicalVariable))
{
// The xarch shift instructions support taking the shift amount as
// a simd16, in which case they take the shift amount from the lower
Expand Down
Loading

0 comments on commit d057ab9

Please sign in to comment.