import glob import subprocess import sys from typing import List sys.path.append(".") from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402 PATTERN = "benchmark_*.py" class SubprocessCallException(Exception): pass # Taken from `test_examples_utils.py` def run_command(command: List[str], return_stdout=False): """ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture if an error occurred while running `command` """ try: output = subprocess.check_output(command, stderr=subprocess.STDOUT) if return_stdout: if hasattr(output, "decode"): output = output.decode("utf-8") return output except subprocess.CalledProcessError as e: raise SubprocessCallException( f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" ) from e def main(): python_files = glob.glob(PATTERN) for file in python_files: print(f"****** Running file: {file} ******") # Run with canonical settings. if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py": command = f"python {file}" run_command(command.split()) command += " --run_compile" run_command(command.split()) # Run variants. for file in python_files: # See: https://github.com/pytorch/pytorch/issues/129637 if file == "benchmark_ip_adapters.py": continue if file == "benchmark_text_to_image.py": for ckpt in ALL_T2I_CKPTS: command = f"python {file} --ckpt {ckpt}" if "turbo" in ckpt: command += " --num_inference_steps 1" run_command(command.split()) command += " --run_compile" run_command(command.split()) elif file == "benchmark_sd_img.py": for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]: command = f"python {file} --ckpt {ckpt}" if ckpt == "stabilityai/sdxl-turbo": command += " --num_inference_steps 2" run_command(command.split()) command += " --run_compile" run_command(command.split()) elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]: sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" command = f"python {file} --ckpt {sdxl_ckpt}" run_command(command.split()) command += " --run_compile" run_command(command.split()) elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]: sdxl_ckpt = ( "diffusers/controlnet-canny-sdxl-1.0" if "controlnet" in file else "TencentARC/t2i-adapter-canny-sdxl-1.0" ) command = f"python {file} --ckpt {sdxl_ckpt}" run_command(command.split()) command += " --run_compile" run_command(command.split()) if __name__ == "__main__": main()