ChenyuRabbitLove committed on
Commit
ebdfacb
1 Parent(s): 8b5f4b6

feat: add NTU LLaMA model

Browse files
app.py CHANGED
@@ -15,7 +15,6 @@ from utils.utils import (
15
  )
16
  from utils.completion_reward import CompletionReward
17
  from utils.completion_reward_utils import (
18
- get_llm_response,
19
  set_player_name,
20
  set_player_selected_character,
21
  create_certificate,
@@ -258,6 +257,10 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
258
  mtk_description = gr.Markdown(
259
  "# 蔚藍", elem_id="mtk_description", visible=False
260
  )
 
 
 
 
261
 
262
  with gr.Row():
263
  openai_img = gr.Image(
@@ -288,6 +291,13 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
288
  interactive=False,
289
  show_download_button=False,
290
  )
 
 
 
 
 
 
 
291
 
292
  with gr.Row():
293
  start_generate_story = gr.Button(
@@ -306,6 +316,7 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
306
  bot2 = gr.Chatbot(visible=False, elem_id="bot2")
307
  bot3 = gr.Chatbot(visible=False, elem_id="bot3")
308
  bot4 = gr.Chatbot(visible=False, elem_id="bot4")
 
309
 
310
  with gr.Row():
311
  select_story = gr.Radio(
@@ -460,31 +471,26 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
460
  ],
461
  queue=False,
462
  ).then(
463
- lambda: create_visibility_updates(True, 11),
464
  None,
465
  [
466
  openai_img,
467
  aws_img,
468
  google_img,
469
  mtk_img,
 
470
  story_title,
471
  story_description,
472
  openai_description,
473
  aws_description,
474
  google_description,
475
  mtk_description,
 
476
  start_generate_story,
477
  ],
478
  queue=False,
479
  )
480
 
481
- get_llm_response_args = dict(
482
- fn=get_llm_response,
483
- inputs=[completion_reward, player_logs],
484
- outputs=[bot1, bot2, bot3, bot4],
485
- queue=False,
486
- )
487
-
488
  get_first_llm_response_args = dict(
489
  fn=get_llm_response_once,
490
  inputs=[completion_reward, player_logs],
@@ -513,12 +519,19 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
513
  queue=False,
514
  )
515
 
 
 
 
 
 
 
 
516
  start_generate_story.click(
517
  lambda: gr.update(visible=False), None, start_generate_story, queue=False
518
  ).then(
519
- lambda: create_visibility_updates(True, 5),
520
  None,
521
- [bot1, bot2, bot3, bot4, weaving],
522
  queue=False,
523
  ).then(
524
  **get_first_llm_response_args
@@ -528,6 +541,8 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
528
  **get_third_llm_response_args
529
  ).then(
530
  **get_fourth_llm_response_args
 
 
531
  ).then(
532
  lambda: gr.update(visible=True), None, [select_story], queue=False
533
  ).then(
@@ -580,23 +595,26 @@ with gr.Blocks(theme=seafoam, css=get_content("css/style.css")) as demo:
580
  )
581
 
582
  start_generate_certificate.click(
583
- lambda: create_visibility_updates(False, 18),
584
  None,
585
  [
586
  openai_img,
587
  aws_img,
588
  google_img,
589
  mtk_img,
 
590
  story_title,
591
  story_description,
592
  openai_description,
593
  aws_description,
594
  google_description,
595
  mtk_description,
 
596
  bot1,
597
  bot2,
598
  bot3,
599
  bot4,
 
600
  select_story,
601
  processing,
602
  cancel_story,
 
15
  )
16
  from utils.completion_reward import CompletionReward
17
  from utils.completion_reward_utils import (
 
18
  set_player_name,
19
  set_player_selected_character,
20
  create_certificate,
 
257
  mtk_description = gr.Markdown(
258
  "# 蔚藍", elem_id="mtk_description", visible=False
259
  )
260
+ ntu_description = gr.Markdown(
261
+ "# 紅寶石", elem_id="ntu_description", visible=False
262
+ )
263
+
264
 
265
  with gr.Row():
266
  openai_img = gr.Image(
 
291
  interactive=False,
292
  show_download_button=False,
293
  )
294
+ ntu_img = gr.Image(
295
+ "medias/ntu.png",
296
+ visible=False,
297
+ elem_id="ntu_img",
298
+ interactive=False,
299
+ show_download_button=False,
300
+ )
301
 
302
  with gr.Row():
303
  start_generate_story = gr.Button(
 
316
  bot2 = gr.Chatbot(visible=False, elem_id="bot2")
317
  bot3 = gr.Chatbot(visible=False, elem_id="bot3")
318
  bot4 = gr.Chatbot(visible=False, elem_id="bot4")
319
+ bot5 = gr.Chatbot(visible=False, elem_id="bot5")
320
 
321
  with gr.Row():
322
  select_story = gr.Radio(
 
471
  ],
472
  queue=False,
473
  ).then(
474
+ lambda: create_visibility_updates(True, 13),
475
  None,
476
  [
477
  openai_img,
478
  aws_img,
479
  google_img,
480
  mtk_img,
481
+ ntu_img,
482
  story_title,
483
  story_description,
484
  openai_description,
485
  aws_description,
486
  google_description,
487
  mtk_description,
488
+ ntu_description,
489
  start_generate_story,
490
  ],
491
  queue=False,
492
  )
493
 
 
 
 
 
 
 
 
494
  get_first_llm_response_args = dict(
495
  fn=get_llm_response_once,
496
  inputs=[completion_reward, player_logs],
 
519
  queue=False,
520
  )
521
 
522
+ get_fifth_llm_response_args = dict(
523
+ fn=get_llm_response_once,
524
+ inputs=[completion_reward, player_logs],
525
+ outputs=bot5,
526
+ queue=False,
527
+ )
528
+
529
  start_generate_story.click(
530
  lambda: gr.update(visible=False), None, start_generate_story, queue=False
531
  ).then(
532
+ lambda: create_visibility_updates(True, 6),
533
  None,
534
+ [bot1, bot2, bot3, bot4, bot5, weaving],
535
  queue=False,
536
  ).then(
537
  **get_first_llm_response_args
 
541
  **get_third_llm_response_args
542
  ).then(
543
  **get_fourth_llm_response_args
544
+ ).then(
545
+ **get_fifth_llm_response_args
546
  ).then(
547
  lambda: gr.update(visible=True), None, [select_story], queue=False
548
  ).then(
 
595
  )
596
 
597
  start_generate_certificate.click(
598
+ lambda: create_visibility_updates(False, 21),
599
  None,
600
  [
601
  openai_img,
602
  aws_img,
603
  google_img,
604
  mtk_img,
605
+ ntu_img,
606
  story_title,
607
  story_description,
608
  openai_description,
609
  aws_description,
610
  google_description,
611
  mtk_description,
612
+ ntu_description,
613
  bot1,
614
  bot2,
615
  bot3,
616
  bot4,
617
+ bot5,
618
  select_story,
619
  processing,
620
  cancel_story,
css/style.css CHANGED
@@ -374,6 +374,11 @@ input[type="range"]::-ms-track {
374
  border: None !important;
375
  }
376
 
 
 
 
 
 
377
  #processing {
378
  margin: 20vh 10vw;
379
  height: 30vh;
@@ -434,4 +439,8 @@ input[type="range"]::-ms-track {
434
 
435
  [data-testid="蔚藍-radio-label"] {
436
  background: rgba(66, 130, 227, 0.4) !important;
 
 
 
 
437
  }
 
374
  border: None !important;
375
  }
376
 
377
+ #bot4 .message {
378
+ background: rgba(227, 66, 104, 0.4) !important;
379
+ border: None !important;
380
+ }
381
+
382
  #processing {
383
  margin: 20vh 10vw;
384
  height: 30vh;
 
439
 
440
  [data-testid="蔚藍-radio-label"] {
441
  background: rgba(66, 130, 227, 0.4) !important;
442
+ }
443
+
444
+ [data-testid="紅寶石-radio-label"] {
445
+ background: rgba(227, 66, 104, 0.4) !important;
446
  }
utils/completion_reward.py CHANGED
@@ -48,17 +48,20 @@ class CompletionReward:
48
  self.paragraph_aws = None
49
  self.paragraph_google = None
50
  self.paragraph_mtk = None
 
51
  self.player_certificate_url = None
52
  self.openai_agent = OpenAIAgent()
53
  self.aws_agent = AWSAgent()
54
  self.google_agent = GoogleAgent()
55
  self.mtk_agent = MTKAgent()
 
56
  self.agents_responses = {}
57
  self.agent_list = [
58
  self.openai_agent,
59
  self.aws_agent,
60
  self.google_agent,
61
  self.mtk_agent,
 
62
  ]
63
  self.shuffled_response_order = {}
64
  self.pop_response_order = []
@@ -67,38 +70,9 @@ class CompletionReward:
67
  "aws": self.paragraph_aws,
68
  "google": self.paragraph_google,
69
  "mtk": self.paragraph_mtk,
 
70
  }
71
 
72
- def get_llm_response(self, player_logs):
73
- openai_story = self.openai_agent.get_story(player_logs)
74
- aws_story = self.aws_agent.get_story(player_logs)
75
- google_story = self.google_agent.get_story(player_logs)
76
- mtk_story = self.mtk_agent.get_story(player_logs)
77
- agents_responses = {
78
- "openai": openai_story,
79
- "aws": aws_story,
80
- "google": google_story,
81
- "mtk": mtk_story,
82
- }
83
- self.paragraph_openai = agents_responses["openai"]
84
- self.paragraph_aws = agents_responses["aws"]
85
- self.paragraph_google = agents_responses["google"]
86
- self.paragraph_mtk = agents_responses["mtk"]
87
- response_items = list(agents_responses.items())
88
- random.shuffle(response_items)
89
-
90
- self.shuffled_response_order = {
91
- str(index): agent for index, (agent, _) in enumerate(response_items)
92
- }
93
-
94
- shuffled_responses = tuple(response for _, response in response_items)
95
- return (
96
- [(None, shuffled_responses[0])],
97
- [(None, shuffled_responses[1])],
98
- [(None, shuffled_responses[2])],
99
- [(None, shuffled_responses[3])],
100
- )
101
-
102
  def get_llm_response_once(self, player_logs):
103
  if self.agent_list:
104
  # Randomly select and remove an agent from the list
@@ -110,7 +84,7 @@ class CompletionReward:
110
  self.agents_responses[agent.name] = story
111
  self.pop_response_order.append(agent.name)
112
 
113
- if len(self.pop_response_order) == 4:
114
  self.shuffled_response_order = {
115
  str(index): agent for index, agent in enumerate(self.pop_response_order)
116
  }
@@ -118,6 +92,7 @@ class CompletionReward:
118
  self.paragraph_aws = self.agents_responses["aws"]
119
  self.paragraph_google = self.agents_responses["google"]
120
  self.paragraph_mtk = self.agents_responses["mtk"]
 
121
 
122
  return [(None, story)]
123
 
@@ -137,6 +112,7 @@ class CompletionReward:
137
  "索拉拉": "1",
138
  "薇丹特": "2",
139
  "蔚藍": "3",
 
140
  }
141
  self.player_selected_character = player_selected_character
142
  self.player_selected_model = self.shuffled_response_order[
@@ -290,6 +266,8 @@ class OpenAIAgent:
290
  logging.error(f"OpenAI Attempt {retry_attempts}: {e}")
291
  time.sleep(1 * retry_attempts)
292
 
 
 
293
  def get_background(self):
294
  client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
295
  image_url = None
@@ -355,6 +333,8 @@ class AWSAgent:
355
  logging.error(f"AWS Attempt {retry_attempts}: {e}")
356
  time.sleep(1 * retry_attempts)
357
 
 
 
358
 
359
  class GoogleAgent:
360
  from google.cloud import aiplatform
@@ -412,6 +392,8 @@ class GoogleAgent:
412
  logging.error(f"Google Attempt {retry_attempts}: {e}")
413
  time.sleep(1 * retry_attempts)
414
 
 
 
415
 
416
  class MTKAgent:
417
  def __init__(self):
@@ -485,7 +467,76 @@ class MTKAgent:
485
  retry_attempts += 1
486
  logging.error(f"MTK Attempt {retry_attempts}: {e}")
487
  time.sleep(1 * retry_attempts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
 
 
489
 
490
  class ImageProcessor:
491
  @staticmethod
 
48
  self.paragraph_aws = None
49
  self.paragraph_google = None
50
  self.paragraph_mtk = None
51
+ self.paragraph_ntu = None
52
  self.player_certificate_url = None
53
  self.openai_agent = OpenAIAgent()
54
  self.aws_agent = AWSAgent()
55
  self.google_agent = GoogleAgent()
56
  self.mtk_agent = MTKAgent()
57
+ self.ntu_agent = NTUAgent()
58
  self.agents_responses = {}
59
  self.agent_list = [
60
  self.openai_agent,
61
  self.aws_agent,
62
  self.google_agent,
63
  self.mtk_agent,
64
+ self.ntu_agent,
65
  ]
66
  self.shuffled_response_order = {}
67
  self.pop_response_order = []
 
70
  "aws": self.paragraph_aws,
71
  "google": self.paragraph_google,
72
  "mtk": self.paragraph_mtk,
73
+ "ntu": self.paragraph_ntu,
74
  }
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def get_llm_response_once(self, player_logs):
77
  if self.agent_list:
78
  # Randomly select and remove an agent from the list
 
84
  self.agents_responses[agent.name] = story
85
  self.pop_response_order.append(agent.name)
86
 
87
+ if len(self.pop_response_order) == 5:
88
  self.shuffled_response_order = {
89
  str(index): agent for index, agent in enumerate(self.pop_response_order)
90
  }
 
92
  self.paragraph_aws = self.agents_responses["aws"]
93
  self.paragraph_google = self.agents_responses["google"]
94
  self.paragraph_mtk = self.agents_responses["mtk"]
95
+ self.paragraph_ntu = self.agents_responses["ntu"]
96
 
97
  return [(None, story)]
98
 
 
112
  "索拉拉": "1",
113
  "薇丹特": "2",
114
  "蔚藍": "3",
115
+ "紅寶石": "4",
116
  }
117
  self.player_selected_character = player_selected_character
118
  self.player_selected_model = self.shuffled_response_order[
 
266
  logging.error(f"OpenAI Attempt {retry_attempts}: {e}")
267
  time.sleep(1 * retry_attempts)
268
 
269
+ return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
270
+
271
  def get_background(self):
272
  client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
273
  image_url = None
 
333
  logging.error(f"AWS Attempt {retry_attempts}: {e}")
334
  time.sleep(1 * retry_attempts)
335
 
336
+ return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
337
+
338
 
339
  class GoogleAgent:
340
  from google.cloud import aiplatform
 
392
  logging.error(f"Google Attempt {retry_attempts}: {e}")
393
  time.sleep(1 * retry_attempts)
394
 
395
+ return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
396
+
397
 
398
  class MTKAgent:
399
  def __init__(self):
 
467
  retry_attempts += 1
468
  logging.error(f"MTK Attempt {retry_attempts}: {e}")
469
  time.sleep(1 * retry_attempts)
470
+
471
+ return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
472
+
473
+ class NTUAgent:
474
+ def __init__(self):
475
+ self.name = "ntu"
476
+
477
+ def get_story(self, user_log):
478
+ system_prompt = """
479
+ 我正在舉辦一個學習型的活動,我為學生設計了一個獨特的故事機制,每天每個學生都會收到屬於自己獨特的冒險紀錄,現在我需要你協助我將這些冒險紀錄,製作成一段冒險故事,請
480
+ - 以「你」稱呼學生
481
+ - 可以裁減內容以將內容限制在 1024 個 token 內
482
+ - 試著合併故事記錄成一段連貫、有吸引力的故事
483
+ - 請使用 zh_TW
484
+ - 請直接回覆故事內容,不需要回覆任何訊息
485
+ """
486
+
487
+ user_log = f"""
488
+ ```{user_log}
489
+ ```
490
+ """
491
+
492
+ url = 'http://api.twllm.com:20002/v1/chat/completions'
493
+
494
+ data = {
495
+ "model": "yentinglin/Taiwan-LLM-13B-v2.0-chat",
496
+ "messages": f"{system_prompt}, 以下是我的冒險故事 ```{user_log}```",
497
+ "temperature": 0.7,
498
+ "top_p": 1,
499
+ "n": 1,
500
+ "max_tokens": 2048,
501
+ "stop": ["string"],
502
+ "stream": False,
503
+ "presence_penalty": 0,
504
+ "frequency_penalty": 0,
505
+ "user": "string",
506
+ "best_of": 1,
507
+ "top_k": -1,
508
+ "ignore_eos": False,
509
+ "use_beam_search": False,
510
+ "stop_token_ids": [0],
511
+ "skip_special_tokens": True,
512
+ "spaces_between_special_tokens": True,
513
+ "add_generation_prompt": True,
514
+ "echo": False,
515
+ "repetition_penalty": 1,
516
+ "min_p": 0
517
+ }
518
+
519
+ headers = {
520
+ 'accept': 'application/json',
521
+ 'Content-Type': 'application/json'
522
+ }
523
+
524
+ retry_attempts = 0
525
+ while retry_attempts < 5:
526
+ try:
527
+ response = requests.post(url, headers=headers, data=json.dumps(data)).json()
528
+ response_text = response["choices"][0]["message"]["content"]
529
+
530
+ chinese_converter = OpenCC("s2tw")
531
+
532
+ return chinese_converter.convert(response_text)
533
+
534
+ except Exception as e:
535
+ retry_attempts += 1
536
+ logging.error(f"NTU Attempt {retry_attempts}: {e}")
537
+ time.sleep(1 * retry_attempts)
538
 
539
+ return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
540
 
541
  class ImageProcessor:
542
  @staticmethod
utils/completion_reward_utils.py CHANGED
@@ -2,11 +2,6 @@ import json
2
 
3
  import gradio as gr
4
 
5
-
6
- def get_llm_response(completion_reward, *args):
7
- return completion_reward.get_llm_response(*args)
8
-
9
-
10
  def get_llm_response_once(completion_reward, *args):
11
  return completion_reward.get_llm_response_once(*args)
12
 
 
2
 
3
  import gradio as gr
4
 
 
 
 
 
 
5
  def get_llm_response_once(completion_reward, *args):
6
  return completion_reward.get_llm_response_once(*args)
7