Skip to content

Commit

Permalink
Restore C CPU inference in the refactored parser (#10747)
Browse files Browse the repository at this point in the history
* Bring back the C parsing model

The C parsing model is used for CPU inference and is still faster for
CPU inference than the forward pass of the Thinc model.

* Use C sgemm provided by the Ops implementation

* Make tb_framework module Cython, merge in C forward implementation

* TransitionModel: raise in backprop returned from forward_cpu

* Re-enable greedy parse test

* Return transition scores when forward_cpu is used

* Apply suggestions from code review

Import `Model` from `thinc.api`

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Use relative imports in tb_framework

* Don't assume a default for beam_width

* We don't have a direct dependency on BLIS anymore

* Rename forwards to _forward_{fallback,greedy_cpu}

* Require thinc >=8.1.0,<8.2.0

* tb_framework: clean up imports

* Fix return type of _get_seen_mask

* Move up _forward_greedy_cpu

* Style fixes.

* Lower thinc lowerbound to 8.1.0.dev0

* Formatting fix

Co-authored-by: Adriane Boyd <[email protected]>

Co-authored-by: Sofie Van Landeghem <[email protected]>
Co-authored-by: Adriane Boyd <[email protected]>
  • Loading branch information
3 people authored May 30, 2022
1 parent 65c770c commit aad3897
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 31 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"spacy.vocab",
"spacy.attrs",
"spacy.kb",
"spacy.ml.tb_framework",
"spacy.morphology",
"spacy.pipeline._edit_tree_internals.edit_trees",
"spacy.pipeline.morphologizer",
Expand Down
1 change: 1 addition & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes):
E1035 = ("Token index {i} out of bounds ({length})")
E1036 = ("Cannot index into NoneNode")
E1037 = ("Invalid attribute value '{attr}'.")
E1038 = ("Backprop is not supported when is_train is not set.")


# Deprecated model shortcuts, only used in errors and warnings
Expand Down
28 changes: 28 additions & 0 deletions spacy/ml/tb_framework.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from libc.stdint cimport int8_t


cdef struct SizesC:
int states
int classes
int hiddens
int pieces
int feats
int embed_width
int tokens


cdef struct WeightsC:
const float* feat_weights
const float* feat_bias
const float* hidden_bias
const float* hidden_weights
const int8_t* seen_mask


cdef struct ActivationsC:
int* token_ids
float* unmaxed
float* hiddens
int* is_valid
int _curr_size
int _max_size
244 changes: 214 additions & 30 deletions spacy/ml/tb_framework.py → spacy/ml/tb_framework.pyx
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
from typing import List, Tuple, Any, Optional, cast
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from libcpp.vector cimport vector
import numpy
cimport numpy as np
from thinc.api import Model, normal_init, chain, list2array, Linear
from thinc.api import uniform_init, glorot_uniform_init, zero_init
from thinc.api import NumpyOps
from thinc.backends.linalg cimport Vec, VecVec
from thinc.backends.cblas cimport CBlas
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
import numpy

from ..errors import Errors
from ..pipeline._parser_internals import _beam_utils
from ..pipeline._parser_internals.batch import GreedyBatch
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
from ..tokens.doc import Doc
from ..util import registry


TransitionSystem = Any # TODO
State = Any # TODO


Expand Down Expand Up @@ -131,29 +142,82 @@ def init(
# model = _lsuv_init(model)
return model


def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
nF = model.get_dim("nF")
tok2vec = model.get_ref("tok2vec")
beam_width = model.attrs["beam_width"]
lower_pad = model.get_param("lower_pad")
lower_W = model.get_param("lower_W")
tok2vec = model.get_ref("tok2vec")

docs, moves = docs_moves
states = moves.init_batch(docs)
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
seen_mask = _get_seen_mask(model)

if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
else:
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train)

def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
np.ndarray[np.npy_bool, ndim=1] seen_mask):
cdef vector[StateC *] c_states
cdef StateClass state
for state in states:
if not state.is_final():
c_states.push_back(state.c)
weights = get_c_weights(model, <float *> feats.data, seen_mask)
# Precomputed features have rows for each token, plus one for padding.
cdef int n_tokens = feats.shape[0] - 1
sizes = get_c_sizes(model, c_states.size(), n_tokens)
cdef CBlas cblas = model.ops.cblas()
scores = _parseC(cblas, moves, &c_states[0], weights, sizes)

