Draft: force buffer protocol in tensor creation for numpy to avoid read/write constraints with dlpack. #56

Closed
wants to merge 3 commits
2 changes: 1 addition & 1 deletion include/nanobind/nb_lib.h
@@ -360,7 +360,7 @@ NB_CORE dlpack::tensor *tensor_inc_ref(tensor_handle *) noexcept;
NB_CORE void tensor_dec_ref(tensor_handle *) noexcept;

/// Wrap a tensor_handle* into a PyCapsule
-NB_CORE PyObject *tensor_wrap(tensor_handle *, int framework) noexcept;
+NB_CORE PyObject *tensor_wrap(tensor_handle *, int framework, bool writeable) noexcept;

// ========================================================================

35 changes: 34 additions & 1 deletion include/nanobind/tensor.h
@@ -75,6 +75,8 @@ struct numpy { };
struct tensorflow { };
struct pytorch { };
struct jax { };
+struct writeable { };
+struct readonly { };

template <typename T> constexpr dlpack::dtype dtype() {
static_assert(
@@ -108,6 +110,7 @@ struct tensor_req {
size_t *shape = nullptr;
bool req_shape = false;
bool req_dtype = false;
+bool readonly = false;
char req_order = '\0';
uint8_t req_device = 0;
};
@@ -176,11 +179,21 @@ template <typename T> struct tensor_arg<T, enable_if_t<T::is_device>> {
static void apply(tensor_req &tr) { tr.req_device = (uint8_t) T::value; }
};

+template <> struct tensor_arg<readonly> {
+    static constexpr size_t size = 0;
+    static constexpr auto name = const_name("readonly");
+    static void apply(tensor_req &tr) { tr.readonly = true; }
+};
+
template <typename... Ts> struct tensor_info {
using scalar_type = void;
using shape_type = void;
constexpr static auto name = const_name("tensor");
constexpr static tensor_framework framework = tensor_framework::none;
+// Relevant for from_cpp
+constexpr static bool writeable = false;
+// Relevant for from_python
+constexpr static bool readonly = false;
};

template <typename T, typename... Ts> struct tensor_info<T, Ts...> : tensor_info<Ts...> {
@@ -213,6 +226,14 @@ template <typename... Ts> struct tensor_info<jax, Ts...> : tensor_info<Ts...> {
constexpr static tensor_framework framework = tensor_framework::jax;
};

+template <typename... Ts> struct tensor_info<writeable, Ts...> : tensor_info<Ts...> {
+    constexpr static bool writeable = true;
+};
+
+template <typename... Ts> struct tensor_info<readonly, Ts...> : tensor_info<Ts...> {
+    constexpr static bool readonly = true;
+};
+
NAMESPACE_END(detail)

template <typename... Args> class tensor {
@@ -319,18 +340,30 @@ template <typename... Args> struct type_caster<tensor<Args...>> {

bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept {
constexpr size_t size = (0 + ... + detail::tensor_arg<Args>::size);

+static_assert(Value::Info::readonly == false ||
+              Value::Info::framework == tensor_framework::numpy,
+              "Currently, only numpy arrays can be passed readonly.");
+
size_t shape[size + 1];
detail::tensor_req req;
req.shape = shape;
(detail::tensor_arg<Args>::apply(req), ...);

value = tensor<Args...>(tensor_import(
src.ptr(), &req, flags & (uint8_t) cast_flags::convert));
return value.is_valid();
}

static handle from_cpp(const tensor<Args...> &tensor, rv_policy,
cleanup_list *) noexcept {
-return tensor_wrap(tensor.handle(), int(Value::Info::framework));
+
+static_assert(Value::Info::writeable == false ||
+              Value::Info::framework == tensor_framework::numpy,
+              "Currently, writeable arrays are only available with numpy.");
+
+return tensor_wrap(tensor.handle(), int(Value::Info::framework),
+                   Value::Info::writeable);
}
};

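For orientation, here is a minimal sketch (not part of this diff) of how the two new annotations compose in a binding, mirroring the tests added in tests/test_tensor.cpp further down. The module and function names are made up for illustration:

#include <nanobind/nanobind.h>
#include <nanobind/tensor.h>

namespace nb = nanobind;

NB_MODULE(sketch_ext, m) {
    // from_python path: nb::readonly lets a numpy array with WRITEABLE=False
    // be accepted, since the caster then imports it through the buffer
    // protocol (PyBUF_RECORDS_RO) instead of __dlpack__.
    m.def("rows", [](nb::tensor<nb::numpy, nb::readonly> a) {
        return a.shape(0);  // metadata access only
    });

    // from_cpp path: nb::writeable makes the caster hand the result to numpy
    // via asarray()/the buffer interface, because from_dlpack() always
    // produces a readonly array.
    m.def("make_ones", []() {
        static float data[4] = { 1.f, 1.f, 1.f, 1.f };  // static storage: no deleter needed
        size_t shape[2] = { 2, 2 };
        return nb::tensor<nb::numpy, float, nb::shape<2, 2>, nb::writeable>(data, 2, shape);
    });
}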
58 changes: 43 additions & 15 deletions src/tensor.cpp
@@ -145,11 +145,11 @@ void nb_tensor_releasebuffer(PyObject *, Py_buffer *view) {
PyMem_Free(view->strides);
}

-static PyObject *dlpack_from_buffer_protocol(PyObject *o) {
+static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool readonly) {
scoped_pymalloc<Py_buffer> view;
scoped_pymalloc<managed_tensor> mt;

-if (PyObject_GetBuffer(o, view.get(), PyBUF_RECORDS)) {
+if (PyObject_GetBuffer(o, view.get(), readonly ? PyBUF_RECORDS_RO : PyBUF_RECORDS)) {
PyErr_Clear();
return nullptr;
}
@@ -264,7 +264,17 @@ tensor_handle *tensor_import(PyObject *o, const tensor_req *req,

// If this is not a capsule, try calling o.__dlpack__()
if (!PyCapsule_CheckExact(o)) {
-capsule = steal(PyObject_CallMethod(o, "__dlpack__", nullptr));
+if (req->readonly) {
+    // We need to go via the buffer protocol for this. Only numpy is supported.
+    // Others are prevented via a static_assert in from_python.
+    capsule = steal(dlpack_from_buffer_protocol(o, req->readonly));
+    if (!capsule.is_valid())
+        return nullptr;
+} else {
+    capsule = steal(PyObject_CallMethod(o, "__dlpack__", nullptr));
+}

if (!capsule.is_valid()) {
PyErr_Clear();
@@ -291,7 +301,7 @@

// Try creating a tensor via the buffer protocol
if (!capsule.is_valid())
-capsule = steal(dlpack_from_buffer_protocol(o));
+capsule = steal(dlpack_from_buffer_protocol(o, req->readonly));

if (!capsule.is_valid())
return nullptr;
@@ -548,7 +558,7 @@ static void tensor_capsule_destructor(PyObject *o) {
PyErr_Clear();
}

-PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept {
+PyObject *tensor_wrap(tensor_handle *th, int framework, bool writeable) noexcept {
tensor_inc_ref(th);
object o = steal(PyCapsule_New(th->tensor, "dltensor", tensor_capsule_destructor)),
package;
@@ -583,23 +593,41 @@ PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept {


if (package.is_valid()) {
-try {
-    o = package.attr("from_dlpack")(o);
-} catch (...) {
-    if ((tensor_framework) framework == tensor_framework::numpy) {
-        try {
-            // Older numpy versions
-            o = package.attr("_from_dlpack")(o);
-        } catch (...) {
-            try {
-                // Yet older numpy versions
-                o = package.attr("asarray")(o);
-            } catch (...) {
-                return nullptr;
-            }
-        }
-    } else {
-        return nullptr;
-    }
-}
+if (writeable) {
+    // Force usage of asarray which goes via the buffer interface
+    // (see nb_tensor_getbuffer) as dlpack does not support returning
+    // a writeable array.
+    if ((tensor_framework) framework == tensor_framework::numpy) {
+        try {
+            o = package.attr("asarray")(o);
+        } catch (...) {
+            return nullptr;
+        }
+    } else {
+        // This should be prevented by the static_assert in
+        // from_cpp of the type_caster.
+        return nullptr;
+    }
+} else {
+    try {
+        o = package.attr("from_dlpack")(o);
+    } catch (...) {
+        if ((tensor_framework) framework == tensor_framework::numpy) {
+            try {
+                // Older numpy versions
+                o = package.attr("_from_dlpack")(o);
+            } catch (...) {
+                try {
+                    // Yet older numpy versions
+                    o = package.attr("asarray")(o);
+                } catch (...) {
+                    return nullptr;
+                }
+            }
+        } else {
+            return nullptr;
+        }
+    }
+}
}
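The import-side change hinges on a CPython buffer-protocol detail: PyBUF_RECORDS is PyBUF_RECORDS_RO | PyBUF_WRITABLE, so requesting PyBUF_RECORDS from an exporter whose buffer is readonly raises BufferError, while PyBUF_RECORDS_RO accepts both. A standalone sketch of that behavior (the helper is hypothetical, not part of this patch):

#include <Python.h>

// Returns true if the exporter grants write access to its buffer. A numpy
// array with WRITEABLE=False rejects PyBUF_RECORDS (which demands a writable
// buffer) but still satisfies PyBUF_RECORDS_RO.
static bool buffer_is_writable(PyObject *o) {
    Py_buffer view;
    if (PyObject_GetBuffer(o, &view, PyBUF_RECORDS) == 0) {
        PyBuffer_Release(&view);
        return true;
    }
    PyErr_Clear();  // readonly exporter (or no buffer support at all)
    return false;
}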
19 changes: 19 additions & 0 deletions tests/test_tensor.cpp
@@ -125,4 +125,23 @@ NB_MODULE(test_tensor_ext, m) {
return nb::tensor<nb::pytorch, float, nb::shape<2, 4>>(f, 2, shape,
deleter);
});
+
+m.def("ret_numpy_writeable", []() {
+    float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
+    size_t shape[2] = { 2, 4 };
+
+    nb::capsule deleter(f, [](void *data) noexcept {
+        destruct_count++;
+        delete[] (float *) data;
+    });
+
+    return nb::tensor<nb::numpy, float, nb::shape<2, 4>, nb::writeable>(
+        f, 2, shape, deleter);
+});
+
+m.def("passthrough", [](nb::tensor<> a) { return a; });
+m.def("accept_numpy_readonly",
+      []([[maybe_unused]] nb::tensor<nb::numpy, nb::readonly> ro,
+         [[maybe_unused]] nb::tensor<nb::numpy, nb::writeable> rw) {
+      });
}
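Note that both annotations are deliberately numpy-only in this draft: the static_asserts added to the type caster turn misuse into a compile-time error instead of a runtime surprise. A hypothetical binding like the following would fail to build:

// Does not compile with this PR: the readonly path goes through the buffer
// protocol, which is only wired up for numpy. The caster's static_assert
// fires with "Currently, only numpy arrays can be passed readonly."
m.def("bad", [](nb::tensor<nb::pytorch, nb::readonly>) { });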
34 changes: 34 additions & 0 deletions tests/test_tensor.py
@@ -306,6 +306,7 @@ def test16_return_numpy():
    x = t.ret_numpy()
    assert x.shape == (2, 4)
    assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
+    assert not x.flags['WRITEABLE']
    del x
    gc.collect()
    assert t.destruct_count() - dc == 1
@@ -325,3 +326,36 @@ def test17_return_pytorch():
    del x
    gc.collect()
    assert t.destruct_count() - dc == 1


+@needs_numpy
+def test18_return_numpy_writeable():
+    gc.collect()
+    dc = t.destruct_count()
+    x = t.ret_numpy_writeable()
+    assert x.shape == (2, 4)
+    assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
+    # Check the flags.
+    assert x.flags['WRITEABLE']
+    # Check we can actually write.
+    x[0, 0] = 1
+    del x
+    gc.collect()
+    assert t.destruct_count() - dc == 1
+
+
+@needs_numpy
+def test19_pass_numpy_readonly():
+    rw = np.zeros((2, 2))
+    ro = np.zeros((2, 2))
+    ro.setflags(write=False)
+
+    assert rw.flags["WRITEABLE"]
+    assert not ro.flags["WRITEABLE"]
+
+    # Only the first parameter accepts readonly, so this should throw.
+    with pytest.raises(TypeError):
+        t.accept_numpy_readonly(ro, ro)
+
+    # This works, though.
+    t.accept_numpy_readonly(ro, rw)