Enabling Tapex in table question answering pipeline. #16663

Merged · 3 commits · Apr 14, 2022
89 changes: 53 additions & 36 deletions src/transformers/pipelines/table_question_answering.py
@@ -105,9 +105,10 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, **kwargs):
             else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
         )

-        self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool(
-            getattr(self.model.config, "num_aggregation_labels")
+        self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
+            getattr(self.model.config, "num_aggregation_labels", None)
         )
+        self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None
NielsRogge (Contributor) commented on Apr 8, 2022:

Wondering if we can leverage model.config.is_encoder_decoder here (which is True for TAPEX but False for TAPAS).

Author (Contributor) replied:

Usually I would agree that reusing more "general" concepts is better.

The thing I want to avoid is treating everything that is not an encoder-decoder as TAPAS. The current code for TAPAS is highly specific, so I'd rather isolate TAPAS than isolate TAPEX, if you will.

Ultimately both are OK, to be honest.
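For reference, the config attributes both options hinge on can be inspected directly. A quick sketch with the public WTQ checkpoints (the commented values are what I'd expect from these configs, not verified pipeline output):

```python
from transformers import AutoConfig

# TAPAS: encoder-only model with cell-selection and aggregation heads.
tapas_cfg = AutoConfig.from_pretrained("google/tapas-base-finetuned-wtq")
print(tapas_cfg.is_encoder_decoder)              # False
print(tapas_cfg.aggregation_labels)              # {0: 'NONE', 1: 'SUM', 2: 'AVERAGE', 3: 'COUNT'}

# TAPEX: BART-based encoder-decoder with no aggregation attributes at all,
# which is what the getattr(..., None) defaults above guard against.
tapex_cfg = AutoConfig.from_pretrained("microsoft/tapex-large-finetuned-wtq")
print(tapex_cfg.is_encoder_decoder)              # True
print(hasattr(tapex_cfg, "aggregation_labels"))  # False
```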


     def batch_inference(self, **inputs):
         return self.model(**inputs)
@@ -335,7 +336,13 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, **kwargs):
             forward_params["sequential"] = sequential
         return preprocess_params, forward_params, {}

-    def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"):
+    def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
+        if truncation is None:
+            if self.type == "tapas":
+                truncation = "drop_rows_to_fit"
+            else:
+                truncation = "do_not_truncate"
Comment on lines +340 to +344
NielsRogge (Contributor):

TAPEX also supports the "drop_rows_to_fit" truncation strategy.

Author (Contributor):

The tokenizer said it didn't, and gave me the standard list of strategies.

Is the tokenizer that gets resolved a BartTokenizer, maybe? (I have seen BartTokenizer in some random error log; I discarded it at the time, but maybe it's a real bug.)

Author (Contributor):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
tokenizer.encode(table, truncation="drop_rows_to_fit")
# ValueError: drop_rows_to_fit is not a valid TruncationStrategy, please select one of ['only_first', 'only_second', 'longest_first', 'do_not_truncate']

NielsRogge (Contributor):

Thanks for reporting, investigating.

NielsRogge (Contributor):

Update: you can continue with this. TAPEX doesn't support drop_rows_to_fit the way TAPAS does.
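Net effect for callers: the pipeline now resolves a model-appropriate truncation default, and the kwarg can still be overridden per call since _sanitize_parameters accepts it. A minimal sketch, assuming the public TAPAS WTQ checkpoint (passing drop_rows_to_fit to a TAPEX tokenizer would raise the ValueError shown above):

```python
from transformers import pipeline

table_qa = pipeline("table-question-answering", model="google/tapas-base-finetuned-wtq")

data = {"Actors": ["Brad Pitt", "George Clooney"], "Age": ["56", "59"]}

# Explicit override; leaving truncation unset now falls back to
# "drop_rows_to_fit" for TAPAS and "do_not_truncate" for everything else.
result = table_qa(table=data, query="How old is George Clooney?", truncation="drop_rows_to_fit")
print(result)  # e.g. {'answer': '59', 'coordinates': [(1, 1)], 'cells': ['59']}
```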


         table, query = pipeline_input["table"], pipeline_input["query"]
         if table.empty:
             raise ValueError("table is empty")
@@ -347,45 +354,55 @@ def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"):

     def _forward(self, model_inputs, sequential=False):
         table = model_inputs.pop("table")
-        outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs)
+
+        if self.type == "tapas":
+            if sequential:
+                outputs = self.sequential_inference(**model_inputs)
+            else:
+                outputs = self.batch_inference(**model_inputs)
+        else:
+            outputs = self.model.generate(**model_inputs)
         model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
         return model_outputs
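For intuition, the new else branch above is just the standard seq2seq flow. A minimal sketch of what the TAPEX path amounts to outside the pipeline (same checkpoint as in the thread):

```python
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/tapex-large-finetuned-wtq")

table = pd.DataFrame({"Actors": ["Brad Pitt", "George Clooney"], "Age": ["56", "59"]})
inputs = tokenizer(table=table, query="How old is George Clooney?", return_tensors="pt")

# Unlike TAPAS, TAPEX does not score table cells: it generates the answer
# as free text, hence model.generate rather than a forward pass over logits.
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # e.g. [' 59']
```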

     def postprocess(self, model_outputs):
         inputs = model_outputs["model_inputs"]
         table = model_outputs["table"]
         outputs = model_outputs["outputs"]
-        if self.aggregate:
-            logits, logits_agg = outputs[:2]
-            predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
-            answer_coordinates_batch, agg_predictions = predictions
-            aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
-
-            no_agg_label_index = self.model.config.no_aggregation_label_index
-            aggregators_prefix = {
-                i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
-            }
+        if self.type == "tapas":
+            if self.aggregate:
+                logits, logits_agg = outputs[:2]
+                predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
+                answer_coordinates_batch, agg_predictions = predictions
+                aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
+
+                no_agg_label_index = self.model.config.no_aggregation_label_index
+                aggregators_prefix = {
+                    i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
+                }
+            else:
+                logits = outputs[0]
+                predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
+                answer_coordinates_batch = predictions[0]
+                aggregators = {}
+                aggregators_prefix = {}
+            answers = []
+            for index, coordinates in enumerate(answer_coordinates_batch):
+                cells = [table.iat[coordinate] for coordinate in coordinates]
+                aggregator = aggregators.get(index, "")
+                aggregator_prefix = aggregators_prefix.get(index, "")
+                answer = {
+                    "answer": aggregator_prefix + ", ".join(cells),
+                    "coordinates": coordinates,
+                    "cells": [table.iat[coordinate] for coordinate in coordinates],
+                }
+                if aggregator:
+                    answer["aggregator"] = aggregator
+
+                answers.append(answer)
+            if len(answer) == 0:
+                raise PipelineException("Empty answer")
         else:
-            logits = outputs[0]
-            predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
-            answer_coordinates_batch = predictions[0]
-            aggregators = {}
-            aggregators_prefix = {}
-
-        answers = []
-        for index, coordinates in enumerate(answer_coordinates_batch):
-            cells = [table.iat[coordinate] for coordinate in coordinates]
-            aggregator = aggregators.get(index, "")
-            aggregator_prefix = aggregators_prefix.get(index, "")
-            answer = {
-                "answer": aggregator_prefix + ", ".join(cells),
-                "coordinates": coordinates,
-                "cells": [table.iat[coordinate] for coordinate in coordinates],
-            }
-            if aggregator:
-                answer["aggregator"] = aggregator
-
-            answers.append(answer)
-        if len(answer) == 0:
-            raise PipelineException("Empty answer")
+            answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]

         return answers if len(answers) > 1 else answers[0]
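One detail worth keeping in mind when reading the TAPAS branch: convert_logits_to_predictions returns (row, column) tuples, and the cell text is recovered with plain positional pandas indexing. A toy sketch of just that lookup:

```python
import pandas as pd

table = pd.DataFrame({"Actors": ["Brad Pitt", "George Clooney"], "Age": ["56", "59"]})

# DataFrame.iat reads a single cell by integer position, so a (row, col)
# tuple from the model maps straight to the cell text.
coordinates = [(1, 1)]
cells = [table.iat[coordinate] for coordinate in coordinates]
print(", ".join(cells))  # 59
```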
28 changes: 28 additions & 0 deletions tests/pipelines/test_pipelines_table_question_answering.py
@@ -632,3 +632,31 @@ def test_integration_sqa_tf(self):
{"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
]
self.assertListEqual(results, expected_results)

@slow
@require_torch
def test_large_model_pt_tapex(self):
model_id = "microsoft/tapex-large-finetuned-wtq"
table_querier = pipeline(
"table-question-answering",
model=model_id,
)
data = {
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
"Age": ["56", "45", "59"],
"Number of movies": ["87", "53", "69"],
"Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
}
queries = [
"How many movies has George Clooney played in?",
"How old is Mr Clooney ?",
"What's the date of birth of Leonardo ?",
]
results = table_querier(data, queries, sequential=True)

expected_results = [
{"answer": " 69"},
{"answer": " 59"},
{"answer": " 10 june 1996"},
]
self.assertListEqual(results, expected_results)
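And the same checkpoint end to end through the pipeline, as a hedged quick-start: a single query returns a single dict rather than a list, and the leading space in the answer comes from the BPE tokenization, matching the expectations above.

```python
from transformers import pipeline

table_qa = pipeline("table-question-answering", model="microsoft/tapex-large-finetuned-wtq")

result = table_qa(
    table={"Actors": ["George Clooney"], "Age": ["59"]},
    query="How old is George Clooney?",
)
print(result)  # e.g. {'answer': ' 59'}
```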