from logging import getLogger
from pathlib import Path

import pandas as pd
import plotly.express as px
import streamlit as st
from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
from streamlit_plotly_events import plotly_events

from utilities import initialization

# Module-level logger so the callbacks below (e.g. ``reset`` used as a button
# ``on_click``) work even when this module is imported rather than run directly.
logger = getLogger(__name__)

initialization()

# Hint shown in both tabs while nothing is selected in the scatter plot.
SELECT_HINT = 'Select points in the graph with the `lasso` or `box` select tools to populate this table.'

# NOTE(review): legacy state loader, presumably superseded by
# ``utilities.initialization()`` above — confirm, then delete.
# @st.cache(show_spinner=False)
# def initialize_state():
#     with st.spinner("Loading app..."):
#         if 'model' not in st.session_state:
#             model = Top2Vec.load('models/model.pkl')
#             model._check_model_status()
#             model.hierarchical_topic_reduction(num_topics=20)
#
#             st.session_state.model = model
#             st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
#             logger.info("loading data...")
#
#         if 'data' not in st.session_state:
#             logger.info("loading data...")
#             data = pd.read_csv(proj_dir / 'data' / 'data.csv')
#             data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
#             st.session_state.data = data
#             st.session_state.selected_data = data
#             st.session_state.all_topics = list(data.topic_id.unique())
#
#         if 'topics' not in st.session_state:
#             logger.info("loading topics...")
#             topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
#             topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
#             st.session_state.topics = topics


def reset():
    """Clear the current plot selection.

    Used as the ``on_click`` callback of the "Reset" button: restores
    ``st.session_state.selected_data`` to the full ``st.session_state.data``
    and empties ``st.session_state.selected_points``.
    """
    logger.info("Resetting...")
    st.session_state.selected_data = st.session_state.data
    st.session_state.selected_points = []


def filter_df():
    """Narrow ``selected_data`` to the rows whose (x, y) match the selected points.

    Reads ``st.session_state.selected_points`` (the list returned by
    ``plotly_events``) and inner-joins its ``x``/``y`` coordinates against
    ``st.session_state.data``. When nothing is selected, ``selected_data`` is
    left unchanged.
    """
    points = st.session_state.selected_points
    if not points:
        logger.info("No points selected; selected_data unchanged")
        return
    # De-duplicate coordinates: several documents may share an (x, y), and
    # duplicate keys on both sides of the merge would multiply those rows.
    points_df = pd.DataFrame(points).loc[:, ['x', 'y']].drop_duplicates()
    st.session_state.selected_data = st.session_state.data.merge(points_df, on=['x', 'y'])
    logger.info(f"Updated selected_data: {len(st.session_state.selected_data)}")


def get_topics_counts() -> pd.DataFrame:
    """Count selected documents per topic and join the topic metadata.

    Returns:
        A DataFrame with columns ``topic_id``, ``topic_count`` followed by the
        remaining columns of ``st.session_state.topics``, ordered by descending
        ``topic_count``.
    """
    # rename_axis + reset_index(name=...) gives the same column names on every
    # pandas version (value_counts' output naming changed in pandas 2.0).
    counts = (
        st.session_state.selected_data['topic_id']
        .value_counts()
        .rename_axis('topic_id')
        .reset_index(name='topic_count')
    )
    return counts.merge(st.session_state.topics, on='topic_id')


def _render_docs_table():
    """Render a paginated grid of the selected documents, with their topic word."""
    docs = st.session_state.selected_data[['id', 'topic_id', 'documents']].copy()
    docs['topic_word'] = docs['topic_id'].replace(st.session_state.topic_str_to_word)
    docs = docs[['id', 'topic_id', 'topic_word', 'documents']]

    builder = GridOptionsBuilder.from_dataframe(docs)
    builder.configure_pagination()
    # Pass the same frame the options were built from so topic_word is shown.
    AgGrid(docs, theme='streamlit', gridOptions=builder.build(),
           columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)


def _render_topics_table():
    """Render a paginated grid of per-topic counts for the current selection."""
    # Assumes ``st.session_state.topics`` has a ``topic_0`` column holding the
    # topic's leading word — TODO confirm against the topics CSV schema.
    cols = ['topic_id', 'topic_count', 'topic_0']
    topic_counts = get_topics_counts().loc[:, cols]

    builder = GridOptionsBuilder.from_dataframe(topic_counts)
    builder.configure_pagination()
    builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
    AgGrid(topic_counts, theme='streamlit', gridOptions=builder.build(),
           columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)


def main():
    """Draw the topic scatter plot and the Docs/Topics tabs for the selection."""
    st.write(
        """
        # Topic Modeling
        This shows a 2d representation of documents embedded in a semantic space.
        Each dot is a document and the dots close together represent documents that are close in meaning.
        Zoom in and explore a topic of your choice. You can see the documents you select with the
        `lasso` or `box` tool below in the corresponding tabs."""
    )
    st.button("Reset", help="Will Reset the selected points and the selected topics", on_click=reset)

    # Sort so the legend is ordered by topic (https://bioinformatics.stackexchange.com/a/18847),
    # then show topic words instead of ids. sort_values returns a new frame, and
    # assigning the column (rather than replace(inplace=True) on a selection)
    # also works under pandas Copy-on-Write.
    data_to_model = st.session_state.data.sort_values(by='topic_id', ascending=True)
    data_to_model['topic_id'] = data_to_model['topic_id'].replace(st.session_state.topic_str_to_word)
    fig = px.scatter(data_to_model, x='x', y='y', color='topic_id', template='plotly_dark',
                     hover_data=['id', 'topic_id', 'x', 'y'])

    st.session_state.selected_points = plotly_events(fig, select_event=True, click_event=False)
    filter_df()

    tab1, tab2 = st.tabs(["Docs", "Topics"])
    with tab1:
        if st.session_state.selected_points:
            _render_docs_table()
        else:
            st.markdown(SELECT_HINT)

    with tab2:
        if st.session_state.selected_points:
            _render_topics_table()
        else:
            st.markdown(SELECT_HINT)


if __name__ == "__main__":
    # Project root; not used by the live code in this file (only by the
    # commented-out loader above).
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    # st.set_page_config(layout="wide")
    md_title = "# Document Explorer 📖"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)
    # initialize_state()
    main()