import logging
import os
import random
import tempfile
from zipfile import ZipFile

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from openai import AzureOpenAI
from PIL import Image
from tensorflow import keras
from tensorflow.keras.preprocessing import image as keras_image
# Emit INFO-level (and above) records from the root logger.
logging.basicConfig(level=logging.INFO)
class DiseaseDetectionApp:
    """Gradio app that labels a chest X-ray as 'Normal' or 'Tuberculosis'.

    The class loads a saved Keras model for the prediction and uses an
    Azure OpenAI chat deployment to produce a short description when the
    predicted label is 'Tuberculosis'.
    """

    # Size every input image is resized to before it is fed to the model.
    IMAGE_SIZE = (256, 256)

    # File extensions (lower-case) accepted as example images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')

    # Text returned to the UI when no prediction could be produced.
    PREDICTION_FAILED_MESSAGE = (
        "Unable to analyse the image. Please upload a valid chest X-ray image."
    )

    # Text returned to the UI when the description request fails.
    DESCRIPTION_FAILED_MESSAGE = "Description is currently unavailable."

    def __init__(self):
        """Load the classifier and create the Azure OpenAI client."""
        # Order matters: the index chosen by argmax in predict_disease is
        # looked up in this list.
        self.class_names = ['Normal', 'Tuberculosis']
        self.model = tf.keras.models.load_model("chest_xray_tuberclosis_prediction_model.keras")
        # No arguments are passed, so the client has to find its endpoint,
        # key and API version on its own (environment variables per the
        # openai SDK -- confirm for your deployment).
        self.client = AzureOpenAI()

    def predict_disease(self, image_path):
        """
        Predict the disease present in the X-Ray image.

        Args:
            - image_path: path of the image file to classify

        Returns:
            - predicted_disease: one of ``self.class_names``, or None when
              no image was given or anything failed while predicting
              (the error is logged).
        """
        if not image_path:
            # Gradio passes None when the button is clicked without an image.
            logging.error("Error predicting disease: no image was provided")
            return None

        try:
            img = keras_image.load_img(image_path, target_size=self.IMAGE_SIZE)
            img_array = keras_image.img_to_array(img)
            # Add a leading batch axis: the model is called on a batch of one.
            batch = tf.expand_dims(img_array, 0)
            predictions = self.model.predict(batch)
            # NOTE(review): assumes the model emits one score per entry of
            # class_names -- confirm against the trained model.
            return self.class_names[int(np.argmax(predictions[0]))]
        except Exception as e:
            logging.error(f"Error predicting disease: {str(e)}")
            return None

    def _describe_disease(self, disease_name):
        """Ask the chat deployment for a three-line summary of *disease_name*.

        Returns the model's text, or DESCRIPTION_FAILED_MESSAGE when the
        request fails, so a successful prediction is never lost because the
        description service is unavailable.
        """
        prompt = (
            " your task describe(classify) about the given disease as a summary only in 3 lines.\n"
            f"```{disease_name}```\n"
        )
        conversation = [
            {"role": "system", "content": "You are a medical assistant"},
            {"role": "user", "content": prompt},
        ]
        try:
            response = self.client.chat.completions.create(
                model="ChatGPT",
                messages=conversation,
                temperature=0,
                max_tokens=1000,
            )
            return response.choices[0].message.content
        except Exception as e:
            logging.error(f"Error describing disease: {e}")
            return self.DESCRIPTION_FAILED_MESSAGE

    def classify_disease(self, image_path):
        """Predict the label for *image_path* and return it with a description.

        Args:
            - image_path: path of the image file to classify

        Returns:
            - (disease_name, description): always a 2-tuple, because the
              Gradio callback is wired to two output components.
              ``disease_name`` is None when the prediction failed.
        """
        disease_name = self.predict_disease(image_path)
        logging.info("Predicted disease: %s", disease_name)

        if disease_name == "Tuberculosis":
            return disease_name, self._describe_disease(disease_name)
        if disease_name == "Normal":
            return disease_name, "No problem in your xray image"
        # Prediction failed (or produced an unexpected label).
        return None, self.PREDICTION_FAILED_MESSAGE

    def unzip_image_data(self, filespath):
        """
        Unzips an image dataset into a freshly created directory.

        A new, uniquely named directory is created under the current working
        directory for every call, so two archives (or two runs) can never
        end up sharing one extraction directory.

        Returns:
            str: The path to the directory containing the extracted image
            files, or "" when extraction failed (the error is logged).
        """
        try:
            with ZipFile(filespath, "r") as archive:
                # Create the directory only after the archive opened fine,
                # so a bad archive does not leave an empty directory behind.
                stem = os.path.splitext(os.path.basename(filespath))[0]
                directory_path = tempfile.mkdtemp(prefix=f"{stem}_", dir=os.getcwd())
                archive.extractall(directory_path)
            return directory_path
        except Exception as e:
            logging.error(f"An error occurred during extraction: {e}")
            return ""

    def example_images(self, filespath):
        """
        Unzips the image dataset and lists the image files found at the top
        level of the extracted directory, for use as UI examples.

        Returns:
            List[str]: One path per image file, sorted by file name; an
            empty list when the archive could not be extracted.
        """
        image_dataset_folder = self.unzip_image_data(filespath)
        if not image_dataset_folder:
            return []

        example = []
        for name in sorted(os.listdir(image_dataset_folder)):
            path = os.path.join(image_dataset_folder, name)
            is_image = os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
            if is_image and os.path.isfile(path):
                example.append(path)
        return example

    def get_example_image(self):
        """Extract both bundled datasets and return their image paths.

        Returns:
            - (normal_images, tuberculosis_images): two lists of file paths
        """
        normal_images = self.example_images("Normal_dataset.zip")
        tuberculosis_images = self.example_images("Tuberculosis_dataset.zip")
        return normal_images, tuberculosis_images

    def gradio_interface(self):
        """Build the Gradio UI and launch it (blocks while the app runs)."""
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.HTML("""<center><h1>Tuberculosis Disease Detection</h1></center>""")

            normal_images, tuberculosis_images = self.get_example_image()

            with gr.Row():
                input_image = gr.Image(type="filepath", sources=["upload"])
                with gr.Column():
                    output = gr.Label(label="Disease Name")
            with gr.Row():
                classify_disease_ = gr.Textbox(label="About disease")
            with gr.Row():
                button = gr.Button(value="Detect The Disease")

            button.click(self.classify_disease, [input_image], [output, classify_disease_])

            example_sets = (
                ("Normal X-ray Images", normal_images),
                ("Tuberculosis X-ray Images", tuberculosis_images),
            )
            for label, examples in example_sets:
                if not examples:
                    # Nothing was extracted; skip the section instead of
                    # building an empty examples table.
                    logging.warning("No example images available for: %s", label)
                    continue
                gr.Examples(
                    examples=examples,
                    label=label,
                    inputs=[input_image],
                    outputs=[output, classify_disease_],
                    fn=self.classify_disease,
                    examples_per_page=5,
                    cache_examples=False,
                )

        demo.launch(debug=True)
if __name__ == "__main__":
    # Build the app (loads the model and creates the API client), then
    # serve the UI; gradio_interface() has no return value to keep.
    app = DiseaseDetectionApp()
    app.gradio_interface()