File size: 8,966 Bytes
a1f367e
009291b
17c1e65
 
 
2d85a77
17c1e65
 
 
 
 
b52c17b
17c1e65
 
a1f367e
17c1e65
 
 
 
 
 
 
 
a1f367e
 
 
 
 
 
 
 
 
17c1e65
 
 
 
 
 
 
a1f367e
 
17c1e65
a1f367e
17c1e65
 
a1f367e
17c1e65
a1f367e
17c1e65
 
 
a1f367e
 
 
17c1e65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1f367e
17c1e65
 
a1f367e
17c1e65
 
 
 
a1f367e
 
17c1e65
a1f367e
 
 
17c1e65
 
 
a1f367e
17c1e65
a1f367e
17c1e65
 
 
 
 
 
 
 
 
 
 
a1f367e
17c1e65
 
 
 
 
 
a1f367e
17c1e65
 
 
 
 
 
a1f367e
17c1e65
 
 
 
 
 
 
 
 
 
 
 
 
 
d8f32a4
 
17c1e65
 
 
 
 
 
 
 
 
 
 
a1f367e
17c1e65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1f367e
17c1e65
 
 
 
 
 
 
a1f367e
17c1e65
a1f367e
17c1e65
 
 
 
a1f367e
17c1e65
 
 
 
a1f367e
17c1e65
a1f367e
17c1e65
 
 
 
 
a1f367e
 
 
 
 
17c1e65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from typing import Tuple, List
from my_model.utilities.gen_utilities import is_pycharm
import seaborn as sns
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset
import my_model.config.fine_tuning_config as config
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager


