#!/usr/bin/env python
"""Gradio Space for MergeLlama-7b, a CodeLlama-7b finetune that proposes merge-conflict resolutions."""

import os
from threading import Event, Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig

DESCRIPTION = "# MergeLlama-7b\nThis is a conversational interface powered by the MergeLlama-7b model, a finetune of CodeLlama-7b designed to assist developers in resolving merge conflicts in their code. "
DESCRIPTION += "It leverages the capabilities of deep learning to provide suggestions for reconciling code differences, presenting potential resolutions for highlighted changes\n"
DESCRIPTION += "The feedback from this space will help develop future versions including more powerful 13b and 34b versions."
DESCRIPTION += "\n# How to use: \n"
DESCRIPTION += "1. Input your merge conflict in the chat in the following format:\n```\n<<<<<<<\n[Current change]\n=======\n[Incoming change]\n>>>>>>>\n```\n"
DESCRIPTION += "The model will generate the merge resolution. Context can be added before the conflict and multiple conflicts/resolutions can be chained together for context.\n"
DESCRIPTION += "**Additional Information:**\n"
DESCRIPTION += "- The model behind this tool is based on the MergeLlama dataset, which can be found [here](https://huggingface.co/datasets/codys12/MergeLlama).\n"
DESCRIPTION += "- For more information about the MergeLlama-7b model, visit [here](https://huggingface.co/codys12/MergeLlama-7b).\n"
DESCRIPTION += "- If you are interested in supporting the larger versions of this model, such as the 13b and 34b variants, you can check them out [here](https://www.dreamcatcher.co/ProjectPage?projectId=uibaxk4sfzetpkg7ch71ui).\n"
DESCRIPTION += "- This model was trained on [DreamcatcherAI](https://www.dreamcatcher.co/Discover)\n"

if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU.\n"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 4096

# The model starts a new conflict with this marker once it has finished a
# resolution; anything from the marker onward is discarded.
CONFLICT_START_MARKER = "<<<<<<<"

# The model and tokenizer are only loaded when a GPU is present; on CPU the
# DESCRIPTION above tells users the demo does not work.
if torch.cuda.is_available():
    model_id = "codys12/MergeLlama-7b"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=0,
        cache_dir="/data",
    )
    tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends generation as soon as ``event`` is set.

    Lets the request handler cancel the background ``model.generate`` thread
    once it no longer needs more tokens.
    """

    def __init__(self, event: Event) -> None:
        self.event = event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.event.is_set()


def _build_prompt(message: str, chat_history: list[tuple[str, str]]) -> str:
    """Concatenate prior (conflict, resolution) turns and the new conflict into one prompt.

    Every conflict, past or current, is followed by a newline before its
    resolution, so history turns use the same layout as the live turn.
    """
    parts = [f"{user}\n{assistant}" for user, assistant in chat_history]
    parts.append(f"{message}\n")
    return "".join(parts)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    #temperature: float = 0.6,
    #top_p: float = 0.9,
    #top_k: int = 50,
    #repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a merge resolution for ``message``, using ``chat_history`` as context.

    Args:
        message: The new merge conflict entered by the user.
        chat_history: Previous (conflict, resolution) pairs supplied by Gradio.
        max_new_tokens: Upper bound on generated tokens (the UI slider
            normally supplies this value).

    Yields:
        The resolution text generated so far. Generation is cut off at the
        first ``CONFLICT_START_MARKER`` the model emits.
    """
    prompt = _build_prompt(message, chat_history)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # input_ids has shape (1, seq_len): measure and trim the sequence axis,
    # keeping the most recent tokens.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    stop_event = Event()
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
        #do_sample=True,
        #top_p=top_p,
        #top_k=top_k,
        #temperature=temperature,
        #num_beams=1,
        #repetition_penalty=repetition_penalty,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    try:
        for text in streamer:
            outputs.append(text)
            # Re-scan the full text so a marker split across chunks is still found.
            resolution, marker, _ = "".join(outputs).partition(CONFLICT_START_MARKER)
            yield resolution
            if marker:
                break
    finally:
        # Runs on normal exit, on the marker break, and when the consumer
        # closes this generator: tell the worker thread to stop generating.
        stop_event.set()


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        # Sampling controls are currently disabled (see generate()):
        # gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        # gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        # gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        # gr.Slider(label="Repetition penalty", minimum=0.1, maximum=2.0, step=0.05, value=1.2),
    ],
    stop_btn=None,
    examples=[
        ["<<<<<<<\nlet x = max(y, 11)\n=======\nvar x = max(y, 12, z)\n>>>>>>>"],
        ["<<<<<<<\nclass Calculator { \nadd(a, b) {\n return a + b;\n }\n}\n=======\nclass Calculator {\n subtract(a, b) {\n return a - b;\n }\n}\n>>>>>>>"],
        ["<<<<<<<\nfunction greet(name) {\n return `Hello, ${name}! Have a good day.`;\n}\n=======\nfunction greet(name, time) {\n return `Good ${time}, ${name}!`;\n}\n>>>>>>>"],
        ["<<<<<<<\nconst user = {\n name: 'John',\n age: 30\n}\n=======\nconst user = {\n name: 'John',\n email: 'john@example.com'\n}\n>>>>>>>"],
        ["<<<<<<<\n.btn {\n background-color: blue;\n padding: 10px 20px;\n}\n=======\n.btn {\n border: 1px solid black;\n font-size: 16px;\n}\n>>>>>>>"],
        ["<<<<<<<\n var visibleSets = beatmapSets.Where(s => !s.Filtered).ToList();\n if (!visibleSets.Any())\n return;\n\n=======\n\n var visible = beatmapSets.Where(s => !s.Filtered).ToList();\n if (!visible.Any())\n return false;\n\n>>>>>>>"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()