Masa-digital-art's picture
Duplicate from Masa-digital-art/Storytelling-AI-test
7f8e8f0
raw
history blame
6.54 kB
import gradio as gr
import openai
import requests
import os
import fileinput
from dotenv import load_dotenv
import io
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
title="Sorytelling-AI-test"
inputs_label="あなたが入力に応じてストーリーを生成します"
outputs_label="AIが生成したストーリー"
visual_outputs_label="AIが生成したビジュアルイメージ"
description="""
- 生成には時間がかかります。また失敗する可能性があります。
"""
article = """
<ul>
<li style="font-size: small;">楽しんでいただけたら、Likeのクリックをお願いします。</li>
</ul>
<ul>
<li style="font-size: small;">よかったらフィードバックの収集にご協力お願いします <a href="https://forms.gle/bLxs2h22JvQK4zwP8">https://forms.gle/bLxs2h22JvQK4zwP8</a></li>
</ul>
<h5>リリースノート</h5>
<ul>
<li style="font-size: small;">2023-08-31 v1.0</li>
</ul>
<h5>注意事項</h5>
<ul>
<li style="font-size: small;">当サービスでは、2023/3/14にリリースされたOpenAI社のChatGPT APIのgpt-4と、2022/4/13にリリースされたSability AI社のStable Diffusion XL 'sAPIを使用しております。</li>
<li style="font-size: small;">当サービスで生成されたテキストは、OpenAI が提供する人工知能によるものであり、当サービスやOpenAI がその正確性や信頼性を保証するものではありません。</li>
<li style="font-size: small;">当サービスで生成されたイメージは、Stability AI が提供する人工知能によるものであり、当サービスやStabiliy AI がその信頼性を保証するものではありません。</li>
<li style="font-size: small;"><a href="https://platform.openai.com/docs/usage-policies">OpenAI の利用規約</a>に従い、データ保持しない方針です(ただし諸般の事情によっては変更する可能性はございます)。
<li style="font-size: small;">当サービスで生成されたコンテンツは事実確認をした上で、コンテンツ生成者およびコンテンツ利用者の責任において利用してください。</li>
<li style="font-size: small;">当サービスでの使用により発生したいかなる損害についても、当社は一切の責任を負いません。</li>
<li style="font-size: small;">当サービスはβ版のため、予告なくサービスを終了する場合がございます。</li>
</ul>
"""
load_dotenv()
openai.api_key = os.getenv('OPENAI_API_KEY')
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
stability_api = client.StabilityInference(
key=os.getenv('STABILITY_KEY'),
engine="stable-diffusion-xl-1024-v1-0",
verbose=True,
)
MODEL = "gpt-4"
def get_filetext(filename, cache={}):
if filename in cache:
# キャッシュに保存されている場合は、キャッシュからファイル内容を取得する
return cache[filename]
else:
if not os.path.exists(filename):
raise ValueError(f"ファイル '{filename}' が見つかりませんでした")
with open(filename, "r") as f:
text = f.read()
# ファイル内容をキャッシュする
cache[filename] = text
return text
class OpenAI:
@classmethod
def chat_completion(cls, prompt, start_with=""):
constraints = get_filetext(filename = "constraints.md")
template = get_filetext(filename = "template.md")
# ChatCompletion APIに渡すデータを定義する
data = {
"model": "gpt-4",
"messages": [
{"role": "system", "content": constraints}
,{"role": "system", "content": template}
,{"role": "assistant", "content": "Sure!"}
,{"role": "user", "content": prompt}
,{"role": "assistant", "content": start_with}
],
}
# ChatCompletion APIを呼び出す
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {openai.api_key}"
},
json=data
)
# ChatCompletion APIから返された結果を取得する
result = response.json()
print(result)
content = result["choices"][0]["message"]["content"].strip()
visualize_prompt = content.split("### Prompt for Visual Expression\n\n")[1]
answers = stability_api.generate(
prompt=("high quality illustlation,Stunning detail, crisp images, high-contrast images, dynamic angles, cinematic lighting, sharp focus, imaginative concept art, Simple colors, impressive shading" + visualize_prompt),
steps=30,
width=768,
height=512,
)
for resp in answers:
for artifact in resp.artifacts:
if artifact.finish_reason == generation.FILTER:
print("NSFW")
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(io.BytesIO(artifact.binary))
return [content, img]
class MasasanAI:
@classmethod
def generate_vision_prompt(cls, user_message):
template = get_filetext(filename="template.md")
prompt = f"""
{user_message}
---
上記を元に、下記テンプレートを埋めてください。
---
{template}
"""
return prompt
@classmethod
def generate_vision(cls, user_message):
prompt = MasasanAI.generate_vision_prompt(user_message);
start_with = ""
result = OpenAI.chat_completion(prompt=prompt, start_with=start_with)
return result
def main():
iface = gr.Interface(fn=MasasanAI.generate_vision,
inputs=gr.Textbox(label=inputs_label),
outputs=[gr.Textbox(label=inputs_label),
gr.Image(label=visual_outputs_label)],
title=title,
description=description,
article=article,
allow_flagging='never'
)
iface.launch()
if __name__ == '__main__':
main()