File size: 956 Bytes
5d526dc
 
 
 
 
 
 
35f403b
 
5d526dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import Dict, List, Any

import torch
from transformers import BertModel, BertTokenizerFast


class EndpointHandler:
    """Custom inference handler that encodes strings into BERT pooled embeddings."""

    def __init__(self, path_to_model: str = '.'):
        """Load the tokenizer and model once so every request can reuse them.

        :param path_to_model: directory (or hub id) passed to ``from_pretrained``
            for both the tokenizer and the model.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        self.model = BertModel.from_pretrained(path_to_model)
        # Inference only: switch off training-time behaviour such as dropout.
        self.model = self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        This method is called whenever a request is made to the endpoint.

        :param data: request payload; ``data['inputs']`` is the string or list
            of strings to be encoded.
        :return: one pooled embedding (a list of floats) per input string.
            Plain Python lists are returned because the endpoint must
            serialize the result, and a ``torch.Tensor`` is not serializable.
        :raises KeyError: if ``data`` has no ``'inputs'`` entry.
        """
        texts = data['inputs']

        # padding: batch strings of different lengths into one tensor.
        # truncation: cap sequences at the tokenizer's max length so overlong
        # inputs do not overflow BERT's position embeddings and crash.
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        with torch.no_grad():
            outputs = self.model(**encoded)

        # pooler_output holds one vector per input row; convert the tensor
        # to nested Python lists so it can be serialized.
        return outputs.pooler_output.tolist()