Spaces:
Runtime error
Runtime error
import plotly.graph_objs as go | |
from sklearn.cluster import KMeans | |
from sklearn.decomposition import PCA | |
import plotly.express as px | |
import numpy as np | |
import os | |
import pprint | |
import codecs | |
import chardet | |
import gradio as gr | |
from langchain.llms import HuggingFacePipeline | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate | |
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory | |
from EdgeGPT import Chatbot | |
def get_content(input_file):
    """Read a text file of unknown encoding and return its decoded contents.

    The raw bytes are read once, ``chardet`` guesses their encoding, and the
    same bytes are then decoded with that guess, so the file is not opened
    a second time.

    Args:
        input_file: Path of the file to read.

    Returns:
        The decoded text of the file.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the bytes are invalid in the chosen encoding.
    """
    with open(input_file, 'rb') as f:
        raw_data = f.read()

    # chardet reports ``None`` when it cannot guess (for example on an empty
    # file); fall back to UTF-8 so the result never depends on the platform
    # locale.
    encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'

    return raw_data.decode(encoding)
def split_text(input_file, chunk_size=1000, chunk_overlap=0):
    """Split the content of ``input_file`` into chunks with per-chunk metadata.

    Args:
        input_file: Path of the file whose text is split.
        chunk_size: Passed to ``RecursiveCharacterTextSplitter`` as
            ``chunk_size``; lengths are measured with ``len``.
        chunk_overlap: Passed to the splitter as ``chunk_overlap``.

    Returns:
        A ``(texts, metadatas, docs)`` tuple: the chunk strings, one
        ``{"source": "<stem>[<i>]"}`` dict per chunk (``<stem>`` is the file
        name without directory or extension), and the result of the
        splitter's ``create_documents`` for those chunks and metadata.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )

    # File name without directory and without extension.
    stem, _extension = os.path.splitext(os.path.basename(input_file))

    chunks = splitter.split_text(text=get_content(input_file=input_file))
    chunk_metadata = [
        {"source": f"{stem}[{position}]"}
        for position, _chunk in enumerate(chunks)
    ]
    documents = splitter.create_documents(
        texts=chunks, metadatas=chunk_metadata)
    return chunks, chunk_metadata, documents
def create_docs(input_file):
    """Build documents from the full content of ``input_file``.

    The whole file content is handed to the splitter's ``create_documents``
    as a single text, with ``{'source': <stem>}`` as its metadata, where
    ``<stem>`` is the file name without directory or extension.

    Args:
        input_file: Path of the file to load.

    Returns:
        Whatever ``RecursiveCharacterTextSplitter.create_documents`` returns
        for that single text (splitter configured with chunk size 1000 and
        no overlap, lengths measured with ``len``).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=0,
        length_function=len,
    )
    file_stem = os.path.splitext(os.path.basename(input_file))[0]
    content = get_content(input_file=input_file)
    return splitter.create_documents(
        texts=[content],
        metadatas=[{'source': file_stem}],
    )
def get_similar_docs(query, index, k=5):
    """Return ``(summary, metadata)`` pairs for the documents closest to ``query``.

    Args:
        query: Search text forwarded to ``index.similarity_search``.
        index: Object exposing ``similarity_search(query=..., k=...)``.
        k: Number of results requested from the index.

    Returns:
        A list with one ``(doc.summary, doc.metadata)`` tuple per document
        returned by the index, in the order the index returned them.
    """
    # NOTE(review): relies on each returned document exposing a ``summary``
    # attribute -- confirm against the installed langchain version.
    pairs = []
    for doc in index.similarity_search(query=query, k=k):
        pairs.append((doc.summary, doc.metadata))
    return pairs
def convert_to_html(similar_docs):
    """Render ``(summary, metadata)`` pairs as an HTML table.

    Both cell values are HTML-escaped, so document text containing ``<``,
    ``&`` or quotes is shown literally instead of being interpreted as markup.

    Args:
        similar_docs: Iterable of ``(summary, metadata)`` pairs, where
            ``metadata`` is a mapping with a ``'source'`` key.

    Returns:
        A string holding a ``<table>`` with one row per pair; rows are
        separated by newlines.

    Raises:
        KeyError: If a metadata mapping has no ``'source'`` entry.
    """
    # Function-scope import: the escape helper is only needed here.
    from html import escape

    rows = [
        '<tr><td>' + escape(str(summary)) + '</td><td>'
        + escape(str(metadata['source'])) + '</td></tr>'
        for summary, metadata in similar_docs
    ]
    return (
        '<table><thead><th>Page Content</th><th>Source</th></thead><tbody>'
        + '\n'.join(rows) + '</tbody></table>'
    )
