File size: 3,923 Bytes
5af1bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from huggingface_hub import from_pretrained_keras
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow import keras
import gradio as gr


def make_bert_preprocessing_model(sentence_features, seq_length=128):
    """Build a Keras model that maps string features to BERT inputs.

    The preprocessing SavedModel is loaded from the module-level
    ``bert_preprocess_path``, which must be defined before this is called.

    Args:
        sentence_features: A list with the names of string-valued features.
            One scalar string input is created per name, in list order.
        seq_length: An integer that defines the sequence length of BERT
            inputs; it is forwarded to the packing step.

    Returns:
        A Keras Model that can be called on a list or dict of string Tensors
        (with the order or names, resp., given by sentence_features) and
        returns a dict of tensors for input to BERT.
    """
    string_inputs = []
    for feature_name in sentence_features:
        string_inputs.append(
            tf.keras.layers.Input(shape=(), dtype=tf.string, name=feature_name)
        )

    preprocessor = hub.load(bert_preprocess_path)

    # Step 1: split each raw string into word pieces.
    tokenize_layer = hub.KerasLayer(preprocessor.tokenize, name="tokenizer")
    tokenized = [tokenize_layer(string_input) for string_input in string_inputs]

    # Step 2: pack all tokenized segments into fixed-length BERT inputs.
    pack_layer = hub.KerasLayer(
        preprocessor.bert_pack_inputs,
        arguments=dict(seq_length=seq_length),
        name="packer",
    )
    return keras.Model(string_inputs, pack_layer(tokenized))


def preprocess_image(image_path, resize):
    """Read an image file, decode it to three channels and resize it.

    Args:
        image_path: Path of the image file, as a string or scalar string
            tensor. The decoder is chosen from the file extension.
        resize: Target size forwarded to ``tf.image.resize``.

    Returns:
        The decoded image tensor after resizing.
    """
    # Split on "." so the last piece really is the file extension. With the
    # default separator (whitespace) the last piece is the whole path, so the
    # JPEG branch below could never be selected.
    extension = tf.strings.lower(tf.strings.split(image_path, sep=".")[-1])
    is_jpeg = tf.logical_or(
        tf.equal(extension, "jpg"), tf.equal(extension, "jpeg")
    )

    image = tf.io.read_file(image_path)
    if is_jpeg:
        image = tf.image.decode_jpeg(image, 3)
    else:
        image = tf.image.decode_png(image, 3)

    image = tf.image.resize(image, resize)
    return image

def preprocess_text(text_1, text_2):
    """Run a pair of texts through the BERT preprocessing model.

    Each text is wrapped in a one-element tensor, the pair is passed to the
    module-level ``bert_preprocess_model``, and every feature named in
    ``bert_input_features`` is squeezed before being returned.

    Args:
        text_1: The first text.
        text_2: The second text.

    Returns:
        A dict keyed by the names in ``bert_input_features`` holding the
        squeezed preprocessing outputs.
    """
    text_batch = [
        tf.convert_to_tensor([text_1]),
        tf.convert_to_tensor([text_2]),
    ]
    packed = bert_preprocess_model(text_batch)

    # Drop the size-1 dimensions introduced by the one-element wrapping.
    return {name: tf.squeeze(packed[name]) for name in bert_input_features}

def preprocess_text_and_image(sample, resize):
    """Preprocess both images and both texts of a single sample.

    Args:
        sample: Mapping with the keys ``image_1_path``, ``image_2_path``,
            ``text_1`` and ``text_2``.
        resize: Target size forwarded to ``preprocess_image``.

    Returns:
        A dict with the keys ``image_1``, ``image_2`` and ``text``.
    """
    features = {}
    features["image_1"] = preprocess_image(sample["image_1_path"], resize)
    features["image_2"] = preprocess_image(sample["image_2_path"], resize)
    features["text"] = preprocess_text(sample["text_1"], sample["text_2"])
    return features


def classify_info(image_1, text_1, image_2, text_2):
    """Classify the relation between two image/text pairs.

    The four inputs are assembled into a one-row dataset, preprocessed with
    ``preprocess_text_and_image`` and fed to the module-level ``model``.

    Args:
        image_1: Path of the first image.
        text_1: Text belonging to the first image.
        image_2: Path of the second image.
        text_2: Text belonging to the second image.

    Returns:
        The entry of ``labels`` selected by the argmax of the model output.
    """
    record = {
        "image_1_path": image_1,
        "image_2_path": image_2,
        "text_1": text_1,
        "text_2": text_2,
    }
    frame = pd.DataFrame(record, index=[0])

    def _prepare(features, target):
        return preprocess_text_and_image(features, resize), target

    # The [0] target is only a placeholder so the dataset yields (x, y) pairs.
    dataset = tf.data.Dataset.from_tensor_slices((dict(frame), [0]))
    dataset = dataset.map(_prepare).cache()
    dataset = dataset.batch(1).prefetch(tf.data.AUTOTUNE)

    predictions = model.predict(dataset)
    return labels[np.argmax(predictions)]


# Pretrained classifier fetched from the Hugging Face Hub.
model = from_pretrained_keras("keras-io/multimodal-entailment")

# Size handed to tf.image.resize for both input images.
resize = (128, 128)

# Keys that preprocess_text picks out of the BERT preprocessing output.
bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]

# NOTE(review): bert_model_path is not referenced by the code visible in this
# file; kept in case something else relies on it.
bert_model_path = (
    "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1"
)
bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
bert_preprocess_model = make_bert_preprocessing_model(["text_1", "text_2"])

# Maps the argmax index of the model output to its display name.
labels = {0: "Contradictory", 1: "Implies", 2: "No Entailment"}

# Gradio widgets: two images (delivered as file paths) and two text boxes.
image_1 = gr.inputs.Image(type="filepath")
image_2 = gr.inputs.Image(type="filepath")

text_1 = gr.inputs.Textbox(lines=5)
text_2 = gr.inputs.Textbox(lines=5)

label = gr.outputs.Label()

iface = gr.Interface(
    classify_info,
    inputs=[image_1, text_1, image_2, text_2],
    outputs=label,
)

iface.launch()