RamAnanth1 commited on
Commit
a12dda5
1 Parent(s): 70bf707

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -30
app.py CHANGED
@@ -39,8 +39,8 @@ def get_conference_notes(venue, blind_submission=False):
39
 
40
  raw_notes = get_conference_notes(venue, blind_submission=True)
41
 
42
- st.set_page_config(page_title="ICLR2023 Papers Visualization", page_icon="🐞", layout="centered")
43
 
 
44
  st.write("Number of submissions at ICLR 2023:", len(raw_notes))
45
 
46
  df_raw = pd.json_normalize(raw_notes)
@@ -50,42 +50,42 @@ accepted_venues = ['ICLR 2023 poster', 'ICLR 2023 notable top 5%', 'ICLR 2023 no
50
  df = df_raw[df_raw["content.venue"].isin(accepted_venues)]
51
  st.write("Number of submissions accepted at ICLR 2023:", len(df))
52
 
53
- # df_filtered = df[['id', 'content.title', 'content.keywords', 'content.abstract']]
54
- # df = df_filtered
55
- # if "CO_API_KEY" not in os.environ:
56
- # raise KeyError("CO_API_KEY not found in st.secrets or os.environ. Please set it in "
57
- # ".streamlit/secrets.toml or as an environment variable.")
58
 
59
- # co = cohere.Client(os.environ["CO_API_KEY"])
60
 
61
- # def get_visualizations():
62
- # list_of_titles = list(df["content.title"].values)
63
- # embeds = co.embed(texts=list_of_titles,
64
- # model="small").embeddings
65
 
66
- # embeds_npy = np.array(embeds)
67
 
68
- # # Load and initialize BERTopic to use KMeans clustering with 8 clusters only.
69
- # cluster_model = KMeans(n_clusters=8)
70
- # topic_model = BERTopic(hdbscan_model=cluster_model)
71
 
72
- # # df is a dataframe. df['title'] is the column of text we're modeling
73
- # df['topic'], probabilities = topic_model.fit_transform(df['content.title'], embeds_npy)
74
 
75
- # app = Topically(os.environ["CO_API_KEY"])
76
 
77
- # df['topic_name'], topic_names = app.name_topics((df['content.title'], df['topic']), num_generations=5)
78
 
79
- # #st.write("Topics extracted are:", topic_names)
80
 
81
- # topic_model.set_topic_labels(topic_names)
82
- # fig1 = topic_model.visualize_documents(df['content.title'].values,
83
- # embeddings=embeds_npy,
84
- # topics = list(range(8)),
85
- # custom_labels=True)
86
- # topic_model.set_topic_labels(topic_names)
87
- # fig2 = topic_model.visualize_barchart(custom_labels=True)
88
- # st.plotly_chart(fig1)
89
- # st.plotly_chart(fig2)
90
 
91
- # st.button("Run Visualization", on_click=get_visualizations)
 
39
 
40
  raw_notes = get_conference_notes(venue, blind_submission=True)
41
 
 
42
 
43
+ st.title("ICLR2023 Papers Visualization")
44
  st.write("Number of submissions at ICLR 2023:", len(raw_notes))
45
 
46
  df_raw = pd.json_normalize(raw_notes)
 
50
  df = df_raw[df_raw["content.venue"].isin(accepted_venues)]
51
  st.write("Number of submissions accepted at ICLR 2023:", len(df))
52
 
53
+ df_filtered = df[['id', 'content.title', 'content.keywords', 'content.abstract']]
54
+ df = df_filtered
55
+ if "CO_API_KEY" not in os.environ:
56
+ raise KeyError("CO_API_KEY not found in st.secrets or os.environ. Please set it in "
57
+ ".streamlit/secrets.toml or as an environment variable.")
58
 
59
+ co = cohere.Client(os.environ["CO_API_KEY"])
60
 
61
+ def get_visualizations():
62
+ list_of_titles = list(df["content.title"].values)
63
+ embeds = co.embed(texts=list_of_titles,
64
+ model="small").embeddings
65
 
66
+ embeds_npy = np.array(embeds)
67
 
68
+ # Load and initialize BERTopic to use KMeans clustering with 8 clusters only.
69
+ cluster_model = KMeans(n_clusters=8)
70
+ topic_model = BERTopic(hdbscan_model=cluster_model)
71
 
72
+ # df is a dataframe. df['title'] is the column of text we're modeling
73
+ df['topic'], probabilities = topic_model.fit_transform(df['content.title'], embeds_npy)
74
 
75
+ app = Topically(os.environ["CO_API_KEY"])
76
 
77
+ df['topic_name'], topic_names = app.name_topics((df['content.title'], df['topic']), num_generations=5)
78
 
79
+ #st.write("Topics extracted are:", topic_names)
80
 
81
+ topic_model.set_topic_labels(topic_names)
82
+ fig1 = topic_model.visualize_documents(df['content.title'].values,
83
+ embeddings=embeds_npy,
84
+ topics = list(range(8)),
85
+ custom_labels=True)
86
+ topic_model.set_topic_labels(topic_names)
87
+ fig2 = topic_model.visualize_barchart(custom_labels=True)
88
+ st.plotly_chart(fig1)
89
+ st.plotly_chart(fig2)
90
 
91
+ st.button("Run Visualization", on_click=get_visualizations)