Reduce unnecessary zero-init'd allocations #632

Merged
10 commits merged on Apr 14, 2022
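In short: wherever an output buffer is fully overwritten before it is read, the allocation no longer needs to be zero-initialized, so `cupy.zeros`/`numpy.zeros` can be swapped for `cupy.empty`/`numpy.empty`. A minimal sketch of the idea (illustrative shape and dtype, not code from this PR):

```python
import numpy

# Zero-initialized: the buffer is written twice, once with zeros and
# once with the actual result.
out = numpy.zeros((256, 128), dtype="float32")

# Uninitialized: skips the zero fill. Safe only when every element is
# written before it is read; until then the contents are arbitrary.
out = numpy.empty((256, 128), dtype="float32")
```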
60 changes: 37 additions & 23 deletions in thinc/backends/_custom_kernels.py
@@ -124,6 +124,20 @@ def compile_mmh(src):
backprop_swish_kernel_float = _get_kernel("backprop_swish<float>")


def _alloc(shape, dtype, *, zeros: bool = True):
    if zeros:
        return cupy.zeros(shape, dtype)
    else:
        return cupy.empty(shape, dtype)


def _alloc_like(array, zeros: bool = True):
    if zeros:
        return cupy.zeros_like(array)
    else:
        return cupy.empty_like(array)


def clipped_linear(
X,
*,
@@ -139,7 +153,7 @@ def clipped_linear(

out = X
if not inplace:
out = cupy.zeros_like(X)
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
clipped_linear_kernel_float(
(num_blocks,),
@@ -160,7 +174,7 @@ def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=1

out = X
if not inplace:
out = cupy.zeros_like(X)
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
gelu_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
@@ -190,7 +204,7 @@ def seq2col(seq, nW, *, lengths=None, threads_per_block=128, num_blocks=128):
lengths = check_seq2col_lengths(lengths, B)
nL = lengths.shape[0]

out = cupy.zeros((B, I * nF), dtype=seq.dtype)
out = _alloc((B, I * nF), dtype=seq.dtype, zeros=True)

if seq.size != 0 and lengths.size != 0:
if seq.dtype == "float32":
@@ -211,8 +225,8 @@ def maxout(X, *, threads_per_block=128, num_blocks=128):
B, I, P = X.shape

out_shape = (B, I)
best = cupy.zeros(out_shape, dtype=X.dtype)
which = cupy.zeros(out_shape, dtype="i")
best = _alloc(out_shape, dtype=X.dtype, zeros=False)
which = _alloc(out_shape, dtype="i", zeros=False)

if X.dtype == "float32":
maxout_kernel_float(
@@ -231,7 +245,7 @@ def mish(X, *, inplace=False, threshold=5, threads_per_block=128, num_blocks=128

out = X
if not inplace:
out = cupy.zeros_like(X)
out = _alloc_like(X, zeros=False)

if X.dtype == "float32":
mish_kernel_float(
@@ -254,7 +268,7 @@ def reduce_sum(X, lengths, *, threads_per_block=128, num_blocks=128):

_check_lengths(lengths, T)

out = cupy.zeros((B, O), dtype=X.dtype)
out = _alloc((B, O), dtype=X.dtype, zeros=True)

if X.dtype == "float32":
reduce_sum_kernel_float(
@@ -277,7 +291,7 @@ def reduce_mean(X, lengths, *, threads_per_block=128, num_blocks=128):

_check_lengths(lengths, T)

out = cupy.zeros((B, O), dtype=X.dtype)
out = _alloc((B, O), dtype=X.dtype, zeros=True)

if X.dtype == "float32":
reduce_sum_kernel_float(
@@ -303,8 +317,8 @@ def reduce_max(X, lengths, *, threads_per_block=128, num_blocks=128):
_check_lengths(lengths, T)

out_shape = (B, O)
maxes = cupy.zeros(out_shape, dtype=X.dtype)
which = cupy.zeros(out_shape, dtype="i")
maxes = _alloc(out_shape, dtype=X.dtype, zeros=False)
which = _alloc(out_shape, dtype="i", zeros=False)

if X.dtype == "float32":
reduce_max_kernel_float(
Expand All @@ -323,7 +337,7 @@ def swish(X, *, inplace=False, threshold=17.0, threads_per_block=128, num_blocks

out = X
if not inplace:
out = cupy.zeros_like(X)
out = _alloc_like(X, zeros=False)
if X.dtype == "float32":
swish_kernel_float(
(num_blocks,), (threads_per_block,), (out, X, threshold, X.size)
@@ -345,7 +359,7 @@ def backprop_seq2col(dY, nW, *, lengths=None, threads_per_block=128, num_blocks=
lengths = check_seq2col_lengths(lengths, B)
nL = lengths.shape[0]

out = cupy.zeros((B, I), dtype=dY.dtype)
out = _alloc((B, I), dtype=dY.dtype, zeros=True)

if dY.size != 0 and lengths.size != 0:
if dY.dtype == "float32":
@@ -377,7 +391,7 @@ def backprop_clipped_linear(

out = dY
if not inplace:
out = cupy.zeros_like(dY)
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_clipped_linear_kernel_float(
@@ -403,7 +417,7 @@ def backprop_hard_swish(

out = dY
if not inplace:
out = cupy.zeros_like(dY)
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_hard_swish_kernel_float(
@@ -425,7 +439,7 @@ def backprop_hard_swish_mobilenet(

out = dY
if not inplace:
out = cupy.zeros_like(dY)
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_hard_swish_mobilenet_kernel_float(
@@ -453,7 +467,7 @@ def backprop_gelu(

out = dY
if not inplace:
out = cupy.zeros_like(dY)
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_gelu_kernel_float(
@@ -473,7 +487,7 @@ def backprop_maxout(dY, which, P, *, threads_per_block=128, num_blocks=128):
B = dY.shape[0]
I = dY.shape[1]

out = cupy.zeros((B, I, P), dtype=dY.dtype)
out = _alloc((B, I, P), dtype=dY.dtype, zeros=True)

_check_which_maxout(which, B, I, P)

@@ -497,7 +511,7 @@ def backprop_mish(

out = dY
if not inplace:
out = cupy.zeros_like(dY)
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_mish_kernel_float(
@@ -519,7 +533,7 @@ def backprop_reduce_sum(d_sums, lengths, *, threads_per_block=128, num_blocks=12
O = d_sums.shape[1]
_check_lengths(lengths, T)

out = cupy.zeros((T, O), dtype=d_sums.dtype)
out = _alloc((T, O), dtype=d_sums.dtype, zeros=False)

if d_sums.dtype == "float32":
backprop_reduce_sum_kernel_float(
@@ -541,7 +555,7 @@ def backprop_reduce_mean(d_means, lengths, *, threads_per_block=128, num_blocks=
O = d_means.shape[1]
_check_lengths(lengths, T)

out = cupy.zeros((T, O), dtype=d_means.dtype)
out = _alloc((T, O), dtype=d_means.dtype, zeros=False)

if d_means.dtype == "float32":
backprop_reduce_mean_kernel_float(
@@ -565,7 +579,7 @@ def backprop_reduce_max(
O = d_maxes.shape[1]
_check_lengths(lengths, T)

out = cupy.zeros((T, O), dtype=d_maxes.dtype)
out = _alloc((T, O), dtype=d_maxes.dtype, zeros=True)

_check_which_reduce_max(which, (B, O), lengths)

@@ -590,7 +604,7 @@ def backprop_swish(

out = dY
if not inplace:
out = cupy.zeros_like(dY)
out = _alloc_like(dY, zeros=False)

if dY.dtype == "float32":
backprop_swish_kernel_float(
@@ -605,7 +619,7 @@ def backprop_swish(


def hash(ids, seed, *, threads_per_block=128, num_blocks=128):
out = cupy.zeros((ids.shape[0], 4), dtype="uint32")
out = _alloc((ids.shape[0], 4), dtype="uint32", zeros=True)

# sizeof(uint32_t) * 4
out_size = 4 * 4
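Note the pattern in the call sites above: buffers that a kernel fully overwrites (the activation outputs, `maxout`, `reduce_max`) are now allocated with `zeros=False`, while buffers that may be only partially written or accumulated into (`seq2col`, `reduce_sum`, `backprop_seq2col`, `backprop_maxout`, `hash`) keep `zeros=True`. A minimal usage sketch of the new helpers (shapes are illustrative and the kernel launches are elided):

```python
import cupy

def _alloc(shape, dtype, *, zeros: bool = True):
    # Zero-filled buffer when required, otherwise uninitialized memory.
    return cupy.zeros(shape, dtype) if zeros else cupy.empty(shape, dtype)

# The kernel writes every element, so a zero fill would be wasted work.
best = _alloc((128, 64), dtype="float32", zeros=False)

# Empty sequences can leave rows untouched, so this must start at zero.
out = _alloc((128, 64), dtype="float32", zeros=True)
```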
35 changes: 19 additions & 16 deletions in thinc/backends/numpy_ops.pyx
@@ -76,8 +76,11 @@ class NumpyOps(Ops):
else:
return self.xp.array(data)

def alloc(self, shape: Shape, *, dtype: Optional[DTypes] = "float32") -> ArrayXd:
    return self.xp.zeros(shape, dtype=dtype)
def alloc(self, shape: Shape, *, dtype: Optional[DTypes] = "float32", zeros: bool = True) -> ArrayXd:
    if zeros:
        return self.xp.zeros(shape, dtype=dtype)
    else:
        return self.xp.empty(shape, dtype=dtype)

def gemm(self, np.ndarray x, np.ndarray y, *, np.ndarray out=None, trans1=False, trans2=False):
if x.ndim != 2:
@@ -159,14 +162,14 @@ class NumpyOps(Ops):
cdef int P = X.shape[2]

cdef np.ndarray best
cdef np.ndarray which = numpy.empty(shape=(B, O), dtype='int32', order='C')
cdef np.ndarray which = self.alloc(shape=(B, O), dtype='int32', zeros=False)
if reals3d_ft is float3d_t:
best = numpy.empty(shape=(B, O), dtype="float32", order='C')
best = self.alloc(shape=(B, O), dtype="float32", zeros=False)
if len(X) > 0:
cpu_maxout(<float*>best.data, <int*>which.data,
&X[0, 0, 0], B, O, P)
else:
best = numpy.empty(shape=(B, O), dtype="float64", order='C')
best = self.alloc(shape=(B, O), dtype="float64", zeros=False)
if len(X) > 0:
cpu_maxout(<double*>best.data, <int*>which.data,
&X[0, 0, 0], B, O, P)
@@ -393,12 +396,12 @@ class NumpyOps(Ops):
assert O != 0

cdef np.ndarray maxes
cdef np.ndarray which = numpy.zeros(shape=(B, O), dtype="i")
cdef np.ndarray which = self.alloc(shape=(B, O), dtype="i", zeros=False)
if reals2d_ft is float2d_t:
maxes = numpy.zeros(shape=(B, O), dtype="float32")
maxes = self.alloc(shape=(B, O), dtype="float32", zeros=False)
cpu_reduce_max(<float*>maxes.data, <int*>which.data, &X[0, 0], &lengths[0], B, T, O)
else:
maxes = numpy.zeros(shape=(B, O), dtype="float64")
maxes = self.alloc(shape=(B, O), dtype="float64", zeros=False)
cpu_reduce_max(<double*>maxes.data, <int*>which.data, &X[0, 0], &lengths[0], B, T, O)

return maxes, which
@@ -472,7 +475,7 @@ class NumpyOps(Ops):
def position_encode(self, int N, int D, int period=10000, out=None):
cdef np.ndarray out_
if out is None:
out_ = self.alloc((N, D))
out_ = self.alloc((N, D), zeros=False)
else:
out_ = out
assert out_.shape[0] == N
@@ -627,7 +630,7 @@ def lstm_forward_training(
Cid = C[i, d]
Gid = G[i, d]
_lstm_forward_training(
d, N, nO, nI, nT,
d, N, nO, nI, nT,
Gid,
<float*>Yid.data,
<float*>Cid.data,
@@ -779,7 +782,7 @@ def backprop_lstm(np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_stat
Wx, Wh, bias = all_layer_params[i][d][0]
dWx, dWh, d_bias = all_layer_grads[i][d][0]
assert Wx.shape[1] == dWx.shape[1] == X.shape[1] == dX.shape[1], (Wx.shape[1], dWx.shape[1], X.shape[1], dX.shape[1])
dYid = dY[d]
dYid = dY[d]
dC.fill(0.)
dG.fill(0.)
Cid = C[i, d]
@@ -800,7 +803,7 @@ def backprop_lstm(np.ndarray dY, np.ndarray lengths, np.ndarray params, fwd_stat
<float*>dWh.data,
<float*>d_bias.data,
<float*>Cid.data,
<float*>Gid.data,
<float*>Gid.data,
<float*>Yid.data,
<float*>X.data,
<float*>Wx.data,
@@ -865,7 +868,7 @@ cdef int _lstm_backward_training(
Ct3 = &C[seq_t3*nO]
Gt3 = &G[seq_t3*nO*4]
Ct2 = &C[seq_t2*nO]

batch_size = min(size_t2, size_t3)
cpu_lstm_gates_bwd(dGt3, dCt2,
dYt3, dCt3, Gt3, Ct3, Ct2, batch_size * nO
@@ -979,7 +982,7 @@ cdef void cpu_lstm_activate_fwd(float* gates, int B, int N) nogil:
"""Apply sigmoid activation in-place to columns 0, 1, 2 and tanh to column 3.
The data is assumed to have the gates in the last dimension.
"""
# This just does the following, but unrolled slightly to give
# This just does the following, but unrolled slightly to give
# a better chance at simd.
#
# gates[g+i+0] = sigmoid(gates[g+i+0])
Expand Down Expand Up @@ -1020,7 +1023,7 @@ cdef void cpu_lstm_activate_fwd(float* gates, int B, int N) nogil:
gates[g+3] = tanhf(gates[g+3])
g += 4


cdef void cpu_lstm_gates_fwd(float* hiddens, float* cells,
const float* gates, const float* prevcells, int B, int N) nogil:
cdef float hf, hi, ho, hc, ct2, ct3
@@ -1064,7 +1067,7 @@ cdef void cpu_lstm_gates_bwd(
hi = Gt3[i*4+1]
ho = Gt3[i*4+2]
hc = Gt3[i*4+3]

tanh_ct3 = tanhf(ct3)
# 3b: Yt3 = tanhCt3 * ho
d_ho = dyt3 * tanh_ct3
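For the NumPy backend, the same choice is exposed through the new `zeros` keyword on `NumpyOps.alloc`. A small usage sketch (shapes and dtype are illustrative):

```python
from thinc.api import NumpyOps

ops = NumpyOps()

# Default behaviour is unchanged: the returned buffer is zero-filled.
zeroed = ops.alloc((4, 3), dtype="float32")

# zeros=False skips the fill; the caller must write every element before
# reading it, as the maxout and reduce_max paths above do.
scratch = ops.alloc((4, 3), dtype="float32", zeros=False)
```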