Josephgflowers committed on
Commit
c6cdb52
1 Parent(s): c1e5d75

Upload tinyllama_agent_cinder_txtai-rag.py

Files changed (1)
  1. tinyllama_agent_cinder_txtai-rag.py +158 -0
tinyllama_agent_cinder_txtai-rag.py ADDED
@@ -0,0 +1,158 @@
+ import requests
+ import os
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from txtai import Embeddings
+ # pip3 install git+https://github.com/neuml/txtai#egg=txtai[pipeline-llm]
+
+ # Wikipedia embeddings database
+ embeddings = Embeddings()
+ embeddings.load(provider="huggingface-hub", container="neuml/txtai-wikipedia")
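+ # This prebuilt index covers Wikipedia article abstracts; embeddings.search(query)
+ # returns a list of dicts with "id", "text", and "score" keys. Optional sanity
+ # check (example query is illustrative):
+ #   hits = embeddings.search("Roman Empire", 3)
+ #   print(hits[0]["text"], hits[0]["score"])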
+
+ # os.environ['OMP_NUM_THREADS'] = '6'
+
+ #
+ # DuckDuckGo
+ #
+ def query_duckduckgo(query):
+     """Query the DuckDuckGo Instant Answer API for a search term and return the JSON result."""
+     url = "https://api.duckduckgo.com/"
+     params = {
+         'q': query,
+         'format': 'json',
+         'pretty': '1',
+         'no_html': '1'
+     }
+
+     try:
+         response = requests.get(url, params=params)
+         response.raise_for_status()  # Raises an HTTPError for bad responses
+         return response.json()
+     except requests.RequestException as e:
+         print(f"An error occurred: {e}")
+         return None
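+ # The Instant Answer API is strongest on encyclopedic queries; the field used
+ # below is "AbstractText" (a plain-text summary), with "AbstractURL" and
+ # "RelatedTopics" also available in the JSON. "AbstractText" is often empty,
+ # which is what triggers the Wikipedia fallback in handle_query().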
+
+ def handle_query(user_input):
+     """Answer from DuckDuckGo when possible, otherwise fall back to the Wikipedia index."""
+     result = query_duckduckgo(user_input)
+     if result and result.get('AbstractText'):
+         print(result['AbstractText'])
+         return result['AbstractText']
+     else:
+         print("DuckDuckGo failed. Falling back to Wikipedia.")
+         context = "\n".join([x["text"] for x in embeddings.search(user_input)])
+         print("Results from Wikipedia:\n", context)
+         return context
+
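+ # handle_query() both prints the answer and returns the text so the chat loop
+ # below can splice it into the prompt as retrieval context, e.g.:
+ #   context = handle_query("Eiffel Tower")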
+ # Load model and tokenizer
+ model_path = "Josephgflowers/TinyLlama-Cinder-Agent-Rag"
+ # Define the device (CPU or GPU)
+ # device = torch.device("cuda")
+ device = torch.device("cpu")
+ model = AutoModelForCausalLM.from_pretrained(model_path, ignore_mismatched_sizes=True).to(device)
+
+ print(model)
+ total_params = sum(p.numel() for p in model.parameters())
+ print("Total number of parameters: ", total_params)
+
+ sequence_length = 2048  # model context length
+ # embedding_size = 2048  # as per the model's definition
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ stop_token = 2  # id of the '</s>' end-of-sequence token
+
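+ # Token id 2 is '</s>' for Llama-family tokenizers. A less brittle way to get
+ # it is to ask the tokenizer directly:
+ #   stop_token = tokenizer.eos_token_id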
+ def chat_with_model(prompt_text, stop_token, model, tokenizer):
+     # Encode the prompt text
+     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt").to(device)
+
+     # Generate response
+     output_sequences = model.generate(
+         input_ids=encoded_prompt,
+         max_new_tokens=256,
+         temperature=0.1,
+         repetition_penalty=1.2,
+         top_k=20,
+         top_p=0.9,
+         do_sample=True,
+         num_return_sequences=1,
+         eos_token_id=stop_token
+     )
+
+     # Decode the generated sequence and keep only the newly generated text
+     generated_sequence = output_sequences[0].tolist()
+     text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+     response_text = text[len(prompt_text):].strip()
+     return response_text
+
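+ # Example one-off call (hypothetical prompt, outside the chat loop below):
+ #   reply = chat_with_model("<|user|>\nWhat is RAG?</s>\n<|assistant|>\n",
+ #                           stop_token, model, tokenizer)
+ # The low temperature (0.1) plus repetition_penalty keeps sampling close to
+ # deterministic, which suits the JSON-extraction prompt used below.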
+ # Initialize conversation history
+ conversation_history = ''
+
+ # Input mode and speaker tag for the chat loop
+ input_mode = 'text'  # or: input("Enter 'text' for text input or 'speech' for speech input: ").lower()
+ character_name = '<|user|>'  # or: input("Enter your character name (USER, JONAH, JOSEPH, KIMBERLY, etc.): ")
+
+ # Chat loop
+ num_chat = 1
+ while num_chat <= 20:
+     question = input(f"{character_name}: ")
+     user_input = question  # Get text input from user
+     context = handle_query(user_input)
+     prompt_text = f"""
+ <s>
+ <|system|>
+ You will be given documentation as context to answer a user's question. You are an expert at summarization. Pay close attention to the key concepts. Use only information from the context in your answer.
+ </s>
+ <|data|>
+ Context:
+ {context}
+ -Use only the above context to answer the question.
+ </s>
+ <|user|>
+ Here is information on "{question}". Extract only the above information into topic, category, keywords, and summary formatted in JSON. Think through the most critical information to provide, then respond with the JSON object of topic, category, keywords, and summary.
+ </s>
+ <|assistant|>
+
+ """
+
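+     # The <|system|>/<|user|>/<|assistant|> markers above appear to follow the
+     # Zephyr-style chat format used by TinyLlama chat models, with a custom
+     # <|data|> section carrying the retrieved context; generation stops at '</s>'.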
+     response_text = chat_with_model(prompt_text, stop_token, model, tokenizer)
+     response_text = response_text.replace('<s>', '')
+
+     # Keep only the first assistant message from the response
+     response_text = response_text.split('</s>\n', 1)[0]
+
+     print(f"\n______________________________________________\n\nAssistant: {response_text}")
+
+     # Update conversation history, trimming it when it grows too long
+     conversation_history += f"{prompt_text}{response_text}</s>\n"
+     if len(conversation_history) > 2048:
+         conversation_history = conversation_history[1024:]
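+     # Note: this trim is by characters, not tokens, and conversation_history is
+     # accumulated but never inserted into prompt_text as written. A token-aware
+     # trim (sketch) would re-encode and keep the most recent tokens:
+     #   ids = tokenizer.encode(conversation_history)
+     #   conversation_history = tokenizer.decode(ids[-1024:])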
+
+     num_chat += 1