Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 4 No.36】为 Paddle 优化 tile op 在 GPU 上的计算性能 #52482

Merged
merged 12 commits into from
Apr 10, 2023
Merged
84 changes: 82 additions & 2 deletions paddle/phi/kernels/gpu/tile_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,90 @@
// limitations under the License.

#include "paddle/phi/kernels/tile_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace phi {

// Broadcast-based GPU implementation of tile: repeats `x` along each axis
// `repeat_times[i]` times by running one BroadcastKernel pass per axis that
// actually needs repeating, materializing at most one intermediate tensor
// per repeated axis.
//
// Parameters:
//   dev_ctx      - concrete device context used for allocation and launch.
//   x            - input tensor to be tiled.
//   repeat_times - per-axis repeat factors; every element must be > 0.
//   out          - output tensor; resized to the tiled shape and filled.
template <typename T, typename Context>
void TileKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const IntArray& repeat_times,
                DenseTensor* out) {
  auto x_dims = x.dims();
  auto rank = x_dims.size();
  auto repeat_times_data = repeat_times.GetData();
  int repeat_times_size = repeat_times_data.size();
  // Effective rank is the larger of the input rank and len(repeat_times).
  rank = std::max(rank, repeat_times_size);

  // Rank-0 (scalar) case: tiling is the identity, so a plain copy suffices.
  // Fix: dispatch phi::Copy on the kernel's own `Context` template parameter
  // (the original used `phi::Copy<DeviceContext>`, routing through the
  // type-erased base context instead of the concrete GPU context).
  if (rank == 0) {
    phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
    return;
  }

  // Every repeat factor must be a positive integer.
  for (size_t i = 0; i < repeat_times_data.size(); ++i) {
    PADDLE_ENFORCE_GT(
        repeat_times_data[i],
        0,
        errors::InvalidArgument(
            "All elements of the input 'repeat_times' for tile op must "
            "be positive integers, but the value received is %d.",
            repeat_times_data[i]));
  }

  // Promote the shorter of {x_dims, repeat_times} with leading 1s so both
  // describe the same rank.
  auto vec_x_dims = phi::vectorize<int>(x_dims);
  if (repeat_times_data.size() < vec_x_dims.size()) {
    int diff = vec_x_dims.size() - repeat_times_data.size();
    repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
  } else {
    int diff = repeat_times_data.size() - vec_x_dims.size();
    vec_x_dims.insert(vec_x_dims.begin(), diff, 1);
  }

  PADDLE_ENFORCE_EQ(
      repeat_times_data.size(),
      vec_x_dims.size(),
      errors::InvalidArgument(
          "The rank (%d) of the input 'x' and the rank (%d) of the input "
          "'repeat_times' for tile op must match after promotion.",
          vec_x_dims.size(),
          repeat_times_data.size()));

  DDim new_x_dims = make_ddim(vec_x_dims);
  DDim out_dims(new_x_dims);
  DenseTensor new_x = x;
  // Prepend a unit dimension; the loop below broadcasts along one axis per
  // iteration, then merges the processed extent into a single entry so the
  // next broadcast sees a collapsed shape.
  vec_x_dims.insert(vec_x_dims.begin(), 1, 1);
  for (size_t i = 0; i < repeat_times_data.size(); ++i) {
    out_dims[i] *= repeat_times_data[i];
    new_x.Resize(make_ddim(vec_x_dims));
    std::vector<const DenseTensor*> ins = {&new_x};
    vec_x_dims[i] *= repeat_times_data[i];
    if (i != repeat_times_data.size() - 1) {
      if (repeat_times_data[i] != 1) {
        // Repeat factor > 1 on a non-final axis: materialize an intermediate
        // tensor holding the data repeated along this axis, then feed it
        // into the next iteration. Axes with factor 1 skip the kernel launch.
        DenseTensor tmp_out;
        tmp_out.Resize(make_ddim(vec_x_dims));
        dev_ctx.template Alloc<T>(&tmp_out);
        std::vector<DenseTensor*> outs = {&tmp_out};
        phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
            dev_ctx, ins, &outs, i, kps::IdentityFunctor<T>());
        tmp_out.Resize(out_dims);
        new_x = tmp_out;
      }
      // Fold the following entry's extent into the current one (total
      // element count is unchanged) so the processed dimensions stay
      // flattened for the next broadcast pass.
      vec_x_dims[i] *= vec_x_dims[i + 1];
      vec_x_dims[i + 1] = 1;
    } else {
      // Final axis: broadcast directly into `out`, then restore the true
      // tiled output shape.
      out->Resize(make_ddim(vec_x_dims));
      dev_ctx.template Alloc<T>(out);
      std::vector<DenseTensor*> outs = {out};
      phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
          dev_ctx, ins, &outs, i, kps::IdentityFunctor<T>());
      out->Resize(out_dims);
    }
  }
}

} // namespace phi

PD_REGISTER_KERNEL(tile,
GPU,
Expand Down