diff --git a/R/codegen-utils.R b/R/codegen-utils.R index 82c5c898e6..68a1392f30 100644 --- a/R/codegen-utils.R +++ b/R/codegen-utils.R @@ -13,17 +13,7 @@ as_1_based_tensor_list <- function(x) { } as_1_based_tensor <- function(x) { - with_no_grad({ - if (!any(x$shape == 0)) { - e <- torch_min(torch_abs(x))$to(dtype = torch_int()) - if (e$item() == 0) { - runtime_error("Indices/Index start at 1 and got a 0.") - } - } - - out <- x - (x > 0)$to(dtype = x$dtype) - }) - out + to_index_tensor(x) } clean_chars <- c("'", "\"", "%", "#", ":", ">", "<", ",", " ", "*", "&") diff --git a/R/gen-method.R b/R/gen-method.R index 52f4512e84..fd76c9451c 100644 --- a/R/gen-method.R +++ b/R/gen-method.R @@ -4582,7 +4582,7 @@ call_c_function( return_types = return_types, fun_type = 'method' )}) -Tensor$set("public", "movedim", function(source, destination) { args <- mget(x = c("source", "destination")) +Tensor$set("private", "_movedim", function(source, destination) { args <- mget(x = c("source", "destination")) args <- c(list(self = self), args) expected_types <- list(self = "Tensor", source = c("IntArrayRef", "int64_t"), destination = c("IntArrayRef", "int64_t")) diff --git a/R/tensor.R b/R/tensor.R index e86757cec3..88bdb09e25 100644 --- a/R/tensor.R +++ b/R/tensor.R @@ -253,6 +253,9 @@ Tensor <- R7Class( }, bincount = function(weights = list(), minlength = 0L) { to_index_tensor(self)$private$`_bincount`(weights = weights, minlength = minlength) + }, + movedim = function(source, destination) { + private$`_movedim`(as_1_based_dim(source), as_1_based_dim(destination)) } ), active = list( diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index d07ff46781..76a5672592 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -5755,13 +5755,13 @@ BEGIN_RCPP END_RCPP } // cpp_torch_method_tile_self_Tensor_dims_IntArrayRef -XPtrTorchTensor cpp_torch_method_tile_self_Tensor_dims_IntArrayRef(XPtrTorchTensor self, XPtrTorchIndexIntArrayRef dims); +XPtrTorchTensor cpp_torch_method_tile_self_Tensor_dims_IntArrayRef(XPtrTorchTensor self, XPtrTorchIntArrayRef dims); RcppExport SEXP _torch_cpp_torch_method_tile_self_Tensor_dims_IntArrayRef(SEXP selfSEXP, SEXP dimsSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< XPtrTorchTensor >::type self(selfSEXP); - Rcpp::traits::input_parameter< XPtrTorchIndexIntArrayRef >::type dims(dimsSEXP); + Rcpp::traits::input_parameter< XPtrTorchIntArrayRef >::type dims(dimsSEXP); rcpp_result_gen = Rcpp::wrap(cpp_torch_method_tile_self_Tensor_dims_IntArrayRef(self, dims)); return rcpp_result_gen; END_RCPP @@ -22005,13 +22005,13 @@ BEGIN_RCPP END_RCPP } // cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef -XPtrTorchTensor cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef(XPtrTorchTensor self, XPtrTorchIndexIntArrayRef dims); +XPtrTorchTensor cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef(XPtrTorchTensor self, XPtrTorchIntArrayRef dims); RcppExport SEXP _torch_cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef(SEXP selfSEXP, SEXP dimsSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< XPtrTorchTensor >::type self(selfSEXP); - Rcpp::traits::input_parameter< XPtrTorchIndexIntArrayRef >::type dims(dimsSEXP); + Rcpp::traits::input_parameter< XPtrTorchIntArrayRef >::type dims(dimsSEXP); rcpp_result_gen = Rcpp::wrap(cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef(self, dims)); return rcpp_result_gen; END_RCPP diff --git a/src/gen-namespace.cpp b/src/gen-namespace.cpp index fbc781bc25..e51f7e0389 100644 --- a/src/gen-namespace.cpp +++ b/src/gen-namespace.cpp @@ -2443,7 +2443,7 @@ return XPtrTorchTensor(r_out); } // [[Rcpp::export]] -XPtrTorchTensor cpp_torch_method_tile_self_Tensor_dims_IntArrayRef (XPtrTorchTensor self, XPtrTorchIndexIntArrayRef dims) { +XPtrTorchTensor cpp_torch_method_tile_self_Tensor_dims_IntArrayRef (XPtrTorchTensor self, XPtrTorchIntArrayRef dims) { auto r_out = lantern_Tensor_tile_tensor_intarrayref(self.get(), dims.get()); return XPtrTorchTensor(r_out); } @@ -9991,7 +9991,7 @@ return XPtrTorchTensor(r_out); } // [[Rcpp::export]] -XPtrTorchTensor cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef (XPtrTorchTensor self, XPtrTorchIndexIntArrayRef dims) { +XPtrTorchTensor cpp_torch_namespace_tile_self_Tensor_dims_IntArrayRef (XPtrTorchTensor self, XPtrTorchIntArrayRef dims) { auto r_out = lantern_tile_tensor_intarrayref(self.get(), dims.get()); return XPtrTorchTensor(r_out); } diff --git a/tests/testthat/test-gen-namespace.R b/tests/testthat/test-gen-namespace.R index 67207ef11f..289ac2c0cf 100644 --- a/tests/testthat/test-gen-namespace.R +++ b/tests/testthat/test-gen-namespace.R @@ -368,3 +368,9 @@ test_that("zeros_out", { expect_tensor(torch_zeros_out(x, c(2))) expect_equal_to_tensor(x, torch_tensor(c(0, 0))) }) + +test_that("tile works correctly", { + x <- torch_tensor(c(1, 2, 3)) + expect_true(length(x$tile(2)) == 6) + expect_true(length(torch_tile(x, 2)) == 6) +}) diff --git a/tests/testthat/test-wrapers.R b/tests/testthat/test-wrapers.R index 1d8eb25d56..d7fb85cdc7 100644 --- a/tests/testthat/test-wrapers.R +++ b/tests/testthat/test-wrapers.R @@ -133,6 +133,8 @@ test_that("movedim", { x <- torch_randn(3, 2, 1) expect_tensor_shape(torch_movedim(x, 1, 2), c(2, 3, 1)) expect_tensor_shape(torch_movedim(x, c(1, 2), c(2, 3)), c(1, 3, 2)) + expect_tensor_shape(x$movedim(1, 2), c(2, 3, 1)) + expect_tensor_shape(x$movedim(c(1, 2), c(2, 3)), c(1, 3, 2)) }) test_that("norm", { diff --git a/tools/torchgen/R/cpp.R b/tools/torchgen/R/cpp.R index 34323a776e..c6a18787fe 100644 --- a/tools/torchgen/R/cpp.R +++ b/tools/torchgen/R/cpp.R @@ -105,15 +105,21 @@ cpp_function_name <- function(method, type) { make_cpp_function_name(method$name, arg_types, type) } +indexing_special_cases <- function(argument) { + !(argument$decl_name %in% c("tile")) +} + cpp_parameter_type <- function(argument) { - if (argument$name %in% c("index", "indices", "dims") && + if (indexing_special_cases(argument) && + argument$name %in% c("index", "indices", "dims") && argument$dynamic_type == "Tensor") { return("XPtrTorchIndexTensor") } - if (argument$name %in% c("dims", "dims_self", "dims_other", "dim") && + if (indexing_special_cases(argument) && + argument$name %in% c("dims", "dims_self", "dims_other", "dim") && argument$dynamic_type == "IntArrayRef") { if (argument$type %in% c("c10::optional", "OptionalIntArrayRef")) { @@ -123,21 +129,23 @@ cpp_parameter_type <- function(argument) { } } - if (argument$name %in% c("dim", "dim0", "dim1", "dim2", "start_dim", "end_dim", "index") && + if (indexing_special_cases(argument) && + argument$name %in% c("dim", "dim0", "dim1", "dim2", "start_dim", "end_dim", "index") && argument$dynamic_type == "int64_t") { - if (argument$type == "c10::optional") return("XPtrTorchoptional_index_int64_t") else return("XPtrTorchindex_int64_t") } - if (argument$name == "indices" && + if (indexing_special_cases(argument) && + argument$name == "indices" && argument$dynamic_type == "TensorList") { return("XPtrTorchIndexTensorList") } - if (argument$name == "indices" && + if (indexing_special_cases(argument) && + argument$name == "indices" && argument$dynamic_type == "const c10::List> &") { return("XPtrTorchOptionalIndexTensorList") } @@ -333,8 +341,11 @@ cpp_parameter <- function(argument) { } cpp_signature <- function(decl) { - - res <- purrr::map_chr(decl$arguments, cpp_parameter) %>% + name <- decl$name + res <- purrr::map_chr(decl$arguments, function(x) { + x$decl_name <- name #expose de declaration name + cpp_parameter(x) + }) %>% glue::glue_collapse(sep = ", ") if(length(res) == 0) diff --git a/tools/torchgen/R/r.R b/tools/torchgen/R/r.R index be0fa0e3fd..0c82720291 100644 --- a/tools/torchgen/R/r.R +++ b/tools/torchgen/R/r.R @@ -426,7 +426,8 @@ internal_methods <- c("_backward", "retain_grad", "size", "to", "stride", "copy_", "topk", "scatter_", "scatter", "rename", "rename_", "narrow", "narrow_copy", "is_leaf", "max", "min", "argsort", "argmax", "argmin", "norm", "split", - "nonzero", "nonzero_numpy", "view", "sort", "bincount") + "nonzero", "nonzero_numpy", "view", "sort", "bincount", + "movedim") r_method_env <- function(decls) { if (decls[[1]]$name %in% internal_methods)