diff --git a/include/nanobind/nb_lib.h b/include/nanobind/nb_lib.h index 2ad8bfe5..c117e7a1 100644 --- a/include/nanobind/nb_lib.h +++ b/include/nanobind/nb_lib.h @@ -360,7 +360,7 @@ NB_CORE dlpack::tensor *tensor_inc_ref(tensor_handle *) noexcept; NB_CORE void tensor_dec_ref(tensor_handle *) noexcept; /// Wrap a tensor_handle* into a PyCapsule -NB_CORE PyObject *tensor_wrap(tensor_handle *, int framework) noexcept; +NB_CORE PyObject *tensor_wrap(tensor_handle *, int framework, bool writeable) noexcept; // ======================================================================== diff --git a/include/nanobind/tensor.h b/include/nanobind/tensor.h index 6ec655f4..1ecd344d 100644 --- a/include/nanobind/tensor.h +++ b/include/nanobind/tensor.h @@ -75,6 +75,8 @@ struct numpy { }; struct tensorflow { }; struct pytorch { }; struct jax { }; +struct writeable { }; +struct readonly { }; template constexpr dlpack::dtype dtype() { static_assert( @@ -108,6 +110,7 @@ struct tensor_req { size_t *shape = nullptr; bool req_shape = false; bool req_dtype = false; + bool readonly = false; char req_order = '\0'; uint8_t req_device = 0; }; @@ -176,11 +179,21 @@ template struct tensor_arg> { static void apply(tensor_req &tr) { tr.req_device = (uint8_t) T::value; } }; +template <> struct tensor_arg { + static constexpr size_t size = 0; + static constexpr auto name = const_name("readonly"); + static void apply(tensor_req &tr) { tr.readonly = true; } +}; + template struct tensor_info { using scalar_type = void; using shape_type = void; constexpr static auto name = const_name("tensor"); constexpr static tensor_framework framework = tensor_framework::none; + // Relevant for from_cpp + constexpr static bool writeable = false; + // Relevant for from_python + constexpr static bool readonly = false; }; template struct tensor_info : tensor_info { @@ -213,6 +226,14 @@ template struct tensor_info : tensor_info { constexpr static tensor_framework framework = tensor_framework::jax; }; +template struct tensor_info : tensor_info { + constexpr static bool writeable = true; +}; + +template struct tensor_info : tensor_info { + constexpr static bool readonly = true; +}; + NAMESPACE_END(detail) template class tensor { @@ -319,10 +340,16 @@ template struct type_caster> { bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept { constexpr size_t size = (0 + ... + detail::tensor_arg::size); + + static_assert(Value::Info::readonly == false || + Value::Info::framework == tensor_framework::numpy, + "Currently, only numpy arrays can be passed readonly."); + size_t shape[size + 1]; detail::tensor_req req; req.shape = shape; (detail::tensor_arg::apply(req), ...); + value = tensor(tensor_import( src.ptr(), &req, flags & (uint8_t) cast_flags::convert)); return value.is_valid(); @@ -330,7 +357,13 @@ template struct type_caster> { static handle from_cpp(const tensor &tensor, rv_policy, cleanup_list *) noexcept { - return tensor_wrap(tensor.handle(), int(Value::Info::framework)); + + static_assert(Value::Info::writeable == false || + Value::Info::framework == tensor_framework::numpy, + "Currently, writeable arrays are only available with numpy."); + + return tensor_wrap(tensor.handle(), int(Value::Info::framework), + Value::Info::writeable); } }; diff --git a/src/tensor.cpp b/src/tensor.cpp index 74b6b1dd..41839b12 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -145,11 +145,11 @@ void nb_tensor_releasebuffer(PyObject *, Py_buffer *view) { PyMem_Free(view->strides); } -static PyObject *dlpack_from_buffer_protocol(PyObject *o) { +static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool readonly) { scoped_pymalloc view; scoped_pymalloc mt; - if (PyObject_GetBuffer(o, view.get(), PyBUF_RECORDS)) { + if (PyObject_GetBuffer(o, view.get(), readonly ? PyBUF_RECORDS_RO : PyBUF_RECORDS)) { PyErr_Clear(); return nullptr; } @@ -264,7 +264,17 @@ tensor_handle *tensor_import(PyObject *o, const tensor_req *req, // If this is not a capsule, try calling o.__dlpack__() if (!PyCapsule_CheckExact(o)) { - capsule = steal(PyObject_CallMethod(o, "__dlpack__", nullptr)); + if (req->readonly) + { + // We need to go via the buffer protocol for this. Only numpy is supported. + // Others are prevented via a static_assert in from_python. + capsule = steal(dlpack_from_buffer_protocol(o, req->readonly)); + if (!capsule.is_valid()) + return nullptr; + } + else { + capsule = steal(PyObject_CallMethod(o, "__dlpack__", nullptr)); + } if (!capsule.is_valid()) { PyErr_Clear(); @@ -291,7 +301,7 @@ tensor_handle *tensor_import(PyObject *o, const tensor_req *req, // Try creating a tensor via the buffer protocol if (!capsule.is_valid()) - capsule = steal(dlpack_from_buffer_protocol(o)); + capsule = steal(dlpack_from_buffer_protocol(o, req->readonly)); if (!capsule.is_valid()) return nullptr; @@ -548,7 +558,7 @@ static void tensor_capsule_destructor(PyObject *o) { PyErr_Clear(); } -PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept { +PyObject *tensor_wrap(tensor_handle *th, int framework, bool writeable) noexcept { tensor_inc_ref(th); object o = steal(PyCapsule_New(th->tensor, "dltensor", tensor_capsule_destructor)), package; @@ -583,23 +593,41 @@ PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept { if (package.is_valid()) { - try { - o = package.attr("from_dlpack")(o); - } catch (...) { + if (writeable) + { + // Force usage of asarray which goes via the buffer interface + // (see nb_tensor_getbuffer) as dlpack does not support returning + // a writeable array. if ((tensor_framework) framework == tensor_framework::numpy) { try { - // Older numpy versions - o = package.attr("_from_dlpack")(o); + o = package.attr("asarray")(o); } catch (...) { + return nullptr; + } + } + else + // This should be prevented by the static_assert in + // from_cpp of the type_caster. + return nullptr; + } else { + try { + o = package.attr("from_dlpack")(o); + } catch (...) { + if ((tensor_framework) framework == tensor_framework::numpy) { try { - // Yet older numpy versions - o = package.attr("asarray")(o); + // Older numpy versions + o = package.attr("_from_dlpack")(o); } catch (...) { - return nullptr; + try { + // Yet older numpy versions + o = package.attr("asarray")(o); + } catch (...) { + return nullptr; + } } + } else { + return nullptr; } - } else { - return nullptr; } } } diff --git a/tests/test_tensor.cpp b/tests/test_tensor.cpp index 914b5863..b55552c5 100644 --- a/tests/test_tensor.cpp +++ b/tests/test_tensor.cpp @@ -125,4 +125,23 @@ NB_MODULE(test_tensor_ext, m) { return nb::tensor>(f, 2, shape, deleter); }); + + m.def("ret_numpy_writeable", []() { + float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; + size_t shape[2] = { 2, 4 }; + + nb::capsule deleter(f, [](void *data) noexcept { + destruct_count++; + delete[] (float *) data; + }); + + return nb::tensor, nb::writeable>(f, 2, shape, + deleter); + }); + + m.def("passthrough", [](nb::tensor<> a) { return a; }); + m.def("accept_numpy_readonly", + []([[maybe_unused]] nb::tensor ro, + [[maybe_unused]] nb::tensor rw) { + }); } diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 89056342..302179b4 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -306,6 +306,7 @@ def test16_return_numpy(): x = t.ret_numpy() assert x.shape == (2, 4) assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) + assert not x.flags['WRITEABLE'] del x gc.collect() assert t.destruct_count() - dc == 1 @@ -325,3 +326,36 @@ def test17_return_pytorch(): del x gc.collect() assert t.destruct_count() - dc == 1 + + +@needs_numpy +def test18_return_numpy_writeable(): + gc.collect() + dc = t.destruct_count() + x = t.ret_numpy_writeable() + assert x.shape == (2, 4) + assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) + # Check the flags. + assert x.flags['WRITEABLE'] + # Check we can actually write. + x[0, 0] = 1 + del x + gc.collect() + assert t.destruct_count() - dc == 1 + + +@needs_numpy +def test19_pass_numpy_readonly(): + rw = np.zeros((2, 2)) + ro = np.zeros((2, 2)) + ro.setflags(write=False) + + assert rw.flags["WRITEABLE"] + assert not ro.flags["WRITEABLE"] + + # Only first parameter accepts readonly, so this should throw. + with pytest.raises(TypeError): + t.accept_numpy_readonly(ro, ro) + + # This works though. + t.accept_numpy_readonly(ro, rw)