File size: 2,294 Bytes
8e4e15e
 
 
 
 
cee3943
 
8e4e15e
 
dfaf13a
8e4e15e
 
dfaf13a
8e4e15e
dfaf13a
8e4e15e
 
dfaf13a
8e4e15e
 
 
 
 
 
 
 
 
 
 
 
 
cee3943
8e4e15e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cee3943
 
 
 
 
 
 
 
 
 
 
 
 
 
8e4e15e
cee3943
 
fd0862d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import tensorflow as tf
import numpy as np
import pickle
import gradio as gr
from tensorflow.keras.preprocessing import sequence
import random
import time

# --- Load the trained seq2seq components and their tokenizer state ---

def _load_pickle(path):
    """Deserialize and return the contents of a pickle file."""
    with open(path, 'rb') as handle:
        # NOTE(review): pickle.load on a local artifact shipped with the app;
        # never point this at untrusted files.
        return pickle.load(handle)

# Encoder and decoder halves of the trained seq2seq chatbot.
enc_model = tf.keras.models.load_model('./encoder_model.h5')
dec_model = tf.keras.models.load_model('./decoder_model.h5')

# Fitted Keras tokenizer plus the padding lengths used at training time.
tokenizer = _load_pickle('./tokenizer.pkl')
tokenizer_params = _load_pickle('./tokenizer_params (1).pkl')

maxlen_questions = tokenizer_params["maxlen_questions"]
maxlen_answers = tokenizer_params["maxlen_answers"]

def str_to_tokens(sentence: str):
    """Convert a raw sentence into a padded sequence of token ids.

    The sentence is lowercased and whitespace-tokenized, then each word is
    mapped through the tokenizer's vocabulary.

    Args:
        sentence: Raw user input text.

    Returns:
        A (1, maxlen_questions) integer array, post-padded with zeros,
        suitable as encoder input.
    """
    # Skip out-of-vocabulary words instead of indexing directly: the
    # original raised KeyError (crashing the app) on any unseen word.
    tokens_list = [
        tokenizer.word_index[word]
        for word in sentence.lower().split()
        if word in tokenizer.word_index
    ]
    return sequence.pad_sequences([tokens_list], maxlen=maxlen_questions, padding='post')

def chatbot_response(question, chat_history):
    """Generate a bot reply for *question* and append it to the chat history.

    Runs the encoder once to obtain the initial decoder states, then decodes
    greedily (argmax) one token at a time until the 'end' marker or the
    maximum answer length is reached.

    Args:
        question: Raw user input string.
        chat_history: Gradio chat history, a list of (user, bot) tuples;
            mutated in place.

    Returns:
        A ("", chat_history) tuple so Gradio clears the textbox and
        refreshes the chat display.
    """
    states_values = enc_model.predict(str_to_tokens(question))

    # Reverse lookup table built once per call — the original scanned the
    # whole word_index dict for every generated token (O(V) per step).
    index_to_word = {index: word for word, index in tokenizer.word_index.items()}

    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = tokenizer.word_index['start']

    decoded_words = []
    # Hard iteration cap: the original used a while-loop whose stop condition
    # depended on the output growing, so an argmax hitting an unmapped index
    # (e.g. the padding id 0) spun forever without producing a word.
    for _ in range(maxlen_answers):
        dec_outputs, h, c = dec_model.predict([target_seq] + states_values)
        sampled_word_index = int(np.argmax(dec_outputs[0, -1, :]))
        sampled_word = index_to_word.get(sampled_word_index)

        # Stop on end-of-sequence or an id with no vocabulary entry.
        if sampled_word is None or sampled_word == 'end':
            break
        decoded_words.append(sampled_word)

        # Feed the sampled token and updated states back into the decoder.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_word_index
        states_values = [h, c]

    # join() also fixes the leading-space artifact of the old += building.
    bot_message = ' '.join(decoded_words)

    chat_history.append((question, bot_message))
    # Deliberate short pause so the reply doesn't appear instantaneous in the UI.
    time.sleep(2)
    return "", chat_history

# --- Gradio Blocks UI: minimal chat front-end wired to the seq2seq model ---
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(user_message, history):
        """Thin adapter between Gradio's submit callback and the model.

        Returns ("", updated_history) so the textbox is cleared and the
        chat display is refreshed.
        """
        return chatbot_response(user_message, history)

    # Pressing Enter in the textbox triggers a model round-trip.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()