Spaces:
Paused
Paused
import argparse | |
import unittest | |
import os | |
import sys | |
import time | |
import datetime | |
from enum import Enum | |
from typing import List, Tuple | |
import cv2 | |
import requests | |
import numpy as np | |
from selenium import webdriver | |
from selenium.webdriver.common.by import By | |
from selenium.webdriver.support.ui import WebDriverWait | |
from selenium.webdriver.common.action_chains import ActionChains | |
from selenium.webdriver.support import expected_conditions as EC | |
from webdriver_manager.chrome import ChromeDriverManager | |
# Default Selenium wait timeout, in seconds. Generation waits use 10x this.
TIMEOUT = 20  # seconds
# Asset paths are resolved against the current working directory, so the
# tests must be launched from the directory containing ``images/``.
CWD = os.getcwd()
SKI_IMAGE = os.path.join(CWD, "images/ski.jpg")
# Every run writes its generated images into a fresh timestamped directory.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
test_result_dir = os.path.join("results", f"test_result_{timestamp}")
test_expectation_dir = "expectations"
os.makedirs(test_result_dir, exist_ok=True)
os.makedirs(test_expectation_dir, exist_ok=True)
# Resolved once at import time and shared by every test's browser session.
driver_path = ChromeDriverManager().install()

# Defaults for when the tests are collected by a runner (``python -m unittest``,
# pytest, ...) instead of being run as a script. Without these, the
# ``__main__`` block never executes and both names are undefined (NameError).
# The ``__main__`` block overrides them from the command line.
overwrite_expectation = False
webui_url = "http://localhost:7860"
class GenType(Enum):
    """A WebUI generation tab; each member locates elements inside its tab.

    The member value is both the tab label and the prefix of the element ids
    the WebUI uses for that tab (e.g. ``txt2img_prompt``).
    """

    txt2img = "txt2img"
    img2img = "img2img"

    def _find_by_xpath(self, driver: webdriver.Chrome, xpath: str) -> "WebElement":
        """Return the first element matching ``xpath`` in the current page."""
        locator = (By.XPATH, xpath)
        return driver.find_element(*locator)

    def tab(self, driver: webdriver.Chrome) -> "WebElement":
        """Return the tab-navigation button labelled with this tab's name."""
        button_xpath = f"//*[@id='tabs']/*[contains(@class, 'tab-nav')]//button[text()='{self.value}']"
        return self._find_by_xpath(driver, button_xpath)

    def controlnet_panel(self, driver: webdriver.Chrome) -> "WebElement":
        """Return the ``#controlnet`` container nested inside this tab."""
        panel_xpath = f"//*[@id='tab_{self.value}']//*[@id='controlnet']"
        return self._find_by_xpath(driver, panel_xpath)

    def generate_button(self, driver: webdriver.Chrome) -> "WebElement":
        """Return this tab's generate box (the element clicked to generate)."""
        box_xpath = f"//*[@id='{self.value}_generate_box']"
        return self._find_by_xpath(driver, box_xpath)

    def prompt_textarea(self, driver: webdriver.Chrome) -> "WebElement":
        """Return the textarea inside this tab's prompt component."""
        textarea_xpath = f"//*[@id='{self.value}_prompt']//textarea"
        return self._find_by_xpath(driver, textarea_xpath)
