Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change option from model-name-or-path to simpler model, fix flake8 len 120 #40

Merged
merged 2 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions docs/experiments-msmarco-passage.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,23 @@ First, let's evaluate using monoBERT!

```
python -um pygaggle.run.evaluate_passage_ranker --split dev \
--method seq_class_transformer \
--model-name-or-path castorini/monobert-large-msmarco \
--dataset data/msmarco_ans_small/ \
--index-dir indexes/index-msmarco-passage-20191117-0ed488 \
--task msmarco \
--output-file runs/run.monobert.ans_small.dev.tsv
--method seq_class_transformer \
--model castorini/monobert-large-msmarco \
--dataset data/msmarco_ans_small/ \
--index-dir indexes/index-msmarco-passage-20191117-0ed488 \
--task msmarco \
--output-file runs/run.monobert.ans_small.dev.tsv
```

Upon completion, the following output will be visible:

```
precision@1 0.2761904761904762
recall@3 0.42698412698412697
recall@50 0.8174603174603176
recall@1000 0.8476190476190476
mrr 0.41089693612003686
mrr@10 0.4026795162509449
precision@1 0.2761904761904762
recall@3 0.42698412698412697
recall@50 0.8174603174603176
recall@1000 0.8476190476190476
mrr 0.41089693612003686
mrr@10 0.4026795162509449
```

It takes ~52 minutes to re-rank this subset on MS MARCO using a P100.
Expand All @@ -106,7 +106,7 @@ We use the monoT5-base variant as it is the easiest to run without access to lar
```
python -um pygaggle.run.evaluate_passage_ranker --split dev \
--method t5 \
--model-name-or-path castorini/monot5-base-msmarco \
--model castorini/monot5-base-msmarco \
--dataset data/msmarco_ans_small \
--model-type t5-base \
--task msmarco \
Expand All @@ -118,12 +118,12 @@ python -um pygaggle.run.evaluate_passage_ranker --split dev \
The following output will be visible after it has finished:

```
precision@1 0.26666666666666666
recall@3 0.4603174603174603
recall@50 0.8063492063492063
recall@1000 0.8476190476190476
mrr 0.3973368360121561
mrr@10 0.39044217687074834
precision@1 0.26666666666666666
recall@3 0.4603174603174603
recall@50 0.8063492063492063
recall@1000 0.8476190476190476
mrr 0.3973368360121561
mrr@10 0.39044217687074834
```

It takes ~13 minutes to re-rank this subset on MS MARCO using a P100.
Expand Down
23 changes: 13 additions & 10 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class PassageRankingEvaluationOptions(BaseModel):
dataset: Path
index_dir: Path
method: str
model_name_or_path: str
model: str
split: str
batch_size: int
device: str
Expand All @@ -63,8 +63,8 @@ def index_dir_exists(cls, v: Path):
assert v.exists(), 'index directory must exist'
return v

@validator('model_name_or_path')
def model_name_sane(cls, v: Optional[str], values, **kwargs):
@validator('model')
def model_sane(cls, v: str, values, **kwargs):
method = values['method']
if method == 'transformer' and v is None:
raise ValueError('transformer name or path must be specified')
Expand All @@ -73,13 +73,13 @@ def model_name_sane(cls, v: Optional[str], values, **kwargs):
@validator('tokenizer_name')
def tokenizer_sane(cls, v: str, values, **kwargs):
if v is None:
return values['model_name_or_path']
return values['model']
return v


def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
model = T5ForConditionalGeneration.from_pretrained(options.model_name_or_path,
model = T5ForConditionalGeneration.from_pretrained(options.model,
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
Expand All @@ -89,7 +89,7 @@ def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
def construct_transformer(options:
PassageRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
model = AutoModel.from_pretrained(options.model_name_or_path,
model = AutoModel.from_pretrained(options.model,
from_tf=options.from_tf).to(device).eval()
tokenizer = SimpleBatchTokenizer(AutoTokenizer.from_pretrained(
options.tokenizer_name),
Expand All @@ -102,15 +102,15 @@ def construct_seq_class_transformer(options: PassageRankingEvaluationOptions
) -> Reranker:
try:
model = AutoModelForSequenceClassification.from_pretrained(
options.model_name_or_path, from_tf=options.from_tf)
options.model, from_tf=options.from_tf)
except AttributeError:
# Hotfix for BioBERT MS MARCO. Refactor.
BertForSequenceClassification.bias = torch.nn.Parameter(
torch.zeros(2))
BertForSequenceClassification.weight = torch.nn.Parameter(
torch.zeros(2, 768))
model = BertForSequenceClassification.from_pretrained(
options.model_name_or_path, from_tf=options.from_tf)
options.model, from_tf=options.from_tf)
model.classifier.weight = BertForSequenceClassification.weight
model.classifier.bias = BertForSequenceClassification.bias
device = torch.device(options.device)
Expand All @@ -134,7 +134,10 @@ def main():
required=True,
type=str,
choices=METHOD_CHOICES),
opt('--model-name-or-path', type=str),
opt('--model',
required=True,
type=str,
help='Path to pre-trained model or huggingface model name'),
opt('--output-file', type=Path, default='.'),
opt('--overwrite-output', action='store_true'),
opt('--split',
Expand All @@ -150,7 +153,7 @@ def main():
nargs='+',
default=metric_names(),
choices=metric_names()),
opt('--model-type', type=str, default='bert-base'),
opt('--model-type', type=str),
opt('--tokenizer-name', type=str))
args = apb.parser.parse_args()
options = PassageRankingEvaluationOptions(**vars(args))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[flake8]
max-line-length = 100
max-line-length = 120