# owner-manual / InnovationHub / llm / vector_store.py
# ctankso_americas_corpdir_net
# feat: temperature adjustment
# e0800e8
import plotly.graph_objs as go
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import plotly.express as px
import numpy as np
import os
import pprint
import codecs
import chardet
import gradio as gr
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
from EdgeGPT import Chatbot
def get_content(input_file):
    """Read a text file of unknown encoding and return its contents as ``str``.

    The encoding is guessed by ``chardet`` from the raw bytes. If chardet
    reports no encoding (``None``), UTF-8 is assumed instead of silently
    falling back to the platform's locale encoding.

    Args:
        input_file: Path of the file to read.

    Returns:
        The decoded file contents. Line endings are kept exactly as stored in
        the file (the bytes are decoded directly, without newline translation).

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the bytes are not valid in the guessed encoding.
    """
    # Read the file once in binary mode; the same bytes feed both the
    # encoding detection and the decoding step (no second open of the file).
    with open(input_file, 'rb') as f:
        raw_data = f.read()
    encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'
    return raw_data.decode(encoding)
def split_text(input_file, chunk_size=1000, chunk_overlap=0):
    """Split a text file into chunks and wrap the chunks as documents.

    Args:
        input_file: Path of the text file to split.
        chunk_size: Maximum chunk length as measured by ``len`` (characters);
            forwarded to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Forwarded to ``RecursiveCharacterTextSplitter``.

    Returns:
        A ``(texts, metadatas, docs)`` tuple:
        the chunk strings, one ``{"source": "<file stem>[<i>]"}`` dict per
        chunk, and the documents the splitter's ``create_documents`` builds
        from those chunks and dicts.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    # File name without directory or extension, used to label every chunk.
    stem, _ext = os.path.splitext(os.path.basename(input_file))
    chunks = splitter.split_text(text=get_content(input_file=input_file))
    sources = [{"source": f"{stem}[{position}]"} for position, _ in enumerate(chunks)]
    documents = splitter.create_documents(texts=chunks, metadatas=sources)
    return chunks, sources, documents
def create_docs(input_file, chunk_size=1000, chunk_overlap=0):
    """Split a whole text file into documents labelled with the file's stem.

    Unlike ``split_text``, the metadata passed to the splitter is a single
    ``{'source': <file stem>}`` dict without a per-chunk index (presumably
    every chunk inherits it; that is up to ``create_documents``).

    Args:
        input_file: Path of the text file to split.
        chunk_size: Maximum chunk length as measured by ``len`` (characters);
            forwarded to ``RecursiveCharacterTextSplitter``. The default
            matches the previously hard-coded value.
        chunk_overlap: Forwarded to ``RecursiveCharacterTextSplitter``; the
            default matches the previously hard-coded value.

    Returns:
        The documents produced by the splitter's ``create_documents``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    # File name without directory or extension.
    basename = os.path.splitext(os.path.basename(input_file))[0]
    # The whole file as one string; the splitter does the chunking.
    text = get_content(input_file=input_file)
    return text_splitter.create_documents(
        texts=[text], metadatas=[{'source': basename}])
def get_similar_docs(query, index, k=5):
    """Return the ``k`` indexed chunks most similar to ``query``.

    Args:
        query: Free-text search query.
        index: Any object exposing ``similarity_search(query=..., k=...)``
            that returns document objects with ``page_content`` and
            ``metadata`` attributes (e.g. a LangChain FAISS store).
        k: Number of results to request from the index.

    Returns:
        A list of ``(page_content, metadata)`` tuples, in the order the index
        returned them.
    """
    # LangChain documents keep their text in ``page_content``; they have no
    # ``summary`` attribute, so reading that raised AttributeError.
    return [(doc.page_content, doc.metadata)
            for doc in index.similarity_search(query=query, k=k)]
def convert_to_html(similar_docs):
    """Render ``(content, metadata)`` pairs as an HTML table.

    Args:
        similar_docs: Iterable of ``(page_content, metadata)`` tuples, as
            returned by ``get_similar_docs``; each ``metadata`` must contain a
            ``'source'`` key (a missing key raises ``KeyError``).

    Returns:
        A ``<table>`` string with a header row and one body row per document,
        rows separated by newlines.
    """
    from html import escape

    # Document text and source names come from arbitrary files; escape them
    # so characters such as '<' or '&' cannot break or inject into the markup.
    rows = [
        '<tr><td>' + escape(str(content)) + '</td><td>'
        + escape(str(metadata['source'])) + '</td></tr>'
        for content, metadata in similar_docs
    ]
    return ('<table><thead><tr><th>Page Content</th><th>Source</th></tr></thead><tbody>'
            + '\n'.join(rows) + '</tbody></table>')
