Update app.py
app.py CHANGED
@@ -53,7 +53,7 @@ def load_and_store_models(model_names):
         generated_text = model.generate(tokenizer.encode(sample_text, return_tensors="pt"), max_length=50)
         decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
         store_to_redis_table(name, decoded_text)
-        redis_client.hset("models", name, decoded_text)
+        redis_client.hset("models", name, decoded_text)
     except Exception as e:
         logger.error(f"Error loading model {name}: {e}")
 
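The changed line 56 mirrors each model's sample output into a Redis hash, one field per model name, alongside the app's own store_to_redis_table call. A minimal sketch of the same calls, assuming a default redis-py client (the commit only shows the hset call itself):

    import redis

    # Assumed client setup; not shown in this diff.
    redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)

    # One "models" hash, one field per model name.
    redis_client.hset("models", "gpt2", "sample generated text")
    print(redis_client.hget("models", "gpt2"))  # -> "sample generated text"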
@@ -104,45 +104,45 @@ async def index():
     <script>
         const userInput = document.getElementById('user-input');
 
-        userInput.addEventListener('keyup', function(event) {
-            if (event.key === 'Enter') {
+        userInput.addEventListener('keyup', function(event) {{
+            if (event.key === 'Enter') {{
                 event.preventDefault();
                 sendMessage();
-            }
-        });
+            }}
+        }});
 
-        function sendMessage() {
+        function sendMessage() {{
             const userMessage = userInput.value.trim();
             if (userMessage === '') return;
 
             appendMessage('user', userMessage);
             userInput.value = '';
 
-            fetch(`/autocomplete?q
+            fetch(`/autocomplete?q=` + encodeURIComponent(userMessage))
                 .then(response => response.json())
-                .then(data => {
-                    fetch(`/get_response?q
+                .then(data => {{
+                    fetch(`/get_response?q=` + encodeURIComponent(userMessage))
                         .then(response => response.json())
-                        .then(data => {
+                        .then(data => {{
                             const botMessage = data.response;
                             appendMessage('bot', botMessage);
-                        })
-                        .catch(error => {
+                        }})
+                        .catch(error => {{
                             console.error('Error:', error);
-                        });
-                })
-                .catch(error => {
+                        }});
+                }})
+                .catch(error => {{
                     console.error('Error:', error);
-                });
-        }
+                }});
+        }}
 
-        function appendMessage(sender, message) {
+        function appendMessage(sender, message) {{
             const chatBox = document.getElementById('chat-box');
             const messageElement = document.createElement('div');
-            messageElement.className =
+            messageElement.className = sender + '-message';
             messageElement.innerText = message;
             chatBox.appendChild(messageElement);
-        }
+        }}
     </script>
 </body>
 </html>
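The doubled braces in the new lines 107-145 indicate that this HTML/JS is embedded in a Python f-string, where a literal { must be written {{ (and } as }}); the fetch calls are likewise rebuilt with encodeURIComponent plus string concatenation, which keeps JavaScript template interpolation out of the f-string and URL-escapes the user's message. A small sketch of the escaping rule, assuming an f-string template (the diff only shows the escaped output):

    # In an f-string, {{ and }} render as literal braces; single braces interpolate.
    greeting = "hello"
    html = f"""
    <script>
        function greet() {{
            console.log("{greeting}");  // {greeting} is filled in by Python
        }}
    </script>
    """
    print(html)  # the braces around the function body come out single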
@@ -164,7 +164,7 @@ async def autocomplete(q: str = Query(..., title='query'), background_tasks: Bac
     message_history.append(('user', q))
 
     background_tasks.add_task(generate_responses, q)
-    return {"status": "Processing request, please wait..."}
+    return {"status": "Processing request, please wait..."}
 
 @app.get('/get_response')
 async def get_response(q: str = Query(..., title='query')):
@@ -189,7 +189,7 @@ def generate_responses(q):
             similarities = calculate_similarity(q, generated_responses)
             most_coherent_response = generated_responses[np.argmax(similarities)]
             store_to_redis_table(q, "\n".join(generated_responses))
-            redis_client.hset("responses", q, most_coherent_response)
+            redis_client.hset("responses", q, most_coherent_response)
         else:
             logger.warning("No valid responses generated.")
     except Exception as e:
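Line 190 picks the candidate whose similarity score to the query is highest via np.argmax. calculate_similarity itself is not part of this diff; one common implementation is cosine similarity over TF-IDF vectors, sketched here purely as an assumption:

    import numpy as np
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    def calculate_similarity(query: str, responses: list[str]) -> np.ndarray:
        # Vectorize the query together with the candidates so they share a vocabulary.
        vectors = TfidfVectorizer().fit_transform([query] + responses)
        # Similarity of the query (row 0) against each candidate row.
        return cosine_similarity(vectors[0:1], vectors[1:]).ravel()

    responses = ["hello there", "unrelated text", "hello world"]
    print(responses[np.argmax(calculate_similarity("hello", responses))])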
@@ -197,12 +197,12 @@ def generate_responses(q):
 
 if __name__ == '__main__':
     gpt2_models = [
-        "gpt2",
-        "gpt2-medium",
-        "gpt2-large",
+        "gpt2",
+        "gpt2-medium",
+        "gpt2-large",
         "gpt2-xl"
     ]
-
+
     programming_models = [
         "google/bert2bert_L-24_uncased",
         "microsoft/CodeGPT-small-java",
@@ -212,4 +212,4 @@ if __name__ == '__main__':
 
     load_and_store_models(gpt2_models + programming_models)
 
-    uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT",
+    uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 8001)))
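The completed line 215 reads the port from the PORT environment variable with 8001 as the fallback, the usual pattern for platforms that inject a port at deploy time. In context:

    import os
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()

    if __name__ == '__main__':
        # PORT comes from the environment when set; 8001 is the local default.
        uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 8001)))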