Skip to content

Commit

Permalink
dynamic dropout refactor (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
ce1adon authored Jul 15, 2020
1 parent 542d1e6 commit d1933c3
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions src/ocl/dropoutocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,23 +271,16 @@ void DropoutDescriptor::DropoutForward(const Handle& handle,
std::string kernel_name = "DropoutForward";

std::string network_config =
"fwd-" + std::string(xDesc.GetType() == miopenHalf ? "fp16-" : "fp32-") + "dim" +
std::to_string(in_len[0]) + "x" + std::to_string(in_len[1]) + "x" +
std::to_string(in_len[2]) + "x" + std::to_string(in_len[3]) + "x" +
std::to_string(in_len[4]) + "-xstr" + std::to_string(in_str[0]) + "x" +
std::to_string(in_str[1]) + "x" + std::to_string(in_str[2]) + "x" +
std::to_string(in_str[3]) + "x" + std::to_string(in_str[4]) + "-ystr" +
std::to_string(out_str[0]) + "x" + std::to_string(out_str[1]) + "x" +
std::to_string(out_str[2]) + "x" + std::to_string(out_str[3]) + "x" +
std::to_string(out_str[4]) + "-dropout" + std::to_string(dropout) + "-seed" +
"fwd-" + std::string(xDesc.GetType() == miopenHalf ? "fp16-" : "fp32-") + "-seed" +
std::to_string(seed) + "-rng" + std::to_string(rng_mode) + "-rsvsp" +
std::to_string(static_cast<int>(use_rsvsp)) + "-mask" +
std::to_string(static_cast<int>(use_mask)) + "-evo" +
std::to_string(static_cast<int>(state_evo)) + "-blk" + std::to_string(RD_BLCK) + "-wg" +
std::to_string(wk_grp_num) + "-noise" + std::to_string(noise_shape.GetLengths()[0]);
std::to_string(wk_grp_num) /* + "-noise" + std::to_string(noise_shape.GetLengths()[0])*/;

for(int i = 1; i < noise_shape.GetSize(); i++)
network_config += "x" + std::to_string(noise_shape.GetLengths()[i]);
// TODO: Add noise shape
// for(int i = 1; i < noise_shape.GetSize(); i++)
// network_config += "x" + std::to_string(noise_shape.GetLengths()[i]);

auto&& kernels = handle.GetKernels(kernel_name, network_config);

Expand Down Expand Up @@ -482,23 +475,15 @@ void DropoutDescriptor::DropoutBackward(const Handle& handle,
std::string kernel_name = "DropoutBackward";

std::string network_config =
"bwd-" + std::string(dyDesc.GetType() == miopenHalf ? "fp16-" : "fp32-") + "dim" +
std::to_string(in_len[0]) + "x" + std::to_string(in_len[1]) + "x" +
std::to_string(in_len[2]) + "x" + std::to_string(in_len[3]) + "x" +
std::to_string(in_len[4]) + "-xstr" + std::to_string(in_str[0]) + "x" +
std::to_string(in_str[1]) + "x" + std::to_string(in_str[2]) + "x" +
std::to_string(in_str[3]) + "x" + std::to_string(in_str[4]) + "-ystr" +
std::to_string(out_str[0]) + "x" + std::to_string(out_str[1]) + "x" +
std::to_string(out_str[2]) + "x" + std::to_string(out_str[3]) + "x" +
std::to_string(out_str[4]) + "-dropout" + std::to_string(dropout) + "-seed" +
"bwd-" + std::string(dyDesc.GetType() == miopenHalf ? "fp16-" : "fp32-") + "-seed" +
std::to_string(seed) + "-rng" + std::to_string(rng_mode) + "-prng" +
std::to_string(static_cast<int>(use_prng)) + "-mask" +
std::to_string(static_cast<int>(use_mask)) + "-evo" +
std::to_string(static_cast<int>(use_prng)) + "-evo" +
std::to_string(static_cast<int>(state_evo)) + "-blk" + std::to_string(RD_BLCK) + "-wg" +
std::to_string(wk_grp_num) + "-noise" + std::to_string(noise_shape.GetLengths()[0]);
std::to_string(wk_grp_num) /* + "-noise" + std::to_string(noise_shape.GetLengths()[0]) */;

for(int i = 1; i < noise_shape.GetSize(); i++)
network_config += "x" + std::to_string(noise_shape.GetLengths()[i]);
// TODO: Add noise shape
// for(int i = 1; i < noise_shape.GetSize(); i++)
// network_config += "x" + std::to_string(noise_shape.GetLengths()[i]);

auto&& kernels = handle.GetKernels(kernel_name, network_config);

Expand Down

0 comments on commit d1933c3

Please sign in to comment.