top2vec / app / pages / 01_Topic_Explorer_📚.py
derek-thomas's picture
derek-thomas HF staff
Updating topic_word
356174d
raw
history blame
2.62 kB
from logging import getLogger
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from utilities import initialization
# Set up shared state for this page. `initialization` comes from the
# project-local `utilities` module (not visible here). main() below reads
# st.session_state.model and st.session_state.topic_str_to_word, and nothing
# else in this file assigns them, so presumably this call does -- TODO confirm.
# It executes at module level, before any code under the __main__ guard.
initialization()
# Legacy loader, disabled and kept for reference only (it never executes).
# It is NOT a drop-in replacement for initialization(): it never sets
# st.session_state.topic_str_to_word, which main() reads.
# NOTE(review): `st.cache` is deprecated in recent Streamlit releases; if this
# is ever revived, use st.cache_resource / st.cache_data instead. Consider
# deleting this dead code -- version control keeps the history.
# @st.cache(show_spinner=False)
# def initialize_state():
# with st.spinner("Loading app..."):
# if 'model' not in st.session_state:
# model = Top2Vec.load('models/model.pkl')
# model._check_model_status()
# model.hierarchical_topic_reduction(num_topics=20)
#
# st.session_state.model = model
# st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
# logger.info("loading data...")
#
# if 'data' not in st.session_state:
# logger.info("loading data...")
# data = pd.read_csv(proj_dir / 'data' / 'data.csv')
# data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
# st.session_state.data = data
# st.session_state.selected_data = data
# st.session_state.all_topics = list(data.topic_id.unique())
#
# if 'topics' not in st.session_state:
# logger.info("loading topics...")
# topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
# topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
# st.session_state.topics = topics
def main():
    """Render the Topic Explorer page: a bar chart of one topic's top words.

    A sidebar slider selects a topic index. For that topic, the words in
    ``model.topic_words_reduced`` are plotted as horizontal bars against
    their scores in ``model.topic_word_scores_reduced``.

    Reads from ``st.session_state`` (not set anywhere in this file; presumably
    populated by ``initialization()`` -- TODO confirm):

    * ``model`` -- a Top2Vec model whose topics have been reduced, exposing
      ``topic_words_reduced`` and ``topic_word_scores_reduced`` indexed by
      topic number.
    * ``topic_str_to_word`` -- maps a zero-padded topic id (``"00"``,
      ``"01"``, ...) to the label shown in the chart title.
    """
    st.write("""
A way to dive into each topic. Use the slider on the left to choose the topic.
The `y` axis shows which words are closest to a topic centroid. The `x` axis shows how correlated they are.""")

    model = st.session_state.model
    num_topics = len(model.topic_words_reduced)
    if num_topics == 0:
        st.warning("No reduced topics are available to explore.")
        return

    # Size the slider from the model rather than hard-coding 0-19: a model
    # reduced to fewer topics would otherwise raise IndexError, and one
    # reduced to more would hide topics. A single topic needs no slider
    # (its range would be empty).
    if num_topics > 1:
        topic_num = st.sidebar.slider("Topic Number", 0, num_topics - 1, value=0)
    else:
        topic_num = 0
    # Zero-padded id, used as the key into topic_str_to_word.
    topic_num_str = f"{topic_num:02}"
    topic_label = st.session_state.topic_str_to_word[topic_num_str]

    # Plotly draws the first category of a horizontal bar chart at the bottom;
    # reversing puts the first stored word (presumably the best-scoring one)
    # at the top.
    words = model.topic_words_reduced[topic_num][::-1]
    scores = model.topic_word_scores_reduced[topic_num][::-1]

    fig = go.Figure(go.Bar(x=scores, y=words, orientation='h'))
    fig.update_layout(
        title=f'Words for Topic {topic_num_str}: {topic_label}',
        # Report how many words are actually plotted instead of a fixed "20".
        yaxis_title=f'Top {len(words)} topic words',
        # Top2Vec word scores are cosine similarities (longer bar = closer to
        # the topic), so "distance" was the wrong label.
        xaxis_title='Cosine similarity to topic centroid',
    )
    st.plotly_chart(fig, use_container_width=True)
if __name__ == "__main__":
    # Logger and project root. This file lives at <root>/app/pages/, so
    # parents[2] is the repository root. main() does not use either name; the
    # disabled legacy loader near the top of the file refers to them.
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # Intended to lift pandas' column-width truncation ("max width tables").
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    # st.set_page_config(layout="wide")

    # Page title, shown both in the main area and at the top of the sidebar.
    # The emoji was previously stored as mojibake (the UTF-8 bytes of U+1F4DA
    # decoded with the wrong code page), which rendered as garbage.
    md_title = "# Topic Explorer 📚"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    # Session state is set up by initialization() at module level, so the old
    # initialize_state() call is no longer needed here.
    main()