rajistics commited on
Commit
374eeee
1 Parent(s): 374500b

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +33 -0
  2. requirements.txt +1 -0
handler.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import pickle
3
+ from typing import Dict, List, Any
4
+ import numpy as np
5
+
6
+
7
+ # set device
8
+ class EndpointHandler():
9
+ def __init__(self, path=""):
10
+ # load the optimized model
11
+ self.pipe = pd.read_pickle(r'churn.pkl')
12
+
13
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
+ """
15
+ Args:
16
+ data (:obj:):
17
+ includes the input data and the parameters for the inference.
18
+ Return:
19
+ A :obj:`list`:. A string representing what the label/class is
20
+ """
21
+ inputs = data.pop("inputs", data)
22
+ parameters = data.pop("parameters", None)
23
+ df = pd.DataFrame(data)
24
+
25
+ df["TotalCharges"] = df["TotalCharges"].replace(" ", np.nan, regex=False).astype(float)
26
+ df = df.drop(columns=["customerID"])
27
+ df = df.drop(columns=["Churn"])
28
+
29
+ # run inference pipeline
30
+ pred = self.pipe.predict(df)
31
+
32
+ # postprocess the prediction
33
+ return {"pred": pred}
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pandas