torch_bernoulli() is failing on device="cuda" with Error : Expected a 'cuda' device type for generator but found 'cpu' #480

Closed
cregouby opened this issue Feb 22, 2021 · 2 comments · Fixed by #906

@cregouby
Collaborator

cregouby commented Feb 22, 2021

Hello,

torch_bernoulli() fails on a CUDA device with the following error message:

Error in (function (self, generator)  : 
  Expected a 'cuda' device type for generator but found 'cpu'

whereas it works as expected on a CPU device.

Workaround

Although it makes no sense performance-wise, a workaround is to sample the Bernoulli values on the CPU and then move the result to the GPU:

library(torch)
cuda_is_available()
#> [1] TRUE
# workaround
torch_bernoulli(torch_ones(3,3, device="cpu") * 0.5)$to(device="cuda")
#> torch_tensor
#>  1  0  1
#>  0  0  1
#>  0  1  0
#> [ CUDAFloatType{3,3} ]
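
The same workaround can be wrapped in a small helper so that calling code does not hard-code the device. This is only a sketch: `bernoulli_via_cpu` is a hypothetical name, not a function of the torch package, and it assumes the input is a tensor of probabilities in [0, 1].

```r
library(torch)

# Hypothetical helper (not part of torch): sample on the CPU, then move the
# result back to whatever device the input probabilities live on.
bernoulli_via_cpu <- function(probs) {
  target <- probs$device
  torch_bernoulli(probs$cpu())$to(device = target)
}

# Falls back to the CPU when no GPU is available, so the example runs anywhere.
device <- if (cuda_is_available()) "cuda" else "cpu"
p <- torch_ones(3, 3, device = device) * 0.5
x <- bernoulli_via_cpu(p)
```

The round trip through the CPU costs two device transfers per call, so this is only worth keeping until the bug is fixed.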

Reprex

library(torch)
cuda_is_available()
#> [1] TRUE
# working function on device="cpu"
torch_bernoulli(torch_ones(3,3, device="cpu") * 0.5)
#> torch_tensor
#>  1  0  0
#>  0  1  0
#>  0  1  1
#> [ CPUFloatType{3,3} ]
# failing function on device = "cuda"
torch_bernoulli(torch_ones(3,3, device="cuda")* 0.5)
#> Error in (function (self, generator) : Expected a 'cuda' device type for generator but found 'cpu'
#> Exception raised from check_generator at /pytorch/aten/src/ATen/Utils.h:108 (most recent call first):
#> frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7f1b17435b89 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libc10.so)
#> frame #1: <unknown function> + 0x1e2179e (0x7f1acc08479e in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #2: <unknown function> + 0x1e21acb (0x7f1acc084acb in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #3: at::native::bernoulli_tensor_kernel(at::Tensor&, at::Tensor const&, c10::optional<at::Generator>) + 0x22 (0x7f1acc084d12 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #4: <unknown function> + 0xd8d684 (0x7f1b071ff684 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #5: <unknown function> + 0xd8db21 (0x7f1b071ffb21 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #6: at::native::bernoulli_(at::Tensor&, at::Tensor const&, c10::optional<at::Generator>) + 0x32 (0x7f1b071f4702 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #7: <unknown function> + 0x32747f3 (0x7f1acd4d77f3 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #8: <unknown function> + 0x32a9ace (0x7f1acd50cace in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #9: <unknown function> + 0x14e2b63 (0x7f1b07954b63 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #10: at::Tensor::bernoulli_(at::Tensor const&, c10::optional<at::Generator>) const + 0xf2 (0x7f1b07ac0b12 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #11: at::native::bernoulli(at::Tensor const&, c10::optional<at::Generator>) + 0x94 (0x7f1b071e4494 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #12: <unknown function> + 0x15b7271 (0x7f1b07a29271 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #13: <unknown function> + 0xb646d1 (0x7f1b06fd66d1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #14: <unknown function> + 0x14e24a1 (0x7f1b079544a1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #15: at::bernoulli(at::Tensor const&, c10::optional<at::Generator>) + 0xda (0x7f1b0785571a in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #16: <unknown function> + 0x2a10704 (0x7f1b08e82704 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #17: <unknown function> + 0xb646d1 (0x7f1b06fd66d1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #18: <unknown function> + 0x14e24a1 (0x7f1b079544a1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #19: at::bernoulli(at::Tensor const&, c10::optional<at::Generator>) + 0xda (0x7f1b0785571a in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #20: _lantern_bernoulli_tensor_generator + 0x6b (0x7f1b179059db in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/liblantern.so)
#> frame #21: cpp_torch_namespace_bernoulli_self_Tensor(XPtrTorchTensor, XPtrTorchGenerator) + 0x35 (0x7f1b18158365 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/libs/torchpkg.so)
#> frame #22: _torch_cpp_torch_namespace_bernoulli_self_Tensor + 0x97 (0x7f1b17fb3ef7 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/libs/torchpkg.so)
#> frame #23: <unknown function> + 0xf5cbc (0x7f1b2204bcbc in /usr/lib/R/lib/libR.so)
#> frame #24: <unknown function> + 0xf6206 (0x7f1b2204c206 in /usr/lib/R/lib/libR.so)
#> frame #25: <unknown function> + 0x1303c1 (0x7f1b220863c1 in /usr/lib/R/lib/libR.so)
#> frame #26: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #27: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #28: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #29: Rf_eval + 0x2af (0x7f1b220a11ef in /usr/lib/R/lib/libR.so)
#> frame #30: <unknown function> + 0xc16ad (0x7f1b220176ad in /usr/lib/R/lib/libR.so)
#> frame #31: <unknown function> + 0x1303c1 (0x7f1b220863c1 in /usr/lib/R/lib/libR.so)
#> frame #32: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #33: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #34: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #35: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #36: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #37: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #38: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #39: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #40: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #41: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #42: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #43: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #44: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #45: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #46: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #47: Rf_eval + 0x2af (0x7f1b220a11ef in /usr/lib/R/lib/libR.so)
#> frame #48: <unknown function> + 0x151172 (0x7f1b220a7172 in /usr/lib/R/lib/libR.so)
#> frame #49: <unknown function> + 0x1303c1 (0x7f1b220863c1 in /usr/lib/R/lib/libR.so)
#> frame #50: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #51: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #52: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #53: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #54: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #55: <unknown function> + 0x14ba5c (0x7f1b220a1a5c in /usr/lib/R/lib/libR.so)
#> frame #56: Rf_eval + 0x39f (0x7f1b220a12df in /usr/lib/R/lib/libR.so)
#> frame #57: <unknown function> + 0x151bf0 (0x7f1b220a7bf0 in /usr/lib/R/lib/libR.so)
#> frame #58: <unknown function> + 0x18f35f (0x7f1b220e535f in /usr/lib/R/lib/libR.so)
#> frame #59: <unknown function> + 0x1301b1 (0x7f1b220861b1 in /usr/lib/R/lib/libR.so)
#> frame #60: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #61: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #62: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #63: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)

Created on 2021-02-22 by the reprex package (v1.0.0)
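
Another possible way to avoid the CPU round trip is to compare uniform draws with the probabilities directly on the device, since a uniform sample on [0, 1) falls below p with probability p. The sketch below is untested against the affected torch version, and `bernoulli_via_uniform` is a hypothetical helper, not part of the torch package.

```r
library(torch)

# Hypothetical helper (not part of torch): draw Bernoulli samples by comparing
# uniform draws on [0, 1) with the probabilities, all on the input's device.
bernoulli_via_uniform <- function(probs) {
  u <- torch_rand_like(probs)          # uniform samples, same shape and device as probs
  (u < probs)$to(dtype = probs$dtype)  # 1 where u < p (probability p), else 0
}

device <- if (cuda_is_available()) "cuda" else "cpu"
p <- torch_ones(3, 3, device = device) * 0.5
y <- bernoulli_via_uniform(p)
```

Because it never calls `torch_bernoulli()`, this variant does not go through the generator check that raises the error above, but it draws from the random stream differently, so results will not match `torch_bernoulli()` for a given seed.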

@dfalbel
Member

dfalbel commented Feb 24, 2021

Thanks @cregouby. This is indeed a bug. I'll work on a fix soon.

dfalbel added this to the torch v0.9.0 milestone Oct 14, 2022
@cregouby
Collaborator Author

Many thanks for this! I'll be able to remove the workaround for it!
