Baskar2005 committed
Commit
06b19dd
1 Parent(s): d1a3805

Create app.py

Files changed (1)
  1. app.py +257 -0
app.py ADDED
@@ -0,0 +1,257 @@
+ import os
+ import re
+ import logging
+ from typing import List, Tuple
+ from huggingface_hub import login, InferenceClient
+ from langchain_openai import AzureChatOpenAI
+ from langchain.chains import LLMSummarizationCheckerChain
+ import gradio as gr
+
+ huggingface_key = os.getenv('HUGGINGFACE_KEY')
+ login(huggingface_key)  # Hugging Face API token
+
+ # Configure logging
+ logging.basicConfig(filename='factchecking.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+ class FactChecking:
+
+     def __init__(self):
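+         # NOTE: AzureChatOpenAI reads AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY
+         # and OPENAI_API_VERSION from the environment unless they are passed
+         # explicitly; "ChatGPT" is the name of the Azure deployment used here.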
+         self.llm = AzureChatOpenAI(
+             azure_deployment="ChatGPT"
+         )
+
+         self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+     def format_prompt(self, question: str) -> str:
+         """
+         Formats the input question into a specific structure for text generation.
+
+         Args:
+             question (str): The user's question to be formatted.
+
+         Returns:
+             str: The formatted prompt including instructions and the question.
+         """
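+         # e.g. format_prompt("Who wrote Hamlet?") returns
+         # "[INST] You are an AI assistant; answer the user's question. [/INST][INST] Who wrote Hamlet? [/INST]"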
+         # Combine the instruction template with the user's question
+         prompt = "[INST] You are an AI assistant; answer the user's question. [/INST]"
+         prompt1 = f"[INST] {question} [/INST]"
+         return prompt + prompt1
+
+     def mixtral_response(self, prompt, temperature=0.9, max_new_tokens=5000, top_p=0.95, repetition_penalty=1.0):
+         """
+         Generates a response to the given prompt using text generation parameters.
+
+         Args:
+             prompt (str): The user's question.
+             temperature (float): Controls randomness in response generation.
+             max_new_tokens (int): The maximum number of tokens to generate.
+             top_p (float): Nucleus sampling parameter controlling diversity.
+             repetition_penalty (float): Penalty for repeating tokens.
+
+         Returns:
+             str: The generated response to the input prompt.
+         """
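+         # e.g. self.mixtral_response("Summarize WW2 in one sentence.", temperature=0.2)
+         # streams the completion and returns it as a single string.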
+         # Adjust temperature and top_p values within acceptable ranges
+         temperature = float(temperature)
+         if temperature < 1e-2:
+             temperature = 1e-2
+         top_p = float(top_p)
+
+         generate_kwargs = dict(
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             do_sample=True,
+             seed=42,
+         )
+         # Stream the generated tokens from the Hugging Face Inference API
+         formatted_prompt = self.format_prompt(prompt)
+         stream = self.client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+         output = ""
+
+         for response in stream:
+             output += response.token.text
+
+         return output.replace("</s>", "")
+
+     def extract_unique_sentences(self, text: str) -> List[str]:
+         """
+         Splits the given text into sentences.
+
+         Args:
+             text (str): The input text.
+
+         Returns:
+             List[str]: A list containing the sentences in order.
+         """
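+         # e.g. extract_unique_sentences("Paris is in France. The moon is rock.")
+         #   -> ["Paris is in France.", "The moon is rock."]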
+         try:
+             # Tokenize the text into sentences using regex
+             sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+             logging.info("Sentence extraction completed successfully.")
+             # Return a list of sentences
+             return sentences
+         except Exception as e:
+             logging.error(f"Error occurred in extract_unique_sentences: {e}")
+             return []
+
+     def find_different_sentences(self, text1: str, text2: str) -> List[Tuple[str, str]]:
+         """
+         Finds sentences that are different between two texts.
+
+         Args:
+             text1 (str): The first text.
+             text2 (str): The second text.
+
+         Returns:
+             List[Tuple[str, str]]: A list of tuples containing sentences and their labels.
+         """
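+         # e.g. find_different_sentences("Paris is in France. The moon is cheese.",
+         #                               "Paris is in France. The moon is rock.")
+         #   -> [("Paris is in France.", "factual"), ("The moon is cheese.", "hallucinated")]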
+         try:
+             sentences_text1 = self.extract_unique_sentences(text1)
+             sentences_text2 = self.extract_unique_sentences(text2)
+             # Initialize labels list
+             labels = []
+             # Iterate over sentences in text1
+             for sentence in sentences_text1:
+                 if sentence in sentences_text2:
+                     # If sentence is common to both texts, assign 'factual' label
+                     labels.append((sentence, 'factual'))
+                 else:
+                     # If sentence is unique to text1, assign 'hallucinated' label
+                     labels.append((sentence, 'hallucinated'))
+             logging.info("Sentence comparison completed successfully.")
+             return labels
+         except Exception as e:
+             logging.error(f"Error occurred in find_different_sentences: {e}")
+             return []
+
+     def extract_words(self, text: str) -> List[str]:
+         """
+         Extracts words and separators from the input text.
+
+         Parameters:
+             text (str): The input text.
+
+         Returns:
+             List[str]: A list containing the extracted word and separator chunks.
+         """
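+         # e.g. extract_words("Hello, world") -> ["Hello", ", ", "world"]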
+         try:
+             # Tokenize the text into words and non-word characters (including spaces) using regex
+             chunks = re.findall(r'\b\w+\b|\W+', text)
+             logging.info("Words extracted successfully.")
+         except Exception as e:
+             logging.error(f"An error occurred while extracting words: {str(e)}")
+             return []
+         else:
+             return chunks
+
+     def label_words(self, text1: str, text2: str) -> List[Tuple[str, str]]:
+         """
+         Labels words in text1 as 'factual' if they are present in text2, otherwise 'hallucinated'.
+
+         Parameters:
+             text1 (str): The first text.
+             text2 (str): The second text.
+
+         Returns:
+             List[Tuple[str, str]]: A list of tuples containing words from text1 and their labels.
+         """
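+         # e.g. label_words("red car", "blue car")
+         #   -> [("red", "hallucinated"), (" ", "factual"), ("car", "factual")]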
+         try:
+             # Extract chunks from both texts
+             chunks_text1 = self.extract_words(text1)
+             chunks_text2 = self.extract_words(text2)
+             # Convert chunks_text2 into a set for faster lookup
+             chunks_set_text2 = set(chunks_text2)
+             # Initialize labels list
+             labels = []
+             # Iterate over chunks in text1
+             for chunk in chunks_text1:
+                 # Check if chunk is present in text2
+                 if chunk in chunks_set_text2:
+                     labels.append((chunk, 'factual'))
+                 else:
+                     labels.append((chunk, 'hallucinated'))
+             logging.info("Words labeled successfully.")
+             return labels
+         except Exception as e:
+             logging.error(f"An error occurred while labeling words: {str(e)}")
+             return []
+
+     def find_hallucinated_sentence(self, question: str) -> Tuple[str, str, List[Tuple[str, str]], List[Tuple[str, str]]]:
+         """
+         Finds hallucinated sentences in the response to a given question.
+
+         Args:
+             question (str): The input question.
+
+         Returns:
+             Tuple[str, str, List[Tuple[str, str]], List[Tuple[str, str]]]: The original Mixtral
+             answer, the fact-checked answer, the sentence-level labels, and the word-level labels.
+         """
+         try:
+             # Generate the initial answer with Mixtral
+             mixtral_answer = self.mixtral_response(question)
+
+             # Create checker chain for summarization checking
+             checker_chain = LLMSummarizationCheckerChain.from_llm(self.llm, verbose=True, max_checks=2)
+
+             # Run fact checking on the generated result
+             fact_checking_result = checker_chain.run(mixtral_answer)
+
+             # Find different sentences between the original result and the fact-checked result
+             prediction_list = self.find_different_sentences(mixtral_answer, fact_checking_result)
+
+             # Word-level prediction list
+             word_prediction_list = self.label_words(mixtral_answer, fact_checking_result)
+
+             logging.info("Sentences comparison completed successfully.")
+             # Return the original answer, corrected answer, and both label lists
+             return mixtral_answer, fact_checking_result, prediction_list, word_prediction_list
+
+         except Exception as e:
+             logging.error(f"Error occurred in find_hallucinated_sentence: {e}")
+             return "", "", [], []
+
+     def interface(self):
+         css = """.gradio-container {background: rgb(157,228,255);
+                  background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"""
+
+         with gr.Blocks(css=css) as demo:
+             gr.HTML("""
+                 <center><h1 style="color:#fff">Detect Hallucination</h1></center>""")
+             with gr.Row():
+                 question = gr.Textbox(label="Question")
+             with gr.Row():
+                 button = gr.Button(value="Submit")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     mixtral_response = gr.Textbox(label="LLM Answer")
+                 with gr.Column(scale=1):
+                     fact_checking_result = gr.Textbox(label="Corrected Result")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     highlighted_prediction = gr.HighlightedText(
+                         label="Sentence Hallucination Detection",
+                         combine_adjacent=True,
+                         color_map={"hallucinated": "red", "factual": "green"},
+                         show_legend=True)
+                 with gr.Column(scale=1):
+                     word_highlighted_prediction = gr.HighlightedText(
+                         label="Word Hallucination Detection",
+                         combine_adjacent=True,
+                         color_map={"hallucinated": "red", "factual": "green"},
+                         show_legend=True)
+             button.click(self.find_hallucinated_sentence,
+                          question,
+                          [mixtral_response, fact_checking_result, highlighted_prediction, word_highlighted_prediction])
+         demo.launch(debug=True)
+
+
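+ # Build the UI and launch the Gradio app when the file is executed.
+ # Assumes HUGGINGFACE_KEY and the Azure OpenAI variables
+ # (AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, OPENAI_API_VERSION) are set.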
+ hallucination_detection = FactChecking()
+ hallucination_detection.interface()