From f6212655972e421fc4398b3956153aa1222c1b87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xueguang=20Ma=20=E9=A9=AC=E9=9B=AA=E5=85=89?= Date: Thu, 11 Jun 2020 04:03:29 +0800 Subject: [PATCH] Add evaluate_document_ranker (#49) * add evaluate_document_ranker * remove .DS_Store file and update gitignore * add options for segment size and stride * move segment method and aggregrate method to class and make aggregate method configurable * set max_length=512 for T5Tokenizer * fix bug in segmentation.py * remove treccar from task choice and change class name from Passage to Document * change sum aggregate method to mean and make aggregate method configurable via options * add aggregate_method argument for evaluate_by_segment * add comment for self.doc_end_indexes * remove try-except for construct_seq_class_transformer * removed unused import * maintain docid for each segment * make SegmentProcess independent from documents text * Create experiments-msmarco-document.md * Update experiments-msmarco-document.md * Update experiments-msmarco-document.md * add tqdm for loading passage * add "Preprocessing:" and change query_passage_tuples to return list * add info for process * add operation info for document * fix typo * Update pre-built index * fix fh&sh scores * add info of if writting to run * remove sanity check using MSMARCO evaluation script Co-authored-by: Ronak * fix typo Co-authored-by: Ronak * remove evaluating by msmarco_eval.py Co-authored-by: Ronak * fix typo Co-authored-by: Ronak Co-authored-by: richard3983 Co-authored-by: Ronak --- .gitignore | 1 + docs/experiments-msmarco-document.md | 125 ++++++++++++++++ pygaggle/data/msmarco.py | 7 +- pygaggle/data/segmentation.py | 63 ++++++++ pygaggle/model/evaluate.py | 21 +++ pygaggle/model/tokenize.py | 1 + pygaggle/model/writer.py | 2 +- pygaggle/run/evaluate_document_ranker.py | 177 +++++++++++++++++++++++ pygaggle/run/evaluate_passage_ranker.py | 3 + 9 files changed, 396 insertions(+), 4 deletions(-) create mode 100644 docs/experiments-msmarco-document.md create mode 100644 pygaggle/data/segmentation.py create mode 100644 pygaggle/run/evaluate_document_ranker.py diff --git a/.gitignore b/.gitignore index 12bda015..39ee120c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # ide .idea/ .vscode +.DS_Store # python *.pyc diff --git a/docs/experiments-msmarco-document.md b/docs/experiments-msmarco-document.md new file mode 100644 index 00000000..63903992 --- /dev/null +++ b/docs/experiments-msmarco-document.md @@ -0,0 +1,125 @@ +# PyGaggle: Neural Ranking Baselines on [MS MARCO Document Retrieval](https://github.com/microsoft/TREC-2019-Deep-Learning) + +This page contains instructions for running various neural reranking baselines on the MS MARCO *document* ranking task. +Note that there is also a separate [MS MARCO *passage* ranking task](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage.md). + +Prior to running this, we suggest looking at our first-stage [BM25 ranking instructions](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-doc.md). +We rerank the BM25 run files that contain ~1000 documents per query using monoT5. +monoT5 is a pointwise reranker. This means that each document is scored independently using T5. + +Since it can take many hours to run these models on all of the 5193 queries from the MS MARCO dev set, we will instead use a subset of 50 queries randomly sampled from the dev set. + +Note 1: Run the following instructions at root of this repo. +Note 2: Make sure that you have access to a GPU +Note 3: Installation must have been done from source and make sure the [anserini-eval](https://github.com/castorini/anserini-eval) submodule is pulled. +To do this, first clone the repository recursively. + +``` +git clone --recursive https://github.com/castorini/pygaggle.git +``` + +Then install PyGaggle using: + +``` +pip install pygaggle/ +``` + +## Models + ++ monoT5-base: Document Ranking with a Pretrained Sequence-to-Sequence Model [(Nogueira et al., 2020)](https://arxiv.org/pdf/2003.06713.pdf) + +## Data Prep + +We're first going to download the queries, qrels and run files corresponding to the MS MARCO set considered. The run file is generated by following the BM25 ranking instructions. We'll store all these files in the `data` directory. + +``` +wget https://www.dropbox.com/s/8lvdkgzjjctxhzy/msmarco_doc_ans_small.zip -P data +``` + +To confirm, `msmarco_doc_ans_small.zip` should have MD5 checksum of `aeed5902c23611e21eaa156d908c4748`. + +Next, we extract the contents into `data`. + +``` +unzip data/msmarco_doc_ans_small.zip -d data +``` + +`msmarco_doc_ans_small` contains two disjoint sets, `fh` and `sh`, and each set has 25 queries. + +Let's download the pre-built MS MARCO index : + +``` +wget https://www.dropbox.com/s/awukuo8c0tkl9sc/index-msmarco-doc-20200527-a1ecfa.tar.gz +``` + +`index-msmarco-doc-20200527-a1ecfa.tar.gz ` should have MD5 checksum of `72b1a0f9a9094a86d15c6f4babf8967a`. + +Then, we can extract it into into `indexes`: + +``` +tar xvfz index-msmarco-doc-20200527-a1ecfa.tar.gz -C indexes +rm index-msmarco-doc-20200527-a1ecfa.tar.gz +``` + +Now, we can begin with re-ranking the set. + +## Re-Ranking with monoT5 + +Let us now re-rank the first half: + +``` +python -um pygaggle.run.evaluate_document_ranker --split dev \ + --method t5 \ + --model castorini/monot5-base-msmarco \ + --dataset data/msmarco_doc_ans_small/fh \ + --model-type t5-base \ + --task msmarco \ + --index-dir indexes/index-msmarco-doc-20200527-a1ecfa \ + --batch-size 32 \ + --output-file runs/run.monot5.doc_fh.dev.tsv +``` + +The following output will be visible after it has finished: + +``` +precision@1 0.16 +recall@3 0.44 +recall@50 0.84 +recall@1000 0.88 +mrr 0.33663 +mrr@10 0.33171 +``` + +It takes about 5 hours to re-rank this subset on MS MARCO using a P100. +It is worth noting again that you might need to modify the batch size to best fit the GPU at hand. + +Upon completion, the re-ranked run file `runs/run.monot5.doc_fh.dev.tsv` will be available in the `runs` directory. + +We can use the official MS MARCO evaluation script to verify the MRR@10: + +``` +python eval/msmarco_eval.py data/msmarco_doc_ans_small/fh/qrels.dev.small.tsv runs/run.monot5.doc_fh.dev.tsv +``` + +You should see the same result. + +We can modify the argument for `--dataset` to `data/msmarco_doc_ans_small/sh` to re-rank the second half of the dataset, and don't forget to change output file name. + +The results are as follows: + +``` +precision@1 0.24 +recall@3 0.32 +recall@50 0.76 +recall@1000 0.88 +mrr 0.31052 +mrr@10 0.29133 +``` + + + + +If you were able to replicate these results, please submit a PR adding to the replication log! + + +## Replication Log diff --git a/pygaggle/data/msmarco.py b/pygaggle/data/msmarco.py index 95c4bf6e..fdbe1a5c 100644 --- a/pygaggle/data/msmarco.py +++ b/pygaggle/data/msmarco.py @@ -7,6 +7,7 @@ from pydantic import BaseModel import scipy.special as sp import numpy as np +from tqdm import tqdm from .relevance import RelevanceExample, MsMarcoPassageLoader from pygaggle.rerank.base import Query, Text @@ -82,16 +83,16 @@ def from_folder(cls, cls.load_run(run_path))) def query_passage_tuples(self, is_duo: bool = False): - return (((ex.qid, ex.text, ex.relevant_candidates), perm_pas) + return [((ex.qid, ex.text, ex.relevant_candidates), perm_pas) for ex in self.examples - for perm_pas in permutations(ex.candidates, r=1+int(is_duo))) + for perm_pas in permutations(ex.candidates, r=1+int(is_duo))] def to_relevance_examples(self, index_path: str, is_duo: bool = False) -> List[RelevanceExample]: loader = MsMarcoPassageLoader(index_path) example_map = {} - for (qid, text, rel_cands), cands in self.query_passage_tuples(): + for (qid, text, rel_cands), cands in tqdm(self.query_passage_tuples()): if qid not in example_map: example_map[qid] = [convert_to_unicode(text), [], [], []] example_map[qid][1].append([cand for cand in cands][0]) diff --git a/pygaggle/data/segmentation.py b/pygaggle/data/segmentation.py new file mode 100644 index 00000000..29241efe --- /dev/null +++ b/pygaggle/data/segmentation.py @@ -0,0 +1,63 @@ +import spacy +import numpy as np +from pygaggle.rerank.base import Text +from typing import List +from dataclasses import dataclass +from copy import deepcopy + + +@dataclass +class SegmentGroup: + """ + 'segments' stores the List of document segments + 'doc_end_indexes' stores the index of the last segment of each + document when 'segment()' converting a 'List[Text]' of documents into + 'List[Text]' of segments. It will be used to split and group segments' + scores and feed the aggregated scores back to documents in 'aggregate()' + """ + segments: List[Text] + doc_end_indexes: List[int] + + +class SegmentProcessor: + + def __init__(self, max_characters=10000000): + self.nlp = spacy.blank("en") + self.nlp.add_pipe(self.nlp.create_pipe("sentencizer")) + self.max_characters = max_characters + self.aggregate_methods = { + "max": self._max_aggregate, + "mean": self._mean_aggregate + } + + def segment(self, documents: List[Text], seg_size: int, stride: int) -> SegmentGroup: + segmented_doc, doc_end_indexes, end_idx = [], [0], 0 + + for document in documents: + doc = self.nlp(document.text[:self.max_characters]) + sentences = [sent.string.strip() for sent in doc.sents] + for i in range(0, len(sentences), stride): + segment_text = ' '.join(sentences[i:i + seg_size]) + segmented_doc.append(Text(segment_text, dict(docid=document.raw["docid"]))) + if i + seg_size >= len(sentences): + end_idx += i/stride + 1 + doc_end_indexes.append(int(end_idx)) + break + return SegmentGroup(segmented_doc, doc_end_indexes) + + def aggregate(self, documents: List[Text], segments_group: SegmentGroup, method: str = "max") -> List[Text]: + docs = deepcopy(documents) + for i in range(len(docs)): + doc_start_idx = segments_group.doc_end_indexes[i] + doc_end_idx = segments_group.doc_end_indexes[i+1] + target_scores = [seg.score for seg in segments_group.segments[doc_start_idx: doc_end_idx]] + docs[i].score = self.aggregate_methods[method](target_scores) + return docs + + @staticmethod + def _max_aggregate(scores): + return max(scores) + + @staticmethod + def _mean_aggregate(scores): + return np.mean(scores) diff --git a/pygaggle/model/evaluate.py b/pygaggle/model/evaluate.py index 72515082..bb136838 100644 --- a/pygaggle/model/evaluate.py +++ b/pygaggle/model/evaluate.py @@ -10,6 +10,8 @@ from pygaggle.rerank.base import Reranker from pygaggle.model.writer import Writer +from pygaggle.data.segmentation import SegmentProcessor + __all__ = ['RerankerEvaluator', 'metric_names'] METRIC_MAP = OrderedDict() @@ -163,3 +165,22 @@ def evaluate(self, for metric in metrics: metric.accumulate(scores, example) return metrics + + def evaluate_by_segments(self, + examples: List[RelevanceExample], + seg_size: int, + stride: int, + aggregate_method: str) -> List[MetricAccumulator]: + metrics = [cls() for cls in self.metrics] + segment_processor = SegmentProcessor() + for example in tqdm(examples, disable=not self.use_tqdm): + segment_group = segment_processor.segment(example.documents, seg_size, stride) + segment_group.segments = self.reranker.rerank(example.query, segment_group.segments) + doc_scores = [x.score for x in segment_processor.aggregate(example.documents, + segment_group, + aggregate_method)] + if self.writer is not None: + self.writer.write(doc_scores, example) + for metric in metrics: + metric.accumulate(doc_scores, example) + return metrics diff --git a/pygaggle/model/tokenize.py b/pygaggle/model/tokenize.py index 940aeffd..6ceb3dc2 100644 --- a/pygaggle/model/tokenize.py +++ b/pygaggle/model/tokenize.py @@ -112,6 +112,7 @@ def __init__(self, *args, **kwargs): kwargs['return_attention_mask'] = True kwargs['pad_to_max_length'] = True kwargs['return_tensors'] = 'pt' + kwargs['max_length'] = 512 super().__init__(*args, **kwargs) diff --git a/pygaggle/model/writer.py b/pygaggle/model/writer.py index 5251d94c..58f70dc7 100644 --- a/pygaggle/model/writer.py +++ b/pygaggle/model/writer.py @@ -10,7 +10,7 @@ class Writer: def __init__(self, path: Optional[Path] = None, overwrite: bool = True): self.to_output = str(path) not in [".", None] - print(self.to_output) + print(f'Writing run: {self.to_output}') if self.to_output: self.f = open(path, "w" if overwrite else "w+") diff --git a/pygaggle/run/evaluate_document_ranker.py b/pygaggle/run/evaluate_document_ranker.py new file mode 100644 index 00000000..aa934a43 --- /dev/null +++ b/pygaggle/run/evaluate_document_ranker.py @@ -0,0 +1,177 @@ +from typing import Optional, List +from pathlib import Path +import logging + +from pydantic import BaseModel, validator +from transformers import (AutoModel, + AutoTokenizer, + AutoModelForSequenceClassification, + T5ForConditionalGeneration) +import torch + +from .args import ArgumentParserBuilder, opt +from pygaggle.rerank.base import Reranker +from pygaggle.rerank.bm25 import Bm25Reranker +from pygaggle.rerank.transformer import ( + UnsupervisedTransformerReranker, + T5Reranker, + SequenceClassificationTransformerReranker +) +from pygaggle.rerank.random import RandomReranker +from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider +from pygaggle.model import (SimpleBatchTokenizer, + T5BatchTokenizer, + RerankerEvaluator, + metric_names, + MsMarcoWriter) +from pygaggle.data import MsMarcoDataset +from pygaggle.settings import MsMarcoSettings + + +SETTINGS = MsMarcoSettings() +METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer', + 'random') + + +class DocumentRankingEvaluationOptions(BaseModel): + task: str + dataset: Path + index_dir: Path + method: str + model: str + split: str + batch_size: int + seg_size: int + seg_stride: int + aggregate_method: str + device: str + is_duo: bool + from_tf: bool + metrics: List[str] + model_type: Optional[str] + tokenizer_name: Optional[str] + + @validator('task') + def task_exists(cls, v: str): + assert v in ['msmarco'] + + @validator('dataset') + def dataset_exists(cls, v: Path): + assert v.exists(), 'data directory must exist' + return v + + @validator('index_dir') + def index_dir_exists(cls, v: Path): + assert v.exists(), 'index directory must exist' + return v + + @validator('model') + def model_sane(cls, v: str, values, **kwargs): + method = values['method'] + if method == 'transformer' and v is None: + raise ValueError('transformer name or path must be specified') + return v + + @validator('tokenizer_name') + def tokenizer_sane(cls, v: str, values, **kwargs): + if v is None: + return values['model'] + return v + + +def construct_t5(options: DocumentRankingEvaluationOptions) -> Reranker: + device = torch.device(options.device) + model = T5ForConditionalGeneration.from_pretrained(options.model, + from_tf=options.from_tf).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(options.model_type) + tokenizer = T5BatchTokenizer(tokenizer, options.batch_size) + return T5Reranker(model, tokenizer) + + +def construct_transformer(options: + DocumentRankingEvaluationOptions) -> Reranker: + device = torch.device(options.device) + model = AutoModel.from_pretrained(options.model, + from_tf=options.from_tf).to(device).eval() + tokenizer = SimpleBatchTokenizer(AutoTokenizer.from_pretrained( + options.tokenizer_name), + options.batch_size) + provider = CosineSimilarityMatrixProvider() + return UnsupervisedTransformerReranker(model, tokenizer, provider) + + +def construct_seq_class_transformer(options: DocumentRankingEvaluationOptions + ) -> Reranker: + model = AutoModelForSequenceClassification.from_pretrained(options.model, from_tf=options.from_tf) + device = torch.device(options.device) + model = model.to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name) + return SequenceClassificationTransformerReranker(model, tokenizer) + + +def construct_bm25(options: DocumentRankingEvaluationOptions) -> Reranker: + return Bm25Reranker(index_path=str(options.index_dir)) + + +def main(): + apb = ArgumentParserBuilder() + apb.add_opts(opt('--task', + type=str, + default='msmarco'), + opt('--dataset', type=Path, required=True), + opt('--index-dir', type=Path, required=True), + opt('--method', + required=True, + type=str, + choices=METHOD_CHOICES), + opt('--model', + required=True, + type=str, + help='Path to pre-trained model or huggingface model name'), + opt('--output-file', type=Path, default='.'), + opt('--overwrite-output', action='store_true'), + opt('--split', + type=str, + default='dev', + choices=('dev', 'eval')), + opt('--batch-size', '-bsz', type=int, default=96), + opt('--device', type=str, default='cuda:0'), + opt('--is-duo', action='store_true'), + opt('--from-tf', action='store_true'), + opt('--metrics', + type=str, + nargs='+', + default=metric_names(), + choices=metric_names()), + opt('--model-type', type=str), + opt('--tokenizer-name', type=str), + opt('--seg-size', type=int, default=10), + opt('--seg-stride', type=int, default=5), + opt('--aggregate-method', type=str, default="max")) + args = apb.parser.parse_args() + options = DocumentRankingEvaluationOptions(**vars(args)) + logging.info("Preprocessing Queries & Docs:") + ds = MsMarcoDataset.from_folder(str(options.dataset), split=options.split, + is_duo=options.is_duo) + examples = ds.to_relevance_examples(str(options.index_dir), + is_duo=options.is_duo) + logging.info("Loading Ranker & Tokenizer:") + construct_map = dict(transformer=construct_transformer, + bm25=construct_bm25, + t5=construct_t5, + seq_class_transformer=construct_seq_class_transformer, + random=lambda _: RandomReranker()) + reranker = construct_map[options.method](options) + writer = MsMarcoWriter(args.output_file, args.overwrite_output) + evaluator = RerankerEvaluator(reranker, options.metrics, writer=writer) + width = max(map(len, args.metrics)) + 1 + logging.info("Reranking:") + for metric in evaluator.evaluate_by_segments(examples, + options.seg_size, + options.seg_stride, + options.aggregate_method): + logging.info(f'{metric.name:<{width}}{metric.value:.5}') + + +if __name__ == '__main__': + main() diff --git a/pygaggle/run/evaluate_passage_ranker.py b/pygaggle/run/evaluate_passage_ranker.py index d424a4e9..d40010f4 100644 --- a/pygaggle/run/evaluate_passage_ranker.py +++ b/pygaggle/run/evaluate_passage_ranker.py @@ -157,10 +157,12 @@ def main(): opt('--tokenizer-name', type=str)) args = apb.parser.parse_args() options = PassageRankingEvaluationOptions(**vars(args)) + logging.info("Preprocessing Queries & Passages:") ds = MsMarcoDataset.from_folder(str(options.dataset), split=options.split, is_duo=options.is_duo) examples = ds.to_relevance_examples(str(options.index_dir), is_duo=options.is_duo) + logging.info("Loading Ranker & Tokenizer:") construct_map = dict(transformer=construct_transformer, bm25=construct_bm25, t5=construct_t5, @@ -170,6 +172,7 @@ def main(): writer = MsMarcoWriter(args.output_file, args.overwrite_output) evaluator = RerankerEvaluator(reranker, options.metrics, writer=writer) width = max(map(len, args.metrics)) + 1 + logging.info("Reranking:") for metric in evaluator.evaluate(examples): logging.info(f'{metric.name:<{width}}{metric.value:.5}')