# hometown / utils.py
# peteli's picture
# Upload 3 files
# 4f23115
# raw
# history blame
# No virus
# 4.37 kB
from paddlenlp.utils.serialization import load_torch
import paddle
import safetensors.numpy
import os
import ppdiffusers
from contextlib import contextmanager
@contextmanager
def context_nologging():
    """Temporarily raise the ppdiffusers logging threshold to ERROR.

    On exit (normal or via an exception) the verbosity that was active when the
    context was entered is restored, rather than unconditionally forcing INFO,
    so using this context manager leaves global logging state unchanged.
    """
    logging_utils = ppdiffusers.utils.logging
    previous_verbosity = logging_utils.get_verbosity()
    logging_utils.set_verbosity_error()
    try:
        yield
    finally:
        logging_utils.set_verbosity(previous_verbosity)
# Names exported by a star-import of this module: the two LoRA weight converters.
__all__ = [
    "convert_paddle_lora_to_safetensor_lora",
    "convert_pytorch_lora_to_paddle_lora",
]
def convert_paddle_lora_to_safetensor_lora(paddle_file, safe_file=None):
    """Convert a Paddle LoRA state dict (loaded with ``paddle.load``) to safetensors.

    2-D weights are transposed because Paddle ``Linear`` layers store their
    weight as ``[in_features, out_features]`` whereas PyTorch uses
    ``[out_features, in_features]``. Tensors of any other rank (biases, conv
    kernels) use the same layout in both frameworks and are copied unchanged.

    Args:
        paddle_file: Path of the Paddle checkpoint to read.
        safe_file: Output path. When ``None`` it is derived from ``paddle_file``
            by renaming ``paddle_lora_weights.pdparams`` to
            ``pytorch_lora_weights.safetensors``; if that name is not present,
            the file extension is replaced by ``.safetensors`` instead.

    Returns:
        None. A missing input file, or an output path identical to the input,
        is reported with ``print`` and nothing is written.
    """
    if not os.path.exists(paddle_file):
        print(f"{paddle_file} 文件不存在!")
        return
    if safe_file is None:
        safe_file = paddle_file.replace("paddle_lora_weights.pdparams", "pytorch_lora_weights.safetensors")
        if safe_file == paddle_file:
            # Unconventional file name: the replace above was a no-op, so derive the
            # output name from the extension instead of reusing the input path.
            safe_file = os.path.splitext(paddle_file)[0] + ".safetensors"
    if os.path.abspath(safe_file) == os.path.abspath(paddle_file):
        # Never overwrite the source checkpoint with the converted weights.
        print(f"{safe_file} 与输入文件相同,请指定其他保存路径!")
        return
    tensors = paddle.load(paddle_file)
    new_tensors = {}
    for k, v in tensors.items():
        array = v.cpu().numpy()
        # Only Linear weights (2-D) differ in layout between Paddle and PyTorch.
        new_tensors[k] = array.T if array.ndim == 2 else array
    safetensors.numpy.save_file(new_tensors, safe_file)
    print(f"文件已经保存到{safe_file}!")
def convert_pytorch_lora_to_paddle_lora(pytorch_file, paddle_file=None):
    """Convert a PyTorch LoRA checkpoint (read with ``load_torch``) to a Paddle state dict.

    2-D weights are transposed because PyTorch ``Linear`` layers store their
    weight as ``[out_features, in_features]`` whereas Paddle uses
    ``[in_features, out_features]``. Tensors of any other rank (biases, conv
    kernels) use the same layout in both frameworks and are copied unchanged.

    Args:
        pytorch_file: Path of the PyTorch checkpoint to read.
        paddle_file: Output path. When ``None`` it is derived from
            ``pytorch_file`` by renaming ``pytorch_lora_weights.bin`` to
            ``paddle_lora_weights.pdparams``; if that name is not present, the
            file extension is replaced by ``.pdparams`` instead.

    Returns:
        None. A missing input file, or an output path identical to the input,
        is reported with ``print`` and nothing is written.
    """
    if not os.path.exists(pytorch_file):
        print(f"{pytorch_file} 文件不存在!")
        return
    if paddle_file is None:
        paddle_file = pytorch_file.replace("pytorch_lora_weights.bin", "paddle_lora_weights.pdparams")
        if paddle_file == pytorch_file:
            # Unconventional file name (e.g. a .safetensors export): the replace above
            # was a no-op, so derive the output name from the extension instead.
            paddle_file = os.path.splitext(pytorch_file)[0] + ".pdparams"
    if os.path.abspath(paddle_file) == os.path.abspath(pytorch_file):
        # Never overwrite the source checkpoint with the converted weights.
        print(f"{paddle_file} 与输入文件相同,请指定其他保存路径!")
        return
    tensors = load_torch(pytorch_file)
    # Only Linear weights (2-D) differ in layout between PyTorch and Paddle.
    new_tensors = {k: (v.T if v.ndim == 2 else v) for k, v in tensors.items()}
    paddle.save(new_tensors, paddle_file)
    print(f"文件已经保存到{paddle_file}!")
