|
|
|
|
|
class FondantInferenceModel:
    """Base class that abstracts model loading and inference.

    Subclasses implement :meth:`load_model`, :meth:`preprocess` and
    :meth:`postprocess`, and the resulting class is passed to
    ``FondantInferenceComponent``, which loads the model and prepares it for
    inference.

    Calling an instance runs ``postprocess(model(*preprocess(...)))``.

    Example subclasses for PyTorch / Hugging Face / TensorFlow / ... models
    are intended to live in the examples folder.
    """

    def __init__(self, device: str = "cpu"):
        """Load the model, then put it in eval mode on ``device``.

        Args:
            device: Device identifier handed to the model's ``.to()`` method
                by :meth:`eval`.
        """
        self.device = device
        self.model = self.load_model()
        self.eval()

    def load_model(self):
        """Return the model object used for inference.

        Must be implemented by subclasses. The returned object is called with
        the unpacked output of :meth:`preprocess`, and the default
        :meth:`eval` calls ``.eval()`` and ``.to(device)`` on it.

        Raises:
            NotImplementedError: If the subclass does not override this hook.
        """
        # Fail loudly here instead of returning None, which would otherwise
        # surface later as an obscure AttributeError in eval().
        raise NotImplementedError(
            f"{type(self).__name__} must implement load_model()"
        )

    def eval(self):
        """Switch the model to evaluation mode and move it to ``self.device``."""
        # NOTE(review): assumes a PyTorch-style model whose .eval() and .to()
        # return the model; other frameworks need to override this method.
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

    def preprocess(self, input):
        """Turn the call arguments into positional model inputs.

        Must be implemented by subclasses. It receives exactly the arguments
        the instance was called with and must return an iterable, which is
        unpacked with ``*`` when calling the model.

        Raises:
            NotImplementedError: If the subclass does not override this hook.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess()"
        )

    def postprocess(self, output):
        """Turn the raw model output into the value returned by ``__call__``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: If the subclass does not override this hook.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement postprocess()"
        )

    def __call__(self, *args, **kwargs):
        """Run preprocess -> model -> postprocess and return the result."""
        processed_inputs = self.preprocess(*args, **kwargs)
        # preprocess() yields the model's positional arguments.
        outputs = self.model(*processed_inputs)
        return self.postprocess(outputs)
|
|
|
|
|
class FondantInferenceComponent(FondantTransformComponent, FondantInferenceModel):
    """Transform component that mixes in FondantInferenceModel's pipeline."""

    # NOTE(review): because FondantTransformComponent is listed first, any
    # __init__ / __call__ it defines takes precedence over
    # FondantInferenceModel's in the MRO. Confirm that
    # FondantInferenceModel.__init__ (which loads the model) actually runs.

    def transform(
        self, args: argparse.Namespace, dataframe: dd.DataFrame
    ) -> dd.DataFrame:
        """Transform ``dataframe`` using the inference model (work in progress)."""
        # NOTE(review): `dataframe` is not used in the lines shown here, and no
        # DataFrame is returned yet. This method may continue beyond this
        # excerpt; confirm.
        # NOTE(review): FondantInferenceModel defines no `infer` method; it
        # exposes inference via __call__. Confirm that `infer` is provided by
        # FondantTransformComponent or elsewhere.
        output = self.infer(args.image)
|
|