Josephgflowers
commited on
Commit
•
c6cdb52
1
Parent(s):
c1e5d75
Upload tinyllama_agent_cinder_txtai-rag.py
Browse files
tinyllama_agent_cinder_txtai-rag.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import requests
|
3 |
+
import os
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchsummary import summary
|
8 |
+
from accelerate import dispatch_model, infer_auto_device_map
|
9 |
+
from txtai import Embeddings
|
10 |
+
from txtai.pipeline import LLM
|
11 |
+
#pip3 install git+https://github.com/neuml/txtai#egg=txtai[pipeline-llm]
|
12 |
+
|
13 |
+
|
14 |
+
# Wikipedia Embeddings Database
|
15 |
+
embeddings = Embeddings()
|
16 |
+
embeddings.load(provider="huggingface-hub", container="neuml/txtai-wikipedia")
|
17 |
+
|
18 |
+
#os.environ['OMP_NUM_THREADS'] = '6'
|
19 |
+
|
20 |
+
#
|
21 |
+
#DuckDuckGo
|
22 |
+
#
|
23 |
+
def query_duckduckgo(query):
|
24 |
+
"""Query DuckDuckGo API for a given search term and return the results."""
|
25 |
+
url = "https://api.duckduckgo.com/"
|
26 |
+
params = {
|
27 |
+
'q': query,
|
28 |
+
'format': 'json',
|
29 |
+
'pretty': '1',
|
30 |
+
'no_html': '1'
|
31 |
+
}
|
32 |
+
|
33 |
+
try:
|
34 |
+
response = requests.get(url, params=params)
|
35 |
+
response.raise_for_status() # Raises an HTTPError for bad responses
|
36 |
+
return response.json()
|
37 |
+
except requests.RequestException as e:
|
38 |
+
print(f"An error occurred: {e}")
|
39 |
+
return None
|
40 |
+
|
41 |
+
def handle_query(user_input):
|
42 |
+
"""Process user input and display the answer from DuckDuckGo."""
|
43 |
+
result = query_duckduckgo(user_input)
|
44 |
+
if result and 'AbstractText' in result and result['AbstractText']:
|
45 |
+
print(result['AbstractText'])
|
46 |
+
else:
|
47 |
+
print("DuckDuck Go failed. Going to Wiki.")
|
48 |
+
result ="\n".join([x["text"] for x in embeddings.search(user_input)])
|
49 |
+
print("Restults from Wiki: \n",result)
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
# Load model and tokenizer
|
55 |
+
model_path = "Josephgflowers/TinyLlama-Cinder-Agent-Rag"#
|
56 |
+
# Define the device (CPU or GPU)
|
57 |
+
#device = torch.device("cuda")
|
58 |
+
device = torch.device("cpu")
|
59 |
+
model = AutoModelForCausalLM.from_pretrained(model_path,ignore_mismatched_sizes=True).to(device)
|
60 |
+
|
61 |
+
print(model)
|
62 |
+
total_params = sum(p.numel() for p in model.parameters())
|
63 |
+
print("Total number of parameters: ", total_params)
|
64 |
+
|
65 |
+
sequence_length = 2048 # or whatever your specific sequence length is
|
66 |
+
#embedding_size = 2048 # as per your model's definition
|
67 |
+
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
69 |
+
stop_token =2 #3556 </ #2 #128247
|
70 |
+
#'</s>' 2
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
def chat_with_model(prompt_text, stop_token, model, tokenizer):
|
75 |
+
# Encode the prompt text
|
76 |
+
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt").to(device)
|
77 |
+
|
78 |
+
# Generate response
|
79 |
+
output_sequences = model.generate(
|
80 |
+
input_ids=encoded_prompt,
|
81 |
+
#max_length=len(encoded_prompt[0]) + 256,
|
82 |
+
max_new_tokens=256,
|
83 |
+
temperature=0.1,
|
84 |
+
repetition_penalty=1.2,
|
85 |
+
top_k=20,
|
86 |
+
top_p=0.9,
|
87 |
+
do_sample=True,
|
88 |
+
num_return_sequences=1,
|
89 |
+
eos_token_id=stop_token
|
90 |
+
)
|
91 |
+
|
92 |
+
# Decode the generated sequence
|
93 |
+
generated_sequence = output_sequences[0].tolist()
|
94 |
+
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
95 |
+
response_text = text[len(prompt_text):].strip() # Extract only the response text
|
96 |
+
#response_text = response_text.replace("<s>","").replace("</s>","")
|
97 |
+
return response_text
|
98 |
+
|
99 |
+
# Initialize conversation history
|
100 |
+
|
101 |
+
conversation_history = ''#'<s>\n<|system|>\nYou are a helpful assistant.</s>\n'#'<s>\n<|system|>\nYou are a
|
102 |
+
|
103 |
+
# Get user's preference for input mode and character name
|
104 |
+
input_mode = 'text' ##input("Enter 'text' for text input or 'speech' for speech input: ").lower()
|
105 |
+
character_name = '<|user|>' # input("Enter your character name (USER, JONAH, JOSEPH, KIMBERLY, etc.): ")
|
106 |
+
|
107 |
+
|
108 |
+
#
|
109 |
+
#handle_query(user_input)
|
110 |
+
# Chat loop
|
111 |
+
num_chat = 1
|
112 |
+
while num_chat <= 20:
|
113 |
+
question = input(f"{character_name}: ")
|
114 |
+
user_input = question # Get text input from user
|
115 |
+
#context = "\n".join([x["text"] for x in embeddings.search(question)])
|
116 |
+
context= handle_query(user_input)
|
117 |
+
#print('History: '+ conversation_history)
|
118 |
+
prompt_text = f"""
|
119 |
+
<s>
|
120 |
+
<|system|>
|
121 |
+
You will be given documentation as context to answer a users question. You are an expert at summarization. Pay close attention to the key concepts. Use only information from the Context in your answer.
|
122 |
+
</s>
|
123 |
+
<|data|>
|
124 |
+
Context:
|
125 |
+
{context}
|
126 |
+
-Use only the above context to answer the question.
|
127 |
+
</s>
|
128 |
+
<|user|>
|
129 |
+
Here is information on "{question}". Extract only the above information into topic, category, keywords, and summary formatted in JSON. Think through the most critical information to provide then respond with the JSON object of topic, category, keywords, and summary.
|
130 |
+
</s>
|
131 |
+
<|assistant|>
|
132 |
+
|
133 |
+
"""
|
134 |
+
#topic, category, keywords, and summary formatted in JSON. Think through the most critical information to provide then respond with the JSON object of topic, category, keywords, and summary
|
135 |
+
#Here is information on "{question}". Extract only the above information into topic, category, keywords, and summary formatted in JSON. Think through the most critical information to provide then respond with the JSON object of topic, category, keywords, and summary
|
136 |
+
|
137 |
+
#Use only the documentation provided to answer this question: {question}
|
138 |
+
|
139 |
+
|
140 |
+
response_text = chat_with_model(prompt_text, stop_token, model, tokenizer)
|
141 |
+
response_text = response_text.replace('<s>','')
|
142 |
+
#print('Response: '+ context)
|
143 |
+
|
144 |
+
# Extract assistant's response from the response_text
|
145 |
+
response_text = response_text.split('</s>\n', 1)[0] # Extract the first message from the assistant
|
146 |
+
|
147 |
+
print(f"\n______________________________________________\n\nAssistant: {response_text}")
|
148 |
+
|
149 |
+
# Update conversation history
|
150 |
+
conversation_history += f"{prompt_text}{response_text}</s>\n"
|
151 |
+
if len(conversation_history) > 2048:
|
152 |
+
conversation_history = conversation_history[1024:]
|
153 |
+
else:
|
154 |
+
conversation_history = conversation_history
|
155 |
+
|
156 |
+
num_chat += 1
|
157 |
+
|
158 |
+
|