Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync Span __eq__ and __hash__ #5005

Merged
merged 2 commits into from
Feb 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions spacy/tests/doc/test_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,12 @@ def test_filter_spans(doc):
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10


def test_span_eq_hash(doc, doc_not_parsed):
    """Check that equality and hashing of document slices agree.

    A slice must compare equal to (and hash the same as) a freshly taken
    slice with identical bounds from the same document, and must differ
    from a slice with other bounds or from a same-bounds slice taken from
    a different document.

    doc, doc_not_parsed: two distinct document objects supplied by the
        caller (presumably pytest fixtures -- confirm against conftest).
    """
    head = slice(0, 2)
    shifted = slice(1, 3)

    # Each operand is sliced anew so two separate objects are compared,
    # never one object against itself.
    assert doc[head] == doc[head]
    assert doc[head] != doc[shifted]
    assert doc[head] != doc_not_parsed[head]

    # The hashes must follow the same pattern as the comparisons above.
    assert hash(doc[head]) == hash(doc[head])
    assert hash(doc[head]) != hash(doc[shifted])
    assert hash(doc[head]) != hash(doc_not_parsed[head])
6 changes: 5 additions & 1 deletion spacy/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def assert_docs_equal(doc1, doc2):

assert [t.ent_type for t in doc1] == [t.ent_type for t in doc2]
assert [t.ent_iob for t in doc1] == [t.ent_iob for t in doc2]
assert [ent for ent in doc1.ents] == [ent for ent in doc2.ents]
for ent1, ent2 in zip(doc1.ents, doc2.ents):
assert ent1.start == ent2.start
assert ent1.end == ent2.end
assert ent1.label == ent2.label
assert ent1.kb_id == ent2.kb_id


def assert_packed_msg_equal(b1, b2):
Expand Down
13 changes: 9 additions & 4 deletions spacy/tokens/span.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,27 @@ cdef class Span:
return False
else:
return True
# Eq
# <
if op == 0:
return self.start_char < other.start_char
# <=
elif op == 1:
return self.start_char <= other.start_char
# ==
elif op == 2:
return self.start_char == other.start_char and self.end_char == other.end_char
return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) == (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
# !=
elif op == 3:
return self.start_char != other.start_char or self.end_char != other.end_char
return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) != (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
# >
elif op == 4:
return self.start_char > other.start_char
# >=
elif op == 5:
return self.start_char >= other.start_char

def __hash__(self):
    """Return a hash over the fields that identify this span.

    The hashed tuple is (doc, start_char, end_char, label, kb_id): the
    same fields, in the same order, as the tuple compared for `==` and
    `!=` in the rich-comparison code, so that hashing and equality stay
    in sync (including `kb_id`, which the earlier hash left out).

    RETURNS (int): The hash of the identifying tuple.
    """
    # A single return only: a leftover earlier return hashed
    # (doc, label, start_char, end_char) and made this one unreachable.
    return hash((self.doc, self.start_char, self.end_char, self.label, self.kb_id))

def __len__(self):
"""Get the number of tokens in the span.
Expand Down