|
import os |
|
import csv |
|
|
|
import torch |
|
from diffusers import AutoPipelineForText2Image |
|
|
|
|
|
def load_prompts(path):
    """Load the prompts stored in a supported prompt file.

    Only a file whose base name is ``ViLG-300.csv`` is supported. The file is
    parsed as a comma-separated CSV with a header row, and the values of its
    ``Prompt`` column are collected.

    Args:
        path: Path to the prompt file.

    Returns:
        A list of the unique prompt strings, in order of first appearance.

    Raises:
        NotImplementedError: If the file name is not a supported prompt set.
        KeyError: If the CSV has data rows but no ``Prompt`` column.
    """
    if os.path.basename(path) != 'ViLG-300.csv':
        # Raise (rather than return the exception class) so the caller fails
        # here with a clear message instead of later with a TypeError.
        raise NotImplementedError(f'Unsupported prompt file: {path!r}')

    # 'utf-8-sig' strips a leading byte-order mark when present and behaves
    # like plain UTF-8 otherwise, so the first header is always 'Prompt'.
    # newline='' is what the csv module expects for correct quoted newlines.
    with open(path, 'r', encoding='utf-8-sig', newline='') as csv_file:
        reader = csv.DictReader(csv_file, delimiter=',')
        # dict.fromkeys drops duplicate prompts while keeping file order.
        return list(dict.fromkeys(row['Prompt'] for row in reader))
|
|
|
|
|
def main(
    model_id="runwayml/stable-diffusion-v1-5",
    prompt_path="assets/ViLG-300.csv",
    save_path=None,
    dtype='fp16',
    variant=None,
):
    """Generate one image per prompt and save the results as JPEG files.

    Args:
        model_id: Model identifier passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        prompt_path: Path of the prompt file handed to ``load_prompts``.
        save_path: Output directory. When ``None``, ``saved/<model_id>`` is
            used, with every ``/`` in the model id replaced by ``_``.
        dtype: ``'fp32'`` selects ``torch.float32``; any other value selects
            ``torch.float16``.
        variant: Forwarded unchanged as the ``variant`` argument of
            ``from_pretrained``.
    """
    out_dir = save_path
    if out_dir is None:
        # Derive a directory name from the model id; '/' would otherwise
        # create a nested directory.
        out_dir = os.path.join('saved', model_id.replace('/', '_'))
    os.makedirs(out_dir, exist_ok=True)

    prompts = load_prompts(prompt_path)

    if dtype == 'fp32':
        weights_dtype = torch.float32
    else:
        weights_dtype = torch.float16

    pipeline = AutoPipelineForText2Image.from_pretrained(
        model_id,
        variant=variant,
        torch_dtype=weights_dtype,
    )
    pipeline.to(device='cuda')
    # The pipeline's safety_checker attribute is cleared before generation.
    pipeline.safety_checker = None

    for index, prompt in enumerate(prompts):
        print(f'{index}|{len(prompts)}: {prompt}')
        output = pipeline(prompt)
        first_image = output.images[0]
        # Files are named after the prompt's position in the prompt list.
        first_image.save(os.path.join(out_dir, f'{index}.jpg'))
|
|
|
|
|
if __name__ == '__main__':
    # fire is imported only when the file is run as a script; it exposes
    # main's keyword arguments as command-line flags.
    from fire import Fire

    Fire(main)
|
|