import os

from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType


class H2OTextGenerationPipeline(TextGenerationPipeline):
    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=False,
                 use_prompter=True, prompter=None,
                 prompt_type=None, prompt_dict=None,
                 max_input_tokens=2048 - 256, **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)
        :param args: positional arguments forwarded to TextGenerationPipeline
        :param debug: passed to Prompter; when True, Prompter prints prompts and outputs
        :param chat: passed to Prompter
        :param stream_output: passed to Prompter
        :param sanitize_bot_response: if True, Prompter.get_response censors profanity in the response
        :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param prompt_type: prompt_type, e.g. human_bot. See prompt_type_to_model_name for the
               prompt_type-to-model mapping. If use_prompter, then will make prompter and use it.
        :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
        :param max_input_tokens: stored as self.max_input_tokens
               (NOTE(review): not read anywhere else in this file)
        :param kwargs: forwarded to TextGenerationPipeline
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None
        self.use_prompter = use_prompter
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
            else:
                self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
                                         stream_output=stream_output)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs

    @staticmethod
    def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
        """
        Shorten prompt_text from the front so it fits within the tokenizer's model_max_length.

        The tail is kept because that is where the question is (e.g. after langchain context).
        :param prompt_text: raw prompt string
        :param tokenizer: HF-style tokenizer; called as tokenizer(text)['input_ids']
        :param max_prompt_length: optional extra cap applied on top of tokenizer.model_max_length
        :return: possibly shortened prompt_text (unchanged if tokenizer has no model_max_length)
        """
        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))

        if hasattr(tokenizer, 'model_max_length'):
            # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
            model_max_length = tokenizer.model_max_length
            if max_prompt_length is not None:
                model_max_length = min(model_max_length, max_prompt_length)
            # cut at some upper likely limit to avoid excessive tokenization etc
            # upper bound of 10 chars/token, e.g. special chars sometimes are long
            if len(prompt_text) > model_max_length * 10:
                len0 = len(prompt_text)
                prompt_text = prompt_text[-model_max_length * 10:]
                if verbose:
                    print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
        else:
            # unknown
            model_max_length = None

        if model_max_length is not None:
            # can't wait for "hole" if not plain prompt_type, since would lose the human prefix
            # For https://github.com/h2oai/h2ogpt/issues/192
            for _ in range(0, 3):
                prompt_tokens = tokenizer(prompt_text)['input_ids']
                num_prompt_tokens = len(prompt_tokens)
                if num_prompt_tokens > model_max_length:
                    # conservative by using int(); at least 1 so the slice below always shrinks the text
                    # (with 0 the slice [-0:] would keep the whole string)
                    chars_per_token = max(1, int(len(prompt_text) / num_prompt_tokens))
                    # keep tail, where question is if using langchain
                    prompt_text = prompt_text[-model_max_length * chars_per_token:]
                    if verbose:
                        print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
                            num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
                else:
                    if verbose:
                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
                    break
            # max_new_tokens is intentionally not reduced here: rely upon stopping to reach limit of model

        return prompt_text

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """Limit the prompt length, wrap it with the prompter (if any), remember it, then defer to HF."""
        prompt_text = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

        data_point = dict(context='', instruction=prompt_text, input='')
        if self.prompter is not None:
            prompt_text = self.prompter.generate_prompt(data_point)
        # kept so postprocess can strip the prompt from the generated text
        self.prompt_text = prompt_text
        # handle_long_generation is passed through unchanged (HF truncation is disabled with new approaches)
        return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                  **generate_kwargs)

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """Run HF postprocess, then reduce each record's generated_text to just the bot response."""
        records = super().postprocess(model_outputs, return_type=return_type,
                                      clean_up_tokenization_spaces=clean_up_tokenization_spaces)
        for rec in records:
            if self.use_prompter:
                outputs = rec['generated_text']
                outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.human:
                outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
            else:
                outputs = rec['generated_text']
            rec['generated_text'] = outputs
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """Attach prompt-type specific stopping criteria (when available) and run generation."""
        if self.can_stop:
            stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
                                             self.tokenizer, self.device,
                                             human=self.human, bot=self.bot,
                                             model_max_length=self.tokenizer.model_max_length)
            generate_kwargs['stopping_criteria'] = stopping_criteria
        # use local copy of HF _forward rather than super()._forward (see __forward)
        return self.__forward(model_inputs, **generate_kwargs)

    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
    def __forward(self, model_inputs, **generate_kwargs):
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        # generate_kwargs = copy.deepcopy(generate_kwargs)
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not available.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}


