Spaces:
Paused
Paused
""" | |
Supports saving and restoring webui and extensions from a known working set of commits | |
""" | |
import os | |
import json | |
import time | |
import tqdm | |
from datetime import datetime | |
import git | |
from modules import shared, extensions, errors | |
from modules.paths_internal import script_path, config_states_dir | |
# Registry of the config states found on disk, keyed by "<name>: <UTC timestamp>".
# Rebuilt in place by list_config_states().
all_config_states = {}


def list_config_states():
    """Reload all saved config states from ``config_states_dir``.

    Every ``*.json`` file in the directory (created if missing) is parsed.
    A file is accepted only when it holds a JSON object whose ``"created_at"``
    is a number that ``time.gmtime`` can convert; anything else is reported
    with an ``[ERROR]`` line on stdout and skipped, so one broken file cannot
    abort the whole listing. Accepted states get a ``"filepath"`` key pointing
    at the file they were loaded from.

    Returns:
        The module-level ``all_config_states`` dict, cleared and refilled in
        newest-first order. Keys have the form ``"<name>: <timestamp>"`` where
        ``name`` defaults to ``"Config"`` and the timestamp is the UTC
        ``time.asctime`` rendering of ``created_at``. States that share both
        name and (second-resolution) timestamp collapse into one entry.
    """
    # The dict is mutated in place (never rebound) so existing references stay valid.
    all_config_states.clear()
    os.makedirs(config_states_dir, exist_ok=True)

    loaded = []
    for filename in os.listdir(config_states_dir):
        if not filename.endswith(".json"):
            continue

        path = os.path.join(config_states_dir, filename)
        try:
            with open(path, "r", encoding="utf-8") as f:
                state = json.load(f)

            # Explicit checks instead of `assert`, which disappears under `python -O`.
            if not isinstance(state, dict) or "created_at" not in state:
                raise ValueError('"created_at" does not exist')

            created_at = state["created_at"]
            if isinstance(created_at, bool) or not isinstance(created_at, (int, float)):
                raise ValueError('"created_at" is not a number')

            # Formatting here keeps NaN / out-of-range timestamps inside the
            # per-file error handler rather than failing after the loop.
            timestamp = time.asctime(time.gmtime(created_at))
            name = state.get("name", "Config")

            state["filepath"] = path
            loaded.append((f"{name}: {timestamp}", state))
        except Exception as e:
            print(f'[ERROR]: Config states {path}, {e}')

    # Newest first; dict insertion order then preserves that ordering for callers.
    loaded.sort(key=lambda item: item[1]["created_at"], reverse=True)
    for full_name, state in loaded:
        all_config_states[full_name] = state

    return all_config_states
def get_webui_config():
    """Describe the git state of the webui checkout at ``script_path``.

    Returns:
        A dict with the keys ``"remote"``, ``"commit_hash"``, ``"commit_date"``
        and ``"branch"``. All values are ``None`` when ``script_path`` has no
        ``.git`` entry, when the repository cannot be opened, or when it is
        bare. If reading any detail raises, ``"remote"`` is reset to ``None``
        while values already collected before the failure are kept.
    """
    repo = None
    try:
        git_dir = os.path.join(script_path, ".git")
        if os.path.exists(git_dir):
            repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)

    info = {
        "remote": None,
        "commit_hash": None,
        "commit_date": None,
        "branch": None,
    }

    # Nothing to inspect without a usable, non-bare repository.
    if not repo or repo.bare:
        return info

    try:
        info["remote"] = next(repo.remote().urls, None)
        head_commit = repo.head.commit
        info["commit_date"] = head_commit.committed_date
        info["commit_hash"] = head_commit.hexsha
        info["branch"] = repo.active_branch.name
    except Exception:
        # Best effort: only the remote is cleared on failure.
        info["remote"] = None

    return info
def get_extension_config():
    """Snapshot the state of every entry in ``extensions.extensions``.

    ``read_info_from_repo()`` is called on each extension before its
    attributes are read (presumably to refresh its git metadata -- confirm
    against the extension class).

    Returns:
        A dict keyed by extension name; each value is a dict holding the
        extension's ``name``, ``path``, ``enabled``, ``is_builtin``,
        ``remote``, ``commit_hash``, ``commit_date``, ``branch`` and
        ``have_info_from_repo`` attributes.
    """
    # Attribute names double as the keys of the saved entry.
    recorded_fields = (
        "name",
        "path",
        "enabled",
        "is_builtin",
        "remote",
        "commit_hash",
        "commit_date",
        "branch",
        "have_info_from_repo",
    )

    snapshot = {}
    for extension in extensions.extensions:
        extension.read_info_from_repo()
        snapshot[extension.name] = {
            field: getattr(extension, field) for field in recorded_fields
        }

    return snapshot
