m7mdal7aj commited on
Commit
c996cf4
1 Parent(s): 67b9883

Create dataset_analysis.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/dataset_analysis.py +177 -0
my_model/tabs/dataset_analysis.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ from collections import Counter
4
+ import contractions
5
+ import csv
6
+ import altair as alt
7
+ from typing import Tuple, List, Optional
8
+ from my_model.dataset.dataset_processor import process_okvqa_dataset
9
+ from my_model.config import dataset_config as config
10
+
11
+ class OKVQADatasetAnalyzer:
12
+ """
13
+ Provides tools for analyzing and visualizing distributions of question types within given question datasets.
14
+ It supports operations such as data loading, categorization of questions based on keywords, visualization of q
15
+ uestion distribution, and exporting data to CSV files.
16
+
17
+ Attributes:
18
+ train_file_path (str): Path to the training dataset file.
19
+ test_file_path (str): Path to the testing dataset file.
20
+ data_choice (str): Choice of dataset(s) to analyze; options include 'train', 'test', or 'train_test'.
21
+ questions (List[str]): List of questions aggregated based on the dataset choice.
22
+ question_types (Counter): Counter object tracking the frequency of each question type.
23
+ Qs (Dict[str, List[str]]): Dictionary mapping question types to lists of corresponding questions.
24
+ """
25
+
26
+ def __init__(self, train_file_path: str, test_file_path: str, data_choice: str):
27
+ """
28
+ Initializes the OKVQADatasetAnalyzer with paths to dataset files and a choice of which datasets to analyze.
29
+
30
+ Parameters:
31
+ train_file_path (str): Path to the training dataset JSON file. This file should contain a list of questions.
32
+ test_file_path (str): Path to the testing dataset JSON file. This file should also contain a list of
33
+ questions.
34
+ data_choice (str): Specifies which dataset(s) to load and analyze. Valid options are 'train', 'test', or
35
+ 'train_test'indicating whether to load training data, testing data, or both.
36
+
37
+ The constructor initializes the paths, selects the dataset based on the choice, and loads the initial data by
38
+ calling the `load_data` method.
39
+ It also prepares structures for categorizing questions and storing the results.
40
+ """
41
+
42
+ self.train_file_path = train_file_path
43
+ self.test_file_path = test_file_path
44
+ self.data_choice = data_choice
45
+ self.questions = []
46
+ self.question_types = Counter()
47
+ self.Qs = {keyword: [] for keyword in config.QUESTION_KEYWORDS}
48
+ self.load_data()
49
+
50
+ def load_data(self) -> None:
51
+ """
52
+ Loads the dataset(s) from the specified JSON file(s) based on the user's choice of 'train', 'test', or
53
+ 'train_test'.
54
+ This method updates the internal list of questions depending on the chosen dataset.
55
+ """
56
+
57
+ if self.data_choice in ['train', 'train_test']:
58
+ with open(self.train_file_path, 'r') as file:
59
+ train_data = json.load(file)
60
+ self.questions += [q['question'] for q in train_data['questions']]
61
+
62
+ if self.data_choice in ['test', 'train_test']:
63
+ with open(self.test_file_path, 'r') as file:
64
+ test_data = json.load(file)
65
+ self.questions += [q['question'] for q in test_data['questions']]
66
+
67
+ def categorize_questions(self) -> None:
68
+ """
69
+ Categorizes each question in the loaded data into predefined categories based on keywords.
70
+ This method updates the internal dictionary `self.Qs` and the Counter `self.question_types` with categorized
71
+ questions.
72
+ """
73
+
74
+ question_keywords = config.QUESTION_KEYWORDS
75
+
76
+ for question in self.questions:
77
+ question = contractions.fix(question)
78
+ words = question.lower().split()
79
+ question_keyword = None
80
+ if words[:2] == ['name', 'the']:
81
+ question_keyword = 'name the'
82
+ else:
83
+ for word in words:
84
+ if word in question_keywords:
85
+ question_keyword = word
86
+ break
87
+ if question_keyword:
88
+ self.question_types[question_keyword] += 1
89
+ self.Qs[question_keyword].append(question)
90
+ else:
91
+ self.question_types["others"] += 1
92
+ self.Qs["others"].append(question)
93
+
94
+ def plot_question_distribution(self) -> None:
95
+ """
96
+ Plots an interactive bar chart of question types using Altair and Streamlit, displaying the count and percentage
97
+ of each type.
98
+ The chart sorts question types by count in descending order and includes detailed tooltips for interaction.
99
+ This method is intended for visualization in a Streamlit application.
100
+ """
101
+
102
+ # Prepare data
103
+ total_questions = sum(self.question_types.values())
104
+ items = [(key, value, (value / total_questions) * 100) for key, value in self.question_types.items()]
105
+ df = pd.DataFrame(items, columns=['Question Keyword', 'Count', 'Percentage'])
106
+
107
+ # Sort data and handle 'others' category specifically if present
108
+ df = df[df['Question Keyword'] != 'others'].sort_values('Count', ascending=False)
109
+ if 'others' in self.question_types:
110
+ others_df = pd.DataFrame([('others', self.question_types['others'],
111
+ (self.question_types['others'] / total_questions) * 100)],
112
+ columns=['Question Keyword', 'Count', 'Percentage'])
113
+ df = pd.concat([df, others_df], ignore_index=True)
114
+
115
+ # Explicitly set the order of the x-axis based on the sorted DataFrame
116
+ order = df['Question Keyword'].tolist()
117
+
118
+ # Create the bar chart
119
+ bars = alt.Chart(df).mark_bar().encode(
120
+ x=alt.X('Question Keyword:N', sort=order, title='Question Keyword', axis=alt.Axis(labelAngle=-45)),
121
+ y=alt.Y('Count:Q', title='Frequency'),
122
+ color=alt.Color('Question Keyword:N', scale=alt.Scale(scheme='category20'), legend=None),
123
+ tooltip=[alt.Tooltip('Question Keyword:N', title='Type'),
124
+ alt.Tooltip('Count:Q', title='Count'),
125
+ alt.Tooltip('Percentage:Q', title='Percentage', format='.1f')]
126
+ )
127
+
128
+ # Create text labels for the bars with count and percentage
129
+ text = bars.mark_text(
130
+ align='center',
131
+ baseline='bottom',
132
+ dy=-5 # Nudges text up so it appears above the bar
133
+ ).encode(
134
+ text=alt.Text('PercentageText:N')
135
+ ).transform_calculate(
136
+ PercentageText="datum.Count + ' (' + format(datum.Percentage, '.1f') + '%)'"
137
+ )
138
+
139
+ # Combine the bar and text layers
140
+ chart = (bars + text).properties(
141
+ width=700,
142
+ height=400,
143
+ title='Distribution of Question Keywords'
144
+ ).configure_title(fontSize=20).configure_axis(
145
+ labelFontSize=12,
146
+ titleFontSize=14
147
+ )
148
+
149
+ # Display the chart in Streamlit
150
+ st.altair_chart(chart, use_container_width=True)
151
+
152
+ def export_to_csv(self, qs_filename: str, question_types_filename: str) -> None:
153
+ """
154
+ Exports the categorized questions and their counts to two separate CSV files.
155
+
156
+ Parameters:
157
+ qs_filename (str): The filename or path for exporting the `self.Qs` dictionary data.
158
+ question_types_filename (str): The filename or path for exporting the `self.question_types` Counter data.
159
+
160
+ This method writes the contents of `self.Qs` and `self.question_types` to the specified files in CSV format.
161
+ Each CSV file includes headers for better understanding and use of the exported data.
162
+ """
163
+
164
+ # Export self.Qs dictionary
165
+ with open(qs_filename, mode='w', newline='', encoding='utf-8') as file:
166
+ writer = csv.writer(file)
167
+ writer.writerow(['Question Type', 'Questions'])
168
+ for q_type, questions in self.Qs.items():
169
+ for question in questions:
170
+ writer.writerow([q_type, question])
171
+
172
+ # Export self.question_types Counter
173
+ with open(question_types_filename, mode='w', newline='', encoding='utf-8') as file:
174
+ writer = csv.writer(file)
175
+ writer.writerow(['Question Type', 'Count'])
176
+ for q_type, count in self.question_types.items():
177
+ writer.writerow([q_type, count])