From d41a45b008a9818501e9f5382a967b36120a2561 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Mon, 14 Aug 2023 15:42:15 -0300 Subject: [PATCH] Raise error when comparing a device with something that's not a dtype. --- R/dtype.R | 3 +++ tests/testthat/test-dtype.R | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/R/dtype.R b/R/dtype.R index f88f74a09e..3871a0af72 100644 --- a/R/dtype.R +++ b/R/dtype.R @@ -147,6 +147,9 @@ torch_qint32 <- function() torch_dtype$new(cpp_torch_qint32()) #' @export `==.torch_dtype` <- function(e1, e2) { + if (!is_torch_dtype(e1) || !is_torch_dtype(e2)) { + runtime_error("One of the objects is not a dtype. Comparison is not possible.") + } cpp_dtype_to_string(e1$ptr) == cpp_dtype_to_string(e2$ptr) } diff --git a/tests/testthat/test-dtype.R b/tests/testthat/test-dtype.R index 635d77f2a6..30af649bce 100644 --- a/tests/testthat/test-dtype.R +++ b/tests/testthat/test-dtype.R @@ -66,3 +66,12 @@ test_that("can set select devices using strings", { } }) + +test_that("error when comparing dtypes", { + + expect_error( + NULL == torch_float64(), + "not a dtype" + ) + +}) \ No newline at end of file