Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve token head verification #5079

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,13 @@ class Errors(object):
"make sure the gold EL data refers to valid results of the "
"named entity recognizer in the `nlp` pipeline.")
E189 = ("Each argument to `get_doc` should be of equal length.")
E190 = ("Token head out of range in `Doc.from_array()` for token index "
"'{index}' with value '{value}' (equivalent to relative head "
"index: '{rel_head_index}'). The head indices should be relative "
"to the current token index rather than absolute indices in the "
"array.")
E191 = ("Invalid head: the head token must be from the same doc as the "
"token itself.")


@add_codes
Expand Down
27 changes: 27 additions & 0 deletions spacy/tests/doc/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,30 @@ def test_doc_array_idx(en_vocab):
assert offsets[0] == 0
assert offsets[1] == 3
assert offsets[2] == 11


def test_doc_from_array_heads_in_bounds(en_vocab):
    """Doc.from_array must reject head indices that fall outside the doc."""
    words = ["This", "is", "a", "sentence", "."]
    doc = Doc(en_vocab, words=words)
    # point every token's head at the first token so heads are valid
    for token in doc:
        token.head = doc[0]

    # a round-trip with in-bounds heads succeeds
    good = doc.to_array(["HEAD"])
    Doc(en_vocab, words=words).from_array(["HEAD"], good)

    # a head resolving before the first token raises
    before = doc.to_array(["HEAD"])
    before[0] = -1
    with pytest.raises(ValueError):
        Doc(en_vocab, words=words).from_array(["HEAD"], before)

    # a head resolving past the last token raises
    after = doc.to_array(["HEAD"])
    after[0] = 5
    with pytest.raises(ValueError):
        Doc(en_vocab, words=words).from_array(["HEAD"], after)
5 changes: 5 additions & 0 deletions spacy/tests/doc/test_token_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ def test_doc_token_api_head_setter(en_tokenizer):
assert doc[4].left_edge.i == 0
assert doc[2].left_edge.i == 0

# head token must be from the same document
doc2 = get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads)
with pytest.raises(ValueError):
doc[0].head = doc2[0]


def test_is_sent_start(en_tokenizer):
doc = en_tokenizer("This is a sentence. This is another.")
Expand Down
10 changes: 9 additions & 1 deletion spacy/tokens/doc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ cdef class Doc:

if SENT_START in attrs and HEAD in attrs:
raise ValueError(Errors.E032)
cdef int i, col
cdef int i, col, abs_head_index
cdef attr_id_t attr_id
cdef TokenC* tokens = self.c
cdef int length = len(array)
Expand All @@ -804,6 +804,14 @@ cdef class Doc:
attr_ids[i] = attr_id
if len(array.shape) == 1:
array = array.reshape((array.size, 1))
# Check that all heads are within the document bounds
if HEAD in attrs:
col = attrs.index(HEAD)
for i in range(length):
# cast index to signed int
abs_head_index = numpy.int32(array[i, col]) + i
(Review note: honnibal marked this conversation as resolved — Show resolved / Hide resolved.)
if abs_head_index < 0 or abs_head_index >= length:
raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col])))
# Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
if TAG in attrs:
col = attrs.index(TAG)
Expand Down
3 changes: 3 additions & 0 deletions spacy/tokens/token.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,9 @@ cdef class Token:
# This function sets the head of self to new_head and updates the
# counters for left/right dependents and left/right corner for the
# new and the old head
# Check that token is from the same document
if self.doc != new_head.doc:
raise ValueError(Errors.E191)
# Do nothing if old head is new head
if self.i + self.c.head == new_head.i:
return
Expand Down