Reimplement parser rehearsal function (#10878)
* Reimplement parser rehearsal function

Before the parser refactor, rehearsal was driven by a loop in the
`rehearse` method itself. For each parsing step, the loop would:

1. Get the predictions of the teacher.
2. Get the predictions and backprop function of the student.
3. Compute the loss and backprop into the student.
4. Move the teacher and student forward with the predictions of
   the student.

In the refactored parser, we cannot perform such stepwise rehearsal
anymore, since the model now predicts all parsing steps at once.
Therefore, rehearsal is performed in the following steps:

1. Get the predictions of all parsing steps from the student, along
   with its backprop function.
2. Get the predictions from the teacher, but use the predictions of
   the student to advance the parser while doing so.
3. Compute the loss and backprop into the student.

To support the second step, a new method, `advance_with_actions`, is
added to `GreedyBatch`; it performs the provided parsing steps.
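
A rough sketch of the resulting rehearsal flow, from the pipe's point of
view (a sketch only: `student_model`, `teacher_model`, `states_to_actions`
and the plain score difference stand in for the actual spaCy internals and
rehearsal loss):

    # 1. Student predicts all parsing steps at once, with a backprop callback.
    (student_states, student_scores), backprop = student_model.begin_update(
        (docs, moves)
    )
    # Recover the per-step actions the student took (hypothetical helper that
    # reads them off the states' transition histories).
    actions = states_to_actions(student_states)
    # 2. Teacher predicts, but its parser states are advanced with the
    #    student's actions, so both models score the same configurations.
    _, teacher_scores = teacher_model.predict((docs, moves, actions))
    # 3. Compute the loss and backprop into the student (the raw difference
    #    is only a placeholder for the real rehearsal loss).
    d_scores = [stu - tch for stu, tch in zip(student_scores, teacher_scores)]
    backprop(d_scores)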

* tb_framework: wrap upper_W and upper_b in Linear

Thinc's Optimizer cannot handle resizing of existing parameters. Until
it does, we work around this by wrapping the weights/biases of the upper
layer of the parser model in Linear. When the upper layer is resized, we
copy over the existing parameters into a new Linear instance. This does
not trigger an error in Optimizer, because it sees the resized layer as
a new set of parameters.
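
A minimal sketch of that copy-over pattern, mirroring the `resize_output`
changes further down (here `model`, `old_upper`, `old_nO`, `new_nO` and `nH`
stand for the values available inside `resize_output`):

    from thinc.api import Linear, zero_init

    # Build a fresh Linear with the new output size and copy the old rows in.
    new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init)
    new_upper.initialize()
    new_W = new_upper.get_param("W")
    new_b = new_upper.get_param("b")
    new_W[:old_nO] = old_upper.get_param("W")
    new_b[:old_nO] = old_upper.get_param("b")
    # Swapping the layer in gives the optimizer a fresh parameter key, so it
    # never sees an existing parameter change shape.
    model.layers[-1] = new_upper
    model.set_ref("upper", new_upper)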

* Add test for TransitionSystem.apply_actions

* Better FIXME marker

Co-authored-by: Madeesh Kannan <[email protected]>

* Fixes from Madeesh

* Apply suggestions from Sofie

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Remove useless assignment

Co-authored-by: Madeesh Kannan <[email protected]>
Co-authored-by: Sofie Van Landeghem <[email protected]>
3 people authored Jun 8, 2022
1 parent aad3897 commit 63e90dd
Showing 8 changed files with 240 additions and 67 deletions.
146 changes: 87 additions & 59 deletions spacy/ml/tb_framework.pyx
@@ -1,5 +1,5 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
from typing import List, Tuple, Any, Optional, cast
from typing import List, Tuple, Any, Optional, TypeVar, cast
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from libcpp.vector cimport vector
@@ -10,12 +10,14 @@ from thinc.api import uniform_init, glorot_uniform_init, zero_init
from thinc.api import NumpyOps
from thinc.backends.linalg cimport Vec, VecVec
from thinc.backends.cblas cimport CBlas
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
from thinc.types import Floats1d, Floats2d, Floats3d, Floats4d
from thinc.types import Ints1d, Ints2d

from ..errors import Errors
from ..pipeline._parser_internals import _beam_utils
from ..pipeline._parser_internals.batch import GreedyBatch
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions
from ..pipeline._parser_internals.transition_system cimport TransitionSystem
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
from ..tokens.doc import Doc
from ..util import registry
@@ -43,18 +45,28 @@ def TransitionModel(
tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore
tok2vec_projected.set_dim("nO", hidden_width)

# FIXME: we use `upper` as a container for the upper layer's
# weights and biases. Thinc optimizers cannot handle resizing
# of parameters. So, when the parser model is resized, we
# construct a new `upper` layer, which has a different key in
# the optimizer. Once the optimizer supports parameter resizing,
# we can replace the `upper` layer by `upper_W` and `upper_b`
# parameters in this model.
upper = Linear(nO=None, nI=hidden_width, init_W=zero_init)

return Model(
name="parser_model",
forward=forward,
init=init,
layers=[tok2vec_projected],
refs={"tok2vec": tok2vec_projected},
layers=[tok2vec_projected, upper],
refs={
"tok2vec": tok2vec_projected,
"upper": upper,
},
params={
"lower_W": None, # Floats2d W for the hidden layer
"lower_b": None, # Floats1d bias for the hidden layer
"lower_pad": None, # Floats1d padding for the hidden layer
"upper_W": None, # Floats2d W for the output layer
"upper_b": None, # Floats1d bias for the output layer
},
dims={
"nO": None, # Output size
@@ -74,29 +86,30 @@ def TransitionModel(

def resize_output(model: Model, new_nO: int) -> Model:
old_nO = model.maybe_get_dim("nO")
upper = model.get_ref("upper")
if old_nO is None:
model.set_dim("nO", new_nO)
upper.set_dim("nO", new_nO)
upper.initialize()
return model
elif new_nO <= old_nO:
return model
elif model.has_param("upper_W"):
elif upper.has_param("W"):
nH = model.get_dim("nH")
new_W = model.ops.alloc2f(new_nO, nH)
new_b = model.ops.alloc1f(new_nO)
old_W = model.get_param("upper_W")
old_b = model.get_param("upper_b")
new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init)
new_upper.initialize()
new_W = new_upper.get_param("W")
new_b = new_upper.get_param("b")
old_W = upper.get_param("W")
old_b = upper.get_param("b")
new_W[:old_nO] = old_W # type: ignore
new_b[:old_nO] = old_b # type: ignore
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
model.set_param("upper_W", new_W)
model.set_param("upper_b", new_b)
model.layers[-1] = new_upper
model.set_ref("upper", new_upper)
# TODO: Avoid this private intrusion
model._dims["nO"] = new_nO
if model.has_grad("upper_W"):
model.set_grad("upper_W", model.get_param("upper_W") * 0)
if model.has_grad("upper_b"):
model.set_grad("upper_b", model.get_param("upper_b") * 0)
return model


@@ -113,9 +126,7 @@ def init(
inferred_nO = _infer_nO(Y)
if inferred_nO is not None:
current_nO = model.maybe_get_dim("nO")
if current_nO is None:
model.set_dim("nO", inferred_nO)
elif current_nO != inferred_nO:
if current_nO is None or current_nO != inferred_nO:
model.attrs["resize_output"](model, inferred_nO)
nO = model.get_dim("nO")
nP = model.get_dim("nP")
@@ -127,72 +138,83 @@ def init(
Wl = ops.alloc2f(nH * nP, nF * nI)
bl = ops.alloc1f(nH * nP)
padl = ops.alloc1f(nI)
Wu = ops.alloc2f(nO, nH)
bu = ops.alloc1f(nO)
Wu = zero_init(ops, Wu.shape)
# Wl = zero_init(ops, Wl.shape)
Wl = glorot_uniform_init(ops, Wl.shape)
padl = uniform_init(ops, padl.shape) # type: ignore
# TODO: Experiment with whether better to initialize upper_W
model.set_param("lower_W", Wl)
model.set_param("lower_b", bl)
model.set_param("lower_pad", padl)
model.set_param("upper_W", Wu)
model.set_param("upper_b", bu)
# model = _lsuv_init(model)
return model

def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
InWithoutActions = Tuple[List[Doc], TransitionSystem]
InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]]
InT = TypeVar("InT", InWithoutActions, InWithActions)

def forward(model, docs_moves: InT, is_train: bool):
if len(docs_moves) == 2:
docs, moves = docs_moves
actions = None
else:
docs, moves, actions = docs_moves

beam_width = model.attrs["beam_width"]
lower_pad = model.get_param("lower_pad")
tok2vec = model.get_ref("tok2vec")

docs, moves = docs_moves
states = moves.init_batch(docs)
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
seen_mask = _get_seen_mask(model)

# Fixme: support actions in forward_cpu
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
return _forward_greedy_cpu(model, moves, states, feats, seen_mask)
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
else:
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train)
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions)


def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
np.ndarray[np.npy_bool, ndim=1] seen_mask):
cdef vector[StateC *] c_states
np.ndarray[np.npy_bool, ndim=1] seen_mask, actions: Optional[List[Ints1d]]=None):
cdef vector[StateC*] c_states
cdef StateClass state
for state in states:
if not state.is_final():
c_states.push_back(state.c)
weights = get_c_weights(model, <float *> feats.data, seen_mask)
weights = get_c_weights(model, <float*>feats.data, seen_mask)
# Precomputed features have rows for each token, plus one for padding.
cdef int n_tokens = feats.shape[0] - 1
sizes = get_c_sizes(model, c_states.size(), n_tokens)
cdef CBlas cblas = model.ops.cblas()
scores = _parseC(cblas, moves, &c_states[0], weights, sizes)
scores = _parseC(cblas, moves, &c_states[0], weights, sizes, actions=actions)

def backprop(dY):
raise ValueError(Errors.E1038)

return (states, scores), backprop

cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
WeightsC weights, SizesC sizes):
WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
cdef int i, j
cdef vector[StateC *] unfinished
cdef ActivationsC activations = alloc_activations(sizes)
cdef np.ndarray step_scores
cdef np.ndarray step_actions

scores = []
while sizes.states >= 1:
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
step_actions = actions[0] if actions is not None else None
with nogil:
predict_states(cblas, &activations, <float *> step_scores.data, states, &weights, sizes)
# Validate actions, argmax, take action.
c_transition_batch(moves, states, <const float *> step_scores.data, sizes.classes,
sizes.states)
predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
if actions is None:
# Validate actions, argmax, take action.
c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
sizes.states)
else:
c_apply_actions(moves, states, <const int*>step_actions.data, sizes.states)
for i in range(sizes.states):
if not states[i].is_final():
unfinished.push_back(states[i])
@@ -201,15 +223,17 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states,
sizes.states = unfinished.size()
scores.append(step_scores)
unfinished.clear()
actions = actions[1:] if actions is not None else None
free_activations(&activations)

return scores

def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool):

def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
actions: Optional[List[Ints1d]]=None):
nF = model.get_dim("nF")
upper = model.get_ref("upper")
lower_b = model.get_param("lower_b")
upper_W = model.get_param("upper_W")
upper_b = model.get_param("upper_b")
nH = model.get_dim("nH")
nP = model.get_dim("nP")

@@ -240,14 +264,17 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
statevecs, which = ops.maxout(preacts)
# Multiply the state-vector by the scores weights and add the bias,
# to get the logits.
scores = ops.gemm(statevecs, upper_W, trans2=True)
scores += upper_b
# We don't use upper's backprop, since we want to backprop for
# all states at once, rather than a single state.
scores = upper.predict(statevecs)
scores[:, seen_mask] = ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished.
cpu_scores = ops.to_numpy(scores)
batch.advance(cpu_scores)
if actions is None:
batch.advance(cpu_scores)
else:
batch.advance_with_actions(actions[0])
actions = actions[1:]
all_scores.append(scores)
if is_train:
# Remember intermediate results for the backprop.
@@ -270,10 +297,11 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC
d_scores *= seen_mask == False
# Calculate the gradients for the parameters of the upper layer.
# The weight gemm is (nS, nO) @ (nS, nH).T
model.inc_grad("upper_b", d_scores.sum(axis=0))
model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True))
upper.inc_grad("b", d_scores.sum(axis=0))
upper.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the upper linear layer.
# This gemm is (nS, nO) @ (nO, nH)
upper_W = upper.get_param("W")
d_statevecs = ops.gemm(d_scores, upper_W)
# Backprop through the maxout activation
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
Expand All @@ -295,11 +323,10 @@ def _forward_reference(
"""Slow reference implementation, without the precomputation"""
nF = model.get_dim("nF")
tok2vec = model.get_ref("tok2vec")
upper = model.get_ref("upper")
lower_pad = model.get_param("lower_pad")
lower_W = model.get_param("lower_W")
lower_b = model.get_param("lower_b")
upper_W = model.get_param("upper_W")
upper_b = model.get_param("upper_b")
nH = model.get_dim("nH")
nP = model.get_dim("nP")
nO = model.get_dim("nO")
@@ -330,10 +357,9 @@ def _forward_reference(
preacts2f += lower_b
preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
statevecs, which = ops.maxout(preacts)
# Multiply the state-vector by the scores weights and add the bias,
# to get the logits.
scores = model.ops.gemm(statevecs, upper_W, trans2=True)
scores += upper_b
# We don't use upper's backprop, since we want to backprop for
# all states at once, rather than a single state.
scores = upper.predict(statevecs)
scores[:, seen_mask] = model.ops.xp.nanmin(scores)
# Transition the states, filtering out any that are finished.
next_states = moves.transition_states(next_states, scores)
@@ -366,10 +392,11 @@ def _forward_reference(
assert d_scores.shape == (nS, nO), d_scores.shape
# Calculate the gradients for the parameters of the upper layer.
# The weight gemm is (nS, nO) @ (nS, nH).T
model.inc_grad("upper_b", d_scores.sum(axis=0))
model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
upper.inc_grad("b", d_scores.sum(axis=0))
upper.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True))
# Now calculate d_statevecs, by backproping through the upper linear layer.
# This gemm is (nS, nO) @ (nO, nH)
upper_W = upper.get_param("W")
d_statevecs = model.ops.gemm(d_scores, upper_W)
# Backprop through the maxout activation
d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
@@ -514,9 +541,10 @@ def _lsuv_init(model: Model):


cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
upper = model.get_ref("upper")
cdef np.ndarray lower_b = model.get_param("lower_b")
cdef np.ndarray upper_W = model.get_param("upper_W")
cdef np.ndarray upper_b = model.get_param("upper_b")
cdef np.ndarray upper_W = upper.get_param("W")
cdef np.ndarray upper_b = upper.get_param("b")

cdef WeightsC output
output.feat_weights = feats
3 changes: 3 additions & 0 deletions spacy/pipeline/_parser_internals/batch.pyx
@@ -32,6 +32,9 @@ class GreedyBatch(Batch):
def advance(self, scores):
self._next_states = self._moves.transition_states(self._next_states, scores)

def advance_with_actions(self, actions):
self._next_states = self._moves.apply_transitions(self._next_states, actions)

def get_states(self):
return self._states

4 changes: 4 additions & 0 deletions spacy/pipeline/_parser_internals/transition_system.pxd
@@ -54,5 +54,9 @@ cdef class TransitionSystem:
cdef int set_costs(self, int* is_valid, weight_t* costs,
const StateC* state, gold) except -1


cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
int batch_size) nogil

cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
int nr_class, int batch_size) nogil
23 changes: 23 additions & 0 deletions spacy/pipeline/_parser_internals/transition_system.pyx
@@ -155,6 +155,17 @@ cdef class TransitionSystem:
action.do(state.c, action.label)
state.c.history.push_back(action.clas)

def apply_actions(self, states, const int[::1] actions):
assert len(states) == actions.shape[0]
cdef StateClass state
cdef vector[StateC*] c_states
c_states.resize(len(states))
cdef int i
for (i, state) in enumerate(states):
c_states[i] = state.c
c_apply_actions(self, &c_states[0], &actions[0], actions.shape[0])
return [state for state in states if not state.c.is_final()]

def transition_states(self, states, float[:, ::1] scores):
assert len(states) == scores.shape[0]
cdef StateClass state
@@ -279,6 +290,18 @@ cdef class TransitionSystem:
return self


cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
int batch_size) nogil:
cdef int i
cdef Transition action
cdef StateC* state
for i in range(batch_size):
state = states[i]
action = moves.c[actions[i]]
action.do(state, action.label)
state.history.push_back(action.clas)


cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
is_valid = <int*>calloc(moves.n_moves, sizeof(int))