---
library_name: transformers
tags:
- image-to-text
- text-generation-inference
license: gemma
datasets:
- ucsahin/pubtables-detection-1500-samples
pipeline_tag: image-text-to-text
---

# paligemma-3b-mix-448-ft-TableDetection

This model is a version of [google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448) fine-tuned with bf16 mixed precision on the [ucsahin/pubtables-detection-1500-samples](https://huggingface.co/datasets/ucsahin/pubtables-detection-1500-samples) dataset.
It achieves the following results on the evaluation set:
- Loss: 1.3544


## Model Details

- This model is a multimodal language model fine-tuned to detect tables in images given a text prompt. It takes an image and a text prompt as input and predicts bounding boxes around the tables in the image.
- The primary purpose of this model is to help automate table detection in images. It can be used in applications such as document processing, data extraction, and image analysis, where locating tables within images is essential.

**Inputs:**
- **Image:** The model requires an image containing one or more tables as input. The image should be in a standard format such as JPEG or PNG.
- **Text Prompt:** A text prompt is also required to tell the model which task to perform. Use **"detect table"** as the prompt.

**Outputs:**
- **Bounding Boxes:** The model outputs bounding boxes as special `<loc[value]>` tokens, where `value` is a normalized coordinate between 0 and 1023 (for example, `<loc0256>`). Each detection consists of four location tokens in the order y_min, x_min, y_max, x_max, followed by the detected label. To convert the values to pixel coordinates in the original image, divide each value by 1024, then multiply the y values by the image height and the x values by the image width.

If everything goes well, the model outputs text similar to `<loc[value]><loc[value]><loc[value]><loc[value]> table; <loc[value]><loc[value]><loc[value]><loc[value]> table`, with one `;`-separated entry per detected table. You can then use the following script to convert the text output into PASCAL VOC formatted bounding boxes, `(x_min, y_min, x_max, y_max)`.
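```python
import re


def post_process(bbox_text, image_width, image_height):
    """Convert the model's text output into PASCAL VOC boxes (x_min, y_min, x_max, y_max)."""
    # Detections are separated by ";".
    loc_values_str = [bbox.strip() for bbox in bbox_text.split(";")]

    converted_bboxes = []
    for loc_value_str in loc_values_str:
        # Extract the numeric values of the <locXXXX> tokens.
        loc_values = [int(x) for x in re.findall(r"<loc(\d+)>", loc_value_str)]
        if len(loc_values) < 4:
            continue  # skip empty or incomplete detections
        # Keep only the first four tags (quantized models may emit extra ones).
        loc_values = loc_values[:4]

        # Normalize by 1024; the order is (y_min, x_min, y_max, x_max).
        loc_values = [value / 1024 for value in loc_values]
        # Convert to pixel coordinates in (x_min, y_min, x_max, y_max) order.
        converted_bboxes.append([
            int(loc_values[1] * image_width), int(loc_values[0] * image_height),
            int(loc_values[3] * image_width), int(loc_values[2] * image_height),
        ])

    return converted_bboxes
```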
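As a worked example of the divide-by-1024-then-scale rule described in the output format, assume the model returned `<loc0256><loc0128><loc0768><loc0896> table` for an image 800 pixels wide and 600 pixels high. Both the token values and the image size are made up for illustration, and `loc_to_pixels` is a hypothetical helper, not part of Transformers.

```python
def loc_to_pixels(loc_values, image_width, image_height):
    """Convert one (y_min, x_min, y_max, x_max) quadruple of <locXXXX> values
    into pixel coordinates of the original image.

    Hypothetical helper for illustration: divide each value by 1024, then
    scale y values by the image height and x values by the image width.
    """
    y_min, x_min, y_max, x_max = (value / 1024 for value in loc_values)
    return (
        y_min * image_height,
        x_min * image_width,
        y_max * image_height,
        x_max * image_width,
    )


# Made-up token values for a made-up 800x600 (width x height) image.
print(loc_to_pixels([256, 128, 768, 896], image_width=800, image_height=600))
# → (150.0, 100.0, 450.0, 700.0)
```

The result keeps the model's (y_min, x_min, y_max, x_max) order; `post_process` additionally reorders it to PASCAL VOC order and truncates the values to integers.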
  
## How to Get Started with the Model

In Transformers, you can load the model as follows:

```python
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch

model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"

device = "cuda:0"
dtype = torch.bfloat16

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device
)

processor = PaliGemmaProcessor.from_pretrained(model_id)
```

For inference, you can use the following:

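```python
from PIL import Image

# Load the input image (replace the placeholder path with your own file).
image = Image.open("path/to/your/image.png").convert("RGB")

# Instruct the model to detect tables
prompt = "detect table"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
    # Drop the prompt tokens and keep only the newly generated ones.
    generation = generation[0][input_len:]
    bbox_text = processor.decode(generation, skip_special_tokens=True)
    print(bbox_text)

# Convert the text output to PASCAL VOC boxes with post_process from above.
bboxes = post_process(bbox_text, image.width, image.height)
```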

**Warning:** You can also load a quantized 4-bit or 8-bit model using `bitsandbytes`. Be aware, though, that the quantized model may generate outputs that need further post-processing, for example five location tags (`<loc[value]>`) instead of four, or labels other than "table". The post-processing script above handles the first case by keeping only the first four location tags.
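The `post_process` script does not check the label, so it cannot handle the second case. A minimal sketch, assuming each detection is `;`-separated and ends with its label as in the sample output above, is to parse the label alongside the coordinates and drop anything that is not `table`. The function name `parse_detections` and its `keep_label` parameter are hypothetical.

```python
import re

# Matches one <locXXXX> token and captures its numeric value.
LOC_PATTERN = re.compile(r"<loc(\d+)>")


def parse_detections(bbox_text, image_width, image_height, keep_label="table"):
    """Return PASCAL VOC boxes (x_min, y_min, x_max, y_max) whose label equals keep_label.

    Hypothetical helper: segments with fewer than four location tags are
    skipped, tags beyond the first four are ignored, and segments with any
    other label are dropped.
    """
    boxes = []
    for segment in bbox_text.split(";"):
        segment = segment.strip()
        loc_values = [int(v) for v in LOC_PATTERN.findall(segment)]
        if len(loc_values) < 4:
            continue  # empty or malformed segment
        # Assume the label is whatever text remains after removing the tags.
        label = LOC_PATTERN.sub("", segment).strip()
        if label != keep_label:
            continue
        y_min, x_min, y_max, x_max = (v / 1024 for v in loc_values[:4])
        boxes.append([
            int(x_min * image_width), int(y_min * image_height),
            int(x_max * image_width), int(y_max * image_height),
        ])
    return boxes


# Made-up output with an extra location tag and an unexpected label.
text = "<loc0256><loc0128><loc0768><loc0896><loc0100> table; <loc0000><loc0000><loc0512><loc0512> figure"
print(parse_detections(text, image_width=800, image_height=600))
# → [[100, 150, 700, 450]]
```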

Use the following to load the 4-bit quantized model:

```python
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, BitsAndBytesConfig
import torch

model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"

device = "cuda:0"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype
)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    quantization_config=bnb_config
)

processor = PaliGemmaProcessor.from_pretrained(model_id)
```


## Bias, Risks, and Limitations


Please refer to [google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448) for bias, risks and limitations.


## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- gradient_accumulation_steps: 4
- mixed precision: bf16
- total_train_batch_size: 16
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 5
- num_epochs: 3
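For readers who want to set up a comparable run, the hyperparameters above can be mapped onto `transformers.TrainingArguments` parameter names, as sketched below. This is not the author's training script: `output_dir` is a placeholder, the optimizer setting (`optim`) is left at the library default because the card only lists the betas and epsilon, and the effective batch size assumes a single device.

```python
# Hyperparameters from the list above, keyed by transformers.TrainingArguments
# parameter names. Pass them as TrainingArguments(**training_kwargs); note that
# bf16=True requires hardware with bfloat16 support.
training_kwargs = {
    "output_dir": "paligemma-table-detection",  # placeholder path
    "learning_rate": 1e-4,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "seed": 42,
    "gradient_accumulation_steps": 4,
    "bf16": True,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "lr_scheduler_type": "linear",
    "warmup_steps": 5,
    "num_train_epochs": 3,
}

# Effective batch size per optimizer step, assuming a single device.
num_devices = 1
total_train_batch_size = (
    training_kwargs["per_device_train_batch_size"]
    * training_kwargs["gradient_accumulation_steps"]
    * num_devices
)
print(total_train_batch_size)  # → 16
```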

### Training results

| Training Loss | Epoch  | Step | Validation Loss |
|:-------------:|:------:|:----:|:---------------:|
| 2.957         | 0.1775 | 15   | 2.1300          |
| 1.9656        | 0.3550 | 30   | 1.8421          |
| 1.6716        | 0.5325 | 45   | 1.6898          |
| 1.5514        | 0.7101 | 60   | 1.5803          |
| 1.5851        | 0.8876 | 75   | 1.5271          |
| 1.4134        | 1.0651 | 90   | 1.4771          |
| 1.3566        | 1.2426 | 105  | 1.4528          |
| 1.3093        | 1.4201 | 120  | 1.4227          |
| 1.2897        | 1.5976 | 135  | 1.4115          |
| 1.256         | 1.7751 | 150  | 1.4007          |
| 1.2666        | 1.9527 | 165  | 1.3678          |
| 1.2213        | 2.1302 | 180  | 1.3744          |
| 1.0999        | 2.3077 | 195  | 1.3633          |
| 1.1931        | 2.4852 | 210  | 1.3606          |
| 1.0722        | 2.6627 | 225  | 1.3619          |
| 1.1485        | 2.8402 | 240  | 1.3544          |


### Framework versions

- PEFT 0.11.1
- Transformers 4.42.0.dev0
- Pytorch 2.3.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1