tori29umai committed
Commit
3d8332e
Parent: 0b5ed8b

Update app.py

Files changed (1)
  1. app.py +16 -7
app.py CHANGED
@@ -1,5 +1,5 @@
 import os
-# os.environ['CUDA_VISIBLE_DEVICES'] = ''
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
 import spaces
 import sys
 import time
@@ -289,12 +289,10 @@ class LlamaAdapter:
             repeat_penalty=repeat_penalty
         )
 
-@spaces.GPU(duration=120)
 def load_model_gpu(model_type, model_path, n_gpu_layers, params):
     llama = LlamaAdapter(model_path, params, n_gpu_layers)
     print(f"{model_type} モデル {model_path} のロードが完了しました。(n_gpu_layers: {n_gpu_layers})")
     return llama
-
 
 class CharacterMaker:
     def __init__(self):
@@ -336,7 +334,7 @@ class CharacterMaker:
 
         try:
             # Load the new model
-            self.llama = load_model_gpu(model_type, model_path, n_gpu_layers, params)
+            self.llama = LlamaAdapter(model_path, params, n_gpu_layers)
             self.current_model = model_type
             self.model_loaded.set()
             print(f"{model_type} モデルをロードしました。モデルパス: {model_path}、GPUレイヤー数: {n_gpu_layers}")
@@ -344,6 +342,17 @@ class CharacterMaker:
             print(f"{model_type} モデルのロード中にエラーが発生しました: {e}")
             self.model_loaded.set()
 
+    @spaces.GPU(duration=120)
+    def chat_or_gen(self, text, gen_characters, gen_token_multiplier, instruction, mode):
+        if mode == "chat":
+            return self.generate_response(text)
+        elif mode == "gen":
+            return self.generate_text(text, gen_characters, gen_token_multiplier, instruction)
+
+
+    def generate_text_pre(self, text, gen_characters, gen_token_multiplier, instruction):
+        return self.chat_or_gen(text, gen_characters, gen_token_multiplier, instruction, mode="gen")
+
     def generate_response(self, input_str):
         self.load_model('CHAT')
         if not self.model_loaded.wait(timeout=30) or not self.llama:
@@ -470,14 +479,14 @@ def chat_with_character(message, history):
         character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
     else:
         character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
-    return character_maker.generate_response(message)
+    return character_maker.chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
 
 def chat_with_character_stream(message, history):
     if character_maker.use_chat_format:
         character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
     else:
         character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
-    response = character_maker.generate_response(message)
+    response = character_maker.chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
     for i in range(len(response)):
         time.sleep(0.05)  # Pacing between streamed characters
         yield response[:i+1]
@@ -702,7 +711,7 @@ def build_gradio_interface():
         generated_output = gr.Textbox(label="生成された文章")
 
         generate_button.click(
-            character_maker.generate_text,
+            character_maker.generate_text_pre,
             inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction],
            outputs=[generated_output]
         )
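For reference, the shape of this change — moving `@spaces.GPU` off the model loader and onto a single dispatch entry point, so ZeroGPU attaches a GPU per inference call rather than per load — reduces to the sketch below. Only `spaces.GPU(duration=...)` is the real ZeroGPU API; `DummyAdapter` and its methods are hypothetical stand-ins for `LlamaAdapter`, and the sketch assumes it runs inside a Hugging Face Space where the `spaces` package is installed.

import spaces  # Hugging Face ZeroGPU helper; assumed available inside a Space

class DummyAdapter:
    """Hypothetical stand-in for LlamaAdapter; does no real inference."""
    def chat(self, text):
        return f"chat: {text}"

    def generate(self, text, characters, token_multiplier, instruction):
        return f"gen({instruction}): {text[:characters]}"

adapter = DummyAdapter()

@spaces.GPU(duration=120)  # GPU is attached only while this call runs
def chat_or_gen(text, characters=None, token_multiplier=None, instruction=None, mode="chat"):
    # Single decorated entry point: both UI paths funnel through here,
    # so the GPU is requested at inference time, not at model-load time.
    if mode == "chat":
        return adapter.chat(text)
    elif mode == "gen":
        return adapter.generate(text, characters, token_multiplier, instruction)

# Thin undecorated wrapper keeps the Gradio event signature unchanged:
def generate_text_pre(text, characters, token_multiplier, instruction):
    return chat_or_gen(text, characters, token_multiplier, instruction, mode="gen")

Decorating the dispatcher instead of the loader means the 120-second GPU window covers only generation; loading can then run on CPU (note the commit also sets CUDA_VISIBLE_DEVICES='' at import time), which is the usual pattern for ZeroGPU Spaces that keep models resident between requests.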