|
import gradio as gr |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import subprocess |
|
import tempfile, time |
|
import shutil |
|
import os |
|
import spaces |
|
|
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import os |
|
|
|
# Emit a startup marker on stdout as soon as the module starts executing.
print("starting the app.")
|
|
|
def download_t5_model(model_id, save_directory):
    """Download a full snapshot of a Hugging Face repo into a local directory.

    Args:
        model_id: Hugging Face repo id to fetch (e.g. "DeepFloyd/t5-v1_1-xxl").
        save_directory: Local directory the snapshot files are written into;
            it is created first if it does not already exist.
    """
    # exist_ok=True replaces the racy "check exists, then create" sequence.
    os.makedirs(save_directory, exist_ok=True)
    # Download the repo the caller asked for rather than a hard-coded id, so
    # the `model_id` parameter is actually honoured.
    # local_dir_use_symlinks=False asks for real files in save_directory
    # instead of links into the shared cache.
    snapshot_download(
        repo_id=model_id,
        local_dir=save_directory,
        local_dir_use_symlinks=False,
    )
|
|
|
|
|
# Download the T5 checkpoint once, at import time, into a fixed local path.
# NOTE(review): presumably the Open-Sora configs read T5 from this path — confirm.
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"

st_time_t5 = time.time()
download_t5_model(model_id, save_directory)
# Elapsed time is end minus start; start minus end would always be negative.
print(f"T5 Download Time : {time.time() - st_time_t5} seconds")
|
|
|
def download_model(repo_id, model_name):
    """Fetch one file from a Hugging Face repository and return its local path.

    Args:
        repo_id: Hugging Face repository id to download from.
        model_name: Name of the file inside that repository.

    Returns:
        The local filesystem path reported by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=model_name)
|
|
|
import glob |
|
|
|
# NOTE(review): presumably requests a GPU allocation of up to 1500 s from the
# `spaces` runtime — confirm against the `spaces` package documentation.
@spaces.GPU(duration=1500)
def run_model(temp_config_path, ckpt_path):
    """Run Open-Sora inference in a ``torchrun`` subprocess.

    Args:
        temp_config_path: Path of the inference config file to pass to
            ``scripts/inference.py``.
        ckpt_path: Path of the checkpoint passed via ``--ckpt-path``.

    Raises:
        subprocess.CalledProcessError: If the inference script exits with a
            non-zero status. Without this, a failed run looks successful and
            the caller would pick up an older video from the outputs folder.
    """
    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    start_time = time.time()
    try:
        # check=True turns a non-zero exit status into an exception.
        subprocess.run(cmd, check=True)
    finally:
        # Report wall time even when the run fails, to help diagnose timeouts.
        execution_time = time.time() - start_time
        print(f"Model Execution time: {execution_time} seconds")
|
|
|
# Line in the stock inference configs that points at the bundled sample
# prompts; it is rewritten to point at the user's prompt file.
_DEFAULT_PROMPT_LINE = 'prompt_path = "./assets/texts/t2v_samples.txt"'


def _write_prompt_file(prompt_text):
    """Write prompt_text to a new temporary .txt file and return its path."""
    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".txt", encoding="utf-8"
    ) as prompt_file:
        prompt_file.write(prompt_text)
    return prompt_file.name


def _write_temp_config(config_path, prompt_path):
    """Copy config_path into a temporary .py file whose prompt_path is prompt_path.

    Returns:
        Path of the new temporary config file.

    Raises:
        RuntimeError: If the config lacks the default prompt_path line; the
            user's prompt would otherwise be silently ignored.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config_content = file.read()
    if _DEFAULT_PROMPT_LINE not in config_content:
        raise RuntimeError(
            f"Config {config_path!r} has no line {_DEFAULT_PROMPT_LINE!r} "
            "to replace; the user prompt would be ignored."
        )
    # !r renders a valid Python string literal even if the path contains
    # quotes or backslashes.
    config_content = config_content.replace(
        _DEFAULT_PROMPT_LINE, f"prompt_path = {prompt_path!r}"
    )
    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".py", encoding="utf-8"
    ) as temp_file:
        temp_file.write(config_content)
    return temp_file.name


