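# Scraped from the Hugging Face file viewer (page chrome, not part of the program):
# author: harshvardhan96 · commit fd0862d "Update app.py" · raw / history / blame · 2.29 kB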
import tensorflow as tf
import numpy as np
import pickle
import gradio as gr
from tensorflow.keras.preprocessing import sequence
import random
import time
def _unpickle(path):
    """Open the pickle file at *path* in binary mode and return the deserialized object."""
    # NOTE(review): pickle can execute arbitrary code on load; only point this at
    # artifacts that ship with the app, never at user-supplied files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Load the encoder model first, then the decoder model (same order as before).
enc_model, dec_model = (
    tf.keras.models.load_model(model_path)
    for model_path in ('./encoder_model.h5', './decoder_model.h5')
)

# Tokenizer object plus a dict of sequence-length settings saved alongside it.
tokenizer = _unpickle('./tokenizer.pkl')
tokenizer_params = _unpickle('./tokenizer_params (1).pkl')

# Maximum question length (encoder input padding) and answer length (decoding cap).
maxlen_questions, maxlen_answers = (
    tokenizer_params[param_name]
    for param_name in ("maxlen_questions", "maxlen_answers")
)
def str_to_tokens(sentence: str):
    """Convert a user sentence into one padded row of token ids for the encoder.

    The sentence is lower-cased. Any characters listed in the pickled tokenizer's
    ``filters`` attribute (if it has one) are replaced with spaces, and the
    result is split on whitespace. Each word is looked up in
    ``tokenizer.word_index``:

    - an unknown word maps to the tokenizer's OOV token when one is configured;
    - otherwise the unknown word is skipped.

    Previously an unknown word raised ``KeyError`` and crashed the request.

    Args:
        sentence: Raw text typed by the user.

    Returns:
        The output of ``sequence.pad_sequences`` for a single token list,
        padded at the end ('post') to ``maxlen_questions`` entries.
    """
    word_index = tokenizer.word_index

    # Strip the same characters the Keras Tokenizer removes when building its
    # vocabulary, so "you?" matches "you". Without a `filters` attribute no
    # characters are stripped, which matches the old behaviour.
    filters = getattr(tokenizer, 'filters', '') or ''
    cleaned = sentence.lower().translate(str.maketrans(filters, ' ' * len(filters)))

    # Mirror texts_to_sequences: unknown words map to the OOV id when one exists.
    oov_token = getattr(tokenizer, 'oov_token', None)
    oov_index = word_index.get(oov_token) if oov_token is not None else None

    tokens_list = []
    for word in cleaned.split():
        index = word_index.get(word, oov_index)
        if index is not None:
            tokens_list.append(index)

    # pad_sequences truncates from the front by default (truncating='pre') when
    # the sentence is longer than maxlen_questions.
    return sequence.pad_sequences([tokens_list], maxlen=maxlen_questions, padding='post')
def _decode_answer(question):
    """Greedily decode an answer to *question* and return it as a list of words.

    Decoding starts from the 'start' token and stops at whichever comes first:
    the 'end' token (which is not included in the result), or
    ``maxlen_answers`` decoding steps.

    Sampled ids with no entry in ``tokenizer.word_index`` (e.g. a padding id)
    contribute no word but still count as a step. This guarantees termination;
    the old word-count stop test could loop forever on such ids.
    """
    # Reverse vocabulary built once per call, instead of scanning word_index
    # for every generated token.
    index_word = {index: word for word, index in tokenizer.word_index.items()}

    states_values = enc_model.predict(str_to_tokens(question), verbose=0)
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = tokenizer.word_index['start']

    words = []
    for _ in range(maxlen_answers):
        # h, c are fed back as the next step's states (presumably LSTM
        # hidden/cell states - confirm against the decoder model).
        dec_outputs, h, c = dec_model.predict([target_seq] + states_values, verbose=0)
        sampled_word_index = int(np.argmax(dec_outputs[0, -1, :]))
        sampled_word = index_word.get(sampled_word_index)
        if sampled_word == 'end':
            break
        if sampled_word is not None:
            words.append(sampled_word)
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_word_index
        states_values = [h, c]
    return words


def chatbot_response(question, chat_history):
    """Answer *question* with the seq2seq model and append the turn to the history.

    Args:
        question: User text from the Gradio textbox.
        chat_history: List of (user, bot) tuples shown by ``gr.Chatbot``. It is
            mutated in place; ``None`` is treated as an empty history.

    Returns:
        A tuple ``("", chat_history)``. The empty string clears the textbox.

    Blank input returns the history unchanged without running the model.
    Unlike the old ``split(' end')`` post-processing, words such as "ending"
    are no longer truncated, and the reply has no leading space.
    """
    if chat_history is None:
        chat_history = []
    if not (question and question.strip()):
        return "", chat_history

    bot_message = ' '.join(_decode_answer(question))
    chat_history.append((question, bot_message))
    # Deliberate 2-second pause before replying, kept from the original.
    # NOTE(review): presumably simulates "typing"; confirm it is still wanted.
    time.sleep(2)
    return "", chat_history
def respond(message, chat_history):
    """Gradio submit handler that forwards the message and history to chatbot_response."""
    return chatbot_response(message, chat_history)


# Gradio Blocks interface: chat transcript, input box, and a button that clears both.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])
    # Submitting the textbox sends (text, history) to the handler. Its return
    # value ("", updated history) empties the box and refreshes the transcript.
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

demo.launch()