Ali-Forootani committed
Commit
7d44fe9
1 Parent(s): 1957d40

Update README.md

Files changed (1): README.md (+1166 −99)

<!-- Provide a quick summary of what the model is/does. -->

In this repository we fine-tune LLaVA ([llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)).

LLaVA (Large Language and Vision Assistant) models are a type of artificial intelligence that combines language understanding with visual perception. These models are designed to process and understand both text and images, allowing them to perform tasks that require interpreting visual information and responding in natural language.

Key features of LLaVA models include:

1. Multimodal capabilities: they can analyze images and respond to questions or prompts about them in natural language.

2. Visual grounding: LLaVA models can connect language concepts to visual elements in images.

3. Task versatility: they can be used for various tasks like visual question answering, image captioning, and visual reasoning.

4. Foundation model integration: LLaVA builds upon large language models, extending their capabilities to include visual understanding.

LLaVA models represent an important step in developing AI systems that can interact with the world more comprehensively, bridging the gap between language and visual perception.

## What you will find in this README

1. How to use the fine-tuned model
2. How I trained the LLaVA model on the dataset
3. How I tested it locally and pushed it to the Hugging Face Hub

## Dataset

The dataset we use to fine-tune the model is [naver-clova-ix/cord-v2](https://huggingface.co/datasets/naver-clova-ix/cord-v2), which you can find on the Hugging Face Hub.

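Before training, it helps to peek at one example to see what the model is expected to produce. The snippet below is a minimal sketch (assuming the `datasets` library is installed); each CORD row contains an `image` and a `ground_truth` field that stores the target annotation as a JSON string.

```python
from datasets import load_dataset
import json

# load a small slice of the training split just for inspection
sample = load_dataset("naver-clova-ix/cord-v2", split="train[:3]")[0]

print(sample["image"].size)                  # PIL image of the receipt
gt = json.loads(sample["ground_truth"])      # the annotation is stored as a JSON string
print(json.dumps(gt["gt_parse"], indent=2))  # the fields the model should learn to generate
```
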
# 1. How to use the fine-tuned model

```python
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaNextForConditionalGeneration
import torch

import sys
import os

import lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np


def setting_directory(depth):
    current_dir = os.path.abspath(os.getcwd())
    root_dir = current_dir
    for i in range(depth):
        root_dir = os.path.abspath(os.path.join(root_dir, os.pardir))
    sys.path.append(os.path.dirname(root_dir))
    return root_dir

root_dir = setting_directory(1)
epochs = 100


model_name = "Ali-Forootani/llava-v1.6-mistral-7b-hf_100epochs_fine_tune"
processor = AutoProcessor.from_pretrained(model_name)
model = LlavaNextForConditionalGeneration.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.eval()
model = model.to(device)


from datasets import load_dataset
dataset = load_dataset("naver-clova-ix/cord-v2")

# You can also save the dataset to a local directory
dataset.save_to_disk("/data/bio-eng-llm/llm_repo/naver-clova-ix/cord-v2")

test_example = dataset["test"][3]
test_image = test_example["image"]

MAX_LENGTH = 256  # or any other suitable value

# prepare the image and prompt for the model
# (this can be replaced by apply_chat_template once the processor supports it;
#  the prompt must match the one used during fine-tuning)
prompt = f"[INST] <image>\nExtract JSON [\INST]"
inputs = processor(text=prompt, images=[test_image], return_tensors="pt").to(device)
for k, v in inputs.items():
    print(k, v.shape)

# Generate token IDs
generated_ids = model.generate(**inputs, max_new_tokens=MAX_LENGTH)

# Decode back into text
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

print(generated_texts)


#######################################
####################################### You can make the output nicer

import re

# let's turn that into JSON
def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    if added_vocab is None:
        added_vocab = processor.tokenizer.get_added_vocab()

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}


generated_json = token2json(generated_texts[0])
print(generated_json)

for key, value in generated_json.items():
    print(key, value)
```

# 2. How to fine-tune LLaVa for document parsing (PDF -> JSON)

In this section, we are going to fine-tune the [LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/llava) model for a document AI use case. LLaVa is one of the better open-source multimodal models at the time of writing (there is already a successor called [LLaVa-NeXT](https://huggingface.co/docs/transformers/main/en/model_doc/llava_next)). As we'll see, fine-tuning these various models is pretty similar, as their API is mostly the same.

The goal for the model in this section is to generate a JSON that contains key fields (like food items and their corresponding prices) from receipts. We will fine-tune LLaVa on the [CORD](https://huggingface.co/datasets/naver-clova-ix/cord-v2) dataset, which contains (receipt image, ground truth JSON) pairs.

Sources:

* LLaVa [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/llava)
* LLaVa [models on the hub](https://huggingface.co/llava-hf)

## Define variables and import modules

We'll first set some variables that are useful throughout this tutorial.

```python
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaNextForConditionalGeneration
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

import torch
import sys
import os

import lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np

# if you would like to set the working directory you can use this helper
def setting_directory(depth):
    current_dir = os.path.abspath(os.getcwd())
    root_dir = current_dir
    for i in range(depth):
        root_dir = os.path.abspath(os.path.join(root_dir, os.pardir))
    sys.path.append(os.path.dirname(root_dir))
    return root_dir

root_dir = setting_directory(1)
epochs = 100

##############################

MAX_LENGTH = 256

# MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
MODEL_ID = "/data/bio-eng-llm/llm_repo/llava-hf/llava-v1.6-mistral-7b-hf"
REPO_ID = "YOUR-HUB-REPO-TO-PUSH"
WANDB_PROJECT = "LLaVaNeXT"
WANDB_NAME = "llava-next-demo-cord"
```

## Load dataset

Let's start by loading the dataset from the hub. Here we use the [CORD](https://huggingface.co/datasets/naver-clova-ix/cord-v2) dataset, created by the [Donut](https://huggingface.co/docs/transformers/en/model_doc/donut) authors (Donut is another powerful, but slightly undertrained, document AI model available in the Transformers library). CORD is an important benchmark for receipt understanding. The Donut authors have prepared it in a format that suits vision-language models: we're going to fine-tune the model to generate the JSON given the image.

If you want to load your own custom dataset, check out this guide: https://huggingface.co/docs/datasets/image_dataset.

```python
from datasets import load_dataset
dataset = load_dataset("naver-clova-ix/cord-v2")

# see one image as an example
example = dataset['train'][0]
image = example["image"]
# resize the image for smaller display
width, height = image.size
image = image.resize((int(0.3 * width), int(0.3 * height)))
print(image)
```

## Load processor

Next, we'll load the processor, which is used to prepare the data in the format that the model expects. Neural networks like LLaVa don't directly take images and text as input, but rather `pixel_values` (a resized, rescaled, normalized and optionally split version of the receipt images), `input_ids` (the text token indices in the vocabulary of the model), etc. This is handled by the processor.

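The collate functions later in this section assume a `processor` object is in scope. Here is a minimal sketch of loading it for the checkpoint set in `MODEL_ID`, mirroring the same calls used in section 3 of this README.

```python
from transformers import AutoProcessor

# load the processor that matches MODEL_ID (defined above)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# during training, one always uses padding on the right
processor.tokenizer.padding_side = "right"
```
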
### Image resolution

The image resolution at which multimodal models are trained has a great impact on performance. One of the shortcomings of LLaVa is that it uses a fairly low image resolution (336x336). Newer models like LLaVa-NeXT and Idefics2 use a much higher image resolution, enabling the model to "see" a lot more detail in the image (which improves its OCR performance, among other things). On the other hand, a bigger image resolution comes at the cost of much higher memory requirements and longer training times. This is less of an issue with LLaVa due to its relatively small image resolution.

## Load model

Next, we're going to load the LLaVa model from the [hub](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf); here we use the `llava-v1.6-mistral-7b-hf` checkpoint set in `MODEL_ID`. This is a model with about 7 billion trainable parameters, as it combines a Mistral-7B language model with a relatively low-parameter vision encoder. Do note that the model we load here has already undergone supervised fine-tuning (SFT) on the [LLaVa-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) instruction dataset, so we can benefit from that instruction tuning.

### Full fine-tuning, LoRa and Q-LoRa

As this model has 7 billion trainable parameters, that's going to have quite an impact on the amount of memory used. For reference, fine-tuning a model with the [AdamW optimizer](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW) (which is often used to optimize neural networks) in mixed precision requires roughly 18 bytes of GPU RAM per parameter. So in this case we would need about 18 x 7 billion bytes = 126 GB of GPU RAM if we wanted to update all the parameters of the model. That's huge, right? And for most people infeasible.

Luckily, some clever people came up with the [LoRa](https://huggingface.co/docs/peft/main/en/conceptual_guides/lora) method (LoRa is short for low-rank adaptation). It allows you to freeze the existing weights and only train a couple of adapter layers on top of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRa, along with other Parameter-Efficient Fine-Tuning methods (that's where the name PEFT comes from).

Moreover, one can not only freeze the existing base model but also quantize it (which means shrinking down its size). A neural network's parameters are typically saved in either float32 (32 bits or 4 bytes per parameter value) or float16 (16 bits or 2 bytes, also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or 4 bits (one byte or half a byte!) without a significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.

This means that we're going to shrink the size of the base LLaVa model considerably using 4-bit quantization, and then only train a couple of adapter layers on top using LoRa (in float16). This idea of combining LoRa with quantization is called Q-LoRa and is the most memory-friendly version.

Of course, if you have the memory available, feel free to use full fine-tuning or LoRa without quantization! In the case of full fine-tuning, the code snippet below instantiates the model with Flash Attention, which considerably speeds up computations.

There exist many forms of quantization; here we leverage the [BitsAndBytes](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig) integration.

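As a quick back-of-the-envelope check of the numbers above, here is a small sketch that estimates the memory footprint of the weights alone for full fine-tuning versus a frozen, quantized base model (optimizer states for the adapters and activation memory come on top of this).

```python
# rough memory estimate for the model weights only (assumption: 7 billion parameters)
params = 7e9

full_finetune_gb = params * 18 / 1e9   # ~18 bytes/param for AdamW mixed-precision training
fp16_weights_gb  = params * 2 / 1e9    # float16: 2 bytes per parameter
int4_weights_gb  = params * 0.5 / 1e9  # 4-bit quantization: half a byte per parameter

print(f"full fine-tuning (weights + optimizer states): ~{full_finetune_gb:.0f} GB")
print(f"frozen fp16 weights:                           ~{fp16_weights_gb:.0f} GB")
print(f"frozen 4-bit weights (Q-LoRa base):            ~{int4_weights_gb:.1f} GB")
```

With Q-LoRa, only the small adapter layers are trained on top of the frozen 4-bit base model, which is exactly what the next code block sets up.
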
```python
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration
import torch

USE_LORA = False
USE_QLORA = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load model

# Three options for training, from the lowest precision training to the highest precision training:
# - QLora
# - Standard Lora
# - Full fine-tuning
if USE_QLORA or USE_LORA:
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16,
        )
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )
else:
    # for full fine-tuning, we can speed up the model using Flash Attention
    # only available on certain devices, see https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",
    )
```

## Apply PEFT

After loading the base model, we're going to add LoRa adapter layers. We will only train these adapter layers (the base model is kept frozen).

The difference compared with other models lies in the layers to which we add adapters (in PEFT this is called `target_modules`). This typically depends a bit on the model.

Here, I based myself on the original `find_all_linear_names` [function](https://github.com/haotian-liu/LLaVA/blob/ec3a32ddea47d8739cb6523fb2661b635c15827e/llava/train/train.py#L169) found in the original LLaVa repository. It means that we're going to add adapters to all linear layers of the model (`nn.Linear`), except for the ones present in the vision encoder and the multimodal projector.
This means that we're mostly going to adapt the language model part of LLaVa for our use case.

```python
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
```

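After wrapping the model with `get_peft_model`, it is worth verifying that only a small fraction of the weights is actually trainable. PEFT models expose `print_trainable_parameters()` for exactly this; the call below is a quick sanity check you can run right after the code above.

```python
# prints something like: trainable params: ... || all params: ... || trainable%: ...
model.print_trainable_parameters()
```
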
## Create PyTorch dataset

Next we'll create a regular [PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) which defines the individual items of the dataset. For that, one needs to implement 3 methods: an `init` method, a `len` method (which returns the length of the dataset) and a `getitem` method (which returns items of the dataset).

The `init` method goes over all the ground truth JSON sequences and turns them into token sequences (which we want the model to generate) using the `json2token` method. Unlike in my Donut and Idefics2 notebooks, we're not going to add special tokens to the model's vocabulary, to omit complexity. Feel free to check those notebooks out; I haven't ablated whether adding special tokens gives a big boost in performance.

Typically, one uses the processor in the `getitem` method to prepare the data in the format that the model expects, but we'll postpone that here for a reason we'll explain later. In our case we're just going to return 2 things: the image and the corresponding ground truth token sequence.

```python
import json
import random

from torch.utils.data import Dataset
from typing import Any, Dict

class LlavaDataset(Dataset):
    """
    PyTorch Dataset for LLaVa. This class takes a HuggingFace Dataset as input.

    Each row consists of an image (png/jpg/jpeg) and ground truth data (json/jsonl/txt).
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        split: str = "train",
        sort_json_key: bool = True,
    ):
        super().__init__()

        self.split = split
        self.sort_json_key = sort_json_key

        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
        self.dataset_length = len(self.dataset)

        self.gt_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        sort_json_key=self.sort_json_key,
                    )
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

    def json2token(self, obj: Any, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    output += (
                        fr"<s_{k}>"
                        + self.json2token(obj[k], sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [self.json2token(item, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            return obj

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Dict:
        """
        Returns one item of the dataset.

        Returns:
            image : the original receipt image
            target_sequence : tokenized ground truth sequence
        """
        sample = self.dataset[idx]

        # inputs
        image = sample["image"]
        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1

        return image, target_sequence

########################################

##################### If you want to use only a small number of samples ##################
class LlavaDataset2(Dataset):
    """
    PyTorch Dataset for LLaVa. This class takes a HuggingFace Dataset as input.

    Each row consists of an image (png/jpg/jpeg) and ground truth data (json/jsonl/txt).
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        split: str = "train",
        sort_json_key: bool = True,
        num_samples: int = None
    ):
        super().__init__()

        self.split = split
        self.sort_json_key = sort_json_key

        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
        self.dataset_length = len(self.dataset)

        # If num_samples is specified and is less than the dataset length, select a subset
        if num_samples is not None and num_samples < self.dataset_length:
            indices = random.sample(range(self.dataset_length), num_samples)
            self.dataset = self.dataset.select(indices)
            self.dataset_length = num_samples

        self.gt_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        sort_json_key=self.sort_json_key,
                    )
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

    def json2token(self, obj: Any, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if isinstance(obj, dict):
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
                for k in keys:
                    output += (
                        fr"<s_{k}>"
                        + self.json2token(obj[k], sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif isinstance(obj, list):
            return r"<sep/>".join(
                [self.json2token(item, sort_json_key) for item in obj]
            )
        else:
            return str(obj)

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Dict:
        """
        Returns one item of the dataset.

        Returns:
            image : the original receipt image
            target_sequence : tokenized ground truth sequence
        """
        sample = self.dataset[idx]

        # inputs
        image = sample["image"]
        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1

        return image, target_sequence


train_dataset = LlavaDataset2("naver-clova-ix/cord-v2", split="train",
                              sort_json_key=False,
                              num_samples=100
                              )

val_dataset = LlavaDataset2("naver-clova-ix/cord-v2", split="validation",
                            sort_json_key=False,
                            num_samples=100
                            )

########################################

# the lines below use the full dataset and override the 100-sample subsets created above
train_dataset = LlavaDataset("naver-clova-ix/cord-v2", split="train", sort_json_key=False)
val_dataset = LlavaDataset("naver-clova-ix/cord-v2", split="validation", sort_json_key=False)


train_example = train_dataset[0]
image, target_sequence = train_example
print(target_sequence)
```

## Define collate functions

Now that we have the PyTorch datasets, we'll define so-called collate functions, which define how items of the dataset should be batched together. This is because we typically train neural networks on batches of data (i.e. various images/target sequences combined) rather than one-by-one, using a variant of stochastic gradient descent or SGD (like Adam, AdamW, etc.).

It's only here that we're going to use the processor to turn the (image, target token sequence) pairs into the format that the model expects (`pixel_values`, `input_ids`, etc.). The reason we do that here is because it allows for **dynamic padding** of the batches: each batch contains ground truth sequences of varying lengths. By only using the processor here, we pad the `input_ids` up to the largest sequence in the batch.

We also decide to limit the length of the text tokens (`input_ids`) to a max length due to memory constraints; feel free to expand it if your target token sequences are longer (I'd recommend plotting the average token length of your dataset to determine the optimal value).

The formatting of the `input_ids` is super important: we need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating). As of now, LLaVa does not yet support chat templates, so we manually write down the prompt in the format used by this Mistral-based checkpoint (wrapped in `[INST] ... [\INST]` tags, matching the prompt used at inference time). We use the text prompt "Extract JSON"; this is just a deliberate choice, you could also omit it and train the model on (image, JSON) pairs without a text prompt.

Labels are created for the model by simply copying the inputs to the LLM (`input_ids`), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together).

Why are the labels a copy of the model inputs, you may ask? The model will internally shift the labels one position to the right so that the model will learn to predict the next token. This can be seen [here](https://github.com/huggingface/transformers/blob/6f465d45d98f9eaeef83cfdfe79aecc7193b0f1f/src/transformers/models/idefics2/modeling_idefics2.py#L1851-L1855).

The collate function for evaluation is different, since there we only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion.

```python
def train_collate_fn(examples):
    images = []
    texts = []
    for example in examples:
        image, ground_truth = example
        images.append(image)
        # TODO: in the future we can replace this by processor.apply_chat_template
        prompt = f"[INST] <image>\nExtract JSON [\INST] {ground_truth}"
        texts.append(prompt)

    batch = processor(text=texts, images=images, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values = batch["pixel_values"]
    image_sizes = batch["image_sizes"]
    labels = batch["labels"]

    return input_ids, attention_mask, pixel_values, image_sizes, labels


def eval_collate_fn(examples):
    # we only feed the prompt to the model
    images = []
    texts = []
    answers = []
    for example in examples:
        image, ground_truth = example
        images.append(image)
        # TODO: in the future we can replace this by processor.apply_chat_template
        prompt = f"[INST] <image>\nExtract JSON [\INST]"
        texts.append(prompt)
        answers.append(ground_truth)

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    pixel_values = batch["pixel_values"]
    image_sizes = batch["image_sizes"]

    return input_ids, attention_mask, pixel_values, image_sizes, answers
```

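Before wiring these functions into the data loaders, you can sanity-check them on a couple of dataset items. The sketch below simply prints the tensor shapes the model will receive; it assumes the datasets, processor and collate functions defined above are already in scope.

```python
# quick sanity check of the training collator on two examples
input_ids, attention_mask, pixel_values, image_sizes, labels = train_collate_fn(
    [train_dataset[0], train_dataset[1]]
)
print(input_ids.shape)     # (batch_size, sequence_length), padded to the longest sequence
print(pixel_values.shape)  # image inputs prepared by the processor
print(labels.shape)        # same shape as input_ids, with padding positions set to -100
```
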
## Define PyTorch LightningModule

There are various ways to train a PyTorch model: one could just use native PyTorch, use the [Trainer API](https://huggingface.co/docs/transformers/en/main_classes/trainer), or frameworks like [Accelerate](https://huggingface.co/docs/accelerate/en/index). In this section, I'll use PyTorch Lightning as it makes it easy to compute evaluation metrics during training.

Below, we define a [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which is the standard way to train a model in PyTorch Lightning. A LightningModule is an `nn.Module` with some additional functionality.

Basically, PyTorch Lightning will take care of all device placements (`.to(device)`) for us, as well as the backward pass, putting the model in training mode, etc.

Notice the difference between a training step and an evaluation step:

- a training step only consists of a forward pass, in which we compute the cross-entropy loss between the model's next token predictions and the ground truth (in parallel for all tokens; this technique is known as "teacher forcing"). The backward pass is handled by PyTorch Lightning.
- an evaluation step consists of making the model autoregressively complete the prompt using the [`generate()`](https://huggingface.co/docs/transformers/v4.40.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) method. After that, we compute an evaluation metric between the predicted sequences and the ground truth ones. This allows us to see how the model is improving over the course of training. The metric we use here is the so-called [Levenshtein edit distance](https://en.wikipedia.org/wiki/Levenshtein_distance). It quantifies how much we would need to edit the predicted token sequence to get the target sequence (the fewer edits the better!). Its optimal value is 0 (which means no edits need to be made).

Besides that, we define the optimizer to use ([AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) is a good default choice) and the data loaders, which use the collate functions defined above to batch together items of the PyTorch datasets. Do note that AdamW is a pretty heavy optimizer in terms of memory requirements, but as we're training with Q-LoRa we only need to store optimizer states for the adapter layers. For full fine-tuning, one could take a look at more memory-friendly optimizers such as [8-bit Adam](https://huggingface.co/docs/bitsandbytes/main/en/optimizers).

```python
import lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np


class LlavaModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

        self.batch_size = config.get("batch_size")

    def training_step(self, batch, batch_idx):

        input_ids, attention_mask, pixel_values, image_sizes, labels = batch

        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             pixel_values=pixel_values,
                             image_sizes=image_sizes,
                             labels=labels
                             )
        loss = outputs.loss

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):

        input_ids, attention_mask, pixel_values, image_sizes, answers = batch

        # autoregressively generate token IDs
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                            pixel_values=pixel_values, image_sizes=image_sizes, max_new_tokens=MAX_LENGTH)
        # turn them back into text, chopping off the prompt
        # important: we don't skip special tokens here, because we want to see them in the output
        predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))

        return scores

    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        return DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(val_dataset, collate_fn=eval_collate_fn, batch_size=self.batch_size, shuffle=False, num_workers=4)


epochs = 100

config = {"max_epochs": epochs,
          # "val_check_interval": 0.2, # how many times we want to validate during an epoch
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-4,
          "batch_size": 1,
          # "seed": 2022,
          "num_nodes": 1,
          "warmup_steps": 50,
          "result_path": "./result",
          "verbose": True,
          }

model_module = LlavaModelPLModule(config, processor, model)
```

## Define callbacks

Optionally, Lightning allows you to define so-called [callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), which are arbitrary pieces of code that can be executed during training.

Here I'm adding a `PushToHubCallback`, which will push the model to the [hub](https://huggingface.co/) at the end of every epoch as well as at the end of training. Do note that you could also pass the `private=True` flag when pushing to the hub, if you wish to keep your model private. Hugging Face also offers the [Enterprise Hub](https://huggingface.co/enterprise) so that you can easily share models with your colleagues privately in a secure way.

We'll also use the `EarlyStopping` callback of Lightning, which will automatically stop training once the evaluation metric (edit distance in our case) doesn't improve after 3 epochs.

```python
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from huggingface_hub import HfApi

api = HfApi()

class PushToHubCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")

    def on_train_end(self, trainer, pl_module):
        print(f"Pushing model to the hub after training")
        pl_module.processor.push_to_hub(REPO_ID,
                                        commit_message=f"Training done")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training done")

early_stop_callback = EarlyStopping(monitor="val_edit_distance", patience=3, verbose=False, mode="min")
```

## Train!

Alright, we're set to start training! We can also pass a Weights and Biases logger so that we get to see some pretty plots of our loss and evaluation metric during training (do note that you may need to log in the first time you run this, see the [docs](https://docs.wandb.ai/guides/integrations/lightning)).

Do note that this Trainer class supports many more flags! See the docs: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer.

_Hint_: you can track training on Weights and Biases, but you first need to create an account and log in. I did not use it, so the corresponding pieces of code are commented out below.

```bash
pip install -U wandb>=0.12.10
```

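If you do want logging, here is a small sketch of how a W&B logger could be created and handed to the Trainer below, reusing the `WANDB_PROJECT` and `WANDB_NAME` variables defined earlier; it would replace `logger=None` in the Trainer call and is entirely optional.

```python
from lightning.pytorch.loggers import WandbLogger

# optional: log metrics to Weights and Biases instead of using logger=None
wandb_logger = WandbLogger(project=WANDB_PROJECT, name=WANDB_NAME)
# then pass `logger=wandb_logger` to L.Trainer(...)
```
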
```python
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=config.get("max_epochs"),
    accumulate_grad_batches=config.get("accumulate_grad_batches"),
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    gradient_clip_val=config.get("gradient_clip_val"),
    precision="16-mixed",
    limit_val_batches=5,
    num_sanity_val_steps=0,
    logger=None,
    # callbacks=[PushToHubCallback(), early_stop_callback],
)

trainer.fit(model_module)


##############################################

# You can save the model in your local directory as you wish
save_dir = root_dir + f"models/fine_tuned_models/llava-v1.6-mistral-7b-hf_{epochs}e_qa_qa"
# trainer.save_model(save_dir)

trainer.save_checkpoint(f"{save_dir}/checkpoint.ckpt")

print("Saved model to:", save_dir)
```

# 3. How to test the model locally by loading the saved checkpoint

```python
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaNextForConditionalGeneration
import torch

import sys
import os

import lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np


def setting_directory(depth):
    current_dir = os.path.abspath(os.getcwd())
    root_dir = current_dir
    for i in range(depth):
        root_dir = os.path.abspath(os.path.join(root_dir, os.pardir))
    sys.path.append(os.path.dirname(root_dir))
    return root_dir

root_dir = setting_directory(1)
epochs = 100

##############################

MAX_LENGTH = 256

# MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
MODEL_ID = "/data/bio-eng-llm/llm_repo/llava-hf/llava-v1.6-mistral-7b-hf"
REPO_ID = "YOUR-HUB-REPO-TO-PUSH"
WANDB_PROJECT = "LLaVaNeXT"
WANDB_NAME = "llava-next-demo-cord"


from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"  # during training, one always uses padding on the right

from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration
import torch

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

USE_LORA = False
USE_QLORA = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load model

# Three options for training, from the lowest precision training to the highest precision training:
# - QLora
# - Standard Lora
# - Full fine-tuning
if USE_QLORA or USE_LORA:
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16,
        )
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )
else:
    # for full fine-tuning, we can speed up the model using Flash Attention
    # only available on certain devices, see https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
    model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",
    )


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
)


base_model = model

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)


##############################


class LlavaModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

        self.batch_size = config.get("batch_size")

    def training_step(self, batch, batch_idx):

        input_ids, attention_mask, pixel_values, image_sizes, labels = batch

        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             pixel_values=pixel_values,
                             image_sizes=image_sizes,
                             labels=labels
                             )
        loss = outputs.loss

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):

        input_ids, attention_mask, pixel_values, image_sizes, answers = batch

        # autoregressively generate token IDs
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                            pixel_values=pixel_values, image_sizes=image_sizes, max_new_tokens=MAX_LENGTH)
        # turn them back into text, chopping off the prompt
        # important: we don't skip special tokens here, because we want to see them in the output
        predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))

        return scores

    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        return DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(val_dataset, collate_fn=eval_collate_fn, batch_size=self.batch_size, shuffle=False, num_workers=4)


from pytorch_lightning import Trainer


config = {"max_epochs": epochs,
          # "val_check_interval": 0.2, # how many times we want to validate during an epoch
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-4,
          "batch_size": 1,
          # "seed": 2022,
          "num_nodes": 1,
          "warmup_steps": 50,
          "result_path": "./result",
          "verbose": True,
          }


# model = LlavaModelPLModule(config, processor, model)


model_path = root_dir + f"/testing_eve_jobmodels/fine_tuned_models/llava-v1.6-mistral-7b-hf_{epochs}e_qa_qa"


# checkpoint = torch.load('model_path/checkpoint.ckpt')

"""
model = LlavaModelPLModule.load_from_checkpoint(f"{model_path}/checkpoint.ckpt",
                                                config,
                                                processor,
                                                model)
"""


# loading the model from the checkpoint

model = LlavaModelPLModule.load_from_checkpoint(
    f"{model_path}/checkpoint.ckpt",
    hparams_file=None,
    config=config,
    processor=processor,
    model=model
)


# model.load_state_dict(checkpoint['model_state_dict'])


print(model)


model.eval()

model = model.to(device)


from datasets import load_dataset

dataset = load_dataset("naver-clova-ix/cord-v2")

test_example = dataset["test"][3]
test_image = test_example["image"]


# prepare the image and prompt for the model
# (this can be replaced by apply_chat_template once the processor supports it)
prompt = f"[INST] <image>\nExtract JSON [\INST]"
inputs = processor(text=prompt, images=[test_image], return_tensors="pt").to(device)
for k, v in inputs.items():
    print(k, v.shape)

# Generate token IDs (the LightningModule wraps the HF model, hence model.model)
generated_ids = model.model.generate(**inputs, max_new_tokens=MAX_LENGTH)

# Decode back into text
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

print(generated_texts)

# processor = AutoProcessor.from_pretrained(model_path)


import re

# let's turn that into JSON
def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    if added_vocab is None:
        added_vocab = processor.tokenizer.get_added_vocab()

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}


generated_json = token2json(generated_texts[0])
print(generated_json)

for key, value in generated_json.items():
    print(key, value)


###################################################################
###################################################################
###################################################################
"""
# First way: pushing the model to the Hugging Face Hub

# Ali-Forootani/llava-v1.6-mistral-7b-hf_20epochs_fine_tune

# Specify the directory where the model and processor will be saved
model_save_path = model_path + "./saved_model"

# Save the processor
processor.save_pretrained(model_save_path)

# Save the model
model.model.save_pretrained(model_save_path)

from transformers import AutoProcessor, LlavaNextForConditionalGeneration

# Load the saved processor and model
processor = AutoProcessor.from_pretrained(model_save_path)
model = LlavaNextForConditionalGeneration.from_pretrained(model_save_path)

# Push the processor and model to the Hugging Face Hub
from huggingface_hub import HfApi, login
login(token="your_huggingface_token")
processor.push_to_hub("Ali-Forootani/llava-v1.6-mistral-7b-hf_100epochs_fine_tune", use_auth_token=True)
model.push_to_hub("Ali-Forootani/llava-v1.6-mistral-7b-hf_100epochs_fine_tune", use_auth_token=True)


from huggingface_hub import HfApi, login
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
"""


#############################
############################# Second way to push to the Hugging Face Hub

from huggingface_hub import HfApi, login
# Login to Hugging Face
login(token="your_huggingface_token")

# Define your Hugging Face repository name
# repo_name = "Ali-Forootani/llava-v1.6-mistral-7b_fine_tune_20epochs"
repo_name = "Ali-Forootani/llava-v1.6-mistral-7b-hf_100epochs_fine_tune"

#######

# Save the model and processor locally (under model_path)
# output_dir = model_path + "/model_to_push"
# model.model.save_pretrained(output_dir)
# processor.save_pretrained(output_dir)

# Push to the Hugging Face Hub
model.model.push_to_hub(repo_name, use_auth_token=True)
processor.push_to_hub(repo_name, use_auth_token=True)
```