Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add evaluate_document_ranker
* remove .DS_Store file and update gitignore
* add options for segment size and stride
* move segment method and aggregate method to class and make aggregate method configurable
* set max_length=512 for T5Tokenizer
* fix bug in segmentation.py
* remove treccar from task choice and change class name from Passage to Document
* change sum aggregate method to mean and make aggregate method configurable via options
* add aggregate_method argument for evaluate_by_segment
* add comment for self.doc_end_indexes
* remove try-except for construct_seq_class_transformer
* removed unused import
* maintain docid for each segment
* make SegmentProcess independent from documents text
* Create experiments-msmarco-document.md
* Update experiments-msmarco-document.md
* Update experiments-msmarco-document.md
* add tqdm for loading passage
* add "Preprocessing:" and change query_passage_tuples to return list
* add info for process
* add operation info for document
* fix typo
* Update pre-built index
* fix fh&sh scores
* add info of if writing to run
* remove sanity check using MSMARCO evaluation script
  Co-authored-by: Ronak <[email protected]>
* fix typo
  Co-authored-by: Ronak <[email protected]>
* remove evaluating by msmarco_eval.py
  Co-authored-by: Ronak <[email protected]>
* fix typo
  Co-authored-by: Ronak <[email protected]>

Co-authored-by: richard3983 <[email protected]>
Co-authored-by: Ronak <[email protected]>
1 parent 13e099b; commit f621265
Showing 9 changed files with 396 additions and 4 deletions.
@@ -1,6 +1,7 @@
# ide
.idea/
.vscode
.DS_Store

# python
*.pyc
@@ -0,0 +1,125 @@
# PyGaggle: Neural Ranking Baselines on [MS MARCO Document Retrieval](https://github.com/microsoft/TREC-2019-Deep-Learning)

This page contains instructions for running various neural reranking baselines on the MS MARCO *document* ranking task.
Note that there is also a separate [MS MARCO *passage* ranking task](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage.md).

Before running these experiments, we suggest looking at our first-stage [BM25 ranking instructions](https://github.com/castorini/anserini/blob/master/docs/experiments-msmarco-doc.md).
We rerank the BM25 run files, which contain ~1000 documents per query, using monoT5.
monoT5 is a pointwise reranker: each document is scored independently using T5.
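
To make "pointwise" concrete, here is a minimal sketch of the idea: every (query, document) pair is scored on its own, and the ranking is just a sort over those scores. The function names `rerank_pointwise` and `term_overlap` are hypothetical, and the toy term-overlap scorer is only a stand-in so the example runs without a GPU; it is not monoT5 and not part of PyGaggle.

```python
from typing import Callable, List, Tuple


def rerank_pointwise(
    query: str,
    documents: List[str],
    score_fn: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Score each (query, document) pair independently, then sort by score."""
    scored = [(doc, score_fn(query, doc)) for doc in documents]
    # Higher score means more relevant; sorted() is stable, so ties keep input order.
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def term_overlap(query: str, document: str) -> float:
    """Toy stand-in scorer (NOT monoT5): fraction of query terms found in the document."""
    query_terms = set(query.lower().split())
    doc_terms = set(document.lower().split())
    return len(query_terms & doc_terms) / max(len(query_terms), 1)


ranked = rerank_pointwise(
    "neural document ranking",
    [
        "cooking pasta at home",
        "neural ranking of documents",
        "document ranking with neural models",
    ],
    term_overlap,
)
for doc, score in ranked:
    print(f"{score:.3f}  {doc}")
```

Because no document's score depends on any other document, the scoring step can be batched freely, which is why the batch size only affects speed and memory use and not the resulting ranking.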

Since it can take many hours to run these models on all 5193 queries from the MS MARCO dev set, we will instead use a subset of 50 queries randomly sampled from the dev set.

Note 1: Run the following instructions from the root of this repo.
Note 2: Make sure that you have access to a GPU.
Note 3: PyGaggle must be installed from source, with the [anserini-eval](https://github.com/castorini/anserini-eval) submodule pulled.
To do this, first clone the repository recursively.

```
git clone --recursive https://github.com/castorini/pygaggle.git
```

Then install PyGaggle using:

```
pip install pygaggle/
```

## Models

+ monoT5-base: Document Ranking with a Pretrained Sequence-to-Sequence Model [(Nogueira et al., 2020)](https://arxiv.org/pdf/2003.06713.pdf)

## Data Prep

We're first going to download the queries, qrels, and run files for the MS MARCO subset considered. The run file is generated by following the BM25 ranking instructions. We'll store all these files in the `data` directory.

```
wget https://www.dropbox.com/s/8lvdkgzjjctxhzy/msmarco_doc_ans_small.zip -P data
```

To confirm, `msmarco_doc_ans_small.zip` should have an MD5 checksum of `aeed5902c23611e21eaa156d908c4748`.
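
If no `md5sum` tool is at hand, the checksum can also be computed in Python. The following is a small sketch using only the standard library; the helper name `md5_of_file` is hypothetical and not part of PyGaggle.

```python
import hashlib


def md5_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read() until it returns b"" (end of file).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Example usage, assuming the archive was downloaded into `data` as described:
# md5_of_file("data/msmarco_doc_ans_small.zip") == "aeed5902c23611e21eaa156d908c4748"
```

The same helper can be used for the index archive that is downloaded later on this page.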

Next, we extract the contents into `data`.

```
unzip data/msmarco_doc_ans_small.zip -d data
```

`msmarco_doc_ans_small` contains two disjoint sets, `fh` and `sh`, each with 25 queries.

Let's download the pre-built MS MARCO index:

```
wget https://www.dropbox.com/s/awukuo8c0tkl9sc/index-msmarco-doc-20200527-a1ecfa.tar.gz
```

`index-msmarco-doc-20200527-a1ecfa.tar.gz` should have an MD5 checksum of `72b1a0f9a9094a86d15c6f4babf8967a`.

Then, we can extract it into `indexes` (create this directory first if it does not already exist):

```
tar xvfz index-msmarco-doc-20200527-a1ecfa.tar.gz -C indexes
rm index-msmarco-doc-20200527-a1ecfa.tar.gz
```

Now, we can begin re-ranking the set.

## Re-Ranking with monoT5

Let us now re-rank the first half:

```
python -um pygaggle.run.evaluate_document_ranker --split dev \
    --method t5 \
    --model castorini/monot5-base-msmarco \
    --dataset data/msmarco_doc_ans_small/fh \
    --model-type t5-base \
    --task msmarco \
    --index-dir indexes/index-msmarco-doc-20200527-a1ecfa \
    --batch-size 32 \
    --output-file runs/run.monot5.doc_fh.dev.tsv
```

The following output will be visible after it has finished:

```
precision@1 0.16
recall@3 0.44
recall@50 0.84
recall@1000 0.88
mrr 0.33663
mrr@10 0.33171
```

It takes about 5 hours to re-rank this subset on MS MARCO using a P100.
Note that you might need to adjust the batch size to best fit the GPU at hand.

Upon completion, the re-ranked run file will be available at `runs/run.monot5.doc_fh.dev.tsv`.

We can use the official MS MARCO evaluation script to verify the MRR@10:

```
python eval/msmarco_eval.py data/msmarco_doc_ans_small/fh/qrels.dev.small.tsv runs/run.monot5.doc_fh.dev.tsv
```

You should see the same MRR@10 as in the output above.
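
For readers unfamiliar with the metric, the sketch below shows how MRR@10 is commonly computed: for each query, take the reciprocal rank of the first relevant document within the top 10 (or 0 if there is none), then average over queries. The function name `mrr_at_k` and the in-memory dictionaries are assumptions made for illustration; this is not the official script and it does not parse the run or qrels files.

```python
from typing import Dict, List, Set


def mrr_at_k(run: Dict[str, List[str]], qrels: Dict[str, Set[str]], k: int = 10) -> float:
    """Mean reciprocal rank of the first relevant document within the top k.

    run:   query id -> document ids, best first (assumed already sorted by score)
    qrels: query id -> set of relevant document ids
    The mean is taken over the queries in qrels (an assumption of this sketch).
    """
    total = 0.0
    for qid, relevant in qrels.items():
        for rank, docid in enumerate(run.get(qid, [])[:k], start=1):
            if docid in relevant:
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(qrels) if qrels else 0.0


toy_run = {"q1": ["d3", "d1"], "q2": ["d9", "d8", "d7"]}
toy_qrels = {"q1": {"d1"}, "q2": {"d5"}}
print(mrr_at_k(toy_run, toy_qrels))  # q1 contributes 1/2, q2 contributes 0
```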

To re-rank the second half of the dataset, change the `--dataset` argument to `data/msmarco_doc_ans_small/sh`, and don't forget to change the output file name.

The results are as follows:

```
precision@1 0.24
recall@3 0.32
recall@50 0.76
recall@1000 0.88
mrr 0.31052
mrr@10 0.29133
```

If you were able to replicate these results, please submit a PR adding to the replication log!

## Replication Log
@@ -0,0 +1,63 @@
import spacy
import numpy as np
from pygaggle.rerank.base import Text
from typing import List
from dataclasses import dataclass
from copy import deepcopy


@dataclass
class SegmentGroup:
    """
    'segments' stores the List of document segments.
    'doc_end_indexes' stores, with a leading 0, the cumulative number of
    segments at the end of each document, recorded when 'segment()' converts
    a 'List[Text]' of documents into a 'List[Text]' of segments. It is used
    to split and group the segments' scores and feed the aggregated scores
    back to the documents in 'aggregate()'.
    """
    segments: List[Text]
    doc_end_indexes: List[int]


class SegmentProcessor:

    def __init__(self, max_characters=10000000):
        # Blank English pipeline with only a rule-based sentence splitter
        self.nlp = spacy.blank("en")
        self.nlp.add_pipe(self.nlp.create_pipe("sentencizer"))
        self.max_characters = max_characters
        self.aggregate_methods = {
            "max": self._max_aggregate,
            "mean": self._mean_aggregate
        }

    def segment(self, documents: List[Text], seg_size: int, stride: int) -> SegmentGroup:
        # Note: the end index of a document is only recorded once a window
        # reaches the last sentence, so this assumes seg_size >= stride and
        # that every document contains at least one sentence.
        segmented_doc, doc_end_indexes, end_idx = [], [0], 0

        for document in documents:
            doc = self.nlp(document.text[:self.max_characters])
            sentences = [sent.string.strip() for sent in doc.sents]
            # Slide a window of seg_size sentences, advancing by stride
            for i in range(0, len(sentences), stride):
                segment_text = ' '.join(sentences[i:i + seg_size])
                segmented_doc.append(Text(segment_text, dict(docid=document.raw["docid"])))
                if i + seg_size >= len(sentences):
                    # i // stride + 1 is the number of segments in this document
                    end_idx += i // stride + 1
                    doc_end_indexes.append(end_idx)
                    break
        return SegmentGroup(segmented_doc, doc_end_indexes)

    def aggregate(self, documents: List[Text], segments_group: SegmentGroup, method: str = "max") -> List[Text]:
        docs = deepcopy(documents)
        for i in range(len(docs)):
            # Segments of document i occupy [doc_start_idx, doc_end_idx)
            doc_start_idx = segments_group.doc_end_indexes[i]
            doc_end_idx = segments_group.doc_end_indexes[i+1]
            target_scores = [seg.score for seg in segments_group.segments[doc_start_idx: doc_end_idx]]
            docs[i].score = self.aggregate_methods[method](target_scores)
        return docs

    @staticmethod
    def _max_aggregate(scores):
        return max(scores)

    @staticmethod
    def _mean_aggregate(scores):
        return np.mean(scores)
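
The bookkeeping between `segment()` and `aggregate()` above can be illustrated without spaCy or PyGaggle. The sketch below is a simplified stand-in, not the module's code: it assumes documents are already split into sentences, and the names `segment_sentences` and `aggregate_max` are hypothetical. Unlike the original, it records an end offset for every document unconditionally.

```python
from typing import List, Tuple


def segment_sentences(
    docs: List[List[str]], seg_size: int, stride: int
) -> Tuple[List[str], List[int]]:
    """Slide a window of seg_size sentences, advancing by stride, over each document.

    Returns the flat list of segments plus cumulative end offsets, so that the
    segments of document d are segments[ends[d]:ends[d + 1]].
    """
    segments: List[str] = []
    ends = [0]
    for sentences in docs:
        for start in range(0, len(sentences), stride):
            segments.append(" ".join(sentences[start:start + seg_size]))
            if start + seg_size >= len(sentences):
                break  # this window already reaches the end of the document
        ends.append(len(segments))
    return segments, ends


def aggregate_max(segment_scores: List[float], ends: List[int]) -> List[float]:
    """Give each document the maximum score among its own segments.

    Assumes every document produced at least one segment.
    """
    return [max(segment_scores[ends[d]:ends[d + 1]]) for d in range(len(ends) - 1)]


toy_docs = [["a", "b", "c", "d", "e"], ["x", "y"]]
toy_segments, toy_ends = segment_sentences(toy_docs, seg_size=3, stride=2)
print(toy_segments)  # ['a b c', 'c d e', 'x y']
print(toy_ends)      # [0, 2, 3]
print(aggregate_max([0.1, 0.7, 0.4], toy_ends))  # [0.7, 0.4]
```

Overlapping windows (stride smaller than the segment size) mean that a sentence near a window boundary also appears with its neighbors in the next segment, at the cost of scoring more segments per document.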