AINovelChat / app.py
tori29umai's picture
Upload 5 files
20e3524 verified
raw
history blame
10 kB
import gradio as gr
from jinja2 import Template
from llama_cpp import Llama
import os
import configparser
from utils.dl_utils import dl_guff_model
# モデルディレクトリが存在しない場合は作成
if not os.path.exists("models"):
os.makedirs("models")
# 使用するモデルのファイル名を指定
model_filename = "Llama-3.1-70B-EZO-1.1-it-Q4_K_M.gguf"
model_path = os.path.join("models", model_filename)
# モデルファイルが存在しない場合はダウンロード
if not os.path.exists(model_path):
dl_guff_model("models", f"https://huggingface.co/mmnga/Llama-3.1-70B-EZO-1.1-it-gguf/resolve/main/{model_filename}")
# 設定をINIファイルに保存する関数
def save_settings_to_ini(settings, filename='character_settings.ini'):
config = configparser.ConfigParser()
config['Settings'] = {
'name': settings['name'],
'gender': settings['gender'],
'situation': '\n'.join(settings['situation']),
'orders': '\n'.join(settings['orders']),
'dirty_talk_list': '\n'.join(settings['dirty_talk_list']),
'example_quotes': '\n'.join(settings['example_quotes'])
}
with open(filename, 'w', encoding='utf-8') as configfile:
config.write(configfile)
# INIファイルから設定を読み込む関数
def load_settings_from_ini(filename='character_settings.ini'):
if not os.path.exists(filename):
return None
config = configparser.ConfigParser()
config.read(filename, encoding='utf-8')
if 'Settings' not in config:
return None
try:
settings = {
'name': config['Settings']['name'],
'gender': config['Settings']['gender'],
'situation': config['Settings']['situation'].split('\n'),
'orders': config['Settings']['orders'].split('\n'),
'dirty_talk_list': config['Settings']['dirty_talk_list'].split('\n'),
'example_quotes': config['Settings']['example_quotes'].split('\n')
}
return settings
except KeyError:
return None
# LlamaCppのラッパークラス
class LlamaCppAdapter:
def __init__(self, model_path, n_ctx=4096):
print(f"モデルの初期化: {model_path}")
self.llama = Llama(model_path=model_path, n_ctx=n_ctx, n_gpu_layers=-1)
def generate(self, prompt, max_new_tokens=4096, temperature=0.5, top_p=0.7, top_k=80, stop=["<END>"]):
return self._generate(prompt, max_new_tokens, temperature, top_p, top_k, stop)
def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, stop: list):
return self.llama(
prompt,
temperature=temperature,
max_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
stop=stop,
repeat_penalty=1.2,
)
# キャラクターメーカークラス
class CharacterMaker:
def __init__(self):
self.llama = LlamaCppAdapter(model_path)
self.history = []
self.settings = load_settings_from_ini()
if not self.settings:
self.settings = {
"name": "ナツ",
"gender": "女性",
"situation": [
"あなたは人工知能アシスタントです。",
"ユーザーの日常生活をサポートし、より良い生活を送るお手伝いをします。",
"AIアシスタント『ナツ』として、ユーザーの健康と幸福をケアし、様々な質問に答えたり課題解決を手伝ったりします。"
],
"orders": [
"丁寧な言葉遣いを心がけてください。",
"ユーザーとの対話を通じてサポートを提供します。",
"ユーザーのことは『ユーザー様』と呼んでください。"
],
"conversation_topics": [
"健康管理",
"目標設定",
"時間管理"
],
"example_quotes": [
"ユーザー様の健康と幸福が何より大切です。どのようなサポートが必要でしょうか?",
"私はユーザー様の生活をより良いものにするためのアシスタントです。お手伝いできることがありましたらお申し付けください。",
"目標達成に向けて一緒に頑張りましょう。具体的な計画を立てるお手伝いをさせていただきます。",
"効率的な時間管理のコツをお教えします。まずは1日のスケジュールを確認してみましょう。",
"ストレス解消法についてアドバイスいたします。リラックスするための簡単な呼吸法から始めてみませんか?"
]
}
save_settings_to_ini(self.settings)
def make(self, input_str: str):
prompt = self._generate_aki(input_str)
print(prompt)
print("-----------------")
res = self.llama.generate(prompt, max_new_tokens=1000, stop=["<END>", "\n"])
res_text = res["choices"][0]["text"]
self.history.append({"user": input_str, "assistant": res_text})
return res_text
def make_prompt(self, name: str, gender: str, situation: list, orders: list, dirty_talk_list: list, example_quotes: list, input_str: str):
with open('test_prompt.jinja2', 'r', encoding='utf-8') as f:
prompt = f.readlines()
fix_example_quotes = [quote+"<END>" for quote in example_quotes]
prompt = "".join(prompt)
prompt = Template(prompt).render(name=name, gender=gender, situation=situation, orders=orders, dirty_talk_list=dirty_talk_list, example_quotes=fix_example_quotes, histories=self.history, input_str=input_str)
return prompt
def _generate_aki(self, input_str: str):
prompt = self.make_prompt(
self.settings["name"],
self.settings["gender"],
self.settings["situation"],
self.settings["orders"],
self.settings["dirty_talk_list"],
self.settings["example_quotes"],
input_str
)
print(prompt)
return prompt
def update_settings(self, new_settings):
self.settings.update(new_settings)
save_settings_to_ini(self.settings)
def reset(self):
self.history = []
self.llama = LlamaCppAdapter(model_path)
character_maker = CharacterMaker()
# 設定を更新する関数
def update_settings(name, gender, situation, orders, dirty_talk_list, example_quotes):
new_settings = {
"name": name,
"gender": gender,
"situation": [s.strip() for s in situation.split('\n') if s.strip()],
"orders": [o.strip() for o in orders.split('\n') if o.strip()],
"dirty_talk_list": [d.strip() for d in dirty_talk_list.split('\n') if d.strip()],
"example_quotes": [e.strip() for e in example_quotes.split('\n') if e.strip()]
}
character_maker.update_settings(new_settings)
return "設定が更新されました。"
# チャット機能の関数
def chat_with_character(message, history):
character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
response = character_maker.make(message)
return response
# チャットをクリアする関数
def clear_chat():
character_maker.reset()
return []
# カスタムCSS
custom_css = """
#chatbot {
height: 60vh !important;
overflow-y: auto;
}
"""
# カスタムJavaScript(HTML内に埋め込む)
custom_js = """
<script>
function adjustChatbotHeight() {
var chatbot = document.querySelector('#chatbot');
if (chatbot) {
chatbot.style.height = window.innerHeight * 0.6 + 'px';
}
}
// ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整
window.addEventListener('load', adjustChatbotHeight);
window.addEventListener('resize', adjustChatbotHeight);
</script>
"""
# Gradioインターフェースの設定
with gr.Blocks(css=custom_css) as iface:
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Tab("チャット"):
gr.ChatInterface(
chat_with_character,
chatbot=chatbot,
textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
theme="soft",
retry_btn="もう一度生成",
undo_btn="前のメッセージを取り消す",
clear_btn="チャットをクリア",
)
with gr.Tab("設定"):
gr.Markdown("## キャラクター設定")
name_input = gr.Textbox(label="名前", value=character_maker.settings["name"])
gender_input = gr.Textbox(label="性別", value=character_maker.settings["gender"])
situation_input = gr.Textbox(label="状況設定", value="\n".join(character_maker.settings["situation"]), lines=5)
orders_input = gr.Textbox(label="指示", value="\n".join(character_maker.settings["orders"]), lines=5)
dirty_talk_input = gr.Textbox(label="淫語リスト", value="\n".join(character_maker.settings["dirty_talk_list"]), lines=5)
example_quotes_input = gr.Textbox(label="例文", value="\n".join(character_maker.settings["example_quotes"]), lines=5)
update_button = gr.Button("設定を更新")
update_output = gr.Textbox(label="更新状態")
update_button.click(
update_settings,
inputs=[name_input, gender_input, situation_input, orders_input, dirty_talk_input, example_quotes_input],
outputs=[update_output]
)
# Gradioアプリの起動
if __name__ == "__main__":
iface.launch(
share=True,
allowed_paths=["models"],
favicon_path="custom.html"
)