# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import os

# Cache setup.
# huggingface_hub resolves its cache directories from these environment
# variables once, when it is first imported -- and it is imported directly or
# indirectly by several of the packages below (transformers, huggingface_hub
# itself, ...). The variables therefore have to be set *before* those imports;
# setting them afterwards leaves downloads in the default cache location
# instead of under CACHE_ROOT.
CACHE_ROOT = '/tmp'
os.environ['HF_HOME'] = CACHE_ROOT
os.environ['HUGGINGFACE_HUB_CACHE'] = os.path.join(CACHE_ROOT, 'hub')
os.environ['XDG_CACHE_HOME'] = os.path.join(CACHE_ROOT, 'cache')

# Deliberately placed after the environment setup above; otherwise the
# original import order is preserved.
import shutil
import gradio as gr
import torch
import numpy as np
from PIL import Image
import torchaudio
from einops import rearrange
import psutil
import humanize
import spaces
from transformers import (
    AutoProcessor,
    AutoModelForVision2Seq,
    pipeline,
)
from huggingface_hub import scan_cache_dir
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
# Global model handles. Each one stays None until initialize_models() assigns
# it; check_api() reports readiness based on these values.
kosmos_model = kosmos_processor = None  # Kosmos-2 captioner and its processor
zephyr_pipe = None  # Zephyr text-generation pipeline
audio_model = audio_config = None  # Stable Audio Open model and its config
def initialize_models():
    """Load Kosmos-2, Zephyr and Stable Audio Open into the module globals.

    The steps run in that order. Each step prints a "Loading ..." line. If a
    step fails, the error is printed and the exception re-raised, so later
    steps are skipped while globals assigned by earlier steps keep their
    values. Kosmos-2 and Stable Audio are moved to CUDA when
    ``torch.cuda.is_available()``.
    """

    def load_kosmos():
        global kosmos_model, kosmos_processor
        kosmos_model = AutoModelForVision2Seq.from_pretrained(
            "microsoft/kosmos-2-patch14-224",
            device_map="auto",
            torch_dtype=torch.float16,
        )
        kosmos_processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
        if torch.cuda.is_available():
            kosmos_model = kosmos_model.to("cuda")

    def load_zephyr():
        global zephyr_pipe
        zephyr_pipe = pipeline(
            "text-generation",
            model="HuggingFaceH4/zephyr-7b-beta",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def load_stable_audio():
        global audio_model, audio_config
        audio_model, audio_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
        if torch.cuda.is_available():
            audio_model = audio_model.to("cuda")

    loading_steps = (
        ("Kosmos-2", load_kosmos),
        ("Zephyr", load_zephyr),
        ("Stable Audio", load_stable_audio),
    )
    for label, loader in loading_steps:
        try:
            print(f"Loading {label}...")
            loader()
        except Exception as e:
            print(f"Error loading {label}: {e}")
            raise
def _load_rgb_image(image_in):
    """Return *image_in* (file path, NumPy array or PIL image) as an RGB PIL image.

    Raises:
        TypeError: For any other input type.
    """
    if isinstance(image_in, str):
        # The context manager closes the file handle; convert() loads the
        # pixels into a new image first, so the result outlives the file.
        with Image.open(image_in) as img:
            return img.convert("RGB")
    if isinstance(image_in, np.ndarray):
        image = Image.fromarray(image_in)
    elif isinstance(image_in, Image.Image):
        image = image_in
    else:
        raise TypeError(f"Unsupported image input type: {type(image_in).__name__}")
    return image if image.mode == "RGB" else image.convert("RGB")


def _strip_caption_boilerplate(text):
    """Remove the echoed prompt and stock openers from the start of *text*."""
    text = text.strip()
    # The decoded Kosmos-2 output starts with the prompt text
    # ("Describe this image in detail without names:"). The trailing colon is
    # listed separately so it does not survive as a leading ":". Only leading
    # occurrences are removed, so the description itself is never altered.
    for prefix in (
        "<grounding>",
        "Describe this image in detail without names",
        ":",
        "An image of",
    ):
        if text.startswith(prefix):
            text = text[len(prefix):].strip()
    return text


def get_caption(image_in):
    """Describe an image in detail with Kosmos-2.

    Args:
        image_in: Image as a file path, a NumPy array, or a PIL image.

    Returns:
        The generated description with the echoed prompt and stock openers
        ("An image of", grounding markers) removed from the front.

    Raises:
        gr.Error: If no image is given or captioning fails.
    """
    # `not image_in` is ambiguous for NumPy arrays (it raises ValueError), so
    # check for a missing value / empty path explicitly.
    if image_in is None or (isinstance(image_in, str) and not image_in):
        raise gr.Error("Please provide an image")
    try:
        image = _load_rgb_image(image_in)
        prompt = "<grounding>Describe this image in detail without names:"
        inputs = kosmos_processor(text=prompt, images=image, return_tensors="pt")
        # Put the inputs on whatever device the model's parameters live on.
        device = next(kosmos_model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = kosmos_model.generate(
                pixel_values=inputs["pixel_values"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                image_embeds_position_mask=inputs["image_embeds_position_mask"],
                max_new_tokens=128,
            )
        generated_text = kosmos_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # post_process_generation returns (text, entities); only the text is used.
        processed_text, _ = kosmos_processor.post_process_generation(generated_text)
        return _strip_caption_boilerplate(processed_text)
    except Exception as e:
        raise gr.Error(f"Image caption generation failed: {str(e)}") from e
# Prompt writing (Zephyr) and music generation (Stable Audio Open)
def get_musical_prompt(user_prompt, chosen_model):
    """Turn an image caption into a single-line music-generation prompt with Zephyr.

    Args:
        user_prompt: Image caption used as the user turn of the chat.
        chosen_model: Not used; kept so existing callers keep working.

    Returns:
        The last non-empty, non-bullet line of Zephyr's reply, if it is at
        least 10 characters long. Otherwise (or on any generation error) a
        fixed ambient-orchestral fallback prompt.

    Raises:
        gr.Error: If ``user_prompt`` is empty.
    """
    if not user_prompt:
        raise gr.Error("No image caption provided")
    try:
        standard_sys = (
            "You are a musician AI who specializes in translating architectural spaces into musical experiences. "
            "Your job is to create concise musical descriptions that capture the essence of architectural photographs.\n"
            "Consider these elements in your composition:\n"
            "- Spatial Experience: expansive/intimate spaces, layered forms, acoustical qualities\n"
            "- Materials & Textures: metallic, glass, concrete translated into instrumental textures\n"
            "- Musical Elements: blend of classical structure and jazz improvisation\n"
            "- Orchestration: symphonic layers, solo instruments, or ensemble variations\n"
            "- Soundscapes: environmental depth and spatial audio qualities\n"
            "Respond immediately with a single musical prompt. No explanation, just the musical description."
        )
        messages = [
            {"role": "system", "content": standard_sys},
            {"role": "user", "content": user_prompt},
        ]
        # Render with Zephyr's own chat template. add_generation_prompt=True
        # appends the "<|assistant|>" turn header (missing from a hand-built
        # prompt) so the model answers as the assistant.
        instruction = zephyr_pipe.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        outputs = zephyr_pipe(
            instruction,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.75,
            top_k=50,
            top_p=0.92,
            # Only the newly generated reply, not the echoed prompt, so
            # system-prompt lines can never be mistaken for the answer.
            return_full_text=False,
        )
        reply = outputs[0]["generated_text"]
        # Drop any chat-template tokens the model may still emit.
        for token in ("<|system|>", "</s>", "<|user|>", "<|assistant|>"):
            reply = reply.replace(token, "")
        skip_prefixes = ("-", "Example", "Instructions", "Consider", "Incorporate")
        candidates = [
            line.strip()
            for line in reply.split("\n")
            if line.strip() and not line.strip().startswith(skip_prefixes)
        ]
        if candidates and len(candidates[-1]) >= 10:
            return candidates[-1]
        raise ValueError("Could not extract valid musical prompt")
    except Exception as e:
        # Best effort: any failure falls back to a generic prompt rather than
        # aborting the whole pipeline.
        print(f"Error in get_musical_prompt: {str(e)}")
        return "Ambient orchestral composition with piano and strings, creating a contemplative atmosphere"
def get_stable_audio_open(prompt, seconds_total=47, steps=100, cfg_scale=7):
    """Generate music for *prompt* with Stable Audio Open and save it as a WAV file.

    Args:
        prompt: Text description of the music.
        seconds_total: Duration conditioning value passed to the model (seconds).
        steps: Number of diffusion sampling steps.
        cfg_scale: Classifier-free guidance scale.

    Returns:
        Path of the written 16-bit WAV file under ``CACHE_ROOT``.

    Raises:
        gr.Error: If generation or saving fails.
    """
    try:
        torch.cuda.empty_cache()  # free cached GPU blocks before a large generation
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sample_rate = audio_config["sample_rate"]
        sample_size = audio_config["sample_size"]
        conditioning = [{
            "prompt": prompt,
            "seconds_start": 0,
            "seconds_total": seconds_total,
        }]
        output = generate_diffusion_cond(
            audio_model,
            steps=steps,
            cfg_scale=cfg_scale,
            conditioning=conditioning,
            sample_size=sample_size,
            sigma_min=0.3,
            sigma_max=500,
            sampler_type="dpmpp-3m-sde",
            device=device,
        )
        # Fold the batch axis into the sample axis: "b d n -> d (b n)".
        output = rearrange(output, "b d n -> d (b n)").to(torch.float32)
        # Peak-normalise to [-1, 1]. An all-silent result has a zero peak, and
        # dividing by it would produce NaNs before the int16 cast.
        peak = output.abs().max()
        if peak > 0:
            output = output / peak
        pcm = output.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
        # NOTE(review): output files are never deleted here; they accumulate
        # under CACHE_ROOT.
        output_path = os.path.join(CACHE_ROOT, f"output_{os.urandom(8).hex()}.wav")
        torchaudio.save(output_path, pcm, sample_rate)
        return output_path
    except Exception as e:
        torch.cuda.empty_cache()  # release GPU memory held by the failed run
        raise gr.Error(f"Music generation failed: {str(e)}") from e
def check_api():
    """Report whether every model global has been loaded.

    Returns:
        "Orchestra ready. 🎹 👁️ 🎼" once all five model globals are set,
        otherwise "Orchestra is tuning...".
    """
    handles = (kosmos_model, kosmos_processor, zephyr_pipe, audio_model, audio_config)
    # Compare against None instead of relying on truthiness: model, pipeline
    # and config objects may define __len__/__bool__ (or raise from them),
    # which says nothing about whether they have been loaded.
    if all(handle is not None for handle in handles):
        return "Orchestra ready. 🎹 👁️ 🎼"
    return "Orchestra is tuning..."
# Utility functions: storage reporting, cache cleanup and example discovery
def get_storage_info(path=None):
    """Summarise disk usage of the filesystem that holds the app's cache.

    Args:
        path: Directory whose filesystem is measured. Defaults to
            ``CACHE_ROOT``, so the report tracks where the app writes its
            model cache and generated audio.

    Returns:
        A string "Storage: <used>/<total> (<percent>% used)", with sizes
        formatted by ``humanize.naturalsize``.
    """
    usage = psutil.disk_usage(CACHE_ROOT if path is None else path)
    used = humanize.naturalsize(usage.used)
    total = humanize.naturalsize(usage.total)
    return f"Storage: {used}/{total} ({usage.percent}% used)"
def smart_cleanup():
    """Delete stale snapshots from the Hugging Face hub cache and report storage.

    Repos are not grouped by name: in one cache a repo id is unique per repo
    type, so same-named entries (e.g. a model and a dataset) are different
    repos and must not be deleted as "duplicates". Instead, for every cached
    repo, revisions that no longer have any ref (detached snapshots, e.g. an
    older ``main``) are removed. The best revision (ref'd first, then most
    recently modified) is always kept. Deletion goes through huggingface_hub's
    cache-deletion strategy, which also drops blobs no other revision uses.

    Returns:
        The storage summary from ``get_storage_info()`` on success, or
        "Cleanup error occurred" if scanning or deletion fails.
    """
    try:
        cache_info = scan_cache_dir()
        stale_hashes = []
        for repo in cache_info.repos:
            if len(repo.revisions) <= 1:
                continue
            ordered = sorted(
                repo.revisions,
                key=lambda rev: (bool(rev.refs), rev.last_modified),
                reverse=True,
            )
            # ordered[0] is kept unconditionally; among the rest only detached
            # revisions are deleted, so no branch or tag ever loses its target.
            stale = [rev.commit_hash for rev in ordered[1:] if not rev.refs]
            if stale:
                stale_hashes.extend(stale)
                print(f"Removing {len(stale)} stale revision(s) of {repo.repo_type} {repo.repo_id}")
        if stale_hashes:
            strategy = cache_info.delete_revisions(*stale_hashes)
            strategy.execute()
            print(f"Freed {strategy.expected_freed_size_str} from the hub cache")
        return get_storage_info()
    except Exception as e:
        print(f"Error during cleanup: {e}")
        return "Cleanup error occurred"
def get_image_examples(image_dir="images"):
    """List example images to feed the ``gr.Examples`` gallery.

    Args:
        image_dir: Directory to scan (non-recursively). Defaults to "images".

    Returns:
        One single-element list ``[path]`` per regular file whose name ends in
        .jpg, .jpeg or .png (case-insensitive). Entries are sorted by file name
        because ``os.listdir`` order is arbitrary and would shuffle the gallery.
        Returns ``[]`` (after printing a warning) if the directory is missing.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')
    if not os.path.isdir(image_dir):
        print(f"Warning: Image directory '{image_dir}' not found")
        return []
    return [
        [os.path.join(image_dir, filename)]
        for filename in sorted(os.listdir(image_dir))
        if filename.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(image_dir, filename))
    ]
def infer(image_in, api_status):
    """Run the full image -> caption -> musical prompt -> audio pipeline.

    Args:
        image_in: Image from the ``gr.Image`` input.
        api_status: Current text of the status box; generation is refused
            while it still reads "Orchestra is tuning...".

    Returns:
        Three values for the (caption box, retry button, audio) outputs: the
        musical prompt made editable, the retry button made visible, and the
        generated WAV path.

    Raises:
        gr.Error: On missing input, models not ready, or any pipeline failure.
    """
    # NOTE(review): `spaces` is imported but no function is decorated with
    # @spaces.GPU. On ZeroGPU hardware, GPU work typically has to run inside
    # such a function -- confirm before relying on CUDA here.
    if image_in is None:
        raise gr.Error("Please provide an image of architecture")
    if api_status == "Orchestra is tuning...":
        raise gr.Error("The model is still tuning, please try again later")
    try:
        gr.Info("🎭 Finding a poetry in form and light...")
        user_prompt = get_caption(image_in)
        gr.Info("🎼 Weaving into melody...")
        musical_prompt = get_musical_prompt(user_prompt, "Stable Audio Open")
        gr.Info("🎻 Breathing life into notes...")
        music_o = get_stable_audio_open(musical_prompt)
    except gr.Error:
        # Already user-facing (raised by a pipeline stage); wrapping it again
        # would produce "Generation failed: ... failed: ..." messages.
        raise
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        torch.cuda.empty_cache()  # release GPU memory on success and on failure
    return gr.update(value=musical_prompt, interactive=True), gr.update(visible=True), music_o
def retry(caption):
    """Generate a new take of the music from the (possibly user-edited) prompt.

    Args:
        caption: Musical prompt from the caption textbox, which ``infer``
            leaves editable, so it may have been changed or cleared.

    Returns:
        Path of the newly generated WAV file.

    Raises:
        gr.Error: If the prompt is blank or generation fails.
    """
    if caption is None or not caption.strip():
        raise gr.Error("Please provide a musical prompt to regenerate from")
    gr.Info("🎹 Refreshing with a new vibe...")
    return get_stable_audio_open(caption)
# UI definition
demo_title = "Musical Toy for Frank"
description = "A humble attempt to hear Architecture through Music"
css = """
#col-container {
    margin: 0 auto;
    max-width: 980px;
    text-align: left;
}
#inspi-prompt textarea {
    font-size: 20px;
    line-height: 24px;
    font-weight: 600;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(f"""
        <h2 style="text-align: center;">{demo_title}</h2>
        <p style="text-align: center;">{description}</p>
        """)
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(
                    label="Inspire us:",
                    type="filepath",
                    elem_id="image-in",
                )
                # Selecting an example only fills `image_in`. `infer` is
                # deliberately not attached: it takes (image_in, api_status)
                # and returns three outputs, none of which this component
                # supplies, so it would fail if example caching were enabled.
                gr.Examples(
                    examples=get_image_examples(),
                    inputs=[image_in],
                    examples_per_page=5,
                    label="♪ ♪ ...",
                )
                submit_btn = gr.Button("Listen to it...")
            with gr.Column():
                # Built before initialize_models() runs, so this starts out as
                # "Orchestra is tuning..."; demo.load below refreshes it.
                check_status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value=check_api(),
                )
                caption = gr.Textbox(
                    label="Explanation & Inspiration...",
                    interactive=False,
                    elem_id="inspi-prompt",
                )
                # Made visible by infer() after the first generation.
                retry_btn = gr.Button("🎲", visible=False)
                result = gr.Audio(
                    label="Music",
                )
        # Credits
        gr.HTML("""
        <div style="margin-top: 40px; padding: 20px; border-top: 1px solid #ddd;">
        <!-- Your existing credits HTML -->
        </div>
        """)

    # Event handlers
    demo.load(
        fn=check_api,
        outputs=check_status,
    )
    retry_btn.click(
        fn=retry,
        inputs=[caption],
        outputs=[result],
    )
    submit_btn.click(
        fn=infer,
        inputs=[
            image_in,
            check_status,
        ],
        outputs=[
            caption,
            retry_btn,
            result,
        ],
    )

    # Cache maintenance
    with gr.Column():
        storage_info = gr.Textbox(label="Storage Info", value=get_storage_info())
        cleanup_btn = gr.Button("Smart Cleanup")
    cleanup_btn.click(
        fn=smart_cleanup,
        outputs=storage_info,
    )
def main():
    """Load every model, then serve the Gradio app on 0.0.0.0:7860 with a queue of 16."""
    print("Initializing models...")
    initialize_models()
    print("Models initialized successfully")
    launch_options = dict(
        show_api=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
    demo.queue(max_size=16).launch(**launch_options)


if __name__ == "__main__":
    main()