File size: 1,487 Bytes
b38913b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import shutil
import subprocess
import tempfile

import gradio as gr

def run_inference(config_path, ckpt_path, prompt_path):
    """Run ``scripts/inference.py`` via ``torchrun`` on a patched copy of a config.

    The config at *config_path* is copied to a temporary file in which the line
    ``prompt_path = "./assets/texts/t2v_samples.txt"`` is replaced by
    *prompt_path*. The temporary copy is always deleted afterwards, even when
    launching the process fails.

    Args:
        config_path: Path of the config file to read.
        ckpt_path: Checkpoint passed to the script as ``--ckpt-path``.
        prompt_path: Value substituted for the default prompt path.

    Returns:
        A ``(status, details)`` pair of strings for the Gradio outputs:
        ``("Inference completed successfully.", stdout)`` on exit code 0,
        otherwise ``("Error occurred:", details)``. The details are stderr
        when the process fails, or an explanatory message when no checkpoint
        is selected, the config cannot be read, or ``torchrun`` cannot be
        launched.
    """
    # A Dropdown with nothing selected yields an empty/None value; passing it
    # to subprocess.run would raise TypeError instead of a readable message.
    if not ckpt_path:
        return "Error occurred:", "No checkpoint path selected."

    try:
        with open(config_path, "r", encoding="utf-8") as file:
            config_content = file.read()
    except OSError as exc:
        return "Error occurred:", str(exc)

    # Escape backslashes and double quotes so the value stays a valid
    # double-quoted literal (e.g. a Windows path like C:\new would otherwise
    # have "\n" interpreted as a newline escape).
    escaped_prompt = str(prompt_path).replace("\\", "\\\\").replace('"', '\\"')
    # NOTE: if the config lacks this exact line, replace() is a no-op and the
    # config's own prompt path is used unchanged.
    config_content = config_content.replace(
        'prompt_path = "./assets/texts/t2v_samples.txt"',
        f'prompt_path = "{escaped_prompt}"',
    )

    # Keep the original extension: the temp name would otherwise have none,
    # and the inference script's config loader may choose a parser by
    # extension (NOTE(review): confirm against scripts/inference.py).
    suffix = os.path.splitext(config_path)[1]
    with tempfile.NamedTemporaryFile(
        "w", suffix=suffix, delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # e.g. torchrun is not installed or not on PATH.
        return "Error occurred:", str(exc)
    finally:
        # The temp config is a regular file: os.remove is correct here
        # (shutil.rmtree only deletes directories). Running this in `finally`
        # also cleans up when the launch itself fails.
        os.remove(temp_config_path)

    if result.returncode == 0:
        return "Inference completed successfully.", result.stdout
    return "Error occurred:", result.stderr

def main():
    """Build the Open-Sora inference web UI and start serving it.

    The form collects a config path, a checkpoint (picked from a fixed list)
    and a prompt path, forwards them to ``run_inference``, and displays the
    returned status and output text.
    """
    # NOTE(review): these look like placeholder checkpoint paths; replace them
    # with real checkpoints before use.
    checkpoint_choices = [
        "./path/to/model1.ckpt",
        "./path/to/model2.ckpt",
        "./path/to/model3.ckpt",
    ]
    input_components = [
        gr.Textbox(label="Configuration Path"),
        gr.Dropdown(choices=checkpoint_choices, label="Checkpoint Path"),
        gr.Textbox(label="Prompt Path"),
    ]
    # Two outputs matching run_inference's (status, details) return pair.
    output_components = [
        gr.Text(label="Status"),
        gr.Text(label="Output"),
    ]
    app = gr.Interface(
        fn=run_inference,
        inputs=input_components,
        outputs=output_components,
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    )
    app.launch()

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    main()