Skip to content

Commit

Permalink
[SYCL] Fix tanhf to support lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
DongBaiYue committed Oct 28, 2022
1 parent 2349d35 commit c32e75b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/runtime/sycl/sycl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ class SYCLWrappedFunc {
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
VLOG(0) << "enter sycl wrapped func operator()";
ICHECK(w_->devices.size() != 0) << "No SYCL device";

// get kernel
void (*kernel_func)(sycl::queue &Q, sycl::range<3> k0_dimGrid, sycl::range<3> k0_dimBlock, void** void_args) = (void (*)(sycl::queue &Q, sycl::range<3> k0_dimGrid, sycl::range<3> k0_dimBlock, void** void_args))dlsym(so_handler_, func_name_.c_str());
ICHECK(kernel_func != NULL) << "ERROR:"<<dlerror()<<":dlsym\n";

// get thread dimension
ThreadWorkLoad wl = launch_param_config_.Extract(args);
sycl::range<3> k0_dimGrid(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
sycl::range<3> k0_dimBlock(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_sycl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,8 @@ void CodeGenSYCL::VisitExpr_(const CallNode* op, std::ostream& os) {
func = StringImm("exp");
}else if(func->value == "powf"){
func = StringImm("pow");
}else if(func->value == "tanhf"){
func = StringImm("tanh");
}
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
} else {
Expand Down

0 comments on commit c32e75b

Please sign in to comment.