Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of installation #17

Merged
merged 12 commits into from
Feb 3, 2023
Merged
27 changes: 17 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ It also comes along with a number of disadvantages:
- no context sensitivity due to the implementation of the "max-prior method" for entity disambiguation (an improved
method for this is in progress)


## Installation

To install the package, run:
```bash
pip install spacy-entity-linker
```

Afterwards, the knowledge base (Wikidata) must be downloaded. This can either be done by manually calling

```bash
python -m spacy_entity_linker "download_knowledge_base"
```

or when you first access the entity linker through spacy.
This will download and extract a ~1.3GB file that contains a preprocessed version of Wikidata.

## Use

```python
Expand Down Expand Up @@ -138,16 +155,6 @@ The current implementation supports only Sqlite. This is advantageous for develo
any special setup and configuration. However, for more performance-critical use cases, a different database with
in-memory access (e.g. Redis) should be used. This may be implemented in the future.

## Installation

To install the package run: <code>pip install spacy-entity-linker</code>

Afterwards, the knowledge base (Wikidata) must be downloaded. This can be done by calling

<code>python -m spacy_entity_linker "download_knowledge_base"</code>

This will download and extract a ~500mb file that contains a preprocessed version of Wikidata

## Data
the knowledge base was derived from this dataset: https://www.kaggle.com/kenshoresearch/kensho-derived-wikimedia-data

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def open_file(fname):
zip_safe=True,
install_requires=[
'spacy>=3.0.0',
'numpy>=1.0.0'
'numpy>=1.0.0',
'tqdm'
],
entry_points={
'spacy_factories': 'entityLinker = spacy_entity_linker.EntityLinker:EntityLinker'
Expand Down
35 changes: 24 additions & 11 deletions spacy_entity_linker/DatabaseConnection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sqlite3
import os

from .__main__ import download_knowledge_base

MAX_DEPTH_CHAIN = 10
P_INSTANCE_OF = 31
P_SUBCLASS = 279
Expand All @@ -11,7 +13,7 @@
entity_cache = {}
chain_cache = {}

DB_DEFAULT_PATH = os.path.abspath(__file__ + '/../../data_spacy_entity_linker/wikidb_filtered.db')
DB_DEFAULT_PATH = os.path.abspath(os.path.join(__file__, "../../data_spacy_entity_linker/wikidb_filtered.db"))

wikidata_instance = None

Expand Down Expand Up @@ -49,6 +51,10 @@ def _add_to_cache(self, cache_type, key, value):
self.cache[cache_type][key] = value

def init_database_connection(self, path=DB_DEFAULT_PATH):
    """Open the SQLite connection, downloading the knowledge base first if absent.

    Args:
        path: Filesystem path of the SQLite database file to open.
    """
    # Bug fix: check the path that is actually being opened, not the default
    # path -- the original tested DB_DEFAULT_PATH, so a caller passing a
    # custom ``path`` could trigger a spurious download or skip a needed one.
    if not os.path.exists(path):
        # Automatically download the knowledge base if it isn't already there.
        # NOTE(review): the download always lands in the default location; a
        # custom ``path`` is not populated by it -- confirm intended behavior.
        download_knowledge_base()
    self.conn = sqlite3.connect(path)

def clear_cache(self):
Expand All @@ -61,9 +67,9 @@ def get_entities_from_alias(self, alias):
if self._is_cached("entity", alias):
return self._get_cached_value("entity", alias).copy()

query_alias = """SELECT j.item_id,j.en_label, j.en_description,j.views,j.inlinks,a.en_alias from aliases as a
LEFT JOIN joined as j ON a.item_id = j.item_id
WHERE a.en_alias_lowercase = ? and j.item_id NOT NULL"""
query_alias = """SELECT j.item_id,j.en_label, j.en_description,j.views,j.inlinks,a.en_alias
FROM aliases as a LEFT JOIN joined as j ON a.item_id = j.item_id
WHERE a.en_alias_lowercase = ? AND j.item_id NOT NULL"""

c.execute(query_alias, [alias.lower()])
fetched_rows = c.fetchall()
Expand Down Expand Up @@ -92,7 +98,7 @@ def get_entity_name(self, item_id):
res = c.fetchone()

if res and len(res):
if res[0] == None:
if res[0] is None:
self._add_to_cache("name", item_id, 'no label')
else:
self._add_to_cache("name", item_id, res[0])
Expand Down Expand Up @@ -148,10 +154,14 @@ def get_recursive_edges(self, item_id):
self._append_chain_elements(self, item_id, 0, chain, edges)
return edges

def _append_chain_elements(self, item_id, level=0, chain=[], edges=[], max_depth=10, property=P_INSTANCE_OF):
properties = property
if type(property) != list:
properties = [property]
def _append_chain_elements(self, item_id, level=0, chain=None, edges=None, max_depth=10, prop=P_INSTANCE_OF):
if chain is None:
chain = []
if edges is None:
edges = []
properties = prop
if type(prop) != list:
properties = [prop]

if self._is_cached("chain", (item_id, max_depth)):
chain += self._get_cached_value("chain", (item_id, max_depth)).copy()
Expand All @@ -176,9 +186,12 @@ def _append_chain_elements(self, item_id, level=0, chain=[], edges=[], max_depth
if not (target_item[0] in chain_ids):
chain += [(target_item[0], level + 1)]
edges.append((item_id, target_item[0], target_item[1]))
self._append_chain_elements(target_item[0], level=level + 1, chain=chain, edges=edges,
self._append_chain_elements(target_item[0],
level=level + 1,
chain=chain,
edges=edges,
max_depth=max_depth,
property=property)
prop=prop)

self._add_to_cache("chain", (item_id, max_depth), chain)

Expand Down
60 changes: 40 additions & 20 deletions spacy_entity_linker/__main__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,48 @@
import sys
import tarfile
import urllib.request
import tqdm
import os


class DownloadProgressBar(tqdm.tqdm):
    """A tqdm bar driven by ``urllib.request.urlretrieve`` reporthook callbacks.

    Adapted from https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    """

    def update_to(self, chunk_id=1, max_chunk_size=1, total_size=None):
        # The reporthook supplies (blocks transferred so far, block size,
        # total download size). Total may be unknown (None) for some servers.
        if total_size is not None:
            self.total = total_size
        bytes_so_far = chunk_id * max_chunk_size
        # tqdm.update takes an increment, so advance by the delta from
        # the bytes already recorded (self.n).
        self.update(bytes_so_far - self.n)


def download_knowledge_base(
    file_url="https://huggingface.co/MartinoMensio/spaCy-entity-linker/resolve/main/knowledge_base.tar.gz"
):
    """Download and extract the preprocessed Wikidata knowledge base.

    Fetches a gzipped tarball from ``file_url`` into the package-adjacent
    ``data_spacy_entity_linker/`` directory, extracts it there, and removes
    the downloaded archive afterwards.

    Args:
        file_url: URL of the ``.tar.gz`` archive containing the database.
    """
    # Build paths with os.path.join instead of string concatenation.
    output_db_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), '..', 'data_spacy_entity_linker'))
    output_tar_file = os.path.join(output_db_path, 'wikidb_filtered.tar.gz')

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_db_path, exist_ok=True)

    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc='Downloading knowledge base') as dpb:
        urllib.request.urlretrieve(file_url, filename=output_tar_file, reporthook=dpb.update_to)

    # Context manager guarantees the archive is closed even if extraction fails.
    # NOTE(review): extractall trusts member paths inside the archive; the
    # source here is the project's own hosted file.
    with tarfile.open(output_tar_file) as tar:
        tar.extractall(output_db_path)

    os.remove(output_tar_file)


if __name__ == "__main__":
    # NOTE: sys/os/tarfile/urllib were re-imported here redundantly; they are
    # already imported at the top of the module, so the re-imports are removed.
    if len(sys.argv) < 2:
        # Bug fix: the original printed the message and then fell through
        # (``pass``), so sys.argv.pop(1) below raised IndexError. Exit with a
        # non-zero status instead.
        print("No arguments given.")
        sys.exit(1)

    command = sys.argv.pop(1)

    if command == "download_knowledge_base":
        download_knowledge_base()
    else:
        raise ValueError("Unrecognized command given. If you are trying to install the knowledge base, run "
                         "'python -m spacy_entity_linker \"download_knowledge_base\"'.")
15 changes: 15 additions & 0 deletions tests/test_EntityLinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,18 @@ def test_initialization(self):
sent._.linkedEntities.pretty_print()

self.nlp.remove_pipe("entityLinker")

def test_empty_root(self):
    # Regression test for #9: sentences with empty root lists must not crash
    # pretty_print(). Runs the linker over an awkward fragment and over a
    # whitespace-only document.
    self.nlp.add_pipe("entityLinker", last=True)

    texts = [
        # quote-heavy fragment ending mid-word (triggered the empty-root case)
        'I was right."\n\n "To that extent."\n\n "But that was all."\n\n "No, no, m',
        # empty document
        '\n\n',
    ]
    for text in texts:
        for sentence in self.nlp(text).sents:
            sentence._.linkedEntities.pretty_print()

    self.nlp.remove_pipe("entityLinker")