# The sections below each re-import their own dependencies; they appear to have been
# separate modules pasted into one file.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StoppingCriteriaSub(StoppingCriteria):
    """
    Stop generation once a stop-token sequence has been seen enough times, or the length limit is hit.

    :param stops: list of 1-D token-id tensors; each is compared against the tail of input_ids[0]
    :param encounters: required hit counts; stop i uses encounters[i % len(encounters)]
    :param device: device the stop tensors are moved to
    :param model_max_length: if set, stop once input_ids[0] reaches this length
    """

    def __init__(self, stops=None, encounters=None, device="cuda", model_max_length=None):
        super().__init__()
        # None defaults avoid shared mutable default lists
        stops = [] if stops is None else stops
        encounters = [] if encounters is None else encounters
        if stops:
            assert encounters and len(stops) % len(encounters) == 0, \
                "Number of stops must be a multiple of the number of encounters"
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)
        self.model_max_length = model_max_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        sequence = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            # a stop longer than the sequence cannot match; comparing mismatched lengths
            # would either raise or broadcast into a misleading comparison
            if sequence.shape[0] < len(stop):
                continue
            if torch.all(stop == sequence[-len(stop):]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    # print("Stopped", flush=True)
                    return True
        if self.model_max_length is not None and sequence.shape[0] >= self.model_max_length:
            # critical limit
            return True
        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
        return False


# NOTE(review): the default human/bot values are both ':' here; they look like markup tags
# (e.g. a "<human>:"-style prefix) lost when this file was captured -- confirm against upstream.
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human=':', bot=":", model_max_length=None):
    """
    Build a StoppingCriteriaList for prompt types that have known stop strings.

    Only human_bot, instruct_vicuna and instruct_with_end (matched by name) get stop words;
    any other prompt_type gets an empty StoppingCriteriaList.
    """
    # FIXME: prompt_dict unused currently
    if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name,
                       PromptType.instruct_with_end.name]:
        if prompt_type == PromptType.human_bot.name:
            # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
            # stopping only starts once output is beyond prompt
            # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
            stop_words = [human, bot, '\n' + human, '\n' + bot]
            encounters = [1, 2]
        elif prompt_type == PromptType.instruct_vicuna.name:
            # even below is not enough, generic strings and many ways to encode
            stop_words = [
                '### Human:',
                '\n### Human:',
                '\n### Human:\n',
                '### Assistant:',
                '\n### Assistant:',
                '\n### Assistant:\n',
            ]
            encounters = [1, 2]
        else:
            # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
            stop_words = ['### End']
            encounters = [1]
        # give every stop word its own required count up front, so counts stay aligned
        # with words even if some words are dropped below
        word_encounters = [encounters[i % len(encounters)] for i in range(len(stop_words))]
        stop_words_ids = [
            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze()
            for stop_word in stop_words]
        # handle single token case
        stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
        has_pad_token = bool(tokenizer._pad_token)  # use hidden variable to avoid annoying properly logger bug

        stops = []
        stop_encounters = []
        for ids, word, count in zip(stop_words_ids, stop_words, word_encounters):
            if ids.shape[0] == 0:
                continue
            # avoid padding in front of tokens
            if has_pad_token and len(ids) > 1 and ids[0] == tokenizer.pad_token_id:
                ids = ids[1:]
            # handle fake \n added
            if word.startswith('\n'):
                ids = ids[1:]
            # an empty stop sequence would match (or mis-broadcast) on every step
            if ids.shape[0] == 0:
                continue
            stops.append(ids)
            stop_encounters.append(count)

        # build stopper
        stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stops, encounters=stop_encounters, device=device,
                                 model_max_length=model_max_length)])
    else:
        stopping_criteria = StoppingCriteriaList()
    return stopping_criteria


