Skip to content

Commit

Permalink
Fix v4 branch to build against Thinc v9 (explosion#11921)
Browse files Browse the repository at this point in the history
* Move `thinc.extra.search` to `spacy.pipeline._parser_internals`

Backport of:
explosion#11317

Co-authored-by: Madeesh Kannan <[email protected]>

* Replace references to `thinc.backends.linalg` with `CBlas`

Backport of:
explosion#11292

Co-authored-by: Madeesh Kannan <[email protected]>

* Use cross entropy from `thinc.legacy`

* Require thinc>=9.0.0.dev0,<9.1.0

Co-authored-by: Madeesh Kannan <[email protected]>
  • Loading branch information
2 people authored and jikanter committed May 10, 2024
1 parent dffe8ae commit f35dc96
Show file tree
Hide file tree
Showing 19 changed files with 606 additions and 50 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.2.2,<8.3.0",
"numpy>=1.15.0; python_version < '3.9'",
"numpy>=1.25.0; python_version >= '3.9'",
"thinc>=9.0.0.dev0,<9.1.0",
"numpy>=1.15.0",
]
build-backend = "setuptools.build_meta"

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ spacy-legacy>=3.0.11,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.2.2,<8.3.0
thinc>=9.0.0.dev0,<9.1.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.2.0
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ install_requires =
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.2.2,<8.3.0
wasabi>=0.9.1,<1.2.0
thinc>=9.0.0.dev0,<9.1.0
wasabi>=0.9.1,<1.1.0
srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0
weasel>=0.1.0,<0.5.0
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"spacy.pipeline._parser_internals.arc_eager",
"spacy.pipeline._parser_internals.ner",
"spacy.pipeline._parser_internals.nonproj",
"spacy.pipeline._parser_internals.search",
"spacy.pipeline._parser_internals._state",
"spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system",
Expand All @@ -66,6 +67,7 @@
"spacy.matcher.dependencymatcher",
"spacy.symbols",
"spacy.vectors",
"spacy.tests.parser._search",
]
COMPILE_OPTIONS = {
"msvc": ["/Ox", "/EHsc"],
Expand Down
26 changes: 17 additions & 9 deletions spacy/ml/parser_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
cimport numpy as np
from libc.math cimport exp
from libc.stdlib cimport calloc, free, realloc
from libc.string cimport memcpy, memset
from thinc.backends.cblas cimport saxpy, sgemm
from thinc.backends.linalg cimport Vec, VecVec

Expand Down Expand Up @@ -116,14 +115,10 @@ cdef void predict_states(
n.hiddens * n.pieces
)
for i in range(n.states):
VecVec.add_i(
&A.unmaxed[i*n.hiddens*n.pieces],
W.feat_bias, 1.,
n.hiddens * n.pieces
)
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
for j in range(n.hiddens):
index = i * n.hiddens * n.pieces + j * n.pieces
which = Vec.arg_max(&A.unmaxed[index], n.pieces)
which = _arg_max(&A.unmaxed[index], n.pieces)
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
memset(A.scores, 0, n.states * n.classes * sizeof(float))
if W.hidden_weights == NULL:
Expand All @@ -138,7 +133,7 @@ cdef void predict_states(
)
# Add bias
for i in range(n.states):
VecVec.add_i(&A.scores[i*n.classes], W.hidden_bias, 1., n.classes)
saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1)
# Set unseen classes to minimum value
i = 0
min_ = A.scores[0]
Expand Down Expand Up @@ -187,7 +182,8 @@ cdef void cpu_log_loss(
"""Do multi-label log loss"""
cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O)
guess = Vec.arg_max(scores, O)
guess = _arg_max(scores, O)

if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access.
Expand Down Expand Up @@ -529,3 +525,15 @@ cdef class precompute_hiddens:
return d_best.reshape((d_best.shape + (1,)))

return state_vector, backprop_relu

cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
    # Return the index of the largest value among scores[0 .. n_classes-1].
    #
    # scores must point to at least one readable float; no bounds or NULL
    # checks are made. Ties resolve to the lowest index, except in the
    # two-class shortcut below, where a tie resolves to index 1.
    cdef int idx
    cdef int winner = 0
    cdef float highest = scores[0]
    if n_classes == 2:
        # Binary shortcut: a single strict comparison, no loop.
        if scores[0] > scores[1]:
            return 0
        return 1
    for idx in range(1, n_classes):
        # Strict '>' keeps the earliest index when values are equal.
        if scores[idx] > highest:
            winner = idx
            highest = scores[idx]
    return winner
3 changes: 1 addition & 2 deletions spacy/pipeline/_parser_internals/_beam_utils.pxd
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from ...typedefs cimport class_t, hash_t


# These are passed as callbacks to thinc.search.Beam
# These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1

cdef int check_final_state(void* _state, void* extra_args) except -1
12 changes: 4 additions & 8 deletions spacy/pipeline/_parser_internals/_beam_utils.pyx
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
# cython: infer_types=True
import numpy

from thinc.extra.search cimport Beam

from thinc.extra.search import MaxViolation

from thinc.extra.search cimport MaxViolation
from cpython.ref cimport PyObject, Py_XDECREF

from ...typedefs cimport class_t
from .transition_system cimport Transition, TransitionSystem

from ...errors import Errors

from .search cimport Beam, MaxViolation
from .search import MaxViolation
from .stateclass cimport StateC, StateClass


# These are passed as callbacks to thinc.search.Beam
# These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateC*>_dest
src = <StateC*>_src
Expand Down
3 changes: 1 addition & 2 deletions spacy/pipeline/_parser_internals/arc_eager.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ from ._state cimport ArcC, StateC
from .stateclass cimport StateClass

