Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor from buffer #1061

Merged
merged 1 commit into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ export(optimizer)
export(sampler)
export(slc)
export(tensor_dataset)
export(tensor_from_buffer)
export(torch_abs)
export(torch_absolute)
export(torch_acos)
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -14597,6 +14597,10 @@ cpp_torch_tensor_print <- function(x, n) {
invisible(.Call(`_torch_cpp_torch_tensor_print`, x, n))
}

cpp_tensor_from_buffer <- function(data, shape, options) {
.Call(`_torch_cpp_tensor_from_buffer`, data, shape, options)
}

cpp_torch_tensor_dtype <- function(x) {
.Call(`_torch_cpp_torch_tensor_dtype`, x)
}
Expand Down
14 changes: 14 additions & 0 deletions R/tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,17 @@ tensor_to_complex <- function(x) {
torch_tensor(Im(x), dtype = torch_double())
)
}

#' Creates a tensor from a buffer of memory
#'
#' It creates a tensor without taking ownership of the memory it points to.
#' You must call `clone` if you want to copy the memory over a new tensor.
#'
#' @param buffer An R atomic object containing the data in a contiguous array.
#' @param shape The shape of the resulting tensor.
#' @param dtype A torch data type for the tresulting tensor.
#'
#' @export
tensor_from_buffer <- function(buffer, shape, dtype = "float") {
cpp_tensor_from_buffer(buffer, shape, list(dtype=dtype))
}
19 changes: 19 additions & 0 deletions man/tensor_from_buffer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48356,6 +48356,19 @@ BEGIN_RCPP
return R_NilValue;
END_RCPP
}
// cpp_tensor_from_buffer
torch::Tensor cpp_tensor_from_buffer(const SEXP& data, std::vector<int64_t> shape, XPtrTorchTensorOptions options);
RcppExport SEXP _torch_cpp_tensor_from_buffer(SEXP dataSEXP, SEXP shapeSEXP, SEXP optionsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const SEXP& >::type data(dataSEXP);
Rcpp::traits::input_parameter< std::vector<int64_t> >::type shape(shapeSEXP);
Rcpp::traits::input_parameter< XPtrTorchTensorOptions >::type options(optionsSEXP);
rcpp_result_gen = Rcpp::wrap(cpp_tensor_from_buffer(data, shape, options));
return rcpp_result_gen;
END_RCPP
}
// cpp_torch_tensor_dtype
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_tensor_dtype(torch::Tensor x);
RcppExport SEXP _torch_cpp_torch_tensor_dtype(SEXP xSEXP) {
Expand Down Expand Up @@ -52515,6 +52528,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_torch_cpp_Tensor_has_storage", (DL_FUNC) &_torch_cpp_Tensor_has_storage, 1},
{"_torch_cpp_Storage_data_ptr", (DL_FUNC) &_torch_cpp_Storage_data_ptr, 1},
{"_torch_cpp_torch_tensor_print", (DL_FUNC) &_torch_cpp_torch_tensor_print, 2},
{"_torch_cpp_tensor_from_buffer", (DL_FUNC) &_torch_cpp_tensor_from_buffer, 3},
{"_torch_cpp_torch_tensor_dtype", (DL_FUNC) &_torch_cpp_torch_tensor_dtype, 1},
{"_torch_torch_tensor_cpp", (DL_FUNC) &_torch_torch_tensor_cpp, 5},
{"_torch_cpp_as_array", (DL_FUNC) &_torch_cpp_as_array, 1},
Expand Down
18 changes: 14 additions & 4 deletions src/lantern/src/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,20 @@
void *_lantern_from_blob(void *data, int64_t *sizes, size_t sizes_size,
int64_t *strides, size_t strides_size, void *options) {
LANTERN_FUNCTION_START
return make_raw::Tensor(
torch::from_blob(data, std::vector<int64_t>(sizes, sizes + sizes_size),
std::vector<int64_t>(strides, strides + strides_size),
from_raw::TensorOptions(options)));
if (strides_size == 0) {
return make_raw::Tensor(torch::from_blob(
data,
std::vector<int64_t>(sizes, sizes + sizes_size),
from_raw::TensorOptions(options)
));
} else {
return make_raw::Tensor(torch::from_blob(
data,
std::vector<int64_t>(sizes, sizes + sizes_size),
std::vector<int64_t>(strides, strides + strides_size),
from_raw::TensorOptions(options)
));
}
LANTERN_FUNCTION_END
}

Expand Down
13 changes: 13 additions & 0 deletions src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ void cpp_torch_tensor_print(torch::Tensor x, int n) {
Rcpp::Rcout << result;
};

// [[Rcpp::export]]
torch::Tensor cpp_tensor_from_buffer(const SEXP& data, std::vector<int64_t> shape, XPtrTorchTensorOptions options) {
return lantern_from_blob(
DATAPTR(data),
&shape[0],
shape.size(),
// we use the default strides
nullptr,
0,
options.get()
);
}

// [[Rcpp::export]]
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_tensor_dtype(torch::Tensor x) {
XPtrTorchDtype out = lantern_Tensor_dtype(x.get());
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,12 @@ test_that("can convert to half using the method `half()`", {

x <- torch_tensor(1, dtype="half")
expect_equal(as.numeric(x), 1)
})

test_that("can create tensor from a buffer", {
x <- runif(10)
y <- tensor_from_buffer(x, shape = 10, dtype = "float64")
expect_equal(as.numeric(y), x)
y$add_(1)
expect_equal(as.numeric(y), x)
})