File size: 4,190 Bytes
58df7f1
 
 
4fdac74
58df7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
7681953
58df7f1
ce79710
 
58df7f1
 
 
 
 
 
 
 
 
 
 
7681953
58df7f1
7681953
 
 
 
58df7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
7681953
58df7f1
7681953
58df7f1
7681953
58df7f1
7681953
 
 
 
 
58df7f1
 
 
 
 
 
 
 
 
 
 
 
4f6d82f
58df7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fdac74
58df7f1
 
 
 
 
 
 
4fdac74
58df7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7681953
 
 
 
 
 
58df7f1
 
 
 
e3e1dc8
58df7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import shutil
from pathlib import Path
from typing import List, Tuple, Union

import numpy
import pandas

from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier

# Max input size to be displayed in the HuggingFace Space browser using Gradio.
# Too large inputs slow down the server: https://github.com/gradio-app/gradio/issues/1877
INPUT_BROWSER_LIMIT = 635

# Store the server's URL
SERVER_URL = "http://localhost:8000/"

CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment_logit_11"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"

# Directories wiped and recreated by `clean_directory`
ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]

# Columns that define the target
TARGET_COLUMNS = ["prognosis_encoded", "prognosis"]

# Resolve the data files relative to this module (like the deployment directories above)
# so that loading them does not depend on the current working directory.
DATA_DIR = CURRENT_DIR / "data"
TRAINING_FILENAME = str(DATA_DIR / "Training_preprocessed.csv")
TESTING_FILENAME = str(DATA_DIR / "Testing_preprocessed.csv")

# pylint: disable=invalid-name

from typing import List, Tuple


def pretty_print(
    inputs, case_conversion=str.title, which_replace: str = "_", to_what: str = " ", delimiter=None
):
    """
    Prettify, deduplicate and sort the input as a list of strings.

    Each item has `which_replace` replaced by `to_what` and is then passed through
    `case_conversion`. Deduplication and sorting are done on the transformed strings,
    so the result is sorted as displayed and holds no visual duplicates
    (e.g. "itching" and "Itching" collapse into a single entry).

    Args:
        inputs (Iterable): Strings and/or lists of strings; lists are flattened one
            level deep.
        case_conversion (Callable[[str], str]): Casing applied to each item.
            Defaults to `str.title`.
        which_replace (str): Substring to replace in each item. Defaults to "_".
        to_what (str): Replacement for `which_replace`. Defaults to " ".
        delimiter (Optional[str]): If truthy, the items are joined with it and a
            trailing "." is appended.

    Returns:
        Union[List[str], str]: The prettified and sorted list of inputs, or a single
        joined string when `delimiter` is given.
    """
    # Flatten one level: each item is either a single string or a list of strings
    flattened = []
    for item in inputs:
        if isinstance(item, list):
            flattened.extend(item)
        else:
            flattened.append(item)

    # Transform first, then dedupe and sort, so the output order matches what is displayed
    prettified = {case_conversion(item.replace(which_replace, to_what)) for item in flattened}
    pretty_list = sorted(prettified)

    if delimiter:
        return f"{delimiter.join(pretty_list)}."

    return pretty_list


def clean_directory() -> None:
    """
    Clear the deployment directories.

    Every directory listed in `ALL_DIRS` is removed together with its content (when it
    exists) and then recreated empty, creating missing parent directories as needed.
    """
    print("Cleaning...\n")
    for target_dir in ALL_DIRS:
        # `is_dir()` is False for missing paths, so no separate existence check is needed
        if target_dir.is_dir():
            shutil.rmtree(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)


def get_disease_name(encoded_prediction: int, file_name: str = TRAINING_FILENAME) -> str:
    """Return the disease name given its encoded label.

    The label is looked up in the CSV file: `TARGET_COLUMNS[0]` holds the encoded
    label and `TARGET_COLUMNS[1]` holds the corresponding disease name.

    Args:
        encoded_prediction (int): The encoded prediction
        file_name (str): The data file path

    Returns:
        str: The according disease name

    Raises:
        ValueError: If the encoded label maps to no disease name, or to more than one.
    """
    df = pandas.read_csv(file_name, usecols=TARGET_COLUMNS)

    # Select the name column explicitly: `usecols` keeps the file's column order (not the
    # order of TARGET_COLUMNS), so positional unpacking would depend on the CSV layout.
    names = df.loc[df[TARGET_COLUMNS[0]] == encoded_prediction, TARGET_COLUMNS[1]].unique()

    if len(names) != 1:
        raise ValueError(
            f"Expected exactly one disease name for label {encoded_prediction!r}, "
            f"found {len(names)}."
        )

    return names[0]


def _split_features_target(df: pandas.DataFrame) -> Tuple[pandas.DataFrame, pandas.Series]:
    """Split a data set into its feature columns and its encoded target column."""
    # TARGET_COLUMNS[0] -> "prognosis_encoded" -> contains the numeric label of the disease
    # TARGET_COLUMNS[1] -> "prognosis"         -> contains the name of the disease
    y = df[TARGET_COLUMNS[0]]
    # Both target columns are dropped from the features; missing ones are ignored
    X = df.drop(columns=TARGET_COLUMNS, errors="ignore")
    return X, y


def load_data() -> Tuple[
    Tuple[pandas.DataFrame, pandas.DataFrame],
    Tuple[pandas.Series, pandas.Series],
    List[str],
    List[str],
]:
    """
    Load the training and testing sets and split them into features and targets.

    Args:
        None

    Return:
        A 4-tuple:
            - (X_train, X_test): the feature sets,
            - (y_train, y_test): the encoded targets (`TARGET_COLUMNS[0]`),
            - the feature (symptom) column names,
            - the unique disease names (`TARGET_COLUMNS[1]`) found in the training set.
    """
    # Load data
    df_train = pandas.read_csv(TRAINING_FILENAME)
    df_test = pandas.read_csv(TESTING_FILENAME)

    # Separate the target from the training / testing set
    X_train, y_train = _split_features_target(df_train)
    X_test, y_test = _split_features_target(df_test)

    return (
        (X_train, X_test),
        (y_train, y_test),
        X_train.columns.to_list(),
        df_train[TARGET_COLUMNS[1]].unique().tolist(),
    )


def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray, concrete_args=None):
    """
    Train and compile a Concrete ML XGBoost classifier.

    Nothing is deserialized here: a fresh model is fitted on the given training data
    and then compiled against the training inputs.

    Args:
        X_train (pandas.DataFrame): Training set
        y_train (numpy.ndarray): Targets of the training set
        concrete_args (Optional[dict]): Hyper-parameters that override the defaults
            {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}.
            Defaults to None (use the defaults as-is).

    Return:
        The fitted Concrete ML model and its compiled circuit
    """
    # Default parameters, optionally overridden by the caller
    params = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
    if concrete_args:
        params.update(concrete_args)

    classifier = ConcreteXGBoostClassifier(**params)
    # Train the model
    classifier.fit(X_train, y_train)
    # Compile the model, using the training inputs as the compilation input set
    circuit = classifier.compile(X_train)

    return classifier, circuit