Spaces:
Running
on
L40S
Running
on
L40S
File size: 6,208 Bytes
d69879c e123fec d69879c e123fec d69879c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os
import logging
import torch
import asyncio
import aiohttp
import requests
from huggingface_hub import hf_hub_download
# Configure logging: DEBUG level on the root logger with timestamped records.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)
# Module-scoped logger used by every helper in this file.
logger = logging.getLogger(__name__)
# Configuration
# Root for all persisted data; override it with the DATA_ROOT environment variable.
DATA_ROOT = os.getenv('DATA_ROOT', '/tmp/data')
# Downloaded model files are stored under <DATA_ROOT>/models.
MODELS_DIR = os.path.join(DATA_ROOT, "models")
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face repository that hosts every file listed in MODEL_FILES.
HF_REPO_ID = "jbilcke-hf/model-cocktail"

# Repository-relative paths to download; each one is stored at the same
# relative path under MODELS_DIR.
MODEL_FILES = [
    "dwpose/dw-ll_ucoco_384.pth",
    "face-detector/s3fd-619a316812.pth",

    # LivePortrait
    "liveportrait/spade_generator.pth",
    "liveportrait/warping_module.pth",
    "liveportrait/motion_extractor.pth",
    "liveportrait/stitching_retargeting_module.pth",
    "liveportrait/appearance_feature_extractor.pth",
    "liveportrait/landmark.onnx",

    # Animal mode (🐶🐱) weights are disabled for now: upstream reportedly
    # doesn't support stitching in animal mode yet, see
    # https://github.com/KwaiVGI/LivePortrait/blob/main/assets/docs/changelog/2024-08-02.md#updates-on-animals-mode
    # "liveportrait-animals/warping_module.pth",
    # "liveportrait-animals/spade_generator.pth",
    # "liveportrait-animals/motion_extractor.pth",
    # "liveportrait-animals/appearance_feature_extractor.pth",
    # "liveportrait-animals/stitching_retargeting_module.pth",
    # "liveportrait-animals/xpose.pth",

    # HACK: listing the insightface files here is a workaround; we should
    # probably fix liveportrait/utils/dependencies/insightface/utils/storage.py
    # instead.
    "insightface/models/buffalo_l.zip",
    "insightface/buffalo_l/det_10g.onnx",
    "insightface/buffalo_l/2d106det.onnx",

    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "sd-vae-ft-mse/config.json",

    # Not used yet:
    # "flux-dev/flux-dev-fp8.safetensors",
    # "flux-dev/flux_dev_quantization_map.json",
    # "pulid-flux/pulid_flux_v0.9.0.safetensors",
    # "pulid-flux/pulid_v1.bin",
]
def create_directory(directory):
    """Create ``directory`` (including missing parents) and log which case applied.

    Args:
        directory: Path of the directory to ensure exists.

    Raises:
        OSError: if the directory cannot be created for any reason other than
            already existing (e.g. permission denied, empty path).
    """
    try:
        # EAFP instead of exists()-then-makedirs(): if another process or
        # thread creates the directory between a check and the call, the
        # check-first version would crash with FileExistsError.
        os.makedirs(directory)
    except FileExistsError:
        # NOTE(review): also reached when ``directory`` exists as a regular
        # file, which the previous exists() check treated the same way.
        logger.info(f" Directory already exists: {directory}")
    else:
        logger.info(f" Directory created: {directory}")
