File size: 4,011 Bytes
605e3c4
 
1835df3
 
605e3c4
 
 
 
 
 
bbb0a94
605e3c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1835df3
 
 
 
 
 
 
 
 
 
 
605e3c4
bbb0a94
 
 
1835df3
 
 
 
 
605e3c4
 
5f3bd7f
605e3c4
 
 
 
 
 
1835df3
605e3c4
1835df3
 
 
 
 
 
 
 
 
 
 
 
 
 
bbb0a94
1835df3
 
 
bbb0a94
1835df3
bbb0a94
1835df3
 
 
 
5c0cabb
1835df3
5c0cabb
1835df3
 
 
 
5c0cabb
1835df3
5c0cabb
1835df3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import numpy as np

# Function to process data and return feature importances and correlation matrix
def calculate_importances(file):
    """Fit three tree-based classifiers on a CSV and collect their feature importances.

    The CSV must contain a ``target`` column; every other column is used as a
    feature. The data is split 75/25 (``random_state=42``) and each model is
    fitted on the training portion only. The combined importance table is also
    written to ``feature_importances.xlsx`` in the current working directory.

    Args:
        file: Anything ``pandas.read_csv`` accepts (a path or a file-like
            object such as a Streamlit upload).

    Returns:
        A 7-tuple ``(file_name, importance_df, corr_matrix, rf_importances,
        xgb_importances, cart_importances, feature_names)``:

        - ``file_name``: name of the Excel file that was written.
        - ``importance_df``: one row per feature with columns ``Feature``,
          ``Random Forest``, ``XGBoost`` and ``CART``.
        - ``corr_matrix``: ``DataFrame.corr()`` of the whole CSV (including
          ``target``).
        - ``rf_importances`` / ``xgb_importances`` / ``cart_importances``: the
          raw ``feature_importances_`` arrays of each fitted model.
        - ``feature_names``: the feature column labels, in the same order as
          the importance arrays.

    Raises:
        KeyError: If the CSV has no ``target`` column.
    """
    heart_df = pd.read_csv(file)
    if 'target' not in heart_df.columns:
        raise KeyError(
            "Input CSV must contain a 'target' column; found columns: "
            f"{list(heart_df.columns)}"
        )

    X = heart_df.drop(columns='target')
    y = heart_df['target']

    # Only the training split is used: importances come from the fitted models,
    # so the held-out portion is intentionally discarded.
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.25, random_state=42)

    models = {
        'Random Forest': RandomForestClassifier(random_state=42),
        # `use_label_encoder` was deprecated in XGBoost 1.6 and removed in 2.0
        # (passing it now only triggers an "unused parameter" warning).
        'XGBoost': XGBClassifier(eval_metric='logloss', random_state=42),
        'CART': DecisionTreeClassifier(random_state=42),
    }
    for model in models.values():
        model.fit(X_train, y_train)

    feature_names = X.columns
    rf_importances = models['Random Forest'].feature_importances_
    xgb_importances = models['XGBoost'].feature_importances_
    cart_importances = models['CART'].feature_importances_

    # All importance arrays are aligned with `feature_names`, so the table can
    # be built in one step instead of merging per-model frames on 'Feature'.
    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Random Forest': rf_importances,
        'XGBoost': xgb_importances,
        'CART': cart_importances,
    })

    corr_matrix = heart_df.corr()

    file_name = 'feature_importances.xlsx'
    importance_df.to_excel(file_name, index=False)

    return file_name, importance_df, corr_matrix, rf_importances, xgb_importances, cart_importances, feature_names

# Streamlit interface
st.title("Ablation Study on Medical Features")

# MIME type for .xlsx (Office Open XML). "application/vnd.ms-excel" is the
# legacy .xls type and does not match the file that is produced.
XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

# File upload
uploaded_file = st.file_uploader("Upload heart.csv file", type=['csv'])

if uploaded_file is not None:
    try:
        (excel_file, importance_df, corr_matrix, rf_importances,
         xgb_importances, cart_importances, feature_names) = calculate_importances(uploaded_file)
    except (KeyError, ValueError) as err:
        # A missing 'target' column (KeyError) or an unreadable/unfittable CSV
        # (pandas and the estimators raise ValueError subclasses) is shown to
        # the user instead of crashing the app with a traceback.
        st.error(f"Could not process the uploaded file: {err}")
        st.stop()

    st.write("Feature Importances (Preview):")
    st.dataframe(importance_df.head())

    # Offer the Excel file written by calculate_importances for download.
    with open(excel_file, "rb") as excel_handle:
        st.download_button(
            label="Download Excel File",
            data=excel_handle,
            file_name=excel_file,
            mime=XLSX_MIME,
        )

    # Streamlit reruns this script on every interaction, so every figure is
    # created explicitly and closed after rendering; otherwise they pile up in
    # pyplot's global registry and leak memory across reruns.
    st.write("Correlation Matrix:")
    fig_corr, ax_corr = plt.subplots(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap="coolwarm", cbar=True, ax=ax_corr)
    st.pyplot(fig_corr)
    plt.close(fig_corr)

    # One horizontal bar chart per model: (section heading, chart title, importances).
    importance_plots = (
        ("Random Forest Feature Importance:", 'Random Forest Feature Importances', rf_importances),
        ("XGBoost Feature Importance:", 'XGBoost Feature Importances', xgb_importances),
        ("CART (Decision Tree) Feature Importance:", 'CART (Decision Tree) Feature Importances', cart_importances),
    )
    for heading, chart_title, importances in importance_plots:
        st.write(heading)
        fig, ax = plt.subplots()
        sns.barplot(x=importances, y=feature_names, ax=ax)
        ax.set_title(chart_title)
        st.pyplot(fig)
        plt.close(fig)