Skip to content

Commit

Permalink
Add resume logic to spacy pretrain (#3652)
Browse files Browse the repository at this point in the history
* Added ability to resume training

* Add to readme

* Remove duplicate entry
  • Loading branch information
tokestermw authored and honnibal committed Jun 12, 2019
1 parent eb3e426 commit 9c064e6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
15 changes: 15 additions & 0 deletions spacy/cli/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
from .._ml import masked_language_model
from .. import util
from .train import _load_pretrained_tok2vec


@plac.annotations(
Expand All @@ -36,6 +37,12 @@
seed=("Seed for random number generators", "option", "s", int),
n_iter=("Number of iterations to pretrain", "option", "i", int),
n_save_every=("Save model every X batches.", "option", "se", int),
init_tok2vec=(
"Path to pretrained weights for the token-to-vector parts of the models. See 'spacy pretrain'. Experimental.",
"option",
"t2v",
Path,
),
)
def pretrain(
texts_loc,
Expand All @@ -53,6 +60,7 @@ def pretrain(
min_length=5,
seed=0,
n_save_every=None,
init_tok2vec=None,
):
"""
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
Expand All @@ -70,6 +78,9 @@ def pretrain(
errors around this need some improvement.
"""
config = dict(locals())
for key in config:
if isinstance(config[key], Path):
config[key] = str(config[key])
msg = Printer()
util.fix_random_seed(seed)

Expand Down Expand Up @@ -112,6 +123,10 @@ def pretrain(
subword_features=True, # Set to False for Chinese etc
),
)
# Load in pre-trained weights
if init_tok2vec is not None:
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
msg.text("Loaded pretrained tok2vec for: {}".format(components))
optimizer = create_default_optimizer(model.ops)
tracker = ProgressTracker(frequency=10000)
msg.divider("Pre-training tok2vec layer")
Expand Down
1 change: 1 addition & 0 deletions website/docs/api/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
| `--n-save-every`, `-se` | option | Save model every X batches. |
| `--init-tok2vec`, `-t2v` <Tag variant="new">2.1</Tag> | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental. |
| **CREATES** | weights | The pre-trained weights that can be used to initialize `spacy train`. |
### JSONL format for raw text {#pretrain-jsonl}
Expand Down

0 comments on commit 9c064e6

Please sign in to comment.