From e3096fdfaaabafcba2705eb95e4cb3ef62598dc1 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Mon, 2 Oct 2023 15:39:29 -0300 Subject: [PATCH] move length implementation to R7 --- NAMESPACE | 2 +- R/R7.R | 10 ++++++++++ R/operators.R | 5 ----- R/tensor.R | 3 +++ tests/testthat/test-utils-data.R | 17 +++++++++++++++++ 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ea8f93626a..3c0a266408 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -82,12 +82,12 @@ S3method(enumerate,dataloader) S3method(exp,torch_tensor) S3method(expm1,torch_tensor) S3method(floor,torch_tensor) +S3method(length,R7) S3method(length,dataloader) S3method(length,dataset) S3method(length,iterable_dataset) S3method(length,nn_module_list) S3method(length,nn_sequential) -S3method(length,torch_tensor) S3method(length,utils_sampler) S3method(log,torch_tensor) S3method(log10,torch_tensor) diff --git a/R/R7.R b/R/R7.R index 8d57b9c4f4..cec9070f4e 100644 --- a/R/R7.R +++ b/R/R7.R @@ -126,3 +126,13 @@ extract_method <- function(self, name, call = TRUE) { print.R7 <- function(x, ...) { x$print(...) } + +#' @export +length.R7 <- function(x) { + tryCatch( + x$length(), + error = function(err) { + cli::cli_abort("{.val length} is not support for objects with class {.cls {class(x)}}") + } + ) +} \ No newline at end of file diff --git a/R/operators.R b/R/operators.R index 68e0db7ba6..da31d2a71c 100644 --- a/R/operators.R +++ b/R/operators.R @@ -180,11 +180,6 @@ dim.torch_tensor <- function(x) { cpp_tensor_dim(x$ptr) } -#' @export -length.torch_tensor <- function(x) { - prod(dim(x)) -} - #' @export as.numeric.torch_tensor <- function(x, ...) { as.numeric(as_array(x)) diff --git a/R/tensor.R b/R/tensor.R index 3b7d60c8e6..dc35438491 100644 --- a/R/tensor.R +++ b/R/tensor.R @@ -45,6 +45,9 @@ Tensor <- R7Class( dim = function() { cpp_tensor_ndim(self) }, + length = function() { + prod(dim(self)) + }, size = function(dim) { x <- cpp_tensor_dim(self$ptr) diff --git a/tests/testthat/test-utils-data.R b/tests/testthat/test-utils-data.R index 301b07ef2a..7e5647198c 100644 --- a/tests/testthat/test-utils-data.R +++ b/tests/testthat/test-utils-data.R @@ -172,3 +172,20 @@ test_that("can get a single element using `[[`", { ds <- tensor_dataset(torch_rand(11,3), torch_rand(11,1)) expect_equal(dim(ds[[1]][[1]]), 3) }) + +test_that("can have a dataset named torch_tensor", { + + ds <- dataset("torch_tensor", + initialize = function() { + }, + .getitem = function(id) { + torch::torch_tensor(1) + }, + .length = function() 1L + ) + + expect_no_error({ + a <- ds() + }) + +})