from enum import Enum


class PromptType(Enum):
    custom = -1
    plain = 0
    instruct = 1
    quality = 2
    human_bot = 3
    dai_faq = 4
    summarize = 5
    simple_instruct = 6
    instruct_vicuna = 7
    instruct_with_end = 8
    human_bot_orig = 9
    prompt_answer = 10
    open_assistant = 11
    wizard_lm = 12
    wizard_mega = 13
    instruct_vicuna2 = 14
    instruct_vicuna3 = 15
    wizard2 = 16
    wizard3 = 17
    instruct_simple = 18


class DocumentChoices(Enum):
    All_Relevant = 0
    All_Relevant_Only_Sources = 1
    Only_All_Sources = 2
    Just_LLM = 3


class LangChainMode(Enum):
    """LangChain mode"""

    DISABLED = "Disabled"
    CHAT_LLM = "ChatLLM"
    LLM = "LLM"
    ALL = "All"
    WIKI = "wiki"
    WIKI_FULL = "wiki_full"
    USER_DATA = "UserData"
    MY_DATA = "MyData"
    GITHUB_H2OGPT = "github h2oGPT"
    H2O_DAI_DOCS = "DriverlessAI docs"


import ast
import time

# NOTE(review): this rebinds the PromptType class defined above to the one from the `enums` module
from enums import PromptType  # also supports imports from this file from other files

non_hf_types = ['gpt4all_llama', 'llama', 'gptj']

prompt_type_to_model_name = {
    'plain': [
        'EleutherAI/gpt-j-6B',
        'EleutherAI/pythia-6.9b',
        'EleutherAI/pythia-12b',
        'EleutherAI/pythia-12b-deduped',
        'EleutherAI/gpt-neox-20b',
        'openlm-research/open_llama_7b_700bt_preview',
        'decapoda-research/llama-7b-hf',
        'decapoda-research/llama-13b-hf',
        'decapoda-research/llama-30b-hf',
        'decapoda-research/llama-65b-hf',
        'facebook/mbart-large-50-many-to-many-mmt',
        'philschmid/bart-large-cnn-samsum',
        'philschmid/flan-t5-base-samsum',
        'gpt2',
        'distilgpt2',
        'mosaicml/mpt-7b-storywriter',
        'mosaicml/mpt-7b-instruct',  # internal code handles instruct
        'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
        'gptj',  # internally handles prompting
        'llama',  # plain, or need to choose prompt_type for given TheBloke model
        'gpt4all_llama',  # internally handles prompting
    ],
    'prompt_answer': [
        'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
        'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
        'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
        'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
        'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
    ],
    'instruct': [],
    'instruct_with_end': ['databricks/dolly-v2-12b'],
    'quality': [],
    'human_bot': [
        'h2oai/h2ogpt-oasst1-512-12b',
        'h2oai/h2ogpt-oasst1-512-20b',
        'h2oai/h2ogpt-oig-oasst1-256-6_9b',
        'h2oai/h2ogpt-oig-oasst1-512-6_9b',
        'h2oai/h2ogpt-oig-oasst1-256-6.9b',  # legacy
        'h2oai/h2ogpt-oig-oasst1-512-6.9b',  # legacy
        'h2oai/h2ogpt-research-oasst1-512-30b',
        'h2oai/h2ogpt-oasst1-falcon-40b',
        'h2oai/h2ogpt-oig-oasst1-falcon-40b',
    ],
    'dai_faq': [],
    'summarize': [],
    'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
    'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
    'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
    "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
    "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
    "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
    "instruct_simple": ['JosephusCheung/Guanaco'],
}

inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}

prompt_types_strings = [p.name for p in PromptType]

# every accepted spelling of each prompt type: name, int value, and str(value)
prompt_types = [x for p in PromptType for x in (p.name, p.value, str(p.value))]


def _prompt_type_is(prompt_type, *members):
    """Return True if prompt_type equals the value, str(value) or name of any of the given PromptType members."""
    return any(prompt_type in [m.value, str(m.value), m.name] for m in members)


