stupidog04
committed on
Commit
•
45c38e5
1
Parent(s):
84fed5c
commit files to HF hub
Browse files
- README.md +0 -1
- config.json +9 -8
- pair_classification.py → pair_classification_pipeline.py +27 -0
- pipeline.py +0 -28
- pytorch_model.bin +1 -1
- runs/events.out.tfevents.1666972137.sa103.11178.0 +3 -0
README.md
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
tags:
|
3 |
- image-classification
|
4 |
- pytorch
|
5 |
-
- huggingpics
|
6 |
library_name: generic
|
7 |
metrics:
|
8 |
- accuracy
|
|
|
2 |
tags:
|
3 |
- image-classification
|
4 |
- pytorch
|
|
|
5 |
library_name: generic
|
6 |
metrics:
|
7 |
- accuracy
|
config.json
CHANGED
@@ -6,13 +6,14 @@
|
|
6 |
"attention_probs_dropout_prob": 0.0,
|
7 |
"custom_pipelines": {
|
8 |
"pair-classification": {
|
9 |
-
"impl": "
|
10 |
"pt": [
|
11 |
"ViTForImageClassification"
|
12 |
],
|
13 |
"tf": [
|
14 |
"TFViTForImageClassification"
|
15 |
-
]
|
|
|
16 |
}
|
17 |
},
|
18 |
"encoder_stride": 16,
|
@@ -31,12 +32,12 @@
|
|
31 |
"initializer_range": 0.02,
|
32 |
"intermediate_size": 3072,
|
33 |
"label2id": {
|
34 |
-
"chk1_fail": 0,
|
35 |
-
"chk1_pass": 1,
|
36 |
-
"chk2_fail": 2,
|
37 |
-
"chk2_pass": 3,
|
38 |
-
"chk3_fail": 4,
|
39 |
-
"chk3_pass": 5
|
40 |
},
|
41 |
"layer_norm_eps": 1e-12,
|
42 |
"model_type": "vit",
|
|
|
6 |
"attention_probs_dropout_prob": 0.0,
|
7 |
"custom_pipelines": {
|
8 |
"pair-classification": {
|
9 |
+
"impl": "pair_classification_pipeline.PairClassificationPipeline",
|
10 |
"pt": [
|
11 |
"ViTForImageClassification"
|
12 |
],
|
13 |
"tf": [
|
14 |
"TFViTForImageClassification"
|
15 |
+
],
|
16 |
+
"type": "image"
|
17 |
}
|
18 |
},
|
19 |
"encoder_stride": 16,
|
|
|
32 |
"initializer_range": 0.02,
|
33 |
"intermediate_size": 3072,
|
34 |
"label2id": {
|
35 |
+
"chk1_fail": "0",
|
36 |
+
"chk1_pass": "1",
|
37 |
+
"chk2_fail": "2",
|
38 |
+
"chk2_pass": "3",
|
39 |
+
"chk3_fail": "4",
|
40 |
+
"chk3_pass": "5"
|
41 |
},
|
42 |
"layer_norm_eps": 1e-12,
|
43 |
"model_type": "vit",
|
pair_classification.py → pair_classification_pipeline.py
RENAMED
@@ -1,8 +1,35 @@
|
|
1 |
from torchvision import transforms
|
|
|
2 |
from transformers import ImageClassificationPipeline
|
3 |
import torch
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
class PairClassificationPipeline(ImageClassificationPipeline):
|
7 |
pipe_to_tensor = transforms.ToTensor()
|
8 |
pipe_to_pil = transforms.ToPILImage()
|
|
|
1 |
from torchvision import transforms
|
2 |
+
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
3 |
from transformers import ImageClassificationPipeline
|
4 |
import torch
|
5 |
|
6 |
|
7 |
+
class PreTrainedPipeline():
|
8 |
+
def __init__(self, path):
|
9 |
+
"""
|
10 |
+
Initialize model
|
11 |
+
"""
|
12 |
+
# self.processor = feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
|
13 |
+
model_flag = 'google/vit-base-patch16-224-in21k'
|
14 |
+
# model_flag = 'google/vit-base-patch16-384'
|
15 |
+
self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
|
16 |
+
self.model = ViTForImageClassification.from_pretrained(path)
|
17 |
+
self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)
|
18 |
+
|
19 |
+
def __call__(self, inputs):
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
inputs (:obj:`np.array`):
|
23 |
+
The raw waveform of audio received. By default at 16KHz.
|
24 |
+
Return:
|
25 |
+
A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
|
26 |
+
the detected text from the input audio.
|
27 |
+
"""
|
28 |
+
# input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
|
29 |
+
# logits = self.model(input_values).logits.cpu().detach().numpy()[0]
|
30 |
+
return self.pipe(inputs)
|
31 |
+
|
32 |
+
|
33 |
class PairClassificationPipeline(ImageClassificationPipeline):
|
34 |
pipe_to_tensor = transforms.ToTensor()
|
35 |
pipe_to_pil = transforms.ToPILImage()
|
pipeline.py
CHANGED
@@ -1,36 +1,8 @@
|
|
1 |
from torchvision import transforms
|
2 |
-
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig
|
3 |
from transformers import ImageClassificationPipeline
|
4 |
import torch
|
5 |
|
6 |
|
7 |
-
|
8 |
-
class PreTrainedPipeline():
|
9 |
-
def __init__(self, path):
|
10 |
-
"""
|
11 |
-
Initialize model
|
12 |
-
"""
|
13 |
-
# self.processor = feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
|
14 |
-
model_flag = 'google/vit-base-patch16-224-in21k'
|
15 |
-
# model_flag = 'google/vit-base-patch16-384'
|
16 |
-
self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
|
17 |
-
self.model = ViTForImageClassification.from_pretrained(path)
|
18 |
-
self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)
|
19 |
-
|
20 |
-
def __call__(self, inputs):
|
21 |
-
"""
|
22 |
-
Args:
|
23 |
-
inputs (:obj:`np.array`):
|
24 |
-
The raw waveform of audio received. By default at 16KHz.
|
25 |
-
Return:
|
26 |
-
A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
|
27 |
-
the detected text from the input audio.
|
28 |
-
"""
|
29 |
-
# input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
|
30 |
-
# logits = self.model(input_values).logits.cpu().detach().numpy()[0]
|
31 |
-
return self.pipe(inputs)
|
32 |
-
|
33 |
-
|
34 |
class PairClassificationPipeline(ImageClassificationPipeline):
|
35 |
pipe_to_tensor = transforms.ToTensor()
|
36 |
pipe_to_pil = transforms.ToPILImage()
|
|
|
1 |
from torchvision import transforms
|
|
|
2 |
from transformers import ImageClassificationPipeline
|
3 |
import torch
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
class PairClassificationPipeline(ImageClassificationPipeline):
|
7 |
pipe_to_tensor = transforms.ToTensor()
|
8 |
pipe_to_pil = transforms.ToPILImage()
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 345635761
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a709dd954e0bcc97bd8ed5652becb0709ac5ebc4d36f6e1ae7b15de08019dc01
|
3 |
size 345635761
|
runs/events.out.tfevents.1666972137.sa103.11178.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93ee6a8b55c58e1a4f81c0e7b2484cc7986e796c8cad905ef2417c30f5469136
|
3 |
+
size 551
|