Skip to content

Commit

Permalink
MSMARCO support: monoBERT (#14)
Browse files Browse the repository at this point in the history
* add monobert for marco

* temp change python3.7 to 3.6 for colab compatibility

* fix evaluation options

* fix issues

* add missing options in evaluate_passage_ranker

* working monobert

* update transformers, clean code

* update tokenizers

* add dataclasses if < 3.7

* cleanup todos

* update to newer transformers along with syntax, clean up settings

* model-name-or-path as str type

* fix tokenizer loading for t5
  • Loading branch information
ronakice authored Apr 30, 2020
1 parent 34345c8 commit 55e4961
Show file tree
Hide file tree
Showing 11 changed files with 360 additions and 15 deletions.
1 change: 1 addition & 0 deletions pygaggle/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .kaggle import *
from .relevance import *
from .msmarco import *
137 changes: 137 additions & 0 deletions pygaggle/data/msmarco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
from collections import OrderedDict, defaultdict
from typing import List, Set, DefaultDict
import json
import logging
from itertools import permutations

from pydantic import BaseModel
import scipy.special as sp
import numpy as np

from .relevance import RelevanceExample, MsMarcoPassageLoader
from pygaggle.model.tokenize import SpacySenticizer
from pygaggle.rerank.base import Query, Text
from pygaggle.data.unicode import convert_to_unicode


__all__ = ['MsMarcoExample', 'MsMarcoDataset']


class MsMarcoExample(BaseModel):
    """A single MS MARCO query with its candidate and relevant passage ids."""

    # Query identifier.
    qid: str
    # Raw query text.
    text: str
    # Candidate passage ids for this query.
    candidates: List[str]
    # Ids of the passages judged relevant for this query.
    relevant_candidates: Set[str]

class MsMarcoDataset(BaseModel):
    """A collection of MS MARCO queries with their candidate passages.

    Provides loaders for the tab-separated qrels / run / queries files and a
    conversion to ``RelevanceExample`` objects for reranker evaluation.
    """

    examples: List[MsMarcoExample]

    @classmethod
    def load_qrels(cls, path: str) -> DefaultDict[str, Set[str]]:
        """Read a qrels file and map each query id to its relevant doc ids.

        Each line must hold four tab-separated fields:
        ``qid, <ignored>, doc_id, relevance``. A document counts as relevant
        when its integer relevance is at least 1.
        """
        qrels = defaultdict(set)
        with open(path, encoding='utf-8') as f:
            for line in f:
                qid, _, doc_id, relevance = line.rstrip().split('\t')
                if int(relevance) >= 1:
                    qrels[qid].add(doc_id)
        return qrels

    @classmethod
    def load_run(cls, path: str):
        """Read a run file and return the candidates of each query by rank.

        Each line must hold three tab-separated fields: ``qid, doc, rank``.

        Returns an ``OrderedDict[str, List[str]]`` keyed by query id (in order
        of first appearance) whose values are the doc ids sorted by ascending
        rank.
        """
        run = OrderedDict()
        with open(path, encoding='utf-8') as f:
            for line in f:
                qid, doc_title, rank = line.rstrip().split('\t')
                run.setdefault(qid, []).append((doc_title, int(rank)))
        sorted_run = OrderedDict()
        for qid, doc_titles_ranks in run.items():
            # Sort in place: a bare ``sorted(...)`` call would throw the
            # ordered copy away and leave the list in file order.
            doc_titles_ranks.sort(key=lambda pair: pair[1])
            sorted_run[qid] = [doc_title for doc_title, _ in doc_titles_ranks]
        return sorted_run

    @classmethod
    def load_queries(cls,
                     path: str,
                     qrels: DefaultDict[str, Set[str]],
                     run) -> List[MsMarcoExample]:
        """Read a queries file (``qid<TAB>query``) and build the examples.

        ``run`` must contain an entry for every query id in the file; a
        missing id raises ``KeyError``. ``qrels`` is indexed directly, so a
        ``defaultdict`` yields an empty relevant set for unjudged queries.
        """
        queries = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                qid, query = line.rstrip().split('\t')
                queries.append(MsMarcoExample(qid=qid,
                                              text=query,
                                              candidates=run[qid],
                                              relevant_candidates=qrels[qid]))
        return queries

    @classmethod
    def from_folder(cls,
                    folder: str,
                    split: str = 'dev',
                    is_duo: bool = False) -> 'MsMarcoDataset':
        """Build a dataset from the ``*.{split}.small.tsv`` files in ``folder``.

        When ``is_duo`` is true the run file name carries a ``mono.`` infix
        (``run.mono.{split}.small.tsv``).
        """
        run_mono = "mono." if is_duo else ""
        query_path = os.path.join(folder, f"queries.{split}.small.tsv")
        qrels_path = os.path.join(folder, f"qrels.{split}.small.tsv")
        run_path = os.path.join(folder, f"run.{run_mono}{split}.small.tsv")
        return cls(examples=cls.load_queries(query_path,
                                             cls.load_qrels(qrels_path),
                                             cls.load_run(run_path)))

    def query_passage_tuples(self, is_duo: bool = False):
        """Lazily yield ``((qid, text, relevant_candidates), passages)`` pairs.

        ``passages`` is a tuple of candidate ids: 1-permutations normally,
        2-permutations when ``is_duo`` is true.
        """
        return (((ex.qid, ex.text, ex.relevant_candidates), perm_pas)
                for ex in self.examples
                for perm_pas in permutations(ex.candidates, r=1 + int(is_duo)))

    def to_relevance_examples(self,
                              index_path: str,
                              is_duo: bool = False) -> List[RelevanceExample]:
        """Load passage texts and convert the dataset to relevance examples.

        Candidates whose passage cannot be loaded (``ValueError``) are skipped
        with a warning. Baseline statistics (random and existing-run P@1,
        R@1000, MRR, MRR@10) are logged at INFO level as a side effect.

        NOTE(review): candidates are always enumerated one at a time here
        (``is_duo`` is not forwarded to ``query_passage_tuples``), and with
        ``is_duo=True`` ``p`` below becomes a float, which ``range`` rejects.
        The duo path looks unfinished - confirm before relying on it.
        """
        loader = MsMarcoPassageLoader(index_path)
        # qid -> [query text, candidate ids, candidate texts, relevance labels]
        example_map = {}
        for (qid, text, rel_cands), cands in self.query_passage_tuples():
            if qid not in example_map:
                example_map[qid] = [convert_to_unicode(text), [], [], []]
            cand = cands[0]
            try:
                cand_text = convert_to_unicode(loader.load_passage(cand).all_text)
            except ValueError:
                logging.warning(f'Skipping {cand}')
                continue
            # Record id, text and label together, and only after a successful
            # load, so the three parallel lists stay index-aligned.
            _, cand_ids, cand_texts, labels = example_map[qid]
            cand_ids.append(cand)
            cand_texts.append(cand_text)
            labels.append(cand in rel_cands)

        mean_stats = defaultdict(list)
        for ex in self.examples:
            # dtype=int keeps an empty label list integral, so ``n`` below
            # stays usable as a ``range`` bound.
            int_rels = np.array([int(label) for label in example_map[ex.qid][3]],
                                dtype=int)
            p = int_rels.sum() / (len(ex.candidates) - 1) if is_duo else int_rels.sum()
            mean_stats['Random P@1'].append(np.mean(int_rels))
            n = len(ex.candidates) - p  # number of non-relevant candidates
            N = len(ex.candidates)
            if len(ex.candidates) <= 1000:
                mean_stats['Random R@1000'].append(1 if 1 in int_rels else 0)
            # Term i: i non-relevant candidates are drawn first, then a
            # relevant one at rank i + 1, i.e. C(n, i) / C(N, i) * p / (N - i).
            numer = np.array([sp.comb(n, i) / (N - i)
                              for i in range(0, n + 1) if i != N]) * p
            if n == N:
                # No relevant candidate: the i == N term was dropped above to
                # avoid dividing by zero; pad so the arrays stay aligned.
                numer = np.append(numer, 0)
            denom = np.array([sp.comb(N, i) for i in range(0, n + 1)])
            rr = 1 / np.arange(1, n + 2)
            rmrr = np.sum(numer * rr / denom)
            mean_stats['Random MRR'].append(rmrr)
            rmrr10 = np.sum(numer[:10] * rr[:10] / denom[:10])
            mean_stats['Random MRR@10'].append(rmrr10)
            # Best (lowest) position of any relevant candidate in the run.
            ex_index = len(ex.candidates)
            for rel_cand in ex.relevant_candidates:
                if rel_cand in ex.candidates:
                    ex_index = min(ex.candidates.index(rel_cand), ex_index)
            mean_stats['Existing MRR'].append(
                1 / (ex_index + 1) if ex_index < len(ex.candidates) else 0)
            mean_stats['Existing MRR@10'].append(
                1 / (ex_index + 1) if ex_index < 10 else 0)
        for k, v in mean_stats.items():
            logging.info(f'{k}: {np.mean(v)}')

        return [RelevanceExample(Query(text=query_text, id=qid),
                                 [Text(cand_text, dict(docid=cand_id))
                                  for cand_id, cand_text in zip(cand_ids, cand_texts)],
                                 labels)
                for qid, (query_text, cand_ids, cand_texts, labels)
                in example_map.items()]
21 changes: 21 additions & 0 deletions pygaggle/data/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ def all_text(self):
return '\n'.join((self.abstract, self.body_text, self.ref_entries))


@dataclass
class MsMarcoPassage:
    """Holds the text of one MS MARCO passage."""

    # The passage text as a single string.
    para_text: str

    @property
    def all_text(self):
        """Return the full passage text, which is simply ``para_text``."""
        return self.para_text


class Cord19DocumentLoader:
double_space_pattern = re.compile(r'\s\s+')

Expand All @@ -50,3 +59,15 @@ def unfold(entries):
return Cord19Document(unfold(article['abstract']),
unfold(article['body_text']),
unfold(ref_entries))


class MsMarcoPassageLoader:
    """Loads MS MARCO passages by id through a ``pysearch.SimpleSearcher``."""

    def __init__(self, index_path: str):
        """Open a searcher over the index located at ``index_path``."""
        self.searcher = pysearch.SimpleSearcher(index_path)

    def load_passage(self, id: str) -> MsMarcoPassage:
        """Return the passage stored under ``id``, read from its 'raw' field.

        Raises:
            ValueError: if the lookup chain hits an ``AttributeError``
                (presumably because ``doc`` returned ``None`` for an unknown
                id - confirm against the searcher API). The original error is
                kept as ``__cause__``.
        """
        try:
            passage = self.searcher.doc(id).lucene_document().get('raw')
        except AttributeError as err:
            # Chain explicitly so the traceback shows the underlying failure.
            raise ValueError('passage unretrievable') from err
        return MsMarcoPassage(passage)
8 changes: 8 additions & 0 deletions pygaggle/data/unicode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def convert_to_unicode(text):
    """Return `text` as a `str`, decoding `bytes` input as utf-8.

    Undecodable byte sequences are dropped (``errors="ignore"``). Any input
    that is neither `str` nor `bytes` raises ``ValueError``.
    """
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    if not isinstance(text, str):
        raise ValueError("Unsupported string type: %s" % (type(text)))
    return text
5 changes: 4 additions & 1 deletion pygaggle/model/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def greedy_decode(model: PreTrainedModel,
past = model.get_encoder()(input_ids, attention_mask=attention_mask)
next_token_logits = None
for _ in range(length):
model_inputs = model.prepare_inputs_for_generation(decode_ids, past=past, attention_mask=attention_mask)
model_inputs = model.prepare_inputs_for_generation(decode_ids,
past=past,
attention_mask=attention_mask,
use_cache=True)
outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size)
decode_ids = torch.cat([decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1)
Expand Down
17 changes: 17 additions & 0 deletions pygaggle/model/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ class RecallAt3Metric(TopkMixin, RecallAccumulator):
top_k = 3


@register_metric('recall@50')
class RecallAt50Metric(TopkMixin, RecallAccumulator):
    """Recall accumulator registered as ``'recall@50'`` with ``top_k`` of 50."""

    top_k = 50


@register_metric('recall@1000')
class RecallAt1000Metric(TopkMixin, RecallAccumulator):
    """Recall accumulator registered as ``'recall@1000'`` with ``top_k`` of 1000."""

    top_k = 1000


@register_metric('mrr')
class MrrMetric(MeanAccumulator):
def accumulate(self, scores: List[float], gold: RelevanceExample):
Expand All @@ -108,6 +118,13 @@ def accumulate(self, scores: List[float], gold: RelevanceExample):
self.scores.append(rr)


@register_metric('mrr@10')
class MrrAt10Metric(MeanAccumulator):
    """Mean reciprocal rank restricted to the ten highest-scoring items."""

    def accumulate(self, scores: List[float], gold: RelevanceExample):
        """Append the reciprocal rank of the first relevant item in the top 10.

        Items are ranked by descending score; ties keep their original order
        because the sort is stable. Appends 0 when none of the ten
        highest-ranked items is labelled relevant in ``gold.labels``.
        """
        ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
        # Only the first ten ranks can contribute, so stop scanning there.
        rr = next((1 / (rank_idx + 1)
                   for rank_idx, (idx, _) in enumerate(ranked[:10])
                   if gold.labels[idx]), 0)
        self.scores.append(rr)

class ThresholdedRecallMetric(DynamicThresholdingMixin, RecallAccumulator):
threshold = 0.5

Expand Down
5 changes: 4 additions & 1 deletion pygaggle/rerank/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ class Query:
----------
text : str
The query text.
id : Optional[str]
The query id.
"""
def __init__(self, text: str):
def __init__(self, text: str, id: Optional[str] = None):
self.text = text
self.id = id


class Text:
Expand Down
4 changes: 2 additions & 2 deletions pygaggle/run/evaluate_kaggle_highlighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
from pygaggle.model import SimpleBatchTokenizer, CachedT5ModelLoader, T5BatchTokenizer, RerankerEvaluator, metric_names
from pygaggle.data import LitReviewDataset
from pygaggle.settings import Settings
from pygaggle.settings import Cord19Settings


SETTINGS = Settings()
SETTINGS = Cord19Settings()
METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer', 'qa_transformer', 'random')


Expand Down
Loading

0 comments on commit 55e4961

Please sign in to comment.