import os import shutil import subprocess import time from pathlib import Path from typing import List, Tuple, Union import gradio as gr import numpy as np import pandas as pd import requests from preprocessing import pretty_print from symptoms_categories import SYMPTOMS_LIST from concrete.ml.common.serialization.loaders import load from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier INPUT_BROWSER_LIMIT = 635 SERVER_URL = "http://localhost:8000/" # This repository's main necessary folders REPO_DIR = Path(__file__).parent MODEL_PATH = REPO_DIR / "client_folder" KEYS_PATH = REPO_DIR / ".fhe_keys" CLIENT_TMP_PATH = REPO_DIR / "client_tmp" SERVER_TMP_PATH = REPO_DIR / "server_tmp" # Create the necessary folders KEYS_PATH.mkdir(exist_ok=True) CLIENT_TMP_PATH.mkdir(exist_ok=True) SERVER_TMP_PATH.mkdir(exist_ok=True) subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR) time.sleep(3) def clean_directory(): target_dir = ".fhe_keys" if os.path.exists(target_dir) and os.path.isdir(target_dir): shutil.rmtree(target_dir) print("The .fhe_keys directory and its contents have been successfully removed.") else: print("The .keys directory does not exist.") def load_data(): # Load data df_train = pd.read_csv("./data/Training_preprocessed.csv") df_test = pd.read_csv("./data/Testing_preprocessed.csv") # Separate the traget from the training set # df['prognosis] contains the name of the disease # df['y] contains the numeric label of the disease y_train = df_train["y"] X_train = df_train.drop(columns=["y", "prognosis"], axis=1, errors="ignore") y_test = df_train["y"] X_test = df_test.drop(columns=["y", "prognosis"], axis=1, errors="ignore") return (df_train, X_train, X_test), (df_test, y_train, y_test) def load_model(X_train, y_train): concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1} classifier = ConcreteXGBoostClassifier(**concrete_args) classifier.fit(X_train, y_train) circuit = classifier.compile(X_train) return classifier, circuit def get_user_vect_symptoms_from_checkboxgroup(*user_symptoms) -> np.array: symptoms_vector = {key: 0 for key in VALID_COLUMNS} for symptom_box in user_symptoms: for pretty_symptom in symptom_box: symptom = "_".join((pretty_symptom.lower().split(" "))) if symptom not in symptoms_vector.keys(): raise KeyError( f"The symptom '{symptom}' you provided is not recognized as a valid " f"symptom.\nHere is the list of valid symptoms: {symptoms_vector}" ) symptoms_vector[symptom] = 1.0 user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :] assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten()) return user_symptoms_vect def get_user_vector_from_default_disease(disease): user_symptom_vector = df_test[df_test["prognosis"] == disease].iloc[0].values user_symptoms_vect = np.fromiter(user_symptom_vector[:-2], dtype=float)[np.newaxis, :] assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten()) return user_symptoms_vect def get_user_symptoms_from_default_disease(disease): df_filtred = df_test[df_test["prognosis"] == disease] columns_with_1 = df_filtred.columns[df_filtred.eq(1).any()].to_list() return pretty_print(columns_with_1) def get_user_symptoms_vector_fn(selected_default_disease, *selected_symptoms): # Display an error box, if: # 1. The user has already selected a default disease and added more symptoms, or # 2. The the user has not selected a default disease or symptoms if ( any(lst for lst in selected_symptoms if lst) and (selected_default_disease is not None and len(selected_default_disease) > 0) and set(pretty_print(selected_symptoms)) - set(get_user_symptoms_from_default_disease(selected_default_disease)) ) or ( not any(lst for lst in selected_symptoms if lst) and ( selected_default_disease is None or (selected_default_disease is not None and len(selected_default_disease) < 1) ) ): return { error_box_1: gr.update( visible=True, value="Enter a default disease or select your own symptoms" ), } # Case 1: The user has checked his own symptoms if any(lst for lst in selected_symptoms if lst): return { error_box_1: gr.update(visible=False), user_vector_textbox: get_user_vect_symptoms_from_checkboxgroup(*selected_symptoms), } # Case 2: The user has selected a default disease if selected_default_disease is not None and len(selected_default_disease) > 0: return { user_vector_textbox: get_user_vector_from_default_disease(selected_default_disease), error_box_1: gr.update(visible=False), **{ box: get_user_symptoms_from_default_disease(selected_default_disease) for box in check_boxes }, } def key_gen_fn(user_symptoms): print("Cleaning directory ...") clean_directory() if user_symptoms is None or (user_symptoms is not None and len(user_symptoms) < 1): print("Please submit your symptoms first") return { error_box_2: gr.update(visible=True, value="Please submit your symptoms first"), } # Key serialization user_id = np.random.randint(0, 2**32) client = FHEModelClient(path_dir=MODEL_PATH, key_dir=KEYS_PATH / f"{user_id}") client.load() # The client first need to create the private and evaluation keys. client.generate_private_and_evaluation_keys() # Get the serialized evaluation keys serialized_evaluation_keys = client.get_serialized_evaluation_keys() assert isinstance(serialized_evaluation_keys, bytes) # np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys) evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key" with evaluation_key_path.open("wb") as f: f.write(serialized_evaluation_keys) serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT] return { error_box_2: gr.update(visible=False), eval_key_textbox: serialized_evaluation_keys_shorten_hex, user_id_textbox: user_id, eval_key_len_textbox: f"{len(serialized_evaluation_keys) / (10**6):.2f} MB", } def encrypt_fn(user_symptoms, user_id): if not user_symptoms or not user_symptoms: return { error_box_3: gr.update( visible=True, value="Please ensure that the evaluation key has been generated!" ) } # Retrieve the client API client = FHEModelClient(path_dir=MODEL_PATH, key_dir=KEYS_PATH / f"{user_id}") client.load() user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1) quant_user_symptoms = client.model.quantize_input(user_symptoms) encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms) assert isinstance(encrypted_quantized_user_symptoms, bytes) encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms" with encrypted_input_path.open("wb") as f: f.write(encrypted_quantized_user_symptoms) # print(client.model.predict(vect_x, fhe="simulate"), client.model.predict(vect_x, fhe="execute")) # pred_s = client.model.fhe_circuit.simulate(quant_vect) # pred_fhe = client.model.fhe_circuit.encrypt_run_decrypt(quant_vect) # # non alpha -> \X1124, base64 ou en exa # Compute size # np.save(f".fhe_keys/{user_id}/encrypted_quant_vect.npy", encrypted_quantized_user_symptoms) encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[ :INPUT_BROWSER_LIMIT ] return { error_box_3: gr.update(visible=False), vect_textbox: user_symptoms, quant_vect_textbox: quant_user_symptoms, encrypted_vect_textbox: encrypted_quantized_user_symptoms_shorten_hex, } def is_nan(input): return input is None or (input is not None and len(input) < 1) def send_input_fn(user_id, user_symptoms): """Send the encrypted input image as well as the evaluation key to the server. Args: user_id (int): The current user's ID. filter_name (str): The current filter to consider. """ # Get the evaluation key path if is_nan(user_id) or is_nan(user_symptoms): return { error_box_4: gr.update( visible=True, value="Please ensure that the evaluation key has been generated " "and the symptoms have been submitted before sending the data to the server", ) } evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key" encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms" if not evaluation_key_path.is_file(): print(f"Please generate the private key, first.{evaluation_key_path.is_file()=}") return { error_box_4: gr.update(visible=True, value="Please generate the private key first.") } if not encrypted_input_path.is_file(): print(f"Please submit your symptoms, first.{encrypted_input_path.is_file()=}") return { error_box_4: gr.update( visible=True, value="Please generate the private key and then encrypt an image first.", ) } # Define the data and files to post data = { "user_id": user_id, "filter": user_symptoms, } files = [ ("files", open(encrypted_input_path, "rb")), ("files", open(evaluation_key_path, "rb")), ] # Send the encrypted input image and evaluation key to the server url = SERVER_URL + "send_input" with requests.post( url=url, data=data, files=files, ) as response: print(f"response.ok: {response.ok}") return {error_box_4: gr.update(visible=False), server_response_box: gr.update(visible=True)} # def decrypt_prediction(encrypted_quantized_vect, user_id): # fhe_api = FHEModelClient(path_dir=REPO_DIR, key_dir=f".fhe_keys/{user_id}") # fhe_api.load() # fhe_api.generate_private_and_evaluation_keys(force=False) # predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_quantized_vect) # return predictions def clear_all_btn(): clean_directory() return { box_default: None, vect_textbox: None, user_id_textbox: None, eval_key_textbox: None, quant_vect_textbox: None, user_vector_textbox: None, eval_key_len_textbox: None, encrypted_vect_textbox: None, error_box_1: gr.update(visible=False), error_box_2: gr.update(visible=False), error_box_3: gr.update(visible=False), error_box_4: gr.update(visible=False), server_response_box: gr.update(visible=False), **{box: None for box in check_boxes}, } if __name__ == "__main__": print("Starting demo ...") (df_train, X_train, X_test), (df_test, y_train, y_test) = load_data() VALID_COLUMNS = X_train.columns.to_list() # Load the model with open("ConcreteXGBoostClassifier.pkl", "r", encoding="utf-8") as file: concrete_classifier = load(file) with gr.Blocks() as demo: # Link + images gr.Markdown( """
Concrete-ML — Documentation — Community — @zama_fhe
""" ) # Gentle introduction gr.Markdown("## Introduction") gr.Markdown("""Blablabla""") # User symptoms gr.Markdown("# Step 1: Provide your symptoms") gr.Markdown("Client side") # Default disease, picked from the dataframe with gr.Row(): default_diseases = list(set(df_test["prognosis"])) box_default = gr.Dropdown(default_diseases, label="Disease") # Box symptoms check_boxes = [] for i, category in enumerate(SYMPTOMS_LIST): check_box = gr.CheckboxGroup( pretty_print(category.values()), label=pretty_print(category.keys()), info=f"Symptoms related to `{pretty_print(category.values())}`", max_batch_size=45, ) check_boxes.append(check_box) error_box_1 = gr.Textbox(label="Error", visible=False) # User symptom vector with gr.Row(): user_vector_textbox = gr.Textbox( label="User symptoms (vector)", interactive=False, max_lines=100, ) with gr.Row(): # Submit botton with gr.Column(): submit_button = gr.Button("Submit") # Clear botton with gr.Column(): clear_button = gr.Button("Clear") # Click submit botton submit_button.click( fn=get_user_symptoms_vector_fn, inputs=[box_default, *check_boxes], outputs=[user_vector_textbox, error_box_1, *check_boxes], ) gr.Markdown("# Step 2: Generate the keys") gr.Markdown("Client side") gen_key_btn = gr.Button("Generate the keys and send public part to server") error_box_2 = gr.Textbox(label="Error", visible=False) with gr.Row(): # User ID with gr.Column(scale=1, min_width=600): user_id_textbox = gr.Textbox( label="User ID:", max_lines=4, interactive=False, ) # Evaluation key size with gr.Column(scale=1, min_width=600): eval_key_len_textbox = gr.Textbox( label="Evaluation key size:", max_lines=4, interactive=False ) with gr.Row(): # Evaluation key (truncated) with gr.Column(scale=2, min_width=600): eval_key_textbox = gr.Textbox( label="Evaluation key (truncated):", max_lines=4, interactive=False, ) gen_key_btn.click( key_gen_fn, inputs=user_vector_textbox, outputs=[eval_key_textbox, user_id_textbox, eval_key_len_textbox, error_box_2], ) gr.Markdown("# Step 3: Encode the message with the private key") gr.Markdown("Client side") encrypt_btn = gr.Button("Encode the message with the private key") error_box_3 = gr.Textbox(label="Error", visible=False) with gr.Row(): with gr.Column(scale=1, min_width=600): vect_textbox = gr.Textbox( label="Vector:", max_lines=4, interactive=False, ) with gr.Column(scale=1, min_width=600): quant_vect_textbox = gr.Textbox( label="Quant vector:", max_lines=4, interactive=False ) with gr.Column(scale=1, min_width=600): encrypted_vect_textbox = gr.Textbox( label="Encrypted vector:", max_lines=4, interactive=False ) encrypt_btn.click( encrypt_fn, inputs=[user_vector_textbox, user_id_textbox], outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox, error_box_3], ) gr.Markdown("# Step 4: Send the encrypted data to the server.") gr.Markdown("Client side") send_input_btn = gr.Button("Send the encrypted data to the server..") error_box_4 = gr.Textbox(label="Error", visible=False) server_response_box = gr.Textbox(value="Data sent", visible=False, show_label=False) send_input_btn.click( send_input_fn, inputs=[user_id_textbox, user_vector_textbox], outputs=[error_box_4, server_response_box], ) gr.Markdown("# Step 5: Run the FHE evaluation") gr.Markdown("Server side") run_fhe = gr.Button("Run the FHE evaluation") gr.Markdown("# Step 6: Decrypt the sentiment") gr.Markdown("Server side") decrypt_target_botton = gr.Button("Decrypt the sentiment") decrypt_target_textbox = gr.Textbox( label="Encrypted vector:", max_lines=4, interactive=False ) # decrypt_target_botton.click( # decrypt_prediction, # inputs=[encrypted_vect_textbox, user_id_textbox], # outputs=[decrypt_target_textbox], # ) clear_button.click( clear_all_btn, outputs=[ box_default, error_box_1, error_box_2, error_box_3, error_box_4, vect_textbox, user_id_textbox, eval_key_textbox, quant_vect_textbox, user_vector_textbox, server_response_box, eval_key_len_textbox, encrypted_vect_textbox, *check_boxes, ], ) demo.launch()