def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
    """
    Return the prompt template pieces for a prompt type.

    :param prompt_type: PromptType name, int value, or str(value)
    :param prompt_dict: for custom prompt_type, a dict (or a str holding a dict literal) of template pieces
    :param chat: with reduced, drops the preamble (promptA/promptB) for several prompt types
    :param context: non-empty context suppresses the date/time preprompt for human_bot_orig
    :param reduced: see chat; also affects human_bot PreResponse spacing
    :param return_dict: if True, return (dict_of_pieces, '') instead of a 9-tuple
    :return: (promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep,
             humanstr, botstr), or (dict, '') with return_dict. If a custom prompt_dict cannot be
             parsed, returns (dict(), error_message) regardless of return_dict.
    :raises RuntimeError: for an unknown prompt_type
    """
    prompt_dict_error = ''
    if _prompt_type_is(prompt_type, PromptType.custom) and not isinstance(prompt_dict, dict):
        try:
            # literal_eval only accepts Python literals, so a str prompt_dict is not executed as code
            prompt_dict = ast.literal_eval(prompt_dict)
        except Exception as e:
            prompt_dict_error = str(e)
        else:
            if not isinstance(prompt_dict, dict):
                prompt_dict_error = "prompt_dict must evaluate to a dict, got %s" % type(prompt_dict).__name__
    if prompt_dict_error:
        return dict(), prompt_dict_error

    if _prompt_type_is(prompt_type, PromptType.custom):
        promptA = prompt_dict.get('promptA', '')
        promptB = prompt_dict.get('promptB', '')
        PreInstruct = prompt_dict.get('PreInstruct', '')
        PreInput = prompt_dict.get('PreInput', '')
        PreResponse = prompt_dict.get('PreResponse', '')
        terminate_response = prompt_dict.get('terminate_response', None)
        chat_sep = prompt_dict.get('chat_sep', '\n')
        humanstr = prompt_dict.get('humanstr', '')
        botstr = prompt_dict.get('botstr', '')
    elif _prompt_type_is(prompt_type, PromptType.plain):
        promptA = promptB = PreInstruct = PreInput = PreResponse = ''
        terminate_response = []
        chat_sep = ''
        humanstr = ''
        botstr = ''
    elif _prompt_type_is(prompt_type, PromptType.simple_instruct):
        promptA = promptB = PreInstruct = PreInput = PreResponse = None
        terminate_response = []
        chat_sep = '\n'
        humanstr = ''
        botstr = ''
    elif _prompt_type_is(prompt_type, PromptType.instruct, PromptType.instruct_with_end):
        promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
                chat and reduced) else ''
        promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
                chat and reduced) else ''

        PreInstruct = "\n### Instruction:\n"
        PreInput = "\n### Input:\n"
        PreResponse = "\n### Response:\n"
        if _prompt_type_is(prompt_type, PromptType.instruct_with_end):
            terminate_response = ['### End']
        else:
            terminate_response = None
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.quality):
        promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
                chat and reduced) else ''
        promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
                chat and reduced) else ''

        PreInstruct = "\n### Instruction:\n"
        PreInput = "\n### Input:\n"
        PreResponse = "\n### Response:\n"
        terminate_response = None
        chat_sep = '\n'
        humanstr = PreInstruct  # first thing human says
        botstr = PreResponse  # first thing bot says
    elif _prompt_type_is(prompt_type, PromptType.human_bot, PromptType.human_bot_orig):
        # NOTE(review): human and bot are identical (':') here; they look like markup tags lost
        # when this file was captured -- confirm against upstream before relying on them.
        human = ':'
        bot = ":"
        if reduced or context or _prompt_type_is(prompt_type, PromptType.human_bot):
            preprompt = ''
        else:
            cur_date = time.strftime('%Y-%m-%d')
            cur_time = time.strftime('%H:%M:%S %p %Z')

            PRE_PROMPT = "Current Date: {}\nCurrent Time: {}\n\n"
            preprompt = PRE_PROMPT.format(cur_date, cur_time)
        start = human
        promptB = promptA = '%s%s ' % (preprompt, start)

        PreInstruct = ""

        PreInput = None

        if reduced:
            # when making context, want it to appear as-if LLM generated, which starts with space after :
            PreResponse = bot + ' '
        else:
            # normally LLM adds space after this, because was how trained.
            # if add space here, non-unique tokenization will often make LLM produce wrong output
            PreResponse = bot

        terminate_response = [start, PreResponse]
        chat_sep = '\n'
        humanstr = human  # tag before human talks
        botstr = bot  # tag before bot talks
    elif _prompt_type_is(prompt_type, PromptType.dai_faq):
        promptA = ''
        promptB = 'Answer the following Driverless AI question.\n'

        PreInstruct = "\n### Driverless AI frequently asked question:\n"

        PreInput = None

        PreResponse = "\n### Driverless AI documentation answer:\n"
        terminate_response = ['\n\n']
        # must be a str: inject_chatsep appends chat_sep to the prompt string
        chat_sep = terminate_response[0]
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.summarize):
        promptA = promptB = PreInput = ''
        PreInstruct = '## Main Text\n\n'
        PreResponse = '\n\n## Summary\n\n'
        terminate_response = None
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.instruct_vicuna):
        promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
                            "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
                chat and reduced) else ''

        PreInstruct = "\n### Human:\n"

        PreInput = None

        PreResponse = "\n### Assistant:\n"
        # but only allow terminate after prompt is found correctly, else can't terminate
        terminate_response = ['### Human:']
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.prompt_answer):
        preprompt = ''
        prompt_tokens = "<|prompt|>"
        answer_tokens = "<|answer|>"
        start = prompt_tokens
        promptB = promptA = '%s%s' % (preprompt, start)
        PreInstruct = ""
        PreInput = None
        PreResponse = answer_tokens
        eos = '<|endoftext|>'  # neox eos
        terminate_response = [start, PreResponse, eos]
        chat_sep = eos
        humanstr = prompt_tokens
        botstr = answer_tokens
    elif _prompt_type_is(prompt_type, PromptType.open_assistant):
        # From added_tokens.json
        preprompt = ''
        prompt_tokens = "<|prompter|>"
        answer_tokens = "<|assistant|>"
        start = prompt_tokens
        promptB = promptA = '%s%s' % (preprompt, start)
        PreInstruct = ""
        PreInput = None
        PreResponse = answer_tokens
        pend = "<|prefix_end|>"
        # NOTE(review): empty eos looks like a stripped markup token -- confirm against upstream
        eos = ""
        terminate_response = [start, PreResponse, pend, eos]
        chat_sep = eos
        humanstr = prompt_tokens
        botstr = answer_tokens
    elif _prompt_type_is(prompt_type, PromptType.wizard_lm):
        # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
        preprompt = ''
        start = ''
        promptB = promptA = '%s%s' % (preprompt, start)
        PreInstruct = ""
        PreInput = None
        PreResponse = "\n\n### Response\n"
        # NOTE(review): empty eos looks like a stripped markup token -- confirm against upstream
        eos = ""
        terminate_response = [PreResponse, eos]
        chat_sep = eos
        humanstr = promptA
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.wizard_mega):
        preprompt = ''
        start = ''
        promptB = promptA = '%s%s' % (preprompt, start)
        PreInstruct = "\n### Instruction:\n"
        PreInput = None
        PreResponse = "\n### Assistant:\n"
        terminate_response = [PreResponse]
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.instruct_vicuna2):
        promptA = promptB = ''

        PreInstruct = "\nHUMAN:\n"

        PreInput = None

        PreResponse = "\nASSISTANT:\n"
        # but only allow terminate after prompt is found correctly, else can't terminate
        terminate_response = ['HUMAN:']
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.instruct_vicuna3):
        promptA = promptB = ''

        PreInstruct = "\n### User:\n"

        PreInput = None

        PreResponse = "\n### Assistant:\n"
        # but only allow terminate after prompt is found correctly, else can't terminate
        terminate_response = ['### User:']
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.wizard2):
        # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
        preprompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        start = ''
        promptB = promptA = '%s%s' % (preprompt, start)
        PreInstruct = "\n### Instruction:\n"
        PreInput = None
        PreResponse = "\n### Response:\n"
        terminate_response = [PreResponse]
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.wizard3):
        # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
        preprompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
        start = ''
        promptB = promptA = '%s%s' % (preprompt, start)
        PreInstruct = "USER: "
        PreInput = None
        PreResponse = "ASSISTANT: "
        terminate_response = [PreResponse]
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif _prompt_type_is(prompt_type, PromptType.instruct_simple):
        promptA = ''
        promptB = ''

        PreInstruct = "\n### Instruction:\n"
        PreInput = "\n### Input:\n"
        PreResponse = "\n### Response:\n"
        terminate_response = None
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    else:
        raise RuntimeError("No such prompt_type=%s" % prompt_type)

    if return_dict:
        return dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
                    PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
                    humanstr=humanstr, botstr=botstr), ''
    else:
        return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr


