Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gensim.models.BaseKeyedVectors.add_entity method for fill KeyedVectors in manual way. Fix #1942 #1957

69 changes: 69 additions & 0 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,75 @@ def get_vector(self, entity):
else:
raise KeyError("'%s' not in vocabulary" % entity)

def add_entity(self, entity, weights, replace=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe remove this method (add_entities looks enough, wdyt @gojomo?)

"""Add entity vector in a manual way.
If `entity` is already in the vocabulary, old vector is keeped unless `replace` flag is True.

Parameters
----------
entity : str
Entity specified by string tag.
weights : np.array
1D numpy array with shape (`vector_size`,)
replace: bool, optional
Boolean flag indicating whether to replace old vector if entity is already in the vocabulary.
Default, False, means that old vector is keeped.
"""
self.add_entities([entity], weights.reshape(1, -1), replace=replace)

def add_entities(self, entities, weights, replace=False):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is keeped unless `replace` flag is True.

Parameters
----------
entities : list of str
Entities specified by string tags.
weights: list of np.array or np.array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{list of numpy.ndarray, numpy.ndarray}

List of 1D np.array vectors or 2D np.array of vectors.
replace: bool, optional
Boolean flag indicating whether to replace vectors for entities which are already in the vocabulary.
Default, False, means that old vectors for those entities are keeped.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to duplicate "default" value for trivial case in docstring, maybe better to write something like

Flag indicating whether to replace vectors for entities which are already in the vocabulary, if True - replace vectors, otherwise - keep old vectors.

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: multiline docstring should ends with empty line, i.e.

"""
...
last text

"""

if isinstance(weights, list):
weights = np.array(weights)

in_vocab_mask = np.zeros(len(entities), dtype=np.bool)
in_vocab_idxs = []
out_vocab_entities = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method might be simpler without separate in_vocab_idxs and out_vocab_entities – just driving those ops from the mask, using options like where() or nonzero().


for idx, entity in zip(range(len(entities)), entities):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enumerate() more idiomatic.

if entity in self.vocab:
in_vocab_mask[idx] = True
in_vocab_idxs.append(self.vocab[entity].index)
else:
out_vocab_entities.append(entity)

# add new entities to the vocab
for entity in out_vocab_entities:
entity_id = len(self.vocab)
self.vocab[entity] = Vocab(index=entity_id, count=1)
self.index2entity.append(entity)

# add vectors for new entities
if len(self.vectors) == 0:
self.vectors = weights[~in_vocab_mask]
else:
self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might this line work even in the case where len(self.vectors)==0, making the check/branch unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not obvious how to do that, because when empty KeyedVectors object is created, self.vectors = [] is true. In that case, we can't use vstack(([], weights[~in_vocab_mask])) and ValueError: all the input array dimensions except for the concatenation axis must match exactly is raised.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for an empty KeyedVectors to have a self.vectors that is already a proper-dimensioned (0, vector_size) empty ndarray? (Not sure myself, but would simplify things in later places like this.)


# change vectors for in_vocab entities if `replace` flag is specified
if replace:
self.vectors[in_vocab_idxs] = weights[in_vocab_mask]

def __setitem__(self, entities, weights):
"""Idiomatic way to call `add_entities` with `replace=True`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to write full docstring

"""
if not isinstance(entities, list):
entities = [entities]
weights = weights.reshape(1, -1)

self.add_entities(entities, weights, replace=True)

def __getitem__(self, entities):
"""
Accept a single entity (string tag) or list of entities as input.
Expand Down
72 changes: 72 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,78 @@ def test_wv_property(self):
"""Test that the deprecated `wv` property returns `self`. To be removed in v4.0.0."""
self.assertTrue(self.vectors is self.vectors.wv)

def test_add_entity(self):
"""Test that adding entity in a manual way works correctly."""
entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)]
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

# Test `add_entity` on already filled kv.
for ent, vector in zip(entities, vectors):
self.vectors.add_entity(ent, vector)

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add_entity` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
for ent, vector in zip(entities, vectors):
kv.add_entity(ent, vector)

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_add_entities(self):
"""Test that adding a bulk of entities in a manual way works correctly."""
entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)]
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

# Test `add_entities` on already filled kv.
vocab_size = len(self.vectors.vocab)
self.vectors.add_entities(entities, vectors, replace=False)
self.assertEqual(vocab_size + len(entities), len(self.vectors.vocab))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add_entities` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
kv[entities] = vectors
self.assertEqual(len(kv.vocab), len(entities))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_set_item(self):
"""Test that __setitem__ works correctly."""
vocab_size = len(self.vectors.vocab)

# Add new entity.
entity = '___some_new_entity___'
vector = np.random.randn(self.vectors.vector_size)
self.vectors[entity] = vector

self.assertEqual(len(self.vectors.vocab), vocab_size + 1)
self.assertTrue(np.allclose(self.vectors[entity], vector))

# Replace vector for entity in vocab.
vocab_size = len(self.vectors.vocab)
vector = np.random.randn(self.vectors.vector_size)
self.vectors['war'] = vector

self.assertEqual(len(self.vectors.vocab), vocab_size)
self.assertTrue(np.allclose(self.vectors['war'], vector))

# __setitem__ on several entities.
vocab_size = len(self.vectors.vocab)
entities = ['war', '___some_new_entity1___', '___some_new_entity2___', 'terrorism', 'conflict']
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(len(entities))]

self.vectors[entities] = vectors

self.assertEqual(len(self.vectors.vocab), vocab_size + 2)
for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down