tori29umai commited on
Commit
e703a54
1 Parent(s): 3d8332e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -289,10 +289,12 @@ class LlamaAdapter:
289
  repeat_penalty=repeat_penalty
290
  )
291
 
292
- def load_model_gpu(model_type, model_path, n_gpu_layers, params):
293
- llama = LlamaAdapter(model_path, params, n_gpu_layers)
294
- print(f"{model_type} モデル {model_path} のロードが完了しました。(n_gpu_layers: {n_gpu_layers})")
295
- return llama
 
 
296
 
297
  class CharacterMaker:
298
  def __init__(self):
@@ -342,12 +344,6 @@ class CharacterMaker:
342
  print(f"{model_type} モデルのロード中にエラーが発生しました: {e}")
343
  self.model_loaded.set()
344
 
345
- @spaces.GPU(duration=120)
346
- def chat_or_gen(self, text, gen_characters, gen_token_multiplier, instruction, mode):
347
- if mode == "chat":
348
- return self.generate_response(text)
349
- elif mode == "gen":
350
- return self.generate_text(text, gen_characters, gen_token_multiplier, instruction)
351
 
352
 
353
  def generate_text_gen_pre(self, text, gen_characters, gen_token_multiplier, instruction):
@@ -479,20 +475,22 @@ def chat_with_character(message, history):
479
  character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
480
  else:
481
  character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
482
- return character_maker.chat_or_gen(text=message,gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
483
 
484
  def chat_with_character_stream(message, history):
485
  if character_maker.use_chat_format:
486
  character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
487
  else:
488
  character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
489
- response = character_maker.chat_or_gen(text=message,gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
490
  for i in range(len(response)):
491
  time.sleep(0.05) # 各文字の表示間隔を調整
492
  yield response[:i+1]
493
- def clear_chat():
494
- character_maker.reset()
495
- return []
 
 
496
 
497
  # ログ関連関数
498
  def list_log_files():
@@ -711,7 +709,7 @@ def build_gradio_interface():
711
  generated_output = gr.Textbox(label="生成された文章")
712
 
713
  generate_button.click(
714
- character_maker.generate_text_pre,
715
  inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction],
716
  outputs=[generated_output]
717
  )
 
289
  repeat_penalty=repeat_penalty
290
  )
291
 
292
+ @spaces.GPU(duration=120)
293
+ def chat_or_gen(text, gen_characters, gen_token_multiplier, instruction, mode):
294
+ if mode == "chat":
295
+ return character_maker.generate_response(text)
296
+ elif mode == "gen":
297
+ return character_maker.generate_text(text, gen_characters, gen_token_multiplier, instruction)
298
 
299
  class CharacterMaker:
300
  def __init__(self):
 
344
  print(f"{model_type} モデルのロード中にエラーが発生しました: {e}")
345
  self.model_loaded.set()
346
 
 
 
 
 
 
 
347
 
348
 
349
  def generate_text_gen_pre(self, text, gen_characters, gen_token_multiplier, instruction):
 
475
  character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
476
  else:
477
  character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
478
+ return chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
479
 
480
  def chat_with_character_stream(message, history):
481
  if character_maker.use_chat_format:
482
  character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
483
  else:
484
  character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
485
+ response = chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
486
  for i in range(len(response)):
487
  time.sleep(0.05) # 各文字の表示間隔を調整
488
  yield response[:i+1]
489
+
490
+ # 文章生成関連関数
491
+ def generate_text_wrapper(text, gen_characters, gen_token_multiplier, instruction):
492
+ return chat_or_gen(text=text, gen_characters=gen_characters, gen_token_multiplier=gen_token_multiplier, instruction=instruction, mode="gen")
493
+
494
 
495
  # ログ関連関数
496
  def list_log_files():
 
709
  generated_output = gr.Textbox(label="生成された文章")
710
 
711
  generate_button.click(
712
+ generate_text_wrapper,
713
  inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction],
714
  outputs=[generated_output]
715
  )