def _format_instruction_input(instruction, input_, PreInstruct, PreInput):
    """Return the instruction/input section of a prompt, or None when there is nothing to add."""
    if instruction and PreInstruct is not None and input_ and PreInput is not None:
        return f"{PreInstruct}{instruction}{PreInput}{input_}"
    if instruction and input_ and PreInstruct is None and PreInput is not None:
        return f"{PreInput}{instruction}\n{input_}"
    if input_ and instruction and PreInput is None and PreInstruct is not None:
        return f"{PreInstruct}{instruction}\n{input_}"
    if instruction and PreInstruct is not None:
        return f"{PreInstruct}{instruction}"
    if input_ and PreInput is not None:
        return f"{PreInput}{input_}"
    # (input and instruction with either marker set is already covered by the two checks above)
    if input_ and instruction:
        # i.e. for simple_instruct
        return f"{instruction}: {input_}"
    if input_:
        return f"{input_}"
    if instruction:
        return f"{instruction}"
    return None


def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
    """
    Build the full prompt text for one data point.

    :param data_point: dict with optional keys context, instruction, input, output, prompt_type, prompt_dict;
                       prompt_type/prompt_dict in data_point override the arguments
    :param reduced: if True, context is not prepended
    :return: (prompt, pre_response, terminate_response, chat_sep)
    """
    context = data_point.get('context')
    if context is None:
        context = ''
    instruction = data_point.get('instruction')
    input_ = data_point.get('input')
    output = data_point.get('output')
    prompt_type = data_point.get('prompt_type', prompt_type)
    prompt_dict = data_point.get('prompt_dict', prompt_dict)
    assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
    promptA, promptB, PreInstruct, PreInput, PreResponse, \
        terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, prompt_dict, chat, context, reduced)

    prompt = context if not reduced else ''

    if input_ and promptA:
        prompt += f"{promptA}"
    elif promptB:
        prompt += f"{promptB}"

    body = _format_instruction_input(instruction, input_, PreInstruct, PreInput)
    if body is not None:
        prompt += body
        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)

    if PreResponse is not None:
        prompt += f"{PreResponse}"
        pre_response = PreResponse  # Don't use strip
    else:
        pre_response = ''

    if output:
        prompt += f"{output}"

    return prompt, pre_response, terminate_response, chat_sep


