Open-Sora / app.py
kadirnar's picture
Update app.py
b38913b verified
raw
history blame
No virus
1.49 kB
import os
import shutil
import subprocess
import tempfile

import gradio as gr
def run_inference(config_path, ckpt_path, prompt_path):
    """Run Open-Sora inference using a copy of a config with its prompt file overridden.

    The config at *config_path* is read and its default
    ``prompt_path = "./assets/texts/t2v_samples.txt"`` line is rewritten to
    point at *prompt_path*. The result is written to a temporary file, and
    ``scripts/inference.py`` is launched on it via ``torchrun --standalone``
    with a single process per node. The temporary file is always removed.

    Args:
        config_path: Path of the inference config file to start from.
        ckpt_path: Checkpoint path, forwarded as ``--ckpt-path``.
        prompt_path: Prompt file path to substitute into the config.

    Returns:
        A ``(status, detail)`` pair:

        * ``("Inference completed successfully.", stdout)`` on exit code 0;
        * ``("Error occurred:", stderr)`` on a non-zero exit code;
        * ``("Error occurred:", <launch error message>)`` if ``torchrun``
          could not be started at all.

    Raises:
        OSError: If *config_path* cannot be read.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config_content = file.read()

    # Escape backslashes and double quotes so the substituted path stays a
    # valid double-quoted string literal. Otherwise a path such as
    # C:\new\prompts.txt would have "\n" reinterpreted as a newline.
    escaped_prompt = prompt_path.replace("\\", "\\\\").replace('"', '\\"')
    # NOTE(review): only this exact default line is replaced; if the config
    # spells it differently, the prompt override is silently skipped.
    config_content = config_content.replace(
        'prompt_path = "./assets/texts/t2v_samples.txt"',
        f'prompt_path = "{escaped_prompt}"',
    )

    # Keep the original file extension on the temp copy.
    # NOTE(review): presumably the inference script's config loader picks a
    # parser by suffix, so a suffix-less temp file could be rejected; confirm.
    suffix = os.path.splitext(config_path)[1]
    with tempfile.NamedTemporaryFile(
        "w", suffix=suffix, delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # e.g. torchrun is not installed / not on PATH: report it through the
        # same (status, detail) shape instead of raising into the UI.
        return "Error occurred:", f"Failed to launch torchrun: {exc}"
    finally:
        # The temp config is a regular file: os.remove, not shutil.rmtree
        # (which only accepts directories). Doing it in `finally` guarantees
        # cleanup even when the launch fails.
        os.remove(temp_config_path)

    if result.returncode == 0:
        return "Inference completed successfully.", result.stdout
    return "Error occurred:", result.stderr
def main():
    """Build the Open-Sora inference form and start the Gradio server.

    The form has three inputs: a configuration path textbox, a checkpoint
    dropdown with fixed choices, and a prompt path textbox. It forwards them
    to ``run_inference`` and shows the two returned strings as "Status" and
    "Output".
    """
    # NOTE(review): these choices look like placeholders ("./path/to/...");
    # replace them with real checkpoint locations.
    checkpoint_options = [
        "./path/to/model1.ckpt",
        "./path/to/model2.ckpt",
        "./path/to/model3.ckpt",
    ]

    form_inputs = [
        gr.Textbox(label="Configuration Path"),
        gr.Dropdown(choices=checkpoint_options, label="Checkpoint Path"),
        gr.Textbox(label="Prompt Path"),
    ]
    form_outputs = [
        gr.Text(label="Status"),
        gr.Text(label="Output"),
    ]

    app = gr.Interface(
        fn=run_inference,
        inputs=form_inputs,
        outputs=form_outputs,
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    )
    app.launch()
# Launch the UI only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()