import time
from typing import Optional, Type
import paddle
import requests
from huggingface_hub import create_repo, upload_folder, get_full_repo_name
# Since HF sometimes timeout, we need to retry uploads
# Credit: https://github.com/huggingface/datasets/blob/06ae3f678651bfbb3ca7dd3274ee2f38e0e0237e/src/datasets/utils/file_utils.py#L265
def _retry(
    func,
    func_args: Optional[tuple] = None,
    func_kwargs: Optional[dict] = None,
    exceptions: Type[requests.exceptions.RequestException] = requests.exceptions.RequestException,
    max_retries: int = 0,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
):
    """Call ``func(*func_args, **func_kwargs)``, retrying on ``exceptions`` with exponential backoff.

    Args:
        func: Callable to invoke.
        func_args: Positional arguments for ``func`` (default: none).
        func_kwargs: Keyword arguments for ``func`` (default: none).
        exceptions: Exception class (or tuple of classes) that triggers a retry;
            any other exception propagates immediately.
        max_retries: Number of retries after the first attempt; ``0`` means a
            single attempt with no retry.
        base_wait_time: Delay in seconds before the first retry; it doubles on
            each subsequent retry.
        max_wait_time: Upper bound in seconds on any single delay.

    Returns:
        Whatever ``func`` returns on its first successful call.

    Raises:
        The last caught exception once all ``max_retries`` retries are used up.
    """
    func_args = func_args or ()
    func_kwargs = func_kwargs or {}
    retry = 0
    while True:
        try:
            return func(*func_args, **func_kwargs)
        except exceptions:
            if retry >= max_retries:
                raise
            sleep_time = min(max_wait_time, base_wait_time * 2**retry)  # Exponential backoff
            # Report progress as "attempt k/n" rather than a float fraction.
            print(f"{func} timed out, retrying in {sleep_time}s... [{retry + 1}/{max_retries}]")
            time.sleep(sleep_time)
            retry += 1
def upload_lora_folder(upload_dir, repo_name, pretrained_model_name_or_path, prompt, hub_token=None):
    """Create the Hub model repo if needed, write its model card, then push ``upload_dir``.

    Args:
        upload_dir: Local folder to upload; ``save_model_card`` writes ``README.md`` into it first.
        repo_name: Repository name, passed through ``get_full_repo_name`` to obtain the repo id.
        pretrained_model_name_or_path: Base model recorded in the model card.
        prompt: Instance prompt recorded in the model card.
        hub_token: Optional Hugging Face access token forwarded to every Hub call.
    """
    full_repo_id = get_full_repo_name(repo_name, token=hub_token)

    # Hub requests sometimes time out, so repo creation and upload both go through _retry.
    create_kwargs = dict(repo_id=full_repo_id, exist_ok=True, token=hub_token)
    _retry(create_repo, func_kwargs=create_kwargs, base_wait_time=1.0, max_retries=5, max_wait_time=10.0)

    save_model_card(
        full_repo_id,
        base_model=pretrained_model_name_or_path,
        prompt=prompt,
        repo_folder=upload_dir,
    )

    print(f"Pushing to {full_repo_id}")
    upload_kwargs = dict(
        repo_id=full_repo_id,
        repo_type="model",
        folder_path=upload_dir,
        commit_message="submit best ckpt",
        token=hub_token,
        # Files matching these patterns (checkpoints, logs, validation images) are not uploaded.
        ignore_patterns=["checkpoint-*/*", "logs/*", "validation_images/*"],
    )
    _retry(upload_folder, func_kwargs=upload_kwargs, base_wait_time=1.0, max_retries=5, max_wait_time=20.0)
def save_model_card(repo_name, base_model: Optional[str] = None, prompt: Optional[str] = None, repo_folder=None):
    """Write a ``README.md`` model card for a DreamBooth LoRA repo into ``repo_folder``.

    The card is YAML front matter (license, base model, instance prompt, tags)
    followed by a short Chinese description of how the weights were trained.

    Args:
        repo_name: Repository id shown in the card title.
        base_model: Base model the LoRA was trained from; written verbatim, so
            callers should always pass it (``None`` would appear as "None").
        prompt: Instance prompt used for training; written verbatim as well.
        repo_folder: Existing folder in which ``README.md`` is (over)written.
            Must be given: ``os.path.join`` fails on ``None``.
    """
    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-ppdiffusers
- text-to-image
- ppdiffusers
- lora
inference: false
---
"""
    model_card = f"""
# LoRA DreamBooth - {repo_name}
本仓库的 LoRA 权重是基于 {base_model} 训练而来的,我们采用[DreamBooth](https://dreambooth.github.io/)的技术并使用 {prompt} 文本进行了训练。
"""
    # The card contains Chinese text: write UTF-8 explicitly instead of relying on
    # the platform default encoding (e.g. cp1252 or GBK on Windows).
    with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f:
        f.write(yaml + model_card)