Skip to content

Commit

Permalink
Add check for callable to 'Language.replace_pipe' to fix #3737 (#3741)
Browse files Browse the repository at this point in the history
  • Loading branch information
BreakBB authored and ines committed May 14, 2019
1 parent 8baff1c commit ed18a6e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 2 additions & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ class Errors(object):
E133 = ("The sum of prior probabilities for alias '{alias}' should not exceed 1, "
"but found {sum}.")
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")


@add_codes
Expand Down
5 changes: 5 additions & 0 deletions spacy/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def replace_pipe(self, name, component):
"""
if name not in self.pipe_names:
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
if not hasattr(component, "__call__"):
msg = Errors.E003.format(component=repr(component), name=name)
if isinstance(component, basestring_) and component in self.factories:
msg += Errors.E135.format(name=name)
raise ValueError(msg)
self.pipeline[self.pipe_names.index(name)] = (name, component)

def rename_pipe(self, old_name, new_name):
Expand Down
6 changes: 4 additions & 2 deletions spacy/tests/pipeline/test_pipe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ def test_get_pipe(nlp, name):
assert nlp.get_pipe(name) == new_pipe


@pytest.mark.parametrize("name,replacement", [("my_component", lambda doc: doc)])
def test_replace_pipe(nlp, name, replacement):
@pytest.mark.parametrize("name,replacement,not_callable", [("my_component", lambda doc: doc, {})])
def test_replace_pipe(nlp, name, replacement, not_callable):
with pytest.raises(ValueError):
nlp.replace_pipe(name, new_pipe)
nlp.add_pipe(new_pipe, name=name)
with pytest.raises(ValueError):
nlp.replace_pipe(name, not_callable)
nlp.replace_pipe(name, replacement)
assert nlp.get_pipe(name) != new_pipe
assert nlp.get_pipe(name) == replacement
Expand Down

0 comments on commit ed18a6e

Please sign in to comment.