stupidog04 commited on
Commit
164a729
1 Parent(s): be4d4fb

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +33 -0
pipeline.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from pair_classification import PairClassificationPipeline
3
+ from typing import Dict
4
+
5
+
6
+ class PreTrainedPipeline():
7
+ def __init__(self, path):
8
+ """
9
+ Initialize model
10
+ """
11
+ model_flag = 'google/vit-base-patch16-224-in21k'
12
+ # self.processor = feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
13
+ self.pipe = pipeline("pair-classification", model=model_flag , feature_extractor=model_flag ,
14
+ model_kwargs={'num_labels':len(label2id),
15
+ 'label2id':label2id,
16
+ 'id2label':id2label,
17
+ 'num_channels':6,
18
+ 'ignore_mismatched_sizes': True })
19
+ self.model = self.pipe.model.from_pretrained(path)
20
+
21
+
22
+ def __call__(self, inputs):
23
+ """
24
+ Args:
25
+ inputs (:obj:`np.array`):
26
+ The raw waveform of audio received. By default at 16KHz.
27
+ Return:
28
+ A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
29
+ the detected text from the input audio.
30
+ """
31
+ # input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
32
+ # logits = self.model(input_values).logits.cpu().detach().numpy()[0]
33
+ return self.pipe(inputs)