-
Notifications
You must be signed in to change notification settings - Fork 26.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enabling Tapex
in table question answering pipeline.
#16663
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,9 +105,10 @@ def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, * | |
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING | ||
) | ||
|
||
self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool( | ||
getattr(self.model.config, "num_aggregation_labels") | ||
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool( | ||
getattr(self.model.config, "num_aggregation_labels", None) | ||
) | ||
self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None | ||
|
||
def batch_inference(self, **inputs): | ||
return self.model(**inputs) | ||
|
@@ -335,7 +336,13 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, * | |
forward_params["sequential"] = sequential | ||
return preprocess_params, forward_params, {} | ||
|
||
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"): | ||
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None): | ||
if truncation is None: | ||
if self.type == "tapas": | ||
truncation = "drop_rows_to_fit" | ||
else: | ||
truncation = "do_not_truncate" | ||
Comment on lines
+340
to
+344
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TAPEX also supports the "drop_rows_to_fit" truncation strategy. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
tokenizer.encode(table, truncation="drop_rows_to_fit")
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for reporting, investigating. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update: you can continue with this, TAPEX doesn't support |
||
|
||
table, query = pipeline_input["table"], pipeline_input["query"] | ||
if table.empty: | ||
raise ValueError("table is empty") | ||
|
@@ -347,45 +354,55 @@ def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=" | |
|
||
def _forward(self, model_inputs, sequential=False): | ||
table = model_inputs.pop("table") | ||
outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs) | ||
|
||
if self.type == "tapas": | ||
if sequential: | ||
outputs = self.sequential_inference(**model_inputs) | ||
else: | ||
outputs = self.batch_inference(**model_inputs) | ||
else: | ||
outputs = self.model.generate(**model_inputs) | ||
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs} | ||
return model_outputs | ||
|
||
def postprocess(self, model_outputs): | ||
inputs = model_outputs["model_inputs"] | ||
table = model_outputs["table"] | ||
outputs = model_outputs["outputs"] | ||
if self.aggregate: | ||
logits, logits_agg = outputs[:2] | ||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg) | ||
answer_coordinates_batch, agg_predictions = predictions | ||
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)} | ||
|
||
no_agg_label_index = self.model.config.no_aggregation_label_index | ||
aggregators_prefix = { | ||
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index | ||
} | ||
if self.type == "tapas": | ||
if self.aggregate: | ||
logits, logits_agg = outputs[:2] | ||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg) | ||
answer_coordinates_batch, agg_predictions = predictions | ||
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)} | ||
|
||
no_agg_label_index = self.model.config.no_aggregation_label_index | ||
aggregators_prefix = { | ||
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index | ||
} | ||
else: | ||
logits = outputs[0] | ||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits) | ||
answer_coordinates_batch = predictions[0] | ||
aggregators = {} | ||
aggregators_prefix = {} | ||
answers = [] | ||
for index, coordinates in enumerate(answer_coordinates_batch): | ||
cells = [table.iat[coordinate] for coordinate in coordinates] | ||
aggregator = aggregators.get(index, "") | ||
aggregator_prefix = aggregators_prefix.get(index, "") | ||
answer = { | ||
"answer": aggregator_prefix + ", ".join(cells), | ||
"coordinates": coordinates, | ||
"cells": [table.iat[coordinate] for coordinate in coordinates], | ||
} | ||
if aggregator: | ||
answer["aggregator"] = aggregator | ||
|
||
answers.append(answer) | ||
if len(answer) == 0: | ||
raise PipelineException("Empty answer") | ||
else: | ||
logits = outputs[0] | ||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits) | ||
answer_coordinates_batch = predictions[0] | ||
aggregators = {} | ||
aggregators_prefix = {} | ||
|
||
answers = [] | ||
for index, coordinates in enumerate(answer_coordinates_batch): | ||
cells = [table.iat[coordinate] for coordinate in coordinates] | ||
aggregator = aggregators.get(index, "") | ||
aggregator_prefix = aggregators_prefix.get(index, "") | ||
answer = { | ||
"answer": aggregator_prefix + ", ".join(cells), | ||
"coordinates": coordinates, | ||
"cells": [table.iat[coordinate] for coordinate in coordinates], | ||
} | ||
if aggregator: | ||
answer["aggregator"] = aggregator | ||
|
||
answers.append(answer) | ||
if len(answer) == 0: | ||
raise PipelineException("Empty answer") | ||
answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)] | ||
|
||
return answers if len(answers) > 1 else answers[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if we can leverage
model.config.is_encoder_decoder
here (which is True for TAPEX but False for TAPAS)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually I would agree that reusing more "general" concepts is better.
The thing I want to avoid is treating
not encoder-decoder
astapas
. currently the code for Tapas is super highly specific, so I'd rather isolatetapas
than isolatetapex
if you want.Ultimately both are OK tbh.