# Spaces: Runtime error  (hosting-page status text, presumably captured along with the file; not program code)
import logging
import os
import re
from typing import List, Set, Tuple

import gradio as gr
from huggingface_hub import InferenceClient, login
from langchain.chains import LLMSummarizationCheckerChain
from langchain_openai import AzureChatOpenAI
# Configure logging first so that any start-up problem below is recorded too.
logging.basicConfig(filename='factchecking.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face API token, read from the environment (None when the variable is unset).
huggingface_key = os.getenv('HUGGINGFACE_KEY')
if huggingface_key:
    login(huggingface_key)
else:
    # NOTE(review): per the huggingface_hub docs, login() without a token falls
    # back to an interactive prompt, which a hosted app cannot answer -- so the
    # login is skipped and the problem is logged instead. Confirm this matches
    # the deployed huggingface_hub version.
    logging.warning("HUGGINGFACE_KEY is not set; skipping Hugging Face login.")
class FactChecking:
    """Answer a question with Mixtral, fact-check the answer, and show the diff.

    The answer is generated through a Hugging Face ``InferenceClient``; it is then
    passed through a LangChain ``LLMSummarizationCheckerChain`` backed by an
    ``AzureChatOpenAI`` model. Sentences and words of the original answer that do
    not literally appear in the checked text are labelled ``'hallucinated'``.
    """

    # Labels used by the comparison helpers and by the UI colour map.
    _FACTUAL = 'factual'
    _HALLUCINATED = 'hallucinated'

    # Compiled once: splits on whitespace that follows '.' or '?', except after
    # patterns such as "e.g." or "Mr.".
    _SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s')
    # Compiled once: runs of word characters, or runs of non-word characters
    # (spaces and punctuation are kept as their own chunks).
    _CHUNK_PATTERN = re.compile(r'\b\w+\b|\W+')

    def __init__(self, azure_deployment: str = "ChatGPT",
                 model_id: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"):
        """
        Creates the two model clients used by the checker.
        Args:
            azure_deployment (str): Azure OpenAI deployment name used for fact checking.
            model_id (str): Hugging Face model id used to generate the initial answer.
        """
        self.llm = AzureChatOpenAI(azure_deployment=azure_deployment)
        self.client = InferenceClient(model_id)

    def format_prompt(self, question: str) -> str:
        """
        Formats the input question into the instruction structure sent to the model.
        Args:
            question (str): The user's question to be formatted.
        Returns:
            str: The fixed instruction block followed by the question block.
        """
        # NOTE(review): the instruction text contains typos ("assitant", "answr")
        # and is sent as a separate [INST] block directly before the question's
        # own [INST] block. It is kept verbatim here because changing it would
        # change what the model generates.
        instruction = "[INST] you are the ai assitant your task is answr for the user question[/INST]"
        return f"{instruction}[INST] {question} [/INST]"

    def mixtral_response(self, prompt: str, temperature: float = 0.9, max_new_tokens: int = 5000,
                         top_p: float = 0.95, repetition_penalty: float = 1.0) -> str:
        """
        Generates a response to the given prompt with the text-generation client.
        Args:
            prompt (str): The user's question (wrapped by format_prompt before sending).
            temperature (float): Controls randomness; values below 0.01 are raised to 0.01.
            max_new_tokens (int): The maximum number of tokens to generate.
            top_p (float): Nucleus sampling parameter controlling diversity.
            repetition_penalty (float): Penalty for repeating tokens.
        Returns:
            str: The generated text with any "</s>" marker removed.
        """
        generate_kwargs = dict(
            # Clamp to a small positive floor, as the original code did.
            temperature=max(float(temperature), 1e-2),
            max_new_tokens=max_new_tokens,
            top_p=float(top_p),
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,  # fixed seed passed to the generation endpoint
        )
        formatted_prompt = self.format_prompt(prompt)
        stream = self.client.text_generation(
            formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        # Each streamed item carries one token; join them in a single pass.
        output = "".join(chunk.token.text for chunk in stream)
        return output.replace("</s>", "")

    def extract_unique_sentences(self, text: str) -> List[str]:
        """
        Splits the given text into sentences.
        Despite the method name, the result is an ordered list and repeated
        sentences are kept: callers rely on the order to rebuild the text.
        Args:
            text (str): The input text.
        Returns:
            List[str]: The sentences in their original order, or an empty list on error.
        """
        try:
            sentences = self._SENTENCE_BOUNDARY.split(text)
        except Exception:
            logging.exception("Error occurred in extract_unique_sentences")
            return []
        logging.info("Sentence extraction completed successfully.")
        return sentences

    def find_different_sentences(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Labels each sentence of text1 by whether it also appears verbatim in text2.
        Args:
            text1 (str): The first text (its sentences are the ones labelled).
            text2 (str): The second text (the reference).
        Returns:
            List[Tuple[str, str]]: (sentence, 'factual' | 'hallucinated') pairs in
            text1 order, or an empty list on error.
        """
        try:
            labels = self._label_against(
                self.extract_unique_sentences(text1),
                self.extract_unique_sentences(text2),
            )
        except Exception:
            logging.exception("Error occurred in find_different_sentences")
            return []
        logging.info("Sentence comparison completed successfully.")
        return labels

    def extract_words(self, text: str) -> List[str]:
        """
        Splits the input text into word chunks and non-word chunks.
        Parameters:
            text (str): The input text.
        Returns:
            List[str]: Words plus the whitespace/punctuation runs between them, in
            order, or an empty list on error.
        """
        try:
            chunks = self._CHUNK_PATTERN.findall(text)
        except Exception:
            logging.exception("An error occurred while extracting words")
            return []
        logging.info("Words extracted successfully.")
        return chunks

    def label_words(self, text1: str, text2: str) -> List[Tuple[str, str]]:
        """
        Labels chunks of text1 as 'factual' if they are present in text2, otherwise 'hallucinated'.
        Parameters:
            text1 (str): The first text (its chunks are the ones labelled).
            text2 (str): The second text (the reference).
        Returns:
            List[Tuple[str, str]]: (chunk, label) pairs in text1 order, or an empty
            list on error.
        """
        try:
            labels = self._label_against(self.extract_words(text1), self.extract_words(text2))
        except Exception:
            logging.exception("An error occurred while labeling words")
            return []
        logging.info("Words labeled successfully.")
        return labels

    def find_hallucinatted_sentence(
        self, question: str
    ) -> Tuple[str, str, List[Tuple[str, str]], List[Tuple[str, str]]]:
        """
        Answers the question, fact-checks the answer, and labels the differences.
        Args:
            question (str): The input question.
        Returns:
            Tuple: (model answer, fact-checked text, sentence labels, word labels).
            On error the same four-element shape is returned with empty values, so
            the four UI outputs bound in interface() always receive a value each.
        """
        try:
            answer = self.mixtral_response(question)
            checker_chain = LLMSummarizationCheckerChain.from_llm(self.llm, verbose=True, max_checks=2)
            fact_checking_result = checker_chain.run(answer)
            prediction_list = self.find_different_sentences(answer, fact_checking_result)
            word_prediction_list = self.label_words(answer, fact_checking_result)
        except Exception:
            logging.exception("Error occurred in find_hallucinatted_sentence")
            return "", "", [], []
        logging.info("Sentences comparison completed successfully.")
        return answer, fact_checking_result, prediction_list, word_prediction_list

    def interface(self):
        """Builds the Gradio Blocks UI and launches it with debug=True."""
        css = (
            ".gradio-container {background: rgb(157,228,255);\n"
            "background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"
        )
        color_map = {self._HALLUCINATED: "red", self._FACTUAL: "green"}

        with gr.Blocks(css=css) as demo:
            gr.HTML('\n<center><h1 style="color:#fff">Detect Hallucination</h1></center>')
            with gr.Row():
                question = gr.Textbox(label="Question")
            with gr.Row():
                button = gr.Button(value="Submit")
            with gr.Row():
                # Equal integer scales give the two columns the same width; the
                # Gradio docs describe `scale` as an integer.
                with gr.Column(scale=1):
                    mixtral_response = gr.Textbox(label="llm answer")
                with gr.Column(scale=1):
                    fact_checking_result = gr.Textbox(label="Corrected Result")
            with gr.Row():
                with gr.Column(scale=1):
                    highlighted_prediction = gr.HighlightedText(
                        label="Sentence Hallucination detection",
                        combine_adjacent=True,
                        color_map=color_map,
                        show_legend=True)
                with gr.Column(scale=1):
                    word_highlighted_prediction = gr.HighlightedText(
                        label="Word Hallucination detection",
                        combine_adjacent=True,
                        color_map=color_map,
                        show_legend=True)
            # One return value of find_hallucinatted_sentence per output component.
            button.click(
                self.find_hallucinatted_sentence,
                question,
                [mixtral_response, fact_checking_result, highlighted_prediction, word_highlighted_prediction],
            )
        demo.launch(debug=True)

    @classmethod
    def _label_against(cls, items, reference) -> List[Tuple[str, str]]:
        """Pairs each item with 'factual' if it occurs in reference, else 'hallucinated'."""
        known = set(reference)  # built once so each membership test is O(1)
        return [(item, cls._FACTUAL if item in known else cls._HALLUCINATED) for item in items]
# Build the application object (its constructor creates the model clients) and
# start the Gradio UI.
hallucination_detection = FactChecking()
hallucination_detection.interface()