def create_similarity_plot(embeddings, labels, query, n_clusters=3):
    """Show a 3D scatter plot of embeddings coloured by k-means cluster.

    Embeddings beyond the number of labels are ignored. The remaining ones
    are projected to three PCA components for display and clustered with
    k-means on their original coordinates. The last retained embedding is
    additionally drawn as a black diamond named after ``query``.

    Args:
        embeddings: Sequence of embedding vectors.
        labels: Hover texts, one per embedding that should be plotted.
        query: Query text, used only in the name of the highlighted marker.
        n_clusters: Number of k-means clusters.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Pairing with the labels drops any embedding that has no label.
    labelled_vectors = [vector for vector, _label in zip(embeddings, labels)]

    # Project to 3 dimensions for plotting.
    projected = PCA(n_components=3).fit_transform(labelled_vectors)

    # Clustering uses the original vectors, not the projection.
    clustering = KMeans(n_clusters=n_clusters)
    clustering.fit(labelled_vectors)

    # The final projected row is treated as the query point.
    query_trace = go.Scatter3d(
        x=[projected[-1, 0]],
        y=[projected[-1, 1]],
        z=[projected[-1, 2]],
        mode='markers',
        marker={'color': 'black', 'symbol': 'diamond', 'size': 10},
        name=f"Query: '{query}'",
    )

    # Every retained point (including the last one), coloured by cluster id.
    points_trace = go.Scatter3d(
        x=projected[:, 0],
        y=projected[:, 1],
        z=projected[:, 2],
        mode='markers',
        marker={
            'color': clustering.labels_,
            'colorscale': px.colors.qualitative.Alphabet,
            'size': 5,
        },
        text=labels,
        name='Points',
    )

    figure = go.Figure(data=[query_trace, points_trace])
    figure.update_layout(
        title="3D Similarity Plot",
        legend_title_text="Cluster",
    )
    figure.show()
# NOTE: the default ``HuggingFaceEmbeddings()`` is created once, when this
# ``def`` statement is executed, and is shared by every call that omits the
# ``embeddings`` argument.
def plot_similarities(query, index, embeddings=HuggingFaceEmbeddings(), k=5):
    """Plot the documents most similar to ``query`` together with the query.

    The ``k`` nearest documents are fetched from ``index``, their summaries
    are embedded, and the query embedding is appended as the final point so
    that ``create_similarity_plot`` (which highlights the last point as the
    query) marks the right one.

    Args:
        query: Query text.
        index: Object accepted by ``get_similar_docs`` as ``index``.
        embeddings: Object exposing ``embed_query(text=...)`` and
            ``embed_documents(texts=...)``.
        k: Number of similar documents to fetch.

    Returns:
        None. The plot is shown by ``create_similarity_plot``.
    """
    query_embedding = embeddings.embed_query(text=query)
    similar_docs = get_similar_docs(query=query, index=index, k=k)
    texts = [summary for summary, _metadata in similar_docs]

    # Documents first, query last: the plot treats the last point as the
    # query marker.
    embeddings_array = list(embeddings.embed_documents(texts=texts))
    embeddings_array.append(query_embedding)
    labels = texts + [query]

    create_similarity_plot(
        embeddings=embeddings_array,
        labels=labels,
        query=query,
        n_clusters=3,
    )
def start_ui(index):
    """Launch a Gradio interface that searches ``index`` and shows an HTML table.

    Args:
        index: Object passed to ``get_similar_docs`` as ``index`` for every
            query typed into the interface.

    Returns:
        None. The interface is launched with ``launch()``.
    """
    def query_index(query):
        # Look up the documents similar to the query and render them as HTML.
        return convert_to_html(
            similar_docs=get_similar_docs(query=query, index=index))

    # Two-line text box in, rendered HTML out.
    query_box = gr.inputs.Textbox(lines=2)
    html_view = gr.outputs.HTML()

    gr.Interface(
        fn=query_index,
        inputs=query_box,
        outputs=html_view,
    ).launch()