class SeleniumTestCase(unittest.TestCase):
    """Base class for Selenium end-to-end tests of the WebUI ControlNet panel.

    Each test gets a fresh Chrome session pointed at ``webui_url``. The helper
    methods act on the tab selected by ``self.gen_type``. ``generate_image``
    saves the generated images and compares them with stored expectations.
    """

    def __init__(self, methodName: str = "runTest") -> None:
        super().__init__(methodName)
        self.driver = None
        self.gen_type = None

    def setUp(self) -> None:
        """Start Chrome, open the WebUI and wait for the ControlNet panel."""
        # Selenium 4.10+ no longer accepts the driver path positionally; it
        # must be wrapped in a Service (supported since Selenium 4.0).
        from selenium.webdriver.chrome.service import Service

        super().setUp()
        self.driver = webdriver.Chrome(service=Service(driver_path))
        # Registered as a cleanup rather than relying on tearDown alone:
        # unittest skips tearDown when setUp raises (e.g. the page never
        # loads), but it still runs cleanups, so the browser is always closed.
        self.addCleanup(self.driver.quit)
        self.driver.get(webui_url)
        wait = WebDriverWait(self.driver, TIMEOUT)
        wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "#controlnet")))
        self.gen_type = GenType.txt2img

    def tearDown(self) -> None:
        # The browser is closed by the cleanup registered in setUp.
        super().tearDown()

    def select_gen_type(self, gen_type: GenType):
        """Switch to the tab for ``gen_type`` and make it the active target."""
        gen_type.tab(self.driver).click()
        self.gen_type = gen_type

    def set_prompt(self, prompt: str):
        """Replace the active tab's prompt text with ``prompt``."""
        textarea = self.gen_type.prompt_textarea(self.driver)
        textarea.clear()
        textarea.send_keys(prompt)

    def expand_controlnet_panel(self):
        """Open the ControlNet accordion unless its input group is already shown."""
        controlnet_panel = self.gen_type.controlnet_panel(self.driver)
        input_image_group = controlnet_panel.find_element(
            By.CSS_SELECTOR, ".cnet-input-image-group"
        )
        if not input_image_group.is_displayed():
            controlnet_panel.click()

    def enable_controlnet_unit(self):
        """Tick the unit's "enabled" checkbox if it is not already ticked."""
        controlnet_panel = self.gen_type.controlnet_panel(self.driver)
        enable_checkbox = controlnet_panel.find_element(
            By.CSS_SELECTOR, ".cnet-unit-enabled input[type='checkbox']"
        )
        if not enable_checkbox.is_selected():
            enable_checkbox.click()

    def iterate_preprocessor_types(self, ignore_none: bool = True):
        """Select each preprocessor of ControlNet unit 0 in turn.

        Yields each option's text after the option has been clicked. Options
        whose text contains "none" are skipped when ``ignore_none`` is true.
        The dropdown is reopened and the option list is re-queried on every
        iteration. Presumably this is because the options are only rendered
        while the dropdown is open.
        """
        dropdown = self.gen_type.controlnet_panel(self.driver).find_element(
            By.CSS_SELECTOR,
            f"#{self.gen_type.value}_controlnet_ControlNet-0_controlnet_preprocessor_dropdown",
        )
        index = 0
        while True:
            dropdown.click()
            # Note: the XPath starts with "//", so it searches the whole
            # document, not only the dropdown's subtree.
            options = dropdown.find_elements(
                By.XPATH, "//ul[contains(@class, 'options')]/li"
            )
            if index >= len(options):
                return
            option = options[index]
            index += 1
            option_text = option.text
            if ignore_none and "none" in option_text:
                continue
            option.click()
            yield option_text

    def select_control_type(self, control_type: str):
        """Click the control-type radio button whose value is ``control_type``."""
        controlnet_panel = self.gen_type.controlnet_panel(self.driver)
        control_type_radio = controlnet_panel.find_element(
            By.CSS_SELECTOR, f'.controlnet_control_type input[value="{control_type}"]'
        )
        control_type_radio.click()
        time.sleep(3)  # Wait for gradio backend to update model/module

    def set_seed(self, seed: int):
        """Type ``seed`` into the active tab's seed field."""
        seed_input = self.driver.find_element(
            By.CSS_SELECTOR, f"#{self.gen_type.value}_seed input[type='number']"
        )
        seed_input.clear()
        seed_input.send_keys(str(seed))

    def set_subseed(self, seed: int):
        """Reveal the variation-seed controls if needed and type ``seed`` into them."""
        show_button = self.driver.find_element(
            By.CSS_SELECTOR,
            f"#{self.gen_type.value}_subseed_show input[type='checkbox']",
        )
        if not show_button.is_selected():
            show_button.click()
        subseed_locator = (
            By.CSS_SELECTOR,
            f"#{self.gen_type.value}_subseed input[type='number']",
        )
        # The field only becomes visible after the checkbox above is ticked.
        WebDriverWait(self.driver, TIMEOUT).until(
            EC.visibility_of_element_located(subseed_locator)
        )
        subseed_input = self.driver.find_element(*subseed_locator)
        subseed_input.clear()
        subseed_input.send_keys(str(seed))

    def upload_controlnet_input(self, img_path: str):
        """Upload ``img_path`` as the ControlNet unit's input image."""
        controlnet_panel = self.gen_type.controlnet_panel(self.driver)
        image_input = controlnet_panel.find_element(
            By.CSS_SELECTOR, '.cnet-input-image-group .cnet-image input[type="file"]'
        )
        image_input.send_keys(img_path)

    def upload_img2img_input(self, img_path: str):
        """Upload ``img_path`` as the img2img source image."""
        image_input = self.driver.find_element(
            By.CSS_SELECTOR, '#img2img_image input[type="file"]'
        )
        image_input.send_keys(img_path)

    def generate_image(self, name: str):
        """Generate, save every gallery image and compare it to its expectation.

        Each image is saved as ``<TestClass>_<name>_<i>.png``. With
        ``overwrite_expectation`` set, images are written to the expectation
        directory and not compared. Otherwise they go to the result directory
        and are compared with the same-named expectation image.
        """
        self.gen_type.generate_button(self.driver).click()
        progress_bar_visible = EC.visibility_of_element_located(
            (By.CSS_SELECTOR, f"#{self.gen_type.value}_results .progress")
        )
        # Wait for generation to start, then allow much longer for it to finish.
        WebDriverWait(self.driver, TIMEOUT).until(progress_bar_visible)
        WebDriverWait(self.driver, TIMEOUT * 10).until_not(progress_bar_visible)

        generated_imgs = self.driver.find_elements(
            By.CSS_SELECTOR,
            f"#{self.gen_type.value}_results #{self.gen_type.value}_gallery img",
        )
        # Without this, an empty gallery would make the test pass vacuously.
        self.assertTrue(generated_imgs, f"No generated images found for '{name}'")

        dest_dir = test_expectation_dir if overwrite_expectation else test_result_dir
        for i, generated_img in enumerate(generated_imgs):
            img_file_name = f"{self.__class__.__name__}_{name}_{i}.png"
            # Download the image served by the WebUI and save it as-is.
            response = requests.get(generated_img.get_attribute("src"), timeout=TIMEOUT)
            response.raise_for_status()
            with open(os.path.join(dest_dir, img_file_name), "wb") as img_file:
                img_file.write(response.content)

            if overwrite_expectation:
                continue

            expected_path = os.path.join(test_expectation_dir, img_file_name)
            actual_path = os.path.join(test_result_dir, img_file_name)
            # cv2.imread does not raise on a missing or unreadable file; it
            # returns None, which must be checked explicitly.
            expected_img = cv2.imread(expected_path)
            actual_img = cv2.imread(actual_path)
            if expected_img is None:
                self.fail(f"Cannot read expectation image: {expected_path}")
            if actual_img is None:
                self.fail(f"Cannot read result image: {actual_path}")
            self.expect_same_image(
                expected_img,
                actual_img,
                diff_img_path=os.path.join(
                    test_result_dir, img_file_name.replace(".png", "_diff.png")
                ),
            )

    def expect_same_image(self, img1, img2, diff_img_path: str):
        """Assert that ``img1`` and ``img2`` match within a loose tolerance.

        On mismatch, a mask of the pixel values that differ by more than 30 is
        written to ``diff_img_path`` for inspection.
        """
        # Report a size mismatch as a test failure instead of a cv2 error.
        self.assertEqual(img1.shape, img2.shape, "Image shapes differ")
        similar = np.allclose(img1, img2, rtol=0.5, atol=1)
        if not similar:
            # Highlight pixels whose absolute difference exceeds the threshold.
            threshold = 30
            diff = cv2.absdiff(img1, img2)
            diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8)
            cv2.imwrite(diff_img_path, diff_highlighted)
        self.assertTrue(similar)