def inject_chatsep(prompt_type, prompt, chat_sep=None):
    """Append chat_sep to prompt when it is non-empty (prompt_type is currently unused)."""
    if chat_sep:
        # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
        prompt += chat_sep
    return prompt


class Prompter(object):
    """Build prompts for a prompt type and extract/clean the model's response from generated text."""

    def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
                 allowed_repeat_line_length=10):
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        data_point = dict(instruction='', input='', output='')
        _, self.pre_response, self.terminate_response, self.chat_sep = \
            generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
        self.debug = debug
        self.chat = chat
        self.stream_output = stream_output
        self.repeat_penalty = repeat_penalty
        self.allowed_repeat_line_length = allowed_repeat_line_length
        self.prompt = None
        context = ""  # not for chat context
        reduced = False  # not for chat context
        self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
            self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
            get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced)

    def generate_prompt(self, data_point):
        """Build the prompt for data_point (non-reduced), store it on self.prompt and return it."""
        reduced = False
        prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced)
        if self.debug:
            print("prompt: %s" % prompt, flush=True)
        self.prompt = prompt
        return prompt

    def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
        """
        Extract the bot response from raw generated text.

        :param outputs: a str or a list of str (one per generated sequence); a list is not modified
        :param prompt: prompt text to strip from each output; also stored on self.prompt when given
        :param sanitize_bot_response: if True, censor profanity via better_profanity
        :return: single str; multiple outputs are prefixed with an "Output N" header and joined by newlines
        """
        if isinstance(outputs, str):
            outputs = [outputs]
        else:
            # work on a copy so a caller-supplied list is not overwritten in place
            outputs = list(outputs)
        if self.debug:
            print("output:\n%s" % '\n\n'.join(outputs), flush=True)
        if prompt is not None:
            self.prompt = prompt

        def clean_response(response):
            # NOTE(review): the two empty entries look like markup tags stripped when this file was
            # captured; as written, replacing '' is a no-op
            meaningless_words = ['', '', '<|endoftext|>']
            for word in meaningless_words:
                response = response.replace(word, "")
            if sanitize_bot_response:
                from better_profanity import profanity
                response = profanity.censor(response)
            response = response.strip("\n")
            return response

        def clean_repeats(response):
            # drop repeated lines, but always keep short lines (shorter than allowed_repeat_line_length)
            lines = response.split('\n')
            new_lines = []
            seen = set()
            for line in lines:
                if line not in seen or len(line) < self.allowed_repeat_line_length:
                    new_lines.append(line)
                    seen.add(line)
            if self.debug and len(lines) != len(new_lines):
                print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
            response = '\n'.join(new_lines)
            return response

        multi_output = len(outputs) > 1

        for oi, output in enumerate(outputs):
            if _prompt_type_is(self.prompt_type, PromptType.plain):
                output = clean_response(output)
            elif prompt is None:
                # then use most basic parsing like pipeline
                if self.botstr in output:
                    if self.humanstr:
                        output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
                    else:
                        # i.e. use after bot but only up to next bot
                        output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
                else:
                    # output = clean_response(output.strip())
                    # assume just not printed yet
                    output = ""
            else:
                # find first instance of prereponse
                # prompt sometimes has odd characters, that mutate length,
                # so can't go by length alone
                if self.pre_response:
                    outputi = output.find(prompt)
                    if outputi >= 0:
                        output = output[outputi + len(prompt):]
                        allow_terminate = True
                    else:
                        # subtraction is risky due to space offsets sometimes, so only do if necessary
                        output = output[len(prompt) - len(self.pre_response):]
                        # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
                        if self.pre_response in output:
                            output = output.split(self.pre_response)[1]
                            allow_terminate = True
                        else:
                            if output:
                                print("Failure of parsing or not enough output yet: %s" % output, flush=True)
                            allow_terminate = False
                else:
                    allow_terminate = True
                    output = output[len(prompt):]
                # clean after subtract prompt out, so correct removal of pre_response
                output = clean_response(output).strip()
                if self.repeat_penalty:
                    output = clean_repeats(output).strip()
                if self.terminate_response and allow_terminate:
                    # skip empty terminators: ''.find() matches at 0 and would erase the whole response
                    finds = [output.find(term) for term in self.terminate_response if term]
                    finds = [x for x in finds if x >= 0]
                    if finds:
                        # cut at the earliest terminator actually present in the output
                        output = output[:min(finds)].strip()
                    else:
                        output = output.strip()
                else:
                    output = output.strip()
            if multi_output:
                # prefix with output counter
                output = "\n=========== Output %d\n\n" % (1 + oi) + output
                if oi > 0:
                    # post fix outputs with separator
                    output += '\n'
            outputs[oi] = output
        # join all outputs, only one extra new line between outputs
        output = '\n'.join(outputs)
        if self.debug:
            print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
        return output