def get_config():
    """Build a complete config state for the current installation.

    Returns:
        A dict with ``"created_at"`` (the current time as a POSIX timestamp),
        ``"webui"`` (result of ``get_webui_config()``) and ``"extensions"``
        (result of ``get_extension_config()``).
    """
    # Dict displays evaluate top to bottom, so the timestamp is taken first.
    return {
        "created_at": datetime.now().timestamp(),
        "webui": get_webui_config(),
        "extensions": get_extension_config(),
    }
def restore_webui_config(config):
    """Hard-reset the webui checkout to the commit stored in ``config``.

    Args:
        config: A config state dict. ``config["webui"]["commit_hash"]`` names
            the commit to restore.

    The function reports problems on stdout / through ``errors.report`` and
    returns ``None`` in every case; it never raises for a missing section,
    a missing or empty commit hash, a missing repository, or a failing git
    command.
    """
    print("* Restoring webui state...")

    if "webui" not in config:
        print("Error: No webui data saved to config")
        return

    webui_commit_hash = config["webui"].get("commit_hash")

    # A state saved without a readable repository stores the key with a None
    # value, so key presence alone is not enough. Passing None on to
    # `git reset --hard` is understood to degrade into a bare reset to HEAD
    # (GitPython drops None arguments -- confirm), discarding local changes
    # while reporting success.
    if not webui_commit_hash:
        print("Error: No commit saved to webui config")
        return

    webui_repo = None
    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
        return

    # Without a .git entry there is nothing to reset; say so explicitly
    # instead of failing later with an AttributeError on None.
    if webui_repo is None:
        print(f"Error: {script_path} is not a git repository, cannot restore webui")
        return

    try:
        webui_repo.git.fetch(all=True)
        webui_repo.git.reset(webui_commit_hash, hard=True)
    except Exception:
        errors.report(f"Error restoring webui to commit {webui_commit_hash}", exc_info=True)
    else:
        print(f"* Restored webui to commit {webui_commit_hash}.")
def restore_extension_config(config):
    """Reset non-builtin extensions to the commits stored in ``config``.

    Args:
        config: A config state dict; ``config["extensions"]`` maps extension
            names to saved entries (``"commit_hash"``, ``"enabled"``, ...).

    For each non-builtin extension in ``extensions.extensions``:

    * not present in the saved state -> flagged as disabled and reported;
    * saved entry has a commit hash -> ``fetch_and_reset_hard`` is called with
      it; a change of commit is reported as success, an exception as failure;
    * saved entry has no commit hash -> reported as failure.

    Extensions whose saved entry is not enabled are flagged as disabled too.
    The collected names are written to ``shared.opts.disabled_extensions`` and
    the options are saved, then a summary is printed. Returns ``None``.
    """
    print("* Restoring extension state...")

    if "extensions" not in config:
        print("Error: No extension data saved to config")
        return

    saved_extensions = config["extensions"]

    report = []  # (extension, previous short hash, success flag, detail)
    disabled_names = []

    for ext in tqdm.tqdm(extensions.extensions):
        if ext.is_builtin:
            continue

        ext.read_info_from_repo()
        previous_commit = ext.commit_hash

        # NOTE(review): this function sets `ext.disabled`, while other code in
        # this module reads `ext.enabled` -- confirm which attribute the
        # extension class actually honours.
        if ext.name not in saved_extensions:
            ext.disabled = True
            disabled_names.append(ext.name)
            report.append((
                ext,
                previous_commit[:8],
                False,
                "Saved extension state not found in config, marking as disabled",
            ))
            continue

        entry = saved_extensions[ext.name]

        if not ("commit_hash" in entry and entry["commit_hash"]):
            report.append((ext, previous_commit[:8], False, "No commit hash found in config"))
        else:
            try:
                ext.fetch_and_reset_hard(entry["commit_hash"])
                ext.read_info_from_repo()
                # Only an actual change of commit is worth a success line.
                if previous_commit != entry["commit_hash"]:
                    report.append((ext, previous_commit[:8], True, entry["commit_hash"][:8]))
            except Exception as ex:
                report.append((ext, previous_commit[:8], False, ex))

        should_disable = not entry.get("enabled", False)
        ext.disabled = should_disable
        if should_disable:
            disabled_names.append(ext.name)

    shared.opts.disabled_extensions = disabled_names
    shared.opts.save(shared.config_filename)

    print("* Finished restoring extensions. Results:")
    for ext, prev_commit, succeeded, detail in report:
        if succeeded:
            print(f"  + {ext.name}: {prev_commit} -> {detail}")
        else:
            print(f"  ! {ext.name}: FAILURE ({detail})")