class FinetuningDataHandler:
    """
    Handles the data used for fine-tuning LLaMA-2 Chat models. It loads the dataset, inspects its
    token count distribution, drops the samples whose token count exceeds a configured limit and
    splits what is left into training and evaluation sets. Limiting the number of tokens per sample
    keeps the samples within the model's context constraints and helps control GPU RAM usage
    during fine-tuning.

    Attributes:
        tokenizer (AutoTokenizer): Tokenizer used for tokenizing the dataset (loaded lazily when None).
        dataset_file (str): File path to the dataset.
        max_token_count (int): Maximum allowable token count per data sample (from config.MAX_TOKEN_COUNT).

    Methods:
        load_llm_tokenizer: Loads the LLM tokenizer and adds special tokens, if not already loaded.
        load_dataset: Loads the dataset from a specified file path.
        plot_tokens_count_distribution: Plots the distribution of token counts in the dataset.
        filter_dataset_by_indices: Keeps only the samples at the given indices.
        get_token_counts: Calculates token counts for each sample in the dataset.
        prepare_dataset: Tokenizes and filters the dataset, preparing it for training. Also visualizes token count
                         distribution before and after filtering.
        split_dataset_for_train_eval: Divides the dataset into training and evaluation sets.
        inspect_prepare_split_data: Coordinates the data preparation and splitting process for fine-tuning.
    """

    def __init__(self, tokenizer: AutoTokenizer = None, dataset_file: str = config.DATASET_FILE) -> None:
        """
        Initializes the FinetuningDataHandler class.

        Args:
            tokenizer (AutoTokenizer, optional): Tokenizer to use for tokenizing the dataset. Defaults to None,
                                                 in which case it is loaded on demand by `load_llm_tokenizer`.
            dataset_file (str): Path to the dataset file. Defaults to config.DATASET_FILE.
        """

        self.tokenizer = tokenizer  # The tokenizer used for processing the dataset.
        self.dataset_file = dataset_file  # Path to the fine-tuning dataset file.
        self.max_token_count = config.MAX_TOKEN_COUNT  # Inclusive upper bound on tokens per kept sample.

    def load_llm_tokenizer(self) -> None:
        """
        Loads the LLM tokenizer and adds special tokens, if not already loaded.
        If the tokenizer is already loaded, this method does nothing.

        Returns:
            None
        """

        if self.tokenizer is not None:
            return

        llm_manager = Llama2ModelManager()  # Initialize Llama2 model manager.
        # Only the tokenizer is needed for the data inspection, not the model itself.
        self.tokenizer = llm_manager.load_tokenizer()
        llm_manager.add_special_tokens()  # Add special tokens specific to LLAMA2 vocab for efficient tokenization.

    def load_dataset(self) -> Dataset:
        """
        Loads the dataset from the specified file path. The dataset is expected to be in CSV format.

        Returns:
            Dataset: The loaded dataset. Because no split is requested, the loader returns a mapping keyed
                     by split name; the rest of this class reads the samples from its 'train' entry.
        """

        return load_dataset('csv', data_files=self.dataset_file)

    def plot_tokens_count_distribution(self, token_counts: List[int], title: str = "Token Count Distribution") -> None:
        """
        Plots the distribution of token counts in the dataset for visualization purposes.

        Args:
            token_counts (List[int]): List of token counts, each count representing the number of tokens in a dataset
                                      sample.
            title (str): Title for the plot, highlighting the nature of the distribution.

        Returns:
            None
        """

        if is_pycharm():  # Ensuring compatibility with PyCharm's environment for interactive plot.
            import matplotlib  # The import is kept here intentionally.
            matplotlib.use('TkAgg')  # Set the backend to 'TkAgg'
        import matplotlib.pyplot as plt  # The import is kept here intentionally (after the backend is chosen).
        sns.set_style("whitegrid")
        plt.figure(figsize=(15, 6))
        plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
        plt.title(title, fontsize=16)
        plt.xlabel("Number of Tokens", fontsize=14)
        plt.ylabel("Number of Samples", fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _train_split(dataset: Dataset) -> Dataset:
        """
        Returns the samples to work on: the 'train' entry of a split-keyed mapping, or the dataset itself.

        Args:
            dataset (Dataset): Either the mapping returned by `load_dataset` or a single, already extracted split.

        Returns:
            Dataset: The single split holding the samples.

        Raises:
            KeyError: If a split-keyed mapping without a 'train' entry is given.
        """

        # `DatasetDict` is a `dict` subclass, so the type check tells the two shapes apart. A membership
        # test such as `'train' in dataset` is avoided on purpose: on a single split it is not a key lookup.
        if isinstance(dataset, dict):
            return dataset['train']
        return dataset

    def filter_dataset_by_indices(self, dataset: Dataset, valid_indices: List[int]) -> Dataset:
        """
        Filters the dataset based on a list of valid indices. This method is used to exclude
        data samples that have a token count exceeding the specified maximum token count.

        Args:
            dataset (Dataset): The dataset to be filtered; either the split-keyed mapping returned by
                               `load_dataset` (its 'train' split is used) or a single split.
            valid_indices (List[int]): Indices of samples with token counts within the limit.

        Returns:
            Dataset: Filtered dataset containing only samples with valid indices, in the given order.
        """

        return self._train_split(dataset).select(valid_indices)

    def get_token_counts(self, dataset: Dataset) -> List[int]:
        """
        Calculates and returns the token counts for each sample in the dataset.
        The samples are read from the 'text' field; `self.tokenizer` must be loaded beforehand.

        Args:
            dataset (Dataset): The dataset for which to count tokens; either the split-keyed mapping returned
                               by `load_dataset` (its 'train' split is used) or a single split, e.g. the
                               result of `filter_dataset_by_indices`.

        Returns:
            List[int]: List of token counts per sample in the dataset.
        """

        texts = self._train_split(dataset)["text"]
        return [len(self.tokenizer.tokenize(text)) for text in texts]

    def prepare_dataset(self) -> Tuple[Dataset, Dataset]:
        """
        Prepares the dataset for fine-tuning by tokenizing the data and filtering out samples
        that exceed the maximum used context window (configurable through max_token_count).
        It also visualizes the token count distribution before and after filtering.

        Returns:
            Tuple[Dataset, Dataset]: The train and evaluate datasets, post-filtering.

        Raises:
            ValueError: If no sample has a token count within `max_token_count`, so there is nothing to split.
        """

        dataset = self.load_dataset()
        self.load_llm_tokenizer()

        # Count tokens in each dataset sample before filtering.
        token_counts_before_filtering = self.get_token_counts(dataset)
        # Plot token count distribution before filtering for visualization.
        self.plot_tokens_count_distribution(token_counts_before_filtering, "Token Count Distribution Before Filtration")

        # Identify valid indices based on max token count (the limit itself is allowed).
        valid_indices = [i for i, count in enumerate(token_counts_before_filtering) if count <= self.max_token_count]
        if not valid_indices:
            # Fail early with an actionable message instead of a late error from the train/eval split.
            raise ValueError(
                f"No dataset sample has a token count within the limit of {self.max_token_count} tokens; "
                "nothing is left to split."
            )

        # Filter the dataset to exclude samples with excessive token counts.
        filtered_dataset = self.filter_dataset_by_indices(dataset, valid_indices)

        # The filtered dataset keeps the samples in `valid_indices` order, so the counts computed above
        # can be reused instead of tokenizing every remaining sample a second time.
        token_counts_after_filtering = [token_counts_before_filtering[i] for i in valid_indices]
        self.plot_tokens_count_distribution(token_counts_after_filtering, "Token Count Distribution After Filtration")

        return self.split_dataset_for_train_eval(filtered_dataset)  # split the dataset into training and evaluation.

    def split_dataset_for_train_eval(self, dataset: Dataset) -> Tuple[Dataset, Dataset]:
        """
        Splits the dataset into training and evaluation datasets. The split is shuffled with the seed
        config.SEED, and the evaluation share is taken from config.TEST_SIZE.

        Args:
            dataset (Dataset): The dataset to split.

        Returns:
            Tuple[Dataset, Dataset]: The split training and evaluation datasets.
        """

        split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
        return split_data['train'], split_data['test']

    def inspect_prepare_split_data(self) -> Tuple[Dataset, Dataset]:
        """
        Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.
        Currently a thin entry point that delegates to `prepare_dataset`.

        Returns:
            Tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
        """

        return self.prepare_dataset()


# Example usage
if __name__ == "__main__":
    # Manual smoke test, disabled by default so that running this module does nothing.
    # Uncomment the lines below to run the data preparation and print the resulting
    # training and evaluation datasets.
    # data_handler = FinetuningDataHandler()
    # fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
    # print(fine_tuning_data_train, fine_tuning_data_eval)
    pass