diff --git a/NAMESPACE b/NAMESPACE index 4a849c6b1c..243d2c948e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -69,6 +69,7 @@ S3method(as.matrix,torch_tensor) S3method(as.numeric,torch_tensor) S3method(as_array,torch_tensor) S3method(as_iterator,dataloader) +S3method(as_iterator,iterable_dataset) S3method(as_iterator,utils_sampler) S3method(asin,torch_tensor) S3method(atan,torch_tensor) @@ -83,6 +84,7 @@ S3method(expm1,torch_tensor) S3method(floor,torch_tensor) S3method(length,dataloader) S3method(length,dataset) +S3method(length,iterable_dataset) S3method(length,nn_module_list) S3method(length,nn_sequential) S3method(length,torch_tensor) @@ -100,6 +102,7 @@ S3method(nn_prune_head,nn_sequential) S3method(print,R7) S3method(print,cuda_memory_stats) S3method(print,dataset_generator) +S3method(print,iterable_dataset_generator) S3method(print,nn_module) S3method(print,script_function) S3method(print,script_method) @@ -178,6 +181,7 @@ export(is_torch_layout) export(is_torch_memory_format) export(is_torch_qscheme) export(is_undefined_tensor) +export(iterable_dataset) export(jit_compile) export(jit_load) export(jit_ops) @@ -526,9 +530,11 @@ export(torch_ceil) export(torch_celu) export(torch_celu_) export(torch_cfloat) +export(torch_cfloat128) export(torch_cfloat32) export(torch_cfloat64) export(torch_chain_matmul) +export(torch_chalf) export(torch_channel_shuffle) export(torch_channels_last_format) export(torch_cholesky) diff --git a/R/utils-data-dataloader.R b/R/utils-data-dataloader.R index ee43b4af99..e0578418dd 100644 --- a/R/utils-data-dataloader.R +++ b/R/utils-data-dataloader.R @@ -152,11 +152,15 @@ DataLoader <- R6::R6Class( if (is_map_dataset(dataset)) { self$.dataset_kind <- "map" + } else if (is_iterable_dataset(dataset)) { + self$.dataset_kind <- "iterable" + } else { + cli::cli_abort("Unknown dataset type with class {.cls {class(dataset)}}") } if (is.null(sampler)) { if (self$.dataset_kind == "iterable") { - # TODO + sampler <- InfiniteSampler() } else { if 
(shuffle) { sampler <- RandomSampler(dataset, generator = generator) @@ -200,13 +204,26 @@ DataLoader <- R6::R6Class( } MultiProcessingDataLoaderIter$new(self) + } else if (self$.dataset_kind == "iterable") { + if (self$num_workers == 0) { + return(SingleProcessDataLoaderIter$new(self)) + } + cli::cli_abort("Multi-process dataloader not implemented yet for Iterable datasets.") } else { not_implemented_error() } }, .length = function() { if (self$.dataset_kind == "iterable") { - not_implemented_error() + l <- length(self$dataset) + + if (is.na(l)) return(l) + + if (self$drop_last) { + return(l %/% self$batch_size) + } else { + return(as.integer(ceiling(l / self$batch_size))) + } } else { length(self$.index_sampler) } @@ -283,6 +300,13 @@ SingleProcessDataLoaderIter <- R6::R6Class( self$.collate_fn, self$.drop_last ) + } else if (self$.dataset_kind == "iterable") { + self$.dataset_fetcher <- IterableDatasetFetcher$new( + self$.dataset, + self$.auto_collation, + self$.collate_fn, + self$.drop_last + ) } else { not_implemented_error() } @@ -294,7 +318,9 @@ SingleProcessDataLoaderIter <- R6::R6Class( return(coro::exhausted()) } + # data can be exhausted in iterable datasets data <- self$.dataset_fetcher$fetch(index) + if (self$.pin_memory) { # TODO } diff --git a/R/utils-data-fetcher.R b/R/utils-data-fetcher.R index be37ba6296..109ccca74c 100644 --- a/R/utils-data-fetcher.R +++ b/R/utils-data-fetcher.R @@ -36,6 +36,12 @@ IterableDatasetFetcher <- R6::R6Class( d <- self$dataset_iter() if (is_exhausted(d)) { + if (self$drop_last || i == 1) { + return(coro::exhausted()) + } + + # we drop the null values in that list. 
+ data <- data[seq_len(i-1L)] break } @@ -44,6 +50,7 @@ IterableDatasetFetcher <- R6::R6Class( } else { data <- self$dataset_iter() } + self$collate_fn(data) } ) diff --git a/R/utils-data-sampler.R b/R/utils-data-sampler.R index aae488c2fe..ff355eef23 100644 --- a/R/utils-data-sampler.R +++ b/R/utils-data-sampler.R @@ -159,6 +159,19 @@ BatchSampler <- sampler( } ) +InfiniteSampler <- sampler( + "infinite_sampler", + initialize = function() {}, + .iter = function() { + function() { + TRUE + } + }, + .length = function() { + Inf + } +) + #' @export as_iterator.utils_sampler <- function(x) { it <- x$.iter() diff --git a/R/utils-data.R b/R/utils-data.R index d850ec85f6..6270950202 100644 --- a/R/utils-data.R +++ b/R/utils-data.R @@ -36,10 +36,32 @@ Dataset <- R6::R6Class( ) ) +IterableDataset <- R6::R6Class( + classname = "iterable_dataset", + lock_objects = FALSE, + public = list( + .iter = function() { + not_implemented_error() + }, + .length = function() { + NA_integer_ + } + ) +) + is_map_dataset <- function(x) { inherits(x, "dataset") } +is_iterable_dataset <- function(x) { + inherits(x, "iterable_dataset") +} + +#' @export +as_iterator.iterable_dataset <- function(x) { + x$.iter() +} + get_init <- function(x) { if (!is.null(x$public_methods$initialize)) { return(x$public_methods$initialize) @@ -107,12 +129,58 @@ dataset <- function(name = NULL, inherit = Dataset, ..., ) } + +#' Creates an iterable dataset +#' +#' @inheritParams dataset +#' @examples +#' ids <- iterable_dataset( +#' name = "hello", +#' initialize = function(n = 5) { +#' self$n <- n +#' self$i <- 0 +#' }, +#' .iter = function() { +#' i <- 0 +#' function() { +#' i <<- i + 1 +#' if (i > self$n) { +#' coro::exhausted() +#' } else { +#' i +#' } +#' } +#' } +#' ) +#' coro::collect(ids()$.iter()) +#' @export +iterable_dataset <- function(name, inherit = IterableDataset, ..., + private = NULL, active = NULL, + parent_env = parent.frame()) { + create_class( + name = name, + inherit = inherit, + ..., + 
private = private, + active = active, + parent_env = parent_env, + attr_name = "Dataset", + constructor_class = "iterable_dataset_generator" + ) +} + #' @export print.dataset_generator <- function(x, ...) { cli::cat_line("") print(attr(x, "Dataset")) } +#' @export +print.iterable_dataset_generator <- function(x, ...) { + cli::cat_line("") + print(attr(x, "Dataset")) +} + #' @export `[.dataset` <- function(x, y) { y <- as.integer(y) @@ -136,6 +204,11 @@ length.dataset <- function(x) { x$.length() } +#' @export +length.iterable_dataset <- function(x) { + x$.length() +} + #' Dataset wrapping tensors. #' #' Each sample will be retrieved by indexing tensors along the first dimension. diff --git a/man/iterable_dataset.Rd b/man/iterable_dataset.Rd new file mode 100644 index 0000000000..5fe61a0599 --- /dev/null +++ b/man/iterable_dataset.Rd @@ -0,0 +1,57 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils-data.R +\name{iterable_dataset} +\alias{iterable_dataset} +\title{Creates an iterable dataset} +\usage{ +iterable_dataset( + name, + inherit = IterableDataset, + ..., + private = NULL, + active = NULL, + parent_env = parent.frame() +) +} +\arguments{ +\item{name}{a name for the dataset. 
It it's also used as the class +for it.} + +\item{inherit}{you can optionally inherit from a dataset when creating a +new dataset.} + +\item{...}{public methods for the dataset class} + +\item{private}{passed to \code{\link[R6:R6Class]{R6::R6Class()}}.} + +\item{active}{passed to \code{\link[R6:R6Class]{R6::R6Class()}}.} + +\item{parent_env}{An environment to use as the parent of newly-created +objects.} +} +\description{ +Creates an iterable dataset +} +\examples{ +if (torch_is_installed()) { +ids <- iterable_dataset( + name = "hello", + initialize = function(n = 5) { + self$n <- n + self$i <- 0 + }, + .iter = function() { + i <- 0 + function() { + i <<- i + 1 + if (i > self$n) { + coro::exhausted() + } else { + i + } + } + } +) +coro::collect(ids()$.iter()) +} +} diff --git a/man/torch_dtype.Rd b/man/torch_dtype.Rd index cdd808efa2..99a76b2a19 100644 --- a/man/torch_dtype.Rd +++ b/man/torch_dtype.Rd @@ -6,10 +6,12 @@ \alias{torch_float} \alias{torch_float64} \alias{torch_double} -\alias{torch_cfloat} \alias{torch_cfloat32} -\alias{torch_cdouble} +\alias{torch_chalf} +\alias{torch_cfloat} \alias{torch_cfloat64} +\alias{torch_cdouble} +\alias{torch_cfloat128} \alias{torch_float16} \alias{torch_half} \alias{torch_uint8} @@ -34,13 +36,17 @@ torch_float64() torch_double() +torch_cfloat32() + +torch_chalf() + torch_cfloat() -torch_cfloat32() +torch_cfloat64() torch_cdouble() -torch_cfloat64() +torch_cfloat128() torch_float16() diff --git a/src/utils.cpp b/src/utils.cpp index 8d95093ce5..850452ec98 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -65,9 +65,9 @@ std::thread::id main_thread_id() noexcept { // [[Rcpp::export]] Rcpp::List transpose2(Rcpp::List x) { auto templ = Rcpp::as(x[0]); - auto num_elements = templ.length(); + const auto num_elements = templ.length(); - auto size = x.length(); + const auto size = x.length(); std::vector out; for (auto i = 0; i < num_elements; i++) { @@ -75,6 +75,9 @@ Rcpp::List transpose2(Rcpp::List x) { } for (size_t j = 0; j < 
size; j++) { + if (Rf_isNull(x[j])) { + Rcpp::stop("NULL is not allowed. Expected a list."); + } auto el = Rcpp::as(x[j]); for (auto i = 0; i < num_elements; i++) { out[i][j] = el[i]; diff --git a/tests/testthat/test-utils-data-dataloader.R b/tests/testthat/test-utils-data-dataloader.R index ce066ae498..fe86ecb104 100644 --- a/tests/testthat/test-utils-data-dataloader.R +++ b/tests/testthat/test-utils-data-dataloader.R @@ -514,3 +514,111 @@ test_that("collate works with bool data", { expect_true(out$dtype == torch_bool()) }) + +test_that("can use dataloaders on iterable datasets", { + ids <- iterable_dataset( + "ids", + initialize = function(n = 320) { + self$n <- n + }, + .iter = function() { + i <- 0 + function() { + i <<- i + 1 + if (i <= self$n) { + i + } else { + coro::exhausted() + } + } + } + ) + + dl <- dataloader(ids(), batch_size = 32) + data <- coro::collect(dl) + + expect_equal(length(data), 10) + expect_equal(data[[10]]$shape, 32) + + dl <- dataloader(ids(33), batch_size = 32) + data <- coro::collect(dl) + + expect_equal(length(data), 2) + expect_equal(data[[2]]$shape, 1) + + dl <- dataloader(ids(33), batch_size = 32, drop_last = TRUE) + data <- coro::collect(dl) + + expect_equal(length(data), 1) + + # length can be NA for iterable datasets + expect_true(is.na(length(dl))) +}) + +test_that("correctly reports length for iterable datasets that provide length", { + + ids <- iterable_dataset( + "ids", + initialize = function(n = 320) { + self$n <- n + }, + .iter = function() { + i <- 0 + function() { + i <<- i + 1 + if (i <= self$n) { + i + } else { + coro::exhausted() + } + } + }, + .length = function() { + self$n + } + ) + + expect_equal(length(ids()), 320) + + dl <- dataloader(ids(), batch_size = 32) + expect_equal(length(dl), 10) + + dl <- dataloader(ids(33), batch_size = 32) + expect_equal(length(dl), 2) + + dl <- dataloader(ids(33), batch_size = 32, drop_last = TRUE) + expect_equal(length(dl), 1) + +}) + +test_that("a case that errors in luz", { + + 
get_iterable_ds <- iterable_dataset( + "iterable_ds", + initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) { + self$len <- len + self$x <- torch::torch_randn(size = c(len, x_size)) + self$y <- torch::torch_randn(size = c(len, y_size)) + }, + .iter = function() { + i <- 0 + function() { + i <<- i + 1 + + if (i > self$len) { + return(coro::exhausted()) + } + + list( + x = self$x[i,..], + y = self$y[i,..] + ) + } + } + ) + + ds <- get_iterable_ds() + dl <- dataloader(ds, batch_size = 32) + expect_equal(length(coro::collect(dl)), 4) + +})