Spaces · Runtime error
Baskar2005 committed · Commit 06b19dd
1 Parent: d1a3805
Create app.py
app.py ADDED
@@ -0,0 +1,257 @@
import os
import re
import logging
from typing import List, Tuple

import gradio as gr
from huggingface_hub import login, InferenceClient
from langchain_openai import AzureChatOpenAI
from langchain.chains import LLMSummarizationCheckerChain

# Hugging Face API token (set as a secret in the Space settings)
huggingface_key = os.getenv('HUGGINGFACE_KEY')
login(huggingface_key)

# Configure logging
logging.basicConfig(filename='factchecking.log', level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')


class FactChecking:

    def __init__(self):
        # Azure OpenAI model used by the fact-checking chain
        self.llm = AzureChatOpenAI(
            azure_deployment="ChatGPT"
        )
        # Mixtral endpoint used to generate the initial answer
        self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    def format_prompt(self, question: str) -> str:
        """
        Formats the input question into the instruction structure the model expects.

        Args:
            question (str): The user's question to be formatted.

        Returns:
            str: The formatted prompt including instructions and the question.
        """
        # Combine the instruction template with the user's question
        system = "[INST] You are an AI assistant. Your task is to answer the user's question. [/INST]"
        user = f"[INST] {question} [/INST]"
        return system + user

    def mixtral_response(self, prompt, temperature=0.9, max_new_tokens=5000,
                         top_p=0.95, repetition_penalty=1.0):
        """
        Generates a response to the given prompt using text-generation parameters.

        Args:
            prompt (str): The user's question.
            temperature (float): Controls randomness in response generation.
            max_new_tokens (int): The maximum number of tokens to generate.
            top_p (float): Nucleus-sampling parameter controlling diversity.
            repetition_penalty (float): Penalty for repeating tokens.

        Returns:
            str: The generated response to the input prompt.
        """
        # Clamp temperature and top_p to acceptable ranges
        temperature = max(float(temperature), 1e-2)
        top_p = float(top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )
        # Stream tokens from the inference endpoint and accumulate them
        formatted_prompt = self.format_prompt(prompt)
        stream = self.client.text_generation(formatted_prompt, **generate_kwargs,
                                             stream=True, details=True,
                                             return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text

        return output.replace("</s>", "")

    def extract_sentences(self, text: str) -> List[str]:
        """
        Splits the given text into sentences.

        Args:
            text (str): The input text.

        Returns:
            List[str]: A list of sentences.
        """
        try:
            # Tokenize the text into sentences using regex
            sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
            logging.info("Sentence extraction completed successfully.")
            return sentences
        except Exception as e:
            logging.error(f"Error occurred in extract_sentences: {e}")
            return []

    def find_different_sentences(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Labels sentences in text1 as 'factual' if they also appear in text2,
        otherwise 'hallucinated'.

        Args:
            text1 (str): The first text (model answer).
            text2 (str): The second text (fact-checked answer).

        Returns:
            List[Tuple[str, str]]: A list of (sentence, label) tuples.
        """
        try:
            sentences_text1 = self.extract_sentences(text1)
            # Use a set for fast membership tests
            sentences_text2 = set(self.extract_sentences(text2))
            labels = []
            for sentence in sentences_text1:
                if sentence in sentences_text2:
                    # Sentence is common to both texts: label it 'factual'
                    labels.append((sentence, 'factual'))
                else:
                    # Sentence is unique to text1: label it 'hallucinated'
                    labels.append((sentence, 'hallucinated'))
            logging.info("Sentence comparison completed successfully.")
            return labels
        except Exception as e:
            logging.error(f"Error occurred in find_different_sentences: {e}")
            return []

    def extract_words(self, text: str) -> List[str]:
        """
        Splits the input text into words and the separators between them.

        Args:
            text (str): The input text.

        Returns:
            List[str]: A list of word and non-word chunks.
        """
        try:
            # Tokenize into words and non-word runs (including spaces) using regex
            chunks = re.findall(r'\b\w+\b|\W+', text)
            logging.info("Words extracted successfully.")
            return chunks
        except Exception as e:
            logging.error(f"An error occurred while extracting words: {str(e)}")
            return []

    def label_words(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Labels words in text1 as 'factual' if they are present in text2,
        otherwise 'hallucinated'.

        Args:
            text1 (str): The first text.
            text2 (str): The second text.

        Returns:
            List[Tuple[str, str]]: A list of (word, label) tuples for text1.
        """
        try:
            # Extract chunks from both texts
            chunks_text1 = self.extract_words(text1)
            # Convert the second text's chunks into a set for faster lookup
            chunks_set_text2 = set(self.extract_words(text2))
            labels = []
            for chunk in chunks_text1:
                # Label each chunk by whether it also occurs in text2
                if chunk in chunks_set_text2:
                    labels.append((chunk, 'factual'))
                else:
                    labels.append((chunk, 'hallucinated'))
            logging.info("Words labeled successfully.")
            return labels
        except Exception as e:
            logging.error(f"An error occurred while labeling words: {str(e)}")
            return []

    def find_hallucinated_sentence(self, question: str):
        """
        Generates an answer, fact-checks it, and labels hallucinated content.

        Args:
            question (str): The input question.

        Returns:
            Tuple[str, str, List[Tuple[str, str]], List[Tuple[str, str]]]:
                The model answer, the fact-checked answer, the sentence-level
                labels, and the word-level labels.
        """
        try:
            # Generate the initial answer with Mixtral
            mixtral_answer = self.mixtral_response(question)

            # Create a checker chain for fact checking the answer
            checker_chain = LLMSummarizationCheckerChain.from_llm(self.llm, verbose=True, max_checks=2)

            # Run fact checking on the generated answer
            fact_checking_result = checker_chain.run(mixtral_answer)

            # Sentence-level comparison between answer and fact-checked answer
            prediction_list = self.find_different_sentences(mixtral_answer, fact_checking_result)

            # Word-level comparison
            word_prediction_list = self.label_words(mixtral_answer, fact_checking_result)

            logging.info("Sentence comparison completed successfully.")
            return mixtral_answer, fact_checking_result, prediction_list, word_prediction_list

        except Exception as e:
            logging.error(f"Error occurred in find_hallucinated_sentence: {e}")
            # Return one value per Gradio output component
            return "", "", [], []

    def interface(self):
        css = """.gradio-container {background: rgb(157,228,255);
        background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"""

        with gr.Blocks(css=css) as demo:
            gr.HTML("""<center><h1 style="color:#fff">Detect Hallucination</h1></center>""")
            with gr.Row():
                question = gr.Textbox(label="Question")
            with gr.Row():
                button = gr.Button(value="Submit")
            with gr.Row():
                with gr.Column(scale=1):
                    mixtral_response = gr.Textbox(label="LLM Answer")
                with gr.Column(scale=1):
                    fact_checking_result = gr.Textbox(label="Corrected Result")
            with gr.Row():
                with gr.Column(scale=1):
                    highlighted_prediction = gr.HighlightedText(
                        label="Sentence Hallucination Detection",
                        combine_adjacent=True,
                        color_map={"hallucinated": "red", "factual": "green"},
                        show_legend=True)
                with gr.Column(scale=1):
                    word_highlighted_prediction = gr.HighlightedText(
                        label="Word Hallucination Detection",
                        combine_adjacent=True,
                        color_map={"hallucinated": "red", "factual": "green"},
                        show_legend=True)
            button.click(self.find_hallucinated_sentence,
                         inputs=question,
                         outputs=[mixtral_response, fact_checking_result,
                                  highlighted_prediction, word_highlighted_prediction])
        demo.launch(debug=True)


hallucination_detection = FactChecking()
hallucination_detection.interface()
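
Note on configuration: the script assumes its credentials come from the environment. HUGGINGFACE_KEY is read explicitly above; AzureChatOpenAI additionally reads the standard Azure OpenAI variables that langchain_openai looks up by default. A minimal sketch of that setup with placeholder values (the exact api-version string depends on your Azure deployment):

import os

os.environ["HUGGINGFACE_KEY"] = "hf_..."          # read via os.getenv in app.py
os.environ["AZURE_OPENAI_API_KEY"] = "..."        # consumed by AzureChatOpenAI
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"   # example version; check your deployment

On a Hugging Face Space these would be set as repository secrets rather than hard-coded.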
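The sentence-level labelling reduces to a set-membership test: sentences present in both the model answer and the fact-checked text are marked 'factual', the rest 'hallucinated'. A self-contained sketch of that comparison, using the same regex as app.py but made-up example texts, so it runs without any API keys:

import re

def split_sentences(text):
    # Same sentence-splitting regex app.py uses
    return re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

answer = "Paris is the capital of France. It was founded in 3000 BC."
checked = "Paris is the capital of France. Its origins date to the 3rd century BC."

reference = set(split_sentences(checked))
labels = [(s, 'factual' if s in reference else 'hallucinated')
          for s in split_sentences(answer)]
print(labels)
# [('Paris is the capital of France.', 'factual'),
#  ('It was founded in 3000 BC.', 'hallucinated')]

Note this is an exact-match comparison: any rewording by the checker chain, however small, marks the whole sentence as hallucinated.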