|
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
from huggingface_hub import hf_hub_download |
|
|
|
import onnxruntime as ort |
|
import torch |
|
import os |
|
|
|
|
|
class ONNXBaseConfig(PretrainedConfig):
    """Configuration for a model backed by a single ONNX graph file.

    Args:
        model_path: Name of the ONNX file, relative to the directory (or hub
            repo) the model is loaded from. ``None`` leaves the choice to the
            loader (``ONNXBaseModel.from_pretrained`` falls back to
            ``'model.onnx'``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = 'onnx-base'

    def __init__(self, model_path=None, **kwargs):
        # Declare the attribute explicitly so configs whose config.json omits
        # it still expose ``model_path`` (as None) instead of raising
        # AttributeError when it is read.
        self.model_path = model_path
        super().__init__(**kwargs)


# Let AutoConfig resolve model_type 'onnx-base' to this config class.
AutoConfig.register('onnx-base', ONNXBaseConfig)
|
|
|
|
|
class ONNXBaseModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that runs inference with onnxruntime.

    The ONNX graph named by ``config.model_path`` is loaded into an
    ``onnxruntime.InferenceSession``. The model registers no torch
    parameters of its own.
    """

    config_class = ONNXBaseConfig

    def __init__(self, config, base_path=None, *, providers=None):
        """Create the model and, when its ONNX file exists, open a session.

        Args:
            config: Model config; ``config.model_path`` names the ONNX file
                relative to ``base_path``.
            base_path: Directory holding the ONNX file. If omitted, or if the
                file is not there, no session is created and ``forward``
                raises until one is available.
            providers: onnxruntime execution providers in priority order.
                Defaults to every provider this onnxruntime build offers.
        """
        super().__init__(config)
        self.session = None
        # Path of the ONNX file behind ``self.session``; save_pretrained
        # copies it rather than re-exporting the graph.
        self._onnx_path = None

        model_file = getattr(config, 'model_path', None)
        if base_path and model_file:
            model_path = os.path.join(base_path, model_file)
            if os.path.exists(model_path):
                if providers is None:
                    # Recent onnxruntime builds with non-CPU providers refuse
                    # to create a session without an explicit provider list.
                    providers = ort.get_available_providers()
                self.session = ort.InferenceSession(model_path, providers=providers)
                self._onnx_path = model_path

    def forward(self, input=None, **kwargs):
        """Run the ONNX graph, feeding ``input`` to the graph input ``'input'``.

        Args:
            input: Value for the graph's ``'input'``. Torch tensors are
                converted to NumPy first, because onnxruntime consumes
                NumPy arrays.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The list returned by ``InferenceSession.run`` (all graph outputs).

        Raises:
            RuntimeError: If no ONNX session has been loaded.
        """
        if self.session is None:
            raise RuntimeError(
                'No ONNX session is loaded; construct the model with a base_path '
                'that contains config.model_path (e.g. via from_pretrained).'
            )
        if isinstance(input, torch.Tensor):
            input = input.detach().cpu().numpy()
        return self.session.run(None, {'input': input})

    def save_pretrained(self, save_directory: str, *, dummy_input=None, **kwargs):
        """Save the config and the ONNX graph into ``save_directory``.

        The graph is written under ``config.model_path`` (default
        ``'model.onnx'``). That name is stored in the saved config so that
        ``from_pretrained`` finds the same file again.

        If the model was loaded from an ONNX file, that file is copied.
        Otherwise the module is traced with ``torch.onnx.export``.
        NOTE(review): tracing only works when ``forward`` is plain torch code
        (presumably a subclass override). The base ``forward`` calls
        onnxruntime and cannot be traced.

        Args:
            save_directory: Target directory (created by the parent method).
            dummy_input: Example input used for tracing. Defaults to a 2x2
                float32 tensor. Axis 0 of input and output is exported as the
                dynamic ``batch_size`` axis.
            **kwargs: Forwarded to ``PreTrainedModel.save_pretrained``.
        """
        import shutil

        file_name = getattr(self.config, 'model_path', None) or 'model.onnx'
        # Record the file name before the parent writes config.json.
        self.config.model_path = file_name
        super().save_pretrained(save_directory=save_directory, **kwargs)

        onnx_file_path = os.path.join(save_directory, file_name)
        # model_path may include sub-directories.
        os.makedirs(os.path.dirname(onnx_file_path), exist_ok=True)

        if self._onnx_path is not None:
            # Saving into the directory we loaded from: the file is already there.
            if os.path.abspath(self._onnx_path) != os.path.abspath(onnx_file_path):
                shutil.copyfile(self._onnx_path, onnx_file_path)
            return

        if dummy_input is None:
            dummy_input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
        torch.onnx.export(self, dummy_input, onnx_file_path,
                          input_names=['input'], output_names=['output'],
                          dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config and ONNX file from a local directory or a hub repo.

        Args:
            pretrained_model_name_or_path: Local directory, or hub repo id.
            *model_args: Accepted for API compatibility and ignored.
            **kwargs: ``providers`` is popped and passed to the constructor.
                Everything else goes to ``AutoConfig.from_pretrained``.
                ``revision``, ``cache_dir``, ``token``, ``force_download`` and
                ``local_files_only`` are also forwarded to ``hf_hub_download``.

        Returns:
            A new instance whose session is opened on the ONNX file, when it exists.
        """
        providers = kwargs.pop('providers', None)
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if getattr(config, 'model_path', None) is None:
            config.model_path = 'model.onnx'

        if os.path.isdir(pretrained_model_name_or_path):
            base_path = pretrained_model_name_or_path
        else:
            hub_kwargs = {
                key: kwargs[key]
                for key in ('revision', 'cache_dir', 'token', 'force_download', 'local_files_only')
                if key in kwargs
            }
            config_path = hf_hub_download(repo_id=pretrained_model_name_or_path,
                                          filename='config.json', **hub_kwargs)
            # Downloaded files share one snapshot directory, so the ONNX file
            # resolves as base_path / config.model_path.
            base_path = os.path.dirname(config_path)
            hf_hub_download(repo_id=pretrained_model_name_or_path,
                            filename=config.model_path, **hub_kwargs)

        return cls(config, base_path=base_path, providers=providers)

    @property
    def device(self):
        """Device the model runs on.

        With a session, this is ``cuda`` if the session uses a CUDA or
        TensorRT provider, else ``cpu``. Without a session, it is ``cuda``
        whenever torch can see a GPU.
        """
        session = getattr(self, 'session', None)
        if session is not None:
            gpu_providers = {'CUDAExecutionProvider', 'TensorrtExecutionProvider'}
            on_gpu = any(p in gpu_providers for p in session.get_providers())
            return torch.device('cuda' if on_gpu else 'cpu')
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Let AutoModel build ONNXBaseModel for ONNXBaseConfig configs.
AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
|
|
|
|
|
from transformers.pipelines import Pipeline |
|
|
|
class ONNXBasePipeline(Pipeline):
    """Minimal pipeline that feeds raw inputs straight to the model.

    Preprocessing wraps the input as the model's ``input`` keyword argument.
    Postprocessing returns the model outputs unchanged.
    """

    def __init__(self, model, **kwargs):
        # ``device`` may be absent from kwargs; indexing it raised KeyError
        # whenever the caller did not pass one.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess / forward / postprocess parameter dicts.

        No call-time options are supported, so any extra kwargs are ignored.
        """
        return {}, {}, {}

    def preprocess(self, input):
        """Wrap the raw input as keyword arguments for the model."""
        return {'input': input}

    def _forward(self, model_input):
        """Call the model on the preprocessed inputs with autograd disabled."""
        with torch.no_grad():
            return self.model(**model_input)

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs


# Expose the pipeline as task 'onnx-base', backed by ONNXBaseModel for PyTorch.
PIPELINE_REGISTRY.register_pipeline(
    task='onnx-base',
    pipeline_class=ONNXBasePipeline,
    pt_model=ONNXBaseModel,
)
|
|