def create_similarity_plot(embeddings, labels, query, n_clusters=3, query_index=-1):
    """Show a 3-D scatter plot of embeddings, coloured by k-means cluster.

    The embeddings are projected to three dimensions with PCA. The point at
    ``query_index`` is additionally drawn as a black diamond named after
    ``query``.

    Args:
        embeddings: Sequence of equal-length embedding vectors.
        labels: Hover text, one per point. Embeddings beyond ``len(labels)``
            are ignored.
        query: Query text used in the highlighted point's legend entry.
        n_clusters: Requested number of k-means clusters. It is clamped to
            the number of points so that small inputs do not raise.
        query_index: Position of the query embedding among the kept points.
            Defaults to the last point, which was the original behaviour.

    Returns:
        The plotly ``Figure``; it is also displayed via ``fig.show()``.

    Raises:
        ValueError: If no labelled embeddings remain to plot.
    """
    # Only include embeddings that have corresponding labels.
    vectors = np.asarray(list(embeddings)[:len(labels)], dtype=float)
    if vectors.ndim != 2 or len(vectors) == 0:
        raise ValueError("create_similarity_plot needs at least one labelled embedding")
    n_points, n_dims = vectors.shape

    # PCA cannot return more components than min(n_samples, n_features).
    # For tiny inputs, ask for fewer components and pad the missing axes with zeros.
    pca = PCA(n_components=min(3, n_points, n_dims))
    coords = pca.fit_transform(vectors)
    if coords.shape[1] < 3:
        coords = np.pad(coords, ((0, 0), (0, 3 - coords.shape[1])))

    # Cluster in the full embedding space. k-means needs at least as many
    # points as clusters.
    kmeans = KMeans(n_clusters=min(n_clusters, n_points))
    kmeans.fit(vectors)

    # Trace highlighting the query point.
    qx, qy, qz = coords[query_index]
    query_trace = go.Scatter3d(
        x=[qx],
        y=[qy],
        z=[qz],
        mode='markers',
        marker=dict(
            color='black',
            symbol='diamond',
            size=10
        ),
        name=f"Query: '{query}'"
    )
    # Trace for all points, coloured by cluster assignment.
    points_trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(
            color=kmeans.labels_,
            colorscale=px.colors.qualitative.Alphabet,
            size=5
        ),
        text=list(labels)[:n_points],
        name='Points'
    )
    fig = go.Figure(data=[query_trace, points_trace])
    fig.update_layout(
        title="3D Similarity Plot",
        legend_title_text="Cluster"
    )
    fig.show()
    return fig
def plot_similarities(query, index, embeddings=None, k=5):
    """Plot the query together with its ``k`` most similar indexed chunks.

    Args:
        query: Free-text search query.
        index: Vector store passed to ``get_similar_docs``.
        embeddings: Embedding model exposing ``embed_documents`` and
            ``embed_query``. Defaults to a new ``HuggingFaceEmbeddings()``,
            created per call. Building it in the signature would load the
            model at import time.
        k: Number of similar chunks to plot.

    Returns:
        Whatever ``create_similarity_plot`` returns.
    """
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings()
    similar_docs = get_similar_docs(query=query, index=index, k=k)
    texts = [content for content, _metadata in similar_docs]
    doc_vectors = embeddings.embed_documents(texts=texts)
    query_vector = embeddings.embed_query(text=query)
    # create_similarity_plot highlights the last point as the query and drops
    # embeddings without a label. So the query embedding goes last, with its
    # own label. Previously it was computed but never plotted, and the call
    # below passed an unknown `query_index` keyword while omitting the
    # required `query` argument.
    return create_similarity_plot(
        embeddings=list(doc_vectors) + [query_vector],
        labels=texts + [query],
        query=query,
        n_clusters=3
    )
def start_ui(index):
    """Launch a Gradio app that shows the chunks of ``index`` closest to a query.

    Args:
        index: Vector store queried via ``get_similar_docs`` for each input.
    """
    def query_index(query):
        # Look up the nearest chunks and render them as an HTML table.
        similar_docs = get_similar_docs(query=query, index=index)
        return convert_to_html(similar_docs=similar_docs)

    # Use the top-level components: gr.inputs / gr.outputs were deprecated in
    # Gradio 3 and removed in Gradio 4. The new names also avoid shadowing
    # the builtin input().
    query_box = gr.Textbox(lines=2)
    results_html = gr.HTML()
    iface = gr.Interface(fn=query_index,
                         inputs=query_box,
                         outputs=results_html)
    iface.launch()