def backprop(dY):
raise ValueError(Errors.E1038)

return (states, scores), backprop

cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
WeightsC weights, SizesC sizes):
cdef int i, j
cdef vector[StateC *] unfinished
cdef ActivationsC activations = alloc_activations(sizes)
cdef np.ndarray step_scores

scores = []
while sizes.states >= 1:
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
with nogil:
predict_states(cblas, &activations, <float *> step_scores.data, states, &weights, sizes)
# Validate actions, argmax, take action.
c_transition_batch(moves, states, <const float *> step_scores.data, sizes.classes,
sizes.states)
for i in range(sizes.states):
if not states[i].is_final():
unfinished.push_back(states[i])
for i in range(unfinished.size()):
states[i] = unfinished[i]
sizes.states = unfinished.size()
scores.append(step_scores)
unfinished.clear()
free_activations(&activations)

return scores

def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool):
nF = model.get_dim("nF")
lower_b = model.get_param("lower_b")
upper_W = model.get_param("upper_W")
upper_b = model.get_param("upper_b")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
nO = model.get_dim("nO")
nI = model.get_dim("nI")

beam_width = model.attrs["beam_width"]
beam_density = model.attrs["beam_density"]

ops = model.ops
docs, moves = docs_moves
states = moves.init_batch(docs)
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)

all_ids = []
all_which = []
all_statevecs = []
Expand All @@ -164,8 +228,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
batch = _beam_utils.BeamBatch(
moves, states, None, width=beam_width, density=beam_density
)
seen_mask = _get_seen_mask(model)
arange = model.ops.xp.arange(nF)
arange = ops.xp.arange(nF)
while not batch.is_done:
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
for i, state in enumerate(batch.get_unfinished_states()):
Expand All @@ -174,16 +237,16 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
# to create the state vectors.
preacts2f = feats[ids, arange].sum(axis=1) # type: ignore
preacts2f += lower_b
preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
statevecs, which = ops.maxout(preacts)
# Multiply the state-vector by the scores weights and add the bias,
# to get the logits.
scores = model.ops.gemm(statevecs, upper_W, trans2=True)
scores = ops.gemm(statevecs, upper_W, trans2=True)
scores += upper_b
scores[:, seen_mask] = model.ops.xp.nanmin(scores)
scores[:, seen_mask] = ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished.
cpu_scores = model.ops.to_numpy(scores)
cpu_scores = ops.to_numpy(scores)
batch.advance(cpu_scores)
all_scores.append(scores)
if is_train:
Expand All @@ -193,10 +256,9 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
all_which.append(which)

def backprop_parser(d_states_d_scores):
d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
ids = model.ops.xp.vstack(all_ids)
ids = ops.xp.vstack(all_ids)
which = ops.xp.vstack(all_which)
statevecs = model.ops.xp.vstack(all_statevecs)
statevecs = ops.xp.vstack(all_statevecs)
_, d_scores = d_states_d_scores
if model.attrs.get("unseen_classes"):
# If we have a negative gradient (i.e. the probability should
Expand All @@ -209,18 +271,18 @@ def backprop_parser(d_states_d_scores):
# Calculate the gradients for the parameters of the upper layer.
# The weight gemm is (nS, nO) @ (nS, nH).T
model.inc_grad("upper_b", d_scores.sum(axis=0))
model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the upper linear layer.
# This gemm is (nS, nO) @ (nO, nH)
d_statevecs = model.ops.gemm(d_scores, upper_W)
d_statevecs = ops.gemm(d_scores, upper_W)
# Backprop through the maxout activation
d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
# We don't need to backprop the summation, because we pass back the IDs instead
d_state_features = backprop_feats((d_preacts2f, ids))
d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
model.ops.scatter_add(d_tokvecs, ids, d_state_features)
d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
ops.scatter_add(d_tokvecs, ids, d_state_features)
model.inc_grad("lower_pad", d_tokvecs[-1])
return (backprop_tok2vec(d_tokvecs[:-1]), None)

Expand Down Expand Up @@ -328,7 +390,7 @@ def backprop_parser(d_states_d_scores):
return (states, all_scores), backprop_parser