from ...errors import Errors

from thinc.extra.search cimport Beam
from .search cimport Beam


cdef weight_t MIN_SCORE = -90000
Expand Down
4 changes: 2 additions & 2 deletions spacy/pipeline/_parser_internals/ner.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ from libc.stdint cimport int32_t

from collections import Counter

from thinc.extra.search cimport Beam

from ...tokens.doc cimport Doc

from ...tokens.span import Span
Expand All @@ -23,6 +21,8 @@ from ...typedefs cimport attr_t, weight_t
from ...training import split_bilu_label

from ...training.example cimport Example
from .search cimport Beam
from .stateclass cimport StateClass
from ._state cimport StateC
from .stateclass cimport StateClass
from .transition_system cimport Transition, do_func_t
Expand Down
89 changes: 89 additions & 0 deletions spacy/pipeline/_parser_internals/search.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from cymem.cymem cimport Pool

from libc.stdint cimport uint32_t
from libc.stdint cimport uint64_t
from libcpp.pair cimport pair
from libcpp.queue cimport priority_queue
from libcpp.vector cimport vector

from ...typedefs cimport class_t, weight_t, hash_t

# Queue element: a (weight_t, size_t) pair. C++ pairs compare on the first
# member before the second.
# NOTE(review): the size_t is presumably an index into the beam's score
# table -- confirm against the implementation in search.pyx.
ctypedef pair[weight_t, size_t] Entry
# std::priority_queue of Entry with the default comparator, so the
# greatest Entry is the one on top.
ctypedef priority_queue[Entry] Queue


# Callback types accepted by Beam (see initialize / advance / check_done).
# Each one reports a Python exception through the sentinel return value named
# in its `except` clause, so that value must never be returned on success.

# Transition callback: receives destination and source state pointers, a
# class id and an opaque extra pointer. Returns -1 on error.
ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1

# State constructor callback: receives a cymem Pool, an int and an opaque
# extra pointer; returns a state pointer, or NULL on error.
ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL

# State destructor callback. Returns -1 on error.
ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1

# Finality check callback for a single state. Returns -1 on error.
# NOTE(review): presumably non-zero means "state is final" -- confirm.
ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1

# State hashing callback. Because 0 is the error sentinel, a valid hash must
# be non-zero.
ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0


# Per-candidate record held in Beam._states / Beam._parents.
# NOTE(review): the field meanings below are inferred from names only;
# confirm them against search.pyx before relying on them.
cdef struct _State:
# Opaque pointer to the caller-owned state; this is what Beam.at() returns.
void* content
# Presumably the sequence of class ids taken so far -- confirm.
class_t* hist
# Presumably the accumulated model score -- confirm.
weight_t score
# Presumably the accumulated cost -- confirm.
weight_t loss
int i
int t
# Flag; presumably set once this candidate has reached a final state.
bint is_done


# Declaration of the Beam extension type (moved here from thinc.extra.search).
# Only the C-level layout and method signatures are declared in this file;
# the method bodies, apart from the two inline helpers, live in the matching
# implementation file.
cdef class Beam:
# cymem memory pool held by the beam.
cdef Pool mem
cdef class_t nr_class
cdef class_t width
cdef class_t size
# `public`: readable and writable from Python code.
cdef public weight_t min_density
cdef int t
# `readonly`: readable, but not assignable, from Python code.
cdef readonly bint is_done
cdef list histories
cdef list _parent_histories
# Three parallel 2-D tables, indexed [i][j] by set_cell().
# NOTE(review): presumably i is the beam slot and j the class id -- confirm.
cdef weight_t** scores
cdef int** is_valid
cdef weight_t** costs
cdef _State* _parents
cdef _State* _states
cdef del_func_t del_func

# Returns -1 to signal a Python exception.
cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1

# Return the opaque content pointer of state i. Runs without the GIL and
# performs no bounds check on i.
cdef inline void* at(self, int i) nogil:
return self._states[i].content

# The three methods below take the callback types declared earlier in this
# file, and each returns -1 to signal a Python exception.
cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1
cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
void* extra_args) except -1
cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1


# Write one cell into each of the scores / is_valid / costs tables. Runs
# without the GIL and performs no bounds check on i or j.
cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid, weight_t cost) nogil:
self.scores[i][j] = score
self.is_valid[i][j] = is_valid
self.costs[i][j] = cost

# Row-wise and whole-table counterparts of set_cell(); each returns -1 to
# signal a Python exception.
cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
const weight_t* costs) except -1
cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1


# Declaration of the MaxViolation extension type (moved here from
# thinc.extra.search). Attributes marked `readonly` can be read, but not
# assigned, from Python code; the others are visible to Cython code only.
# NOTE(review): the p_* / g_* prefixes presumably refer to the predicted and
# gold beams passed to check() -- confirm against search.pyx.
cdef class MaxViolation:
cdef Pool mem
cdef weight_t cost
cdef weight_t delta
cdef readonly weight_t p_score
cdef readonly weight_t g_score
cdef readonly double Z
cdef readonly double gZ
cdef class_t n
cdef readonly list p_hist
cdef readonly list g_hist
cdef readonly list p_probs
cdef readonly list g_probs

# `cpdef`: callable from both Cython and Python. Each takes a predicted beam
# and a gold beam, and returns -1 to signal a Python exception.
cpdef int check(self, Beam pred, Beam gold) except -1
cpdef int check_crf(self, Beam pred, Beam gold) except -1
Loading

0 comments on commit f35dc96

Please sign in to comment.