Spaces:
Sleeping
Sleeping
from logging import getLogger | |
from pathlib import Path | |
import pandas as pd | |
import plotly.graph_objects as go | |
import streamlit as st | |
from utilities import initialization | |
# Project-local setup from `utilities`, executed at import time (outside the
# __main__ guard below). main() reads st.session_state.model and
# st.session_state.topic_str_to_word without ever setting them in this file,
# so this call presumably populates them — confirm in utilities.initialization.
initialization()
# @st.cache(show_spinner=False) | |
# def initialize_state(): | |
# with st.spinner("Loading app..."): | |
# if 'model' not in st.session_state: | |
# model = Top2Vec.load('models/model.pkl') | |
# model._check_model_status() | |
# model.hierarchical_topic_reduction(num_topics=20) | |
# | |
# st.session_state.model = model | |
# st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav') | |
# logger.info("loading data...") | |
# | |
# if 'data' not in st.session_state: | |
# logger.info("loading data...") | |
# data = pd.read_csv(proj_dir / 'data' / 'data.csv') | |
# data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}') | |
# st.session_state.data = data | |
# st.session_state.selected_data = data | |
# st.session_state.all_topics = list(data.topic_id.unique()) | |
# | |
# if 'topics' not in st.session_state: | |
# logger.info("loading topics...") | |
# topics = pd.read_csv(proj_dir / 'data' / 'topics.csv') | |
# topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}') | |
# st.session_state.topics = topics | |
def main():
    """Render a horizontal bar chart of the words that characterise one topic.

    The topic is chosen with a sidebar slider. Word scores come from
    ``st.session_state.model.topic_word_scores_reduced`` and words from
    ``st.session_state.model.topic_words_reduced``. The chart title is looked up
    in ``st.session_state.topic_str_to_word``, keyed by the zero-padded topic
    id (e.g. ``"07"``). None of these session-state entries are set in this
    function; presumably ``initialization()`` populates them — confirm.
    """
    st.write("""
A way to dive into each topic. Use the slider on the left to choose the topic.
The `y` axis shows which words are closest to a topic centroid. The `x` axis shows how correlated they are.""")

    model = st.session_state.model

    # Derive the slider range from the model rather than hard-coding 20 topics:
    # a model with fewer reduced topics would otherwise raise IndexError for
    # high slider values, and one with more would hide topics.
    max_topic = len(model.topic_words_reduced) - 1
    topic_num = st.sidebar.slider("Topic Number", 0, max_topic, value=0)
    # Zero-padded id, the key format used for topic_str_to_word below.
    topic_num_str = f"{topic_num:02}"

    words = model.topic_words_reduced[topic_num]
    scores = model.topic_word_scores_reduced[topic_num]

    # Plotly draws the first category of a horizontal bar chart at the bottom,
    # so reverse both sequences to put the first word (presumably the
    # best-scoring one) at the top.
    fig = go.Figure(go.Bar(x=scores[::-1], y=words[::-1], orientation='h'))
    fig.update_layout(
        title=f'Words for Topic {topic_num_str}: {st.session_state.topic_str_to_word[topic_num_str]}',
        # Label with the number of words actually plotted instead of a fixed count.
        yaxis_title=f'Top {len(words)} topic words',
        # Higher bars mean closer words (similarity scores), not larger distances.
        xaxis_title='Similarity to topic centroid',
    )
    st.plotly_chart(fig, use_container_width=True)
if __name__ == "__main__":
    # Logger and project root. Neither name is read by the live code in this
    # file (only by the commented-out initialize_state above); kept so that
    # code can be re-enabled unchanged.
    logger = getLogger(__name__)
    # parents[2]: two directories above the directory containing this file —
    # presumably the repository root; confirm against the project layout.
    proj_dir = Path(__file__).parents[2]

    # Never truncate cell text in pandas' text repr. None is the documented
    # "unlimited" value; 0 only avoided truncation as a formatter side effect.
    pd.set_option('display.max_colwidth', None)

    # Streamlit settings
    # st.set_page_config(layout="wide")
    # NOTE(review): "π" looks like a mis-encoded emoji — confirm intended title.
    md_title = "# Topic Explorer π"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    # initialize_state()
    main()