import numpy as np
import os
import IPython
import traceback
import json
from gensim.utils import (
    save_text,
    add_to_txt,
    extract_dict,
    format_dict_prompt,
    generate_feedback,
)
import copy
import random


class Critic:
    """Reflect on and criticize a newly generated task for improvement.

    The critic fills prompt templates found under ``prompts/<prompt_folder>``,
    sends them through ``generate_feedback`` and interprets the responses.

    Args:
        cfg: Mapping-style configuration. The keys read by this class are
            ``prompt_folder``, ``model_output_dir``, ``target_task_name`` and
            ``reflection_agreement_num``.
        memory: Object exposing ``chat_log`` and ``online_task_buffer``
            (a dict keyed by task name).
    """

    # Upper bound on how many known tasks are listed in the reflection prompt.
    MAX_REFLECTION_TASKS = 40

    def __init__(self, cfg, memory):
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = self.memory.chat_log
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]

    @staticmethod
    def _read_prompt(path):
        """Return the content of the prompt file at *path*, closing the handle."""
        with open(path) as prompt_file:
            return prompt_file.read()

    def error_review(self, new_task):
        """Run the commonly-made-error review for *new_task*.

        Does nothing unless the common-errors template exists in the prompt
        folder and *new_task* contains a ``"task-name"`` entry.

        Args:
            new_task: Task description; ``new_task["task-name"]`` replaces
                ``TASK_NAME_TEMPLATE`` in the template.

        Returns:
            None. The ``generate_feedback`` result is not used here.
        """
        template_path = f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt"
        if not os.path.exists(template_path) or "task-name" not in new_task:
            return

        self.chat_log = add_to_txt(self.chat_log, "================= Error Book Preview!", with_print=True)
        errorbook_prompt_text = self._read_prompt(template_path)
        errorbook_prompt_text = errorbook_prompt_text.replace("TASK_NAME_TEMPLATE", new_task["task-name"])
        generate_feedback(errorbook_prompt_text, temperature=0., interaction_txt=self.chat_log)  # cfg['gpt_temperature']

    def reflection(self, new_task, new_code, current_tasks=None):
        """Reflect on whether the new task should be added to the task list.

        Builds the reflection prompt from the known tasks (a copy of
        ``memory.online_task_buffer`` plus *current_tasks*), requests
        ``cfg['reflection_agreement_num']`` responses and rejects the task as
        soon as one of them says ``add_to_the_task_list`` is false.

        Args:
            new_task: Task description; ``new_task["task-name"]`` is used to
                name the saved reflection output.
            new_code: Code of the new task, inserted into the prompt.
            current_tasks: Optional iterable of task dicts (each with a
                ``"task-name"`` key) produced in the current run; they are
                listed in the prompt so the new task does not overlap them.

        Returns:
            bool: ``True`` when no critic objected, or when the reflection
            template does not exist; ``False`` otherwise.
        """
        template_path = f"{self.prompt_folder}/cliport_prompt_task_reflection.txt"
        if not os.path.exists(template_path):
            return True

        # only consider successful task
        self.chat_log = add_to_txt(self.chat_log, "================= Code Reflect!", with_print=True)

        # Work on a private copy so building the prompt never mutates the
        # shared memory buffer. The original re-bound this name to the live
        # buffer right after the merge, which silently dropped current_tasks.
        total_tasks = copy.deepcopy(self.memory.online_task_buffer)
        if current_tasks is not None:
            # adding all the tasks in the current run.
            # at least should not overlap with those
            for t in current_tasks:
                total_tasks[t['task-name']] = t

        if len(total_tasks) > self.MAX_REFLECTION_TASKS:
            # random.sample needs a real sequence: dict views are deprecated
            # there since Python 3.9 and raise TypeError from 3.11 on.
            total_tasks = dict(random.sample(list(total_tasks.items()), self.MAX_REFLECTION_TASKS))

        print("reflection history task num:", len(total_tasks))
        task_descriptions_replacement_str = format_dict_prompt(total_tasks, -1)

        # append current new task
        code_reflection_prompt_text = self._read_prompt(template_path)
        code_reflection_prompt_text = code_reflection_prompt_text.replace("CURRENT_TASK_NAME_TEMPLATE", str(task_descriptions_replacement_str))
        code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_STRING_TEMPLATE", str(new_task))
        code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_CODE_TEMPLATE", str(new_code))
        if self.cfg['target_task_name']:
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])

        res = generate_feedback(
            code_reflection_prompt_text,
            temperature=0.4,
            interaction_txt=self.chat_log,
            n=int(self.cfg['reflection_agreement_num']),
        )  # cfg['gpt_temperature']

        all_add_to_the_task_list_flag = True
        for idx, r in enumerate(res):
            # iterate through for agreement
            task_reflection = None
            try:
                reflection_def_cmd = extract_dict(r, prefix='task_reflection')
                # NOTE(review): this executes generated text. Run it in a fresh
                # namespace so it cannot overwrite module globals and so a
                # stale `task_reflection` from an earlier response is never
                # reused when this one fails to define it.
                namespace = {}
                exec(reflection_def_cmd, namespace)
                task_reflection = namespace["task_reflection"]
                print(f"critic {idx}:", task_reflection)

                # Accept both the string 'False' and a real boolean False.
                verdict = str(task_reflection["add_to_the_task_list"]).strip().lower()
                if verdict == 'false':
                    all_add_to_the_task_list_flag = False
                    print(f"critic {idx} suggests not adding this task to the buffer! ")
            except Exception:
                # Show what went wrong, then keep the interactive debugging
                # hook the original code relied on.
                print(f"critic {idx} reflection could not be parsed:\n{traceback.format_exc()}")
                IPython.embed()

            save_text(self.model_output_dir, new_task['task-name'] + "_reflection_output", str(task_reflection))

        return all_add_to_the_task_list_flag