TESTTT / app.py
Roberta2024's picture
Update app.py
5f3bd7f verified
raw
history blame contribute delete
No virus
4.01 kB
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import numpy as np
# Function to process data and return feature importances and correlation matrix
def calculate_importances(file):
    """Fit three tree-based classifiers on a CSV and collect their feature importances.

    The CSV is read with ``pd.read_csv``; the ``target`` column is used as the
    label and every other column as a feature. A Random Forest, an XGBoost
    classifier and a single decision tree (CART) are each fitted on a 75%
    training split (``random_state=42``), and their ``feature_importances_``
    are gathered into one table.

    Side effect: the importance table is written to ``feature_importances.xlsx``
    (a relative path, so it lands in the current working directory).
    NOTE(review): the path is fixed, so concurrent callers would overwrite each
    other's file -- confirm whether that matters for this deployment.

    Args:
        file: Anything ``pd.read_csv`` accepts (path or file-like object).

    Returns:
        A 7-tuple ``(file_name, importance_df, corr_matrix, rf_importances,
        xgb_importances, cart_importances, feature_names)`` where
        ``importance_df`` has the columns ``Feature``, ``Random Forest``,
        ``XGBoost`` and ``CART``, and ``corr_matrix`` is ``DataFrame.corr()``
        of the full input (features and target).

    Raises:
        KeyError: If the CSV has no ``target`` column.
    """
    heart_df = pd.read_csv(file)

    # Fail early with an actionable message instead of pandas' generic
    # "not found in axis" KeyError (same exception type as before).
    if 'target' not in heart_df.columns:
        raise KeyError("Uploaded CSV must contain a 'target' column to use as the label.")

    X = heart_df.drop('target', axis=1)
    y = heart_df['target']

    # Only the training portion is needed here; the held-out 25% is discarded
    # but the split is kept so the models see the same rows as before.
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.25, random_state=42)

    # Insertion order determines the column order of the resulting table.
    models = {
        'Random Forest': RandomForestClassifier(random_state=42),
        'XGBoost': XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42),
        'CART': DecisionTreeClassifier(random_state=42),
    }
    for model in models.values():
        model.fit(X_train, y_train)

    rf_importances = models['Random Forest'].feature_importances_
    xgb_importances = models['XGBoost'].feature_importances_
    cart_importances = models['CART'].feature_importances_
    feature_names = X.columns

    # All three importance arrays are aligned with X's column order, so a
    # single DataFrame is equivalent to merging three per-model frames.
    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Random Forest': rf_importances,
        'XGBoost': xgb_importances,
        'CART': cart_importances,
    })

    # Pairwise correlation of every column, target included.
    corr_matrix = heart_df.corr()

    # Persist the table so the caller can offer it as a download.
    file_name = 'feature_importances.xlsx'
    importance_df.to_excel(file_name, index=False)

    return file_name, importance_df, corr_matrix, rf_importances, xgb_importances, cart_importances, feature_names
# Streamlit interface: upload a CSV, show the feature-importance table,
# offer it as an Excel download, and plot the correlation matrix plus one
# importance bar chart per model.
st.title("Ablation Study on Medical Features")

# File upload
uploaded_file = st.file_uploader("Upload heart.csv file", type=['csv'])

if uploaded_file is not None:
    # Process the file and get results
    (
        excel_file,
        importance_df,
        corr_matrix,
        rf_importances,
        xgb_importances,
        cart_importances,
        feature_names,
    ) = calculate_importances(uploaded_file)

    # Display a preview of the DataFrame
    st.write("Feature Importances (Preview):")
    st.dataframe(importance_df.head())

    # Offer the generated workbook for download. The file is .xlsx, so use
    # the OOXML spreadsheet MIME type rather than the legacy .xls one.
    with open(excel_file, "rb") as file:
        btn = st.download_button(
            label="Download Excel File",
            data=file,
            file_name=excel_file,
            mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        )

    # Plot and display the Correlation Matrix on an explicit figure instead
    # of handing the global pyplot module to Streamlit.
    st.write("Correlation Matrix:")
    fig_corr, ax_corr = plt.subplots(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap="coolwarm", cbar=True, ax=ax_corr)
    st.pyplot(fig_corr)
    # Close after rendering so figures do not pile up across script reruns.
    plt.close(fig_corr)

    # One bar chart per model, in the same order as before.
    importance_plots = [
        ("Random Forest", rf_importances),
        ("XGBoost", xgb_importances),
        ("CART (Decision Tree)", cart_importances),
    ]
    for model_label, model_importances in importance_plots:
        st.write(f"{model_label} Feature Importance:")
        fig_imp, ax_imp = plt.subplots()
        sns.barplot(x=model_importances, y=feature_names, ax=ax_imp)
        ax_imp.set_title(f"{model_label} Feature Importances")
        st.pyplot(fig_imp)
        plt.close(fig_imp)