Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.model_selection import train_test_split | |
from tensorflow.keras.models import Sequential | |
from tensorflow.keras.layers import Dense, Dropout, LSTM | |
import tensorflow as tf | |
import json | |
import datetime | |
import os | |
import plotly.express as px | |
import logging | |
from typing import Dict, List, Optional | |
# Configure root logging once, at import time: records at INFO and above,
# each prefixed with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything defined below.
logger = logging.getLogger(__name__)
class VRTherapySystem:
    """Manage user profiles, therapy-session plans, progress logs and analytics.

    Profiles and session logs are held in pandas DataFrames that mirror two
    CSV files inside ``data_dir`` (``user_profiles.csv`` and
    ``session_data.csv``). Every mutating call rewrites the corresponding CSV
    in full. Public methods return strings (plain status text or JSON) so they
    can be wired directly to UI outputs. Errors are logged and reported in the
    returned string instead of being raised, except in ``__init__``.
    """

    SESSION_FILE = "session_data.csv"
    PROFILE_FILE = "user_profiles.csv"
    SESSION_COLUMNS = [
        "user_id", "timestamp", "session_duration",
        "pain_reduction", "mobility_improvement",
    ]
    PROFILE_COLUMNS = ["user_id", "age", "condition", "therapy_goals"]

    # Exercises offered per difficulty tier; _create_session_plan falls back
    # to "basic" for any unknown tier.
    EXERCISES: Dict[str, List[str]] = {
        "basic": [
            "Guided Breathing",
            "Gentle Stretching",
            "Simple Range of Motion",
        ],
        "intermediate": [
            "Balance Training",
            "Strength Exercises",
            "Coordination Tasks",
        ],
        "advanced": [
            "Complex Movement Patterns",
            "Endurance Training",
            "Dynamic Balance",
        ],
    }

    def __init__(self, data_dir: str = "vr_therapy_data"):
        """Create ``data_dir`` if needed and load (or create) both CSV files.

        Args:
            data_dir: Directory holding the CSV files. Defaults to
                ``"vr_therapy_data"`` (the previously hard-coded location).

        Raises:
            Exception: Any error creating the directory is logged with its
                traceback and re-raised.
        """
        try:
            self.data_dir = data_dir
            os.makedirs(self.data_dir, exist_ok=True)
            self.session_data = self._load_or_create_session_data()
            self.user_profiles = self._load_or_create_user_profiles()
            logger.info("VR Therapy System initialized successfully")
        except Exception as e:
            logger.exception("Error initializing VR Therapy System: %s", e)
            raise

    def _path(self, filename: str) -> str:
        """Return the full path of ``filename`` inside the data directory."""
        return os.path.join(self.data_dir, filename)

    def _load_or_create_csv(self, filename: str, columns: List[str],
                            label: str) -> pd.DataFrame:
        """Load ``filename`` from the data directory, creating it if absent.

        Args:
            filename: CSV file name inside ``data_dir``.
            columns: Header to use when creating the file, or for the empty
                fallback frame.
            label: Human-readable name used in the error log message.

        Returns:
            The loaded DataFrame. If the file does not exist, a new empty
            frame with ``columns`` is written to disk and returned. If loading
            fails, an empty frame with ``columns`` is returned and not written.
        """
        file_path = self._path(filename)
        try:
            if os.path.exists(file_path):
                # Read user_id as text. Otherwise an all-digit ID column is
                # parsed as int64 and never equals the str IDs from the UI.
                return pd.read_csv(file_path, dtype={"user_id": str})
            df = pd.DataFrame(columns=columns)
            df.to_csv(file_path, index=False)
            return df
        except Exception as e:
            logger.exception("Error loading %s: %s", label, e)
            # Keep the expected columns so later df["user_id"] lookups still
            # work instead of raising KeyError.
            return pd.DataFrame(columns=columns)

    def _load_or_create_session_data(self) -> pd.DataFrame:
        """Load existing session data or create a new, empty session log."""
        return self._load_or_create_csv(
            self.SESSION_FILE, self.SESSION_COLUMNS, "session data"
        )

    def _load_or_create_user_profiles(self) -> pd.DataFrame:
        """Load existing user profiles or create a new, empty profile table."""
        return self._load_or_create_csv(
            self.PROFILE_FILE, self.PROFILE_COLUMNS, "user profiles"
        )

    @staticmethod
    def _require_user_id(user_id: Optional[str]) -> str:
        """Return ``user_id`` unchanged; raise ValueError if it is blank/None."""
        if user_id is None or not str(user_id).strip():
            raise ValueError("User ID is required")
        return user_id

    def save_user_profile(self, user_id: str, age: int, condition: str,
                          therapy_goals: str) -> str:
        """Insert a new profile, or overwrite the existing one for ``user_id``.

        Returns:
            ``"Profile saved successfully"``, or an
            ``"Error saving user profile: ..."`` message on failure (including
            a blank ``user_id``).
        """
        try:
            self._require_user_id(user_id)
            profile = {
                'user_id': user_id,
                'age': age,
                'condition': condition,
                'therapy_goals': therapy_goals,
            }
            mask = self.user_profiles['user_id'] == user_id
            if mask.any():
                # Assign column by column. `df.loc[mask] = series` aligns the
                # Series on the row index (not the columns) and would
                # overwrite the matched row with NaN.
                for column, value in profile.items():
                    self.user_profiles.loc[mask, column] = value
            else:
                self.user_profiles = pd.concat(
                    [self.user_profiles, pd.DataFrame([profile])],
                    ignore_index=True,
                )
            self.user_profiles.to_csv(
                self._path(self.PROFILE_FILE), index=False
            )
            logger.info("Profile saved successfully for user %s", user_id)
            return "Profile saved successfully"
        except Exception as e:
            error_msg = f"Error saving user profile: {str(e)}"
            logger.exception(error_msg)
            return error_msg

    def generate_therapy_session(self, user_id: str, pain_level: int,
                                 mobility_score: int) -> str:
        """Return a JSON session plan sized to the given pain/mobility levels.

        ``user_id`` is only used for logging. On failure, returns
        ``{"error": ...}`` as JSON.
        """
        try:
            difficulty = self._calculate_difficulty(pain_level, mobility_score)
            session = self._create_session_plan(difficulty)
            logger.info("Therapy session generated for user %s", user_id)
            return json.dumps(session, indent=2)
        except Exception as e:
            error_msg = f"Error generating therapy session: {str(e)}"
            logger.exception(error_msg)
            return json.dumps({"error": error_msg})

    def _calculate_difficulty(self, pain_level: int, mobility_score: int) -> str:
        """Map pain and mobility to "basic", "intermediate" or "advanced".

        The score weights inverted pain (10 - pain) at 0.3 and mobility at 0.7.
        Below 4 is "basic", below 7 is "intermediate", otherwise "advanced".
        Falls back to "basic" if the score cannot be computed (e.g. None input).
        """
        try:
            score = (10 - pain_level) * 0.3 + mobility_score * 0.7
        except Exception as e:
            logger.exception("Error calculating difficulty: %s", e)
            return "basic"
        if score < 4:
            return "basic"
        if score < 7:
            return "intermediate"
        return "advanced"

    def _create_session_plan(self, difficulty: str) -> Dict:
        """Build the session-plan dict for ``difficulty`` (fixed 30 minutes)."""
        exercises = self.EXERCISES.get(difficulty, self.EXERCISES["basic"])
        return {
            "difficulty": difficulty,
            # Copy so callers cannot mutate the shared class-level catalogue.
            "exercises": list(exercises),
            "duration": 30,
            "rest_periods": "As needed",
            "modifications": "Available upon request",
        }

    def log_session_progress(self, user_id: str, session_duration: int,
                             pain_reduction: int,
                             mobility_improvement: int) -> str:
        """Append one session record, timestamped now, and persist the log.

        Returns:
            ``"Session progress logged successfully"``, or an
            ``"Error logging session progress: ..."`` message on failure
            (including a blank ``user_id``).
        """
        try:
            self._require_user_id(user_id)
            new_session = pd.DataFrame([{
                'user_id': user_id,
                # Naive local time in ISO-8601, matching existing rows.
                'timestamp': datetime.datetime.now().isoformat(),
                'session_duration': session_duration,
                'pain_reduction': pain_reduction,
                'mobility_improvement': mobility_improvement,
            }])
            self.session_data = pd.concat(
                [self.session_data, new_session], ignore_index=True
            )
            self.session_data.to_csv(
                self._path(self.SESSION_FILE), index=False
            )
            logger.info("Session progress logged for user %s", user_id)
            return "Session progress logged successfully"
        except Exception as e:
            error_msg = f"Error logging session progress: {str(e)}"
            logger.exception(error_msg)
            return error_msg

    def get_user_analytics(self, user_id: str) -> str:
        """Return JSON summary statistics over all sessions of ``user_id``.

        Keys: total_sessions, average_duration, average_pain_reduction,
        average_mobility_improvement and progress_trend (mobility values in
        logged order). If the user has no sessions, a ``{"message": ...}``
        object is returned; on failure, ``{"error": ...}``.
        """
        try:
            user_sessions = self.session_data[
                self.session_data['user_id'] == user_id
            ]
            if user_sessions.empty:
                return json.dumps({"message": "No sessions found for this user"})

            def _mean(column: str) -> float:
                # Columns can be object dtype after concatenating onto an
                # empty frame, so coerce to numbers before averaging.
                return float(pd.to_numeric(user_sessions[column],
                                           errors="coerce").mean())

            analytics = {
                "total_sessions": int(len(user_sessions)),
                "average_duration": _mean('session_duration'),
                "average_pain_reduction": _mean('pain_reduction'),
                "average_mobility_improvement": _mean('mobility_improvement'),
                "progress_trend": user_sessions['mobility_improvement'].tolist(),
            }
            logger.info("Analytics generated for user %s", user_id)
            return json.dumps(analytics, indent=2)
        except Exception as e:
            error_msg = f"Error generating analytics: {str(e)}"
            logger.exception(error_msg)
            return json.dumps({"error": error_msg})
