DonKane58 commited on
Commit
af4e1a3
1 Parent(s): f60ecb3

Update app/tapas.py

Browse files
Files changed (1) hide show
  1. app/tapas.py +99 -0
app/tapas.py CHANGED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
2
+ import pandas as pd
3
+ import re
4
+
5
+ p = re.compile('\d+(\.\d+)?')
6
+
7
+ def load_model_and_tokenizer():
8
+ """
9
+ Load
10
+ """
11
+ tokenizer = AutoTokenizer.from_pretrained("Meena/table-question-answering-tapas")
12
+ model = AutoModelForTableQuestionAnswering.from_pretrained("Meena/table-question-answering-tapas")
13
+
14
+ # Return tokenizer and model
15
+ return tokenizer, model
16
+
17
+
18
+ def prepare_inputs(table, queries, tokenizer):
19
+ """
20
+ Convert dictionary into data frame and tokenize inputs given queries.
21
+ """
22
+ table = table.astype('str').head(100)
23
+ inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
24
+ return table, inputs
25
+
26
+
27
+ def generate_predictions(inputs, model, tokenizer):
28
+ """
29
+ Generate predictions for some tokenized input.
30
+ """
31
+ # Generate model results
32
+ outputs = model(**inputs)
33
+
34
+ # Convert logit outputs into predictions for table cells and aggregation operators
35
+ predicted_table_cell_coords, predicted_aggregation_operators = tokenizer.convert_logits_to_predictions(
36
+ inputs,
37
+ outputs.logits.detach(),
38
+ outputs.logits_aggregation.detach()
39
+ )
40
+
41
+ # Return values
42
+ return predicted_table_cell_coords, predicted_aggregation_operators
43
+
44
+ def postprocess_predictions(predicted_aggregation_operators, predicted_table_cell_coords, table):
45
+ """
46
+ Compute the predicted operation and nicely structure the answers.
47
+ """
48
+ # Process predicted aggregation operators
49
+ aggregation_operators = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
50
+ aggregation_predictions_string = [aggregation_operators[x] for x in predicted_aggregation_operators]
51
+ # Process predicted table cell coordinates
52
+ answers = []
53
+ for agg, coordinates in zip(predicted_aggregation_operators, predicted_table_cell_coords):
54
+ if len(coordinates) == 1:
55
+ # 1 cell
56
+ answers.append(table.iat[coordinates[0]])
57
+ else:
58
+ # > 1 cell
59
+ cell_values = []
60
+ for coordinate in coordinates:
61
+ cell_values.append(table.iat[coordinate])
62
+ answers.append(", ".join(cell_values))
63
+
64
+ # Return values
65
+ return aggregation_predictions_string, answers
66
+
67
+
68
+ def show_answers(queries, answers, aggregation_predictions_string):
69
+ """
70
+ Visualize the postprocessed answers.
71
+ """
72
+ agg = {"NONE": lambda x: x, "SUM" : lambda x: sum(x), "AVERAGE": lambda x: (sum(x) / len(x)), "COUNT": lambda x: len(x)}
73
+ results = []
74
+ for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
75
+ print(query)
76
+ if predicted_agg == "NONE":
77
+ print("Predicted answer: " + answer)
78
+ else:
79
+ if all([not p.match(val) == None for val in answer.split(', ')]):
80
+ # print("Predicted answer: " + predicted_agg + "(" + answer + ") = " + str(agg[predicted_agg](list(map(float, answer.split(','))))))
81
+ result = str(agg[predicted_agg](list(map(float, answer.split(',')))))
82
+ elif predicted_agg == "COUNT":
83
+ # print("Predicted answer: " + predicted_agg + "(" + answer + ") = " + str(agg[predicted_agg](answer.split(','))))
84
+ result = str(agg[predicted_agg](answer.split(',')))
85
+ else:
86
+ result = predicted_agg + " > " + answer
87
+ results.append(result)
88
+ return results
89
+
90
+ def execute_query(query, table):
91
+ """
92
+ Invoke the TAPAS model.
93
+ """
94
+ queries = [query]
95
+ tokenizer, model = load_model_and_tokenizer()
96
+ table, inputs = prepare_inputs(table, queries, tokenizer)
97
+ predicted_table_cell_coords, predicted_aggregation_operators = generate_predictions(inputs, model, tokenizer)
98
+ aggregation_predictions_string, answers = postprocess_predictions(predicted_aggregation_operators, predicted_table_cell_coords, table)
99
+ return show_answers(queries, answers, aggregation_predictions_string)