def _newest_output(save_dir, since):
    """Return the most recently modified entry of save_dir modified at/after `since`, else None."""
    fresh = [
        path
        for path in glob.glob(os.path.join(save_dir, "*"))
        if os.path.getmtime(path) >= since
    ]
    if not fresh:
        return None
    return max(fresh, key=os.path.getmtime)


def run_inference(model_name, prompt_text):
    """Generate a video for prompt_text with the selected Open-Sora checkpoint.

    Args:
        model_name: Checkpoint filename in the "hpcai-tech/Open-Sora" repo; it
            must be one of the keys of the config mapping below.
        prompt_text: Prompt text written verbatim to the prompt file used by
            the inference config.

    Returns:
        Path of the newest file written to ./outputs/samples/ during this run,
        or None if the run produced no file there.

    Raises:
        KeyError: If model_name is not a known checkpoint.
        RuntimeError: If the selected config cannot be pointed at the prompt.
    """
    repo_id = "hpcai-tech/Open-Sora"
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/16x512x512.py",
    }
    config_path = config_mapping[model_name]

    st_time_sora = time.time()
    ckpt_path = download_model(repo_id, model_name)
    # Elapsed time is end minus start (start minus end is always negative).
    print(f"Open-Sora Download Time : {time.time() - st_time_sora} seconds")

    prompt_path = _write_prompt_file(prompt_text)
    temp_config_path = None
    try:
        temp_config_path = _write_temp_config(config_path, prompt_path)
        run_started = time.time()
        run_model(temp_config_path, ckpt_path)
    finally:
        # Always delete the temp files, whether the run succeeded or failed.
        for path in (prompt_path, temp_config_path):
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    # Only consider files touched by this run, so a failed or empty run
    # cannot return a stale video produced by an earlier request.
    latest_file = _newest_output("./outputs/samples/", run_started)
    if latest_file is None:
        print("No new files found in the output directory.")
    return latest_file
|
|
|
|
|
|
|
def main():
    """Build the Gradio interface around run_inference and launch it.

    The interface takes a checkpoint filename and a prompt text and displays
    the path returned by run_inference in a video component.
    """
    # These must match the keys of run_inference's config mapping, or the
    # lookup there raises KeyError.
    model_choices = [
        "OpenSora-v1-16x256x256.pth",
        "OpenSora-v1-HQ-16x256x256.pth",
        "OpenSora-v1-HQ-16x512x512.pth",
    ]

    # Markdown table rows must be contiguous (no blank lines between the
    # header, separator and body rows) or the table does not render.
    article = """
# Examples

| Model | Description | Video Player Embedding |
|------------------------------|----------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| OpenSora-v1-HQ-16x256x256.pth | Iron Man riding a skateboard in New York City | ![ironman](https://github.com/sandeshrajbhandari/open-sora-examples/assets/12326258/8173e37f-6405-44f3-aaaa-fafc88187933) |
| OpenSora-v1-16x256x256.pth | A man is skiing down a snowy mountain. A drone shot from above. An avalanche is chasing him from behind. | ![skiing](https://github.com/sandeshrajbhandari/open-sora-examples/assets/12326258/d2cab73a-a77e-4e0b-a80e-668e252b6b6a) |
| OpenSora-v1-16x256x256.pth | Extreme close-up of a 24-year-old woman’s eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic | ![woman](https://github.com/sandeshrajbhandari/open-sora-examples/assets/12326258/38322939-f7bf-4f72-8a5e-ccc427970afc) |

"""

    gr.Interface(
        fn=run_inference,
        inputs=[
            gr.Dropdown(
                choices=model_choices,
                value="OpenSora-v1-16x256x256.pth",
                label="Model Selection",
            ),
            gr.Textbox(
                label="Prompt Text",
                value="iron man riding a skateboard in new york city",
            ),
        ],
        outputs=gr.Video(label="Output Video"),
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
        article=article,
    ).launch()
|
|
|
if __name__ == "__main__":
    # Start the UI only when run as a script, not when imported.
    main()