def _get_seen_mask(model: Model) -> Floats1d:
def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
for class_ in model.attrs.get("unseen_classes", set()):
mask[class_] = True
Expand Down Expand Up @@ -449,3 +511,125 @@ def predict(ids, tokvecs):
else:
break
return model


cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
cdef np.ndarray lower_b = model.get_param("lower_b")
cdef np.ndarray upper_W = model.get_param("upper_W")
cdef np.ndarray upper_b = model.get_param("upper_b")

cdef WeightsC output
output.feat_weights = feats
output.feat_bias = <const float*>lower_b.data
output.hidden_weights = <const float *> upper_W.data
output.hidden_bias = <const float *> upper_b.data
output.seen_mask = <const int8_t*> seen_mask.data

return output


cdef SizesC get_c_sizes(model, int batch_size, int tokens) except *:
cdef SizesC output
output.states = batch_size
output.classes = model.get_dim("nO")
output.hiddens = model.get_dim("nH")
output.pieces = model.get_dim("nP")
output.feats = model.get_dim("nF")
output.embed_width = model.get_dim("nI")
output.tokens = tokens
return output


cdef ActivationsC alloc_activations(SizesC n) nogil:
cdef ActivationsC A
memset(&A, 0, sizeof(A))
resize_activations(&A, n)
return A


cdef void free_activations(const ActivationsC* A) nogil:
free(A.token_ids)
free(A.unmaxed)
free(A.hiddens)
free(A.is_valid)


cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
if n.states <= A._max_size:
A._curr_size = n.states
return
if A._max_size == 0:
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
A._max_size = n.states
else:
A.token_ids = <int*>realloc(A.token_ids,
n.states * n.feats * sizeof(A.token_ids[0]))
A.unmaxed = <float*>realloc(A.unmaxed,
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
A.hiddens = <float*>realloc(A.hiddens,
n.states * n.hiddens * sizeof(A.hiddens[0]))
A.is_valid = <int*>realloc(A.is_valid,
n.states * n.classes * sizeof(A.is_valid[0]))
A._max_size = n.states
A._curr_size = n.states


cdef void predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
resize_activations(A, n)
for i in range(n.states):
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
for i in range(n.states):
VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
W.feat_bias, 1., n.hiddens * n.pieces)
for j in range(n.hiddens):
index = i * n.hiddens * n.pieces + j * n.pieces
which = Vec.arg_max(&A.unmaxed[index], n.pieces)
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
if W.hidden_weights == NULL:
memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
else:
# Compute hidden-to-output
cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
1.0, <const float *>A.hiddens, n.hiddens,
<const float *>W.hidden_weights, n.hiddens,
0.0, scores, n.classes)
# Add bias
for i in range(n.states):
VecVec.add_i(&scores[i*n.classes], W.hidden_bias, 1., n.classes)
# Set unseen classes to minimum value
i = 0
min_ = scores[0]
for i in range(1, n.states * n.classes):
if scores[i] < min_:
min_ = scores[i]
for i in range(n.states):
for j in range(n.classes):
if W.seen_mask[j]:
scores[i*n.classes+j] = min_


cdef void sum_state_features(CBlas cblas, float* output,
const float* cached, const int* token_ids, SizesC n) nogil:
cdef int idx, b, f, i
cdef const float* feature
cdef int B = n.states
cdef int O = n.hiddens * n.pieces
cdef int F = n.feats
cdef int T = n.tokens
padding = cached + (T * F * O)
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F
3 changes: 3 additions & 0 deletions spacy/pipeline/_parser_internals/transition_system.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ cdef class TransitionSystem:

cdef int set_costs(self, int* is_valid, weight_t* costs,
const StateC* state, gold) except -1

cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
int nr_class, int batch_size) nogil
1 change: 0 additions & 1 deletion spacy/tests/parser/test_add_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def test_ner_labels_added_implicitly_on_beam_parse():
assert "D" in ner.labels


@pytest.mark.skip(reason="greedy_parse is deprecated")
def test_ner_labels_added_implicitly_on_greedy_parse():
nlp = Language()
ner = nlp.add_pipe("beam_ner")
Expand Down

0 comments on commit aad3897

Please sign in to comment.