# Control types exercised by the "simple" tests: each needs only an input
# image. Only the keys (the control-type radio values) are used. The mapped
# names look like each type's preprocessor and are kept for reference only.
simple_control_types = dict(
    Canny="canny",
    Depth="depth_midas",
    Normal="normal_bae",
    OpenPose="openpose_full",
    MLSD="mlsd",
    Lineart="lineart_standard (from white bg & black line)",
    SoftEdge="softedge_pidinet",
    Scribble="scribble_pidinet",
    Seg="seg_ofade20k",
    Tile="tile_resample",
    # Shuffle and Reference are not stable and are expected to fail: most
    # pixels match, but some outlier pixels can differ a lot.
    Shuffle="shuffle",
    Reference="reference_only",
).keys()
class SeleniumTxt2ImgTest(SeleniumTestCase):
    """Checks each simple ControlNet control type on the txt2img tab."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed/subseed so images are comparable with stored expectations.
        self.select_gen_type(GenType.txt2img)
        self.set_seed(100)
        self.set_subseed(1000)

    def _generate_with_control_type(self, control_type: str) -> None:
        """Configure ``control_type`` with the ski image as input, then generate."""
        self.expand_controlnet_panel()
        self.select_control_type(control_type)
        self.upload_controlnet_input(SKI_IMAGE)
        self.generate_image(f"{control_type}_ski")

    def test_simple_control_types(self):
        """Test simple control types that only require an input image."""
        for control_type in simple_control_types:
            with self.subTest(control_type=control_type):
                self._generate_with_control_type(control_type)
class SeleniumImg2ImgTest(SeleniumTestCase):
    """Checks each simple ControlNet control type on the img2img tab."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed/subseed so images are comparable with stored expectations.
        self.select_gen_type(GenType.img2img)
        self.set_seed(100)
        self.set_subseed(1000)

    def _generate_with_control_type(self, control_type: str) -> None:
        """Use the ski image as both img2img source and ControlNet input, then generate."""
        self.expand_controlnet_panel()
        self.select_control_type(control_type)
        self.upload_img2img_input(SKI_IMAGE)
        self.upload_controlnet_input(SKI_IMAGE)
        self.generate_image(f"img2img_{control_type}_ski")

    def test_simple_control_types(self):
        """Test simple control types that only require an input image."""
        for control_type in simple_control_types:
            with self.subTest(control_type=control_type):
                self._generate_with_control_type(control_type)