def print_directory_structure(startpath):
    """Log the directory tree under ``startpath``, indenting 4 spaces per level.

    Each directory is logged as ``name/``; its files follow one level deeper.

    Args:
        startpath: Root directory to walk.
    """
    for root, _dirs, files in os.walk(startpath):
        # Depth of ``root`` below ``startpath``. relpath() replaces the old
        # ``root.replace(startpath, '')``, which removed *every* occurrence of
        # startpath (not just the prefix) and under-counted depth when
        # startpath ended with a path separator.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath() drops a trailing separator so basename() is not empty.
        logger.info(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")
async def download_hf_file(filename: str) -> None:
    """Download one file from ``HF_REPO_ID`` into ``MODELS_DIR``.

    Nothing is downloaded if ``MODELS_DIR/filename`` already exists. The
    blocking ``hf_hub_download`` call runs in a worker thread, so several of
    these coroutines can make progress concurrently (e.g. under
    ``asyncio.gather``).

    Args:
        filename: Path of the file inside the repository; it is stored at the
            same relative path under ``MODELS_DIR``.

    Raises:
        Exception: whatever ``hf_hub_download`` raised, re-raised after it is
            logged and any file left at the destination is removed.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)

    if os.path.exists(dest):
        # this is really for debugging purposes only
        logger.debug(f" ✅ {filename}")
        return

    logger.info(f" ⏳ Downloading {HF_REPO_ID}/{filename}")
    try:
        # asyncio.to_thread() uses the running loop's default executor. It
        # replaces get_event_loop().run_in_executor(), since get_event_loop()
        # is discouraged inside coroutines. It also matches how
        # ModelLoader.load_live_portrait offloads blocking work.
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as e:
        logger.error(f"🚨 Error downloading file from Hugging Face: {e}")
        # Remove anything left at ``dest``: the existence check above would
        # otherwise treat a possibly partial file as complete on the next run.
        if os.path.exists(dest):
            os.remove(dest)
        raise
    else:
        logger.info(f" ✅ Downloaded {filename}")
async def download_all_models():
    """Download every file in ``MODEL_FILES`` that is not already present.

    All downloads are started concurrently with ``asyncio.gather``. The first
    exception is propagated to the caller; the remaining downloads are not
    cancelled by gather and keep running.
    """
    logger.info(" π Looking for models...")
    await asyncio.gather(*(download_hf_file(filename) for filename in MODEL_FILES))
    logger.info(" ✅ All models are available")
    # are you looking to debug the app and verify that models are downloaded properly?
    # then un-comment the two following lines:
    # logger.info("💡 Printing directory structure of models:")
    # print_directory_structure(MODELS_DIR)
class ModelLoader:
    """A class responsible for loading and initializing all required models."""

    def __init__(self):
        # Snapshot of the module configuration.
        # NOTE(review): neither attribute is read by load_live_portrait below.
        self.device = DEVICE
        self.models_dir = MODELS_DIR

    async def load_live_portrait(self):
        """Construct and return a ``LivePortraitPipeline``.

        The ``liveportrait`` package is imported lazily, on first call. The
        pipeline constructor runs in a worker thread via ``asyncio.to_thread``,
        so the event loop is not blocked while it executes.

        Returns:
            The constructed ``LivePortraitPipeline`` instance.
        """
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info(" ⏳ Loading LivePortrait models...")
        live_portrait_pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=InferenceConfig(
                # default values
                flag_stitching=True,  # we recommend setting it to True!
                flag_relative=True,  # whether to use relative motion
                flag_pasteback=True,  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
                flag_do_crop=True,  # whether to crop the source portrait to the face-cropping space
                flag_do_rot=True,  # whether to conduct the rotation when flag_do_crop is True
            ),
            crop_cfg=CropConfig(),
        )
        logger.info(" ✅ LivePortrait models loaded successfully.")
        return live_portrait_pipeline
async def initialize_models():
    """Download any missing model files, then load the LivePortrait pipeline.

    Returns:
        The pipeline returned by ``ModelLoader().load_live_portrait()``.
    """
    logger.info("π Starting model initialization...")
    # Ensure all required models are downloaded
    await download_all_models()
    # Initialize the ModelLoader
    loader = ModelLoader()
    # Load LivePortrait models
    live_portrait = await loader.load_live_portrait()
    logger.info("✅ Model initialization completed.")
    return live_portrait
# Initial setup: runs at import time and ensures MODELS_DIR exists.
logger.info("π Setting up storage directories...")
create_directory(MODELS_DIR)
logger.info("✅ Storage directories setup completed.")
|