# Create Gradio interface
def create_interface():
    """Build the Gradio UI for the VR therapy system.

    A single ``VRTherapySystem`` (default data directory) backs four tabs:
    profile editing, session-plan generation, progress logging and analytics.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.

    Raises:
        Exception: Any construction error is logged with its traceback and
            re-raised.
    """
    try:
        vr_system = VRTherapySystem()
        with gr.Blocks(title="VR Therapy System") as interface:
            gr.Markdown("# VR Therapy and Rehabilitation System")

            with gr.Tab("User Profile"):
                with gr.Row():
                    user_id = gr.Textbox(label="User ID")
                    # precision=0 makes Gradio round and pass an int, matching
                    # save_user_profile's `age: int` (otherwise a float).
                    age = gr.Number(label="Age", precision=0)
                    condition = gr.Textbox(label="Medical Condition")
                therapy_goals = gr.TextArea(label="Therapy Goals")
                save_profile_btn = gr.Button("Save Profile")
                profile_output = gr.Textbox(label="Profile Status")

            with gr.Tab("Therapy Session"):
                with gr.Row():
                    session_user_id = gr.Textbox(label="User ID")
                    # Explicit step=1 keeps levels whole numbers; the default
                    # step has differed across Gradio versions.
                    pain_level = gr.Slider(1, 10, step=1, label="Pain Level")
                    mobility_score = gr.Slider(1, 10, step=1,
                                               label="Mobility Score")
                generate_btn = gr.Button("Generate Session")
                session_output = gr.JSON(label="Session Plan")

            with gr.Tab("Progress Logging"):
                with gr.Row():
                    log_user_id = gr.Textbox(label="User ID")
                    duration = gr.Number(label="Session Duration (minutes)",
                                         precision=0)
                    pain_reduction = gr.Slider(0, 10, step=1,
                                               label="Pain Reduction")
                    mobility_improvement = gr.Slider(
                        0, 10, step=1, label="Mobility Improvement"
                    )
                log_btn = gr.Button("Log Progress")
                log_output = gr.Textbox(label="Logging Status")

            with gr.Tab("Analytics"):
                analytics_user_id = gr.Textbox(label="User ID")
                analytics_btn = gr.Button("Generate Analytics")
                analytics_output = gr.JSON(label="User Analytics")

            # Connect interface functions (registered inside the Blocks
            # context). The JSON outputs receive JSON strings from the
            # handlers.
            save_profile_btn.click(
                vr_system.save_user_profile,
                inputs=[user_id, age, condition, therapy_goals],
                outputs=profile_output,
            )
            generate_btn.click(
                vr_system.generate_therapy_session,
                inputs=[session_user_id, pain_level, mobility_score],
                outputs=session_output,
            )
            log_btn.click(
                vr_system.log_session_progress,
                inputs=[log_user_id, duration, pain_reduction,
                        mobility_improvement],
                outputs=log_output,
            )
            analytics_btn.click(
                vr_system.get_user_analytics,
                inputs=analytics_user_id,
                outputs=analytics_output,
            )
        return interface
    except Exception as e:
        logger.exception(f"Error creating interface: {str(e)}")
        raise
# Launch the application
if __name__ == "__main__":
    try:
        interface = create_interface()
        # launch() blocks until the server shuts down when run as a script,
        # so log before calling it; logging afterwards only fired on exit.
        logger.info("Launching VR Therapy System")
        interface.launch(share=True)
    except Exception as e:
        logger.exception(f"Error launching application: {str(e)}")
        print(f"Error: {str(e)}")
        # Exit non-zero so the host/supervisor sees the failed start.
        raise SystemExit(1) from e