Skip to content

Commit

Permalink
Enabling Tapex in table question answering pipeline. (huggingface#1…
Browse files Browse the repository at this point in the history
…6663)

* Enabling `Tapex` in table question answering pipeline.

* Questions are independant for Tapex, making the test respect that.

* Missing extra space.
  • Loading branch information
Narsil authored and elusenji committed Jun 12, 2022
1 parent 5ebf4dc commit 71d2f09
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 36 deletions.
89 changes: 53 additions & 36 deletions src/transformers/pipelines/table_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, *
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
)

self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool(
getattr(self.model.config, "num_aggregation_labels")
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
getattr(self.model.config, "num_aggregation_labels", None)
)
self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None

def batch_inference(self, **inputs):
return self.model(**inputs)
Expand Down Expand Up @@ -335,7 +336,13 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, *
forward_params["sequential"] = sequential
return preprocess_params, forward_params, {}

def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"):
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
if truncation is None:
if self.type == "tapas":
truncation = "drop_rows_to_fit"
else:
truncation = "do_not_truncate"

table, query = pipeline_input["table"], pipeline_input["query"]
if table.empty:
raise ValueError("table is empty")
Expand All @@ -347,45 +354,55 @@ def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="

def _forward(self, model_inputs, sequential=False):
table = model_inputs.pop("table")
outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs)

if self.type == "tapas":
if sequential:
outputs = self.sequential_inference(**model_inputs)
else:
outputs = self.batch_inference(**model_inputs)
else:
outputs = self.model.generate(**model_inputs)
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
return model_outputs

def postprocess(self, model_outputs):
inputs = model_outputs["model_inputs"]
table = model_outputs["table"]
outputs = model_outputs["outputs"]
if self.aggregate:
logits, logits_agg = outputs[:2]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
answer_coordinates_batch, agg_predictions = predictions
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}

no_agg_label_index = self.model.config.no_aggregation_label_index
aggregators_prefix = {
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
}
if self.type == "tapas":
if self.aggregate:
logits, logits_agg = outputs[:2]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
answer_coordinates_batch, agg_predictions = predictions
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}

no_agg_label_index = self.model.config.no_aggregation_label_index
aggregators_prefix = {
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
}
else:
logits = outputs[0]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
answer_coordinates_batch = predictions[0]
aggregators = {}
aggregators_prefix = {}
answers = []
for index, coordinates in enumerate(answer_coordinates_batch):
cells = [table.iat[coordinate] for coordinate in coordinates]
aggregator = aggregators.get(index, "")
aggregator_prefix = aggregators_prefix.get(index, "")
answer = {
"answer": aggregator_prefix + ", ".join(cells),
"coordinates": coordinates,
"cells": [table.iat[coordinate] for coordinate in coordinates],
}
if aggregator:
answer["aggregator"] = aggregator

answers.append(answer)
if len(answer) == 0:
raise PipelineException("Empty answer")
else:
logits = outputs[0]
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
answer_coordinates_batch = predictions[0]
aggregators = {}
aggregators_prefix = {}

answers = []
for index, coordinates in enumerate(answer_coordinates_batch):
cells = [table.iat[coordinate] for coordinate in coordinates]
aggregator = aggregators.get(index, "")
aggregator_prefix = aggregators_prefix.get(index, "")
answer = {
"answer": aggregator_prefix + ", ".join(cells),
"coordinates": coordinates,
"cells": [table.iat[coordinate] for coordinate in coordinates],
}
if aggregator:
answer["aggregator"] = aggregator

answers.append(answer)
if len(answer) == 0:
raise PipelineException("Empty answer")
answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]

return answers if len(answers) > 1 else answers[0]
28 changes: 28 additions & 0 deletions tests/pipelines/test_pipelines_table_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,3 +632,31 @@ def test_integration_sqa_tf(self):
{"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
]
self.assertListEqual(results, expected_results)

@slow
@require_torch
def test_large_model_pt_tapex(self):
model_id = "microsoft/tapex-large-finetuned-wtq"
table_querier = pipeline(
"table-question-answering",
model=model_id,
)
data = {
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
"Age": ["56", "45", "59"],
"Number of movies": ["87", "53", "69"],
"Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
}
queries = [
"How many movies has George Clooney played in?",
"How old is Mr Clooney ?",
"What's the date of birth of Leonardo ?",
]
results = table_querier(data, queries, sequential=True)

expected_results = [
{"answer": " 69"},
{"answer": " 59"},
{"answer": " 10 june 1996"},
]
self.assertListEqual(results, expected_results)

0 comments on commit 71d2f09

Please sign in to comment.