diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index 01bb93c504b..917f22e9c8a 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -279,3 +279,12 @@ def test_filter_spans(doc): assert len(filtered[1]) == 5 assert filtered[0].start == 1 and filtered[0].end == 4 assert filtered[1].start == 5 and filtered[1].end == 10 + + +def test_span_eq_hash(doc, doc_not_parsed): + assert doc[0:2] == doc[0:2] + assert doc[0:2] != doc[1:3] + assert doc[0:2] != doc_not_parsed[0:2] + assert hash(doc[0:2]) == hash(doc[0:2]) + assert hash(doc[0:2]) != hash(doc[1:3]) + assert hash(doc[0:2]) != hash(doc_not_parsed[0:2]) diff --git a/spacy/tests/util.py b/spacy/tests/util.py index 175480fe723..9ee5b89f8bc 100644 --- a/spacy/tests/util.py +++ b/spacy/tests/util.py @@ -95,7 +95,11 @@ def assert_docs_equal(doc1, doc2): assert [t.ent_type for t in doc1] == [t.ent_type for t in doc2] assert [t.ent_iob for t in doc1] == [t.ent_iob for t in doc2] - assert [ent for ent in doc1.ents] == [ent for ent in doc2.ents] + for ent1, ent2 in zip(doc1.ents, doc2.ents): + assert ent1.start == ent2.start + assert ent1.end == ent2.end + assert ent1.label == ent2.label + assert ent1.kb_id == ent2.kb_id def assert_packed_msg_equal(b1, b2): diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 24857790bd0..35c70f236b3 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -127,22 +127,27 @@ cdef class Span: return False else: return True - # Eq + # < if op == 0: return self.start_char < other.start_char + # <= elif op == 1: return self.start_char <= other.start_char + # == elif op == 2: - return self.start_char == other.start_char and self.end_char == other.end_char + return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) == (other.doc, other.start_char, other.end_char, other.label, other.kb_id) + # != elif op == 3: - return self.start_char != other.start_char or self.end_char != other.end_char + return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) != (other.doc, other.start_char, other.end_char, other.label, other.kb_id) + # > elif op == 4: return self.start_char > other.start_char + # >= elif op == 5: return self.start_char >= other.start_char def __hash__(self): - return hash((self.doc, self.label, self.start_char, self.end_char)) + return hash((self.doc, self.start_char, self.end_char, self.label, self.kb_id)) def __len__(self): """Get the number of tokens in the span.