class SeleniumInpaintTest(SeleniumTestCase):
    """End-to-end checks of the ControlNet "Inpaint" control type."""

    def draw_inpaint_mask(self, target_canvas):
        """Paint a zig-zag mask onto ``target_canvas`` by dragging the mouse.

        The pointer starts at the canvas centre. It then repeatedly steps right
        by ``brush_radius`` pixels and moves down or up by 20% of the canvas
        height. In total it covers roughly 20% of the canvas width.
        """
        size = target_canvas.size
        width = size["width"]
        height = size["height"]
        brush_radius = 5
        stroke_height = int(height * 0.2)
        # Each repetition advances 2 * brush_radius px horizontally. max(1, ...)
        # keeps the trace non-empty on canvases narrower than 50px; otherwise
        # trace[0] below raises IndexError.
        repeat = max(1, int(width * 0.1 / brush_radius))
        trace: List[Tuple[int, int]] = [
            (brush_radius, 0),
            (0, stroke_height),
            (brush_radius, 0),
            (0, -stroke_height),
        ] * repeat

        actions = ActionChains(self.driver)
        actions.move_to_element(target_canvas)  # move to the canvas
        actions.move_by_offset(*trace[0])
        actions.click_and_hold()  # press the left mouse button
        for dx, dy in trace[1:]:
            actions.move_by_offset(dx, dy)
        actions.release()  # release the left mouse button
        actions.perform()  # run the whole chain

    def draw_cn_mask(self):
        """Draw the mask onto the ControlNet input-image canvas."""
        canvas = self.gen_type.controlnet_panel(self.driver).find_element(
            By.CSS_SELECTOR, ".cnet-input-image-group .cnet-image canvas"
        )
        self.draw_inpaint_mask(canvas)

    def draw_a1111_mask(self):
        """Draw the mask onto the WebUI's own img2img inpaint canvas."""
        canvas = self.driver.find_element(By.CSS_SELECTOR, "#img2maskimg canvas")
        self.draw_inpaint_mask(canvas)

    def test_txt2img_inpaint(self):
        """Inpaint via ControlNet on txt2img, once per inpaint preprocessor."""
        self.select_gen_type(GenType.txt2img)
        self.expand_controlnet_panel()
        self.select_control_type("Inpaint")
        self.upload_controlnet_input(SKI_IMAGE)
        self.draw_cn_mask()
        self.set_seed(100)
        self.set_subseed(1000)
        for option in self.iterate_preprocessor_types():
            with self.subTest(option=option):
                self.generate_image(f"{option}_txt2img_ski")

    def test_img2img_inpaint(self):
        # Note: img2img inpaint can only use the A1111 mask.
        # ControlNet input is disabled in img2img inpaint.
        self._test_img2img_inpaint(use_cn_mask=False, use_a1111_mask=True)

    def _test_img2img_inpaint(self, use_cn_mask: bool, use_a1111_mask: bool):
        """Run img2img inpainting with the selected mask sources.

        The mask source(s) are joined into the output file-name prefix
        ("controlnet", "A1111", or both concatenated).
        """
        self.select_gen_type(GenType.img2img)
        self.expand_controlnet_panel()
        self.select_control_type("Inpaint")
        self.upload_img2img_input(SKI_IMAGE)
        # Send to inpaint
        self.driver.find_element(
            By.XPATH, "//*[@id='img2img_copy_to_img2img']//button[text()='inpaint']"
        ).click()
        time.sleep(3)
        # Select latent noise to make inpaint effect more visible.
        self.driver.find_element(
            By.XPATH,
            "//input[@name='radio-img2img_inpainting_fill' and @value='latent noise']",
        ).click()
        self.set_prompt("(coca-cola:2.0)")
        self.enable_controlnet_unit()
        self.upload_controlnet_input(SKI_IMAGE)
        self.set_seed(100)
        self.set_subseed(1000)
        prefix = ""
        if use_cn_mask:
            self.draw_cn_mask()
            prefix += "controlnet"
        if use_a1111_mask:
            self.draw_a1111_mask()
            prefix += "A1111"
        for option in self.iterate_preprocessor_types():
            with self.subTest(option=option, mask_prefix=prefix):
                self.generate_image(f"{option}_{prefix}_img2img_ski")
if __name__ == "__main__":
    # Script-specific options are parsed here. Everything else (test names,
    # -v, -k, ...) is passed through to unittest.
    parser = argparse.ArgumentParser(
        description=(
            "Selenium end-to-end tests for the ControlNet panel of a running "
            "WebUI instance. Unrecognized arguments are forwarded to unittest."
        )
    )
    parser.add_argument(
        "--overwrite_expectation",
        action="store_true",
        help=(
            "write generated images into the expectation directory instead "
            "of comparing them against it"
        ),
    )
    parser.add_argument(
        "--target_url",
        type=str,
        default="http://localhost:7860",
        help="URL of the WebUI under test",
    )
    args, unknown_args = parser.parse_known_args()
    # Rebind the module-level globals read by SeleniumTestCase.
    overwrite_expectation = args.overwrite_expectation
    webui_url = args.target_url
    unittest.main(argv=sys.argv[:1] + unknown_args)