diff --git a/paddle/phi/kernels/cpu/stack_kernel.cc b/paddle/phi/kernels/cpu/stack_kernel.cc index a9c428c68047d..cd43c32dfecc8 100644 --- a/paddle/phi/kernels/cpu/stack_kernel.cc +++ b/paddle/phi/kernels/cpu/stack_kernel.cc @@ -25,6 +25,15 @@ void StackKernel(const Context& dev_ctx, int axis, DenseTensor* out) { if (axis < 0) axis += (x[0]->dims().size() + 1); + + auto x_dims = x[0]->dims(); + for (int i = 0; i < x_dims.size(); i++) { + PADDLE_ENFORCE_GT(x_dims[i], + 0, + phi::errors::InvalidArgument( + "The dims of Input(X) should be greater than 0")); + } + int n = static_cast(x.size()); T* y_data = dev_ctx.template Alloc(out); std::vector x_datas(n); diff --git a/paddle/phi/kernels/funcs/stack_and_unstack.h b/paddle/phi/kernels/funcs/stack_and_unstack.h index 0b2b5443383a9..d82cbd523f8fb 100644 --- a/paddle/phi/kernels/funcs/stack_and_unstack.h +++ b/paddle/phi/kernels/funcs/stack_and_unstack.h @@ -77,11 +77,12 @@ void StackRawKernel(const Context& ctx, // Split x dim from axis to matrix of shape [x_row, x_col], and the output // tensor's shape is [x_row, out_col]. - int64_t x_row = 1; + int64_t x_row = 1, x_row_bak = 1; for (int i = 0; i < axis; ++i) { x_row *= x[0]->dims()[i]; } - int64_t x_col = x[0]->numel() / x_row; + x_row_bak = x_row == 0 ? 1 : x_row; + int64_t x_col = x[0]->numel() / x_row_bak; int64_t out_col = x_col * num; if (out->numel() < std::numeric_limits::max()) { diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py index cdff723a11b9d..a130d76e1e0d7 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py @@ -113,6 +113,12 @@ def test_sample_shape(self): == case.get('expect') ) + def test_errors(self): + with self.assertRaises(ValueError): + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32') + paddle.distribution.Beta(alpha=x, beta=x) + if __name__ == '__main__': unittest.main()