Skip to content

Commit

Permalink
Cover all shapes for add / cast.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Aug 4, 2024
1 parent 1b568f4 commit 6c6ade0
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 16 deletions.
4 changes: 0 additions & 4 deletions lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ static int _ccv_nnc_add_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
use_mfa = false;
fallback_reason = "Broadcast semantics unsupported.";
}
if (length % 4 != 0) {
use_mfa = false;
fallback_reason = "Length cannot divide by 4.";
}
}
if (use_mfa) {
if (!CCV_IS_TENSOR_CONTIGUOUS(a) || !CCV_IS_TENSOR_CONTIGUOUS(b) || !CCV_IS_TENSOR_CONTIGUOUS(c)) {
Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h
assert(output_size <= input_size);
int i;
@autoreleasepool {
bool use_mfa = false;
bool use_mfa = true;
const char *fallback_reason = NULL;
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();

Expand Down
61 changes: 55 additions & 6 deletions lib/nnc/mfa/ccv_nnc_mfa_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,26 @@ mfa::add::pipeline::pipeline(mfa::context* context, mfa::add::hash hash) {

auto* pool = NS::AutoreleasePool::alloc()->init();

std::string shader = R"(
std::string shader;
// In this case, we can igore the boundary check.
if (hash.length % (4 * 256) == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device const real4 *src0 [[buffer(0)]],
device const real4 *src1 [[buffer(1)]],
device real4 *dst [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
dst[idx] = src0[idx] + src1[idx];
}
)";
} else if (hash.length % 4 == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
Expand All @@ -90,21 +109,51 @@ kernel void add(
dst[idx] = src0[idx] + src1[idx];
}
)";
} else {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device const real *src0 [[buffer(0)]],
device const real *src1 [[buffer(1)]],
device real *dst [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
dst[idx] = src0[idx] + src1[idx];
}
)";
}

std::string defines = "";
if (hash.data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float4 real4;");
defines += "\n";
defines += std::string("typedef float real;");
defines += "\n";
} else {
defines += std::string("typedef half4 real4;");
defines += "\n";
defines += std::string("typedef half real;");
defines += "\n";
}

defines += "constant uint count = ";
CCV_NNC_MFA_PRECONDITION(hash.length % 4 == 0)
const unsigned int count = hash.length / 4;
defines += std::to_string(count) + ";";
defines += "\n";
unsigned int count;
if (hash.length % 4 == 0) {
count = hash.length / 4;
} else {
count = hash.length;
}
// Only boundary check needs this const in the shader.
if (hash.length % (4 * 256) != 0) {
defines += "constant uint count = ";
defines += std::to_string(count) + ";";
defines += "\n";
}
this->group_size = MTL::Size(256, 1, 1);
const int num_blocks = (count + 255) / 256;
this->grid_size = MTL::Size(num_blocks, 1, 1);
Expand Down
62 changes: 57 additions & 5 deletions lib/nnc/mfa/ccv_nnc_mfa_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,42 @@ mfa::cast::pipeline::pipeline(mfa::context* context, mfa::cast::hash hash) {

auto* pool = NS::AutoreleasePool::alloc()->init();

std::string shader = R"(
std::string shader;
// In this case, we can igore the boundary check.
if (hash.length % (4 * 256) == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void cast(
device original_real4 *src [[buffer(0)]],
device real4 *destination [[buffer(1)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
destination[idx] = (real4)(src[idx]);
}
)";
} else if (hash.length % 4 == 0) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void cast(
device original_real4 *src [[buffer(0)]],
device real4 *destination [[buffer(1)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
destination[idx] = (real4)(src[idx]);
}
)";
} else {
shader = R"(
#include <metal_stdlib>
using namespace metal;
Expand All @@ -95,29 +130,46 @@ kernel void cast(
destination[idx] = (real)(src[idx]);
}
)";
}

std::string defines = "";
if (hash.data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float real;");
defines += "\n";
defines += std::string("typedef float4 real4;");
defines += "\n";
} else {
defines += std::string("typedef half real;");
defines += "\n";
defines += std::string("typedef half4 real4;");
defines += "\n";
}

if (hash.original_data_type == MTL::DataTypeFloat) {
defines += std::string("typedef float original_real;");
defines += "\n";
defines += std::string("typedef float4 original_real4;");
defines += "\n";
} else {
defines += std::string("typedef half original_real;");
defines += "\n";
defines += std::string("typedef half4 original_real4;");
defines += "\n";
}

defines += "constant uint count = ";
defines += std::to_string(hash.length) + ";";
defines += "\n";
unsigned int count;
if (hash.length % 4 == 0) {
count = hash.length / 4;
} else {
count = hash.length;
}
if (hash.length % (4 * 256) != 0) {
defines += "constant uint count = ";
defines += std::to_string(count) + ";";
defines += "\n";
}
this->group_size = MTL::Size(256, 1, 1);
const int num_blocks = (hash.length + 255) / 256;
const int num_blocks = (count + 255) / 256;
this->grid_size = MTL::Size(num_blocks, 1, 1);

auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
Expand Down

0 comments on commit 6c6ade0

Please sign in to comment.