File size: 6,124 Bytes
8824f88
 
 
 
 
 
 
 
 
5ac471c
192a9de
8824f88
f2e9284
 
 
 
 
 
 
 
 
 
 
 
8824f88
 
 
 
 
14acdab
8824f88
 
 
faf8f3f
36ff73f
9a7bc17
36ff73f
 
8824f88
 
 
 
 
 
 
8ef0569
 
 
b2015f4
8824f88
 
faf8f3f
8824f88
a5f97a2
 
faf8f3f
151d4c2
faf8f3f
f2e9284
faf8f3f
5b4300b
1ba36bf
73d0fad
8824f88
 
 
 
 
96b060f
99ab088
 
 
 
8ef0569
 
 
 
 
b2015f4
99ab088
 
 
8824f88
99ab088
 
 
5f9f635
 
f473751
 
5f9f635
f473751
 
8824f88
 
 
 
 
 
 
 
 
 
 
 
8ef0569
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2015f4
 
 
 
 
 
 
8824f88
 
 
f2e9284
 
 
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel, PeftConfig

# Markdown rendered above the chat UI: what the model is, how to format a
# conflict, and links to the dataset, the weights and the training platform.
DESCRIPTION = (
    "This is a conversational interface powered by the MergeLlama-7b model, a finetune of CodeLlama-7b designed to assist developers in resolving merge conflicts in their code. "
    "It leverages the capabilities of deep learning to provide suggestions for reconciling code differences, presenting potential resolutions for highlighted changes\n"
    "The feedback from this space will help develop future versions including more powerful 13b and 34b versions."
    "\n# How to use: \n"
    "1. Input your merge conflict in the chat in the following format:\n```\n<<<<<<<\n[change]\n=======\n[base]\n>>>>>>>\n```\n"
    "The model will generate the merge resolution. Context can be added before the conflict and multiple conflicts/resolutions can be chained together for context.\n"
    "**Additional Information:**\n"
    "- The model behind this tool is based on the MergeLlama dataset, which can be found [here](https://huggingface.co/datasets/codys12/MergeLlama).\n"
    "- For more information about the MergeLlama-7b model, visit [here](https://huggingface.co/codys12/MergeLlama-7b).\n"
    "- If you are interested in supporting the larger versions of this model, such as the 13b and 34b variants, you can check them out [here](https://www.dreamcatcher.co/ProjectPage?projectId=uibaxk4sfzetpkg7ch71ui).\n"
    "- This model was trained on [DreamcatcherAI](https://www.dreamcatcher.co/Discover)\n"
)

# Generation limits: ceiling and default for the "Max new tokens" slider, and
# the prompt-length limit (in tokens) consulted by generate().
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 4096

if torch.cuda.is_available():
    # Half-precision weights placed on GPU 0 and cached under /data.
    model_id = "codys12/MergeLlama-7b"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=0,
        cache_dir="/data",
    )
    # The tokenizer is taken from the base CodeLlama-7b repo; EOS doubles as
    # the padding token, with padding applied on the right.
    tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
else:
    # Neither model nor tokenizer is loaded without CUDA; say so in the UI.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


# generate() treats the appearance of this git conflict-start marker in the
# model output as the end of the resolution.
_CONFLICT_START = "<<<<<<<"


def _build_prompt(message: str, chat_history: list[tuple[str, str]]) -> str:
    """Return the raw prompt text: all previous turns, then the new message.

    The new message is followed by a newline, and the model's reply is its
    continuation of that text. Previous turns are therefore replayed as the
    user message, a newline, then the reply, so the context matches what the
    model originally saw and produced. A missing (None) reply is treated as
    empty.
    """
    parts = []
    for user, assistant in chat_history:
        parts.extend((user, "\n", assistant or ""))
    parts.extend((message, "\n"))
    return "".join(parts)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    #temperature: float = 0.6,
    #top_p: float = 0.9,
    #top_k: int = 50,
    #repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's merge resolution for ``message``.

    Args:
        message: The new user input, expected to contain a merge conflict.
        chat_history: Previous (user, assistant) turns, replayed as context.
        max_new_tokens: Upper bound on the number of tokens generated.

    Yields:
        The cumulative decoded output so far; each value supersedes the
        previous one. Streaming stops at the first ``<<<<<<<`` in the output,
        and only the text before that marker is yielded.
    """
    prompt = _build_prompt(message, chat_history)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0")

    # input_ids is (batch=1, seq_len): len() would return the batch size, so
    # both the length check and the trim must use the sequence axis.
    if input_ids.shape[-1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # Sampling arguments are disabled, matching the commented-out UI sliders:
        #do_sample=True,
        #top_p=top_p,
        #top_k=top_k,
        #temperature=temperature,
        #num_beams=1,
        #repetition_penalty=repetition_penalty,
    )
    # model.generate runs in a worker thread; decoded text is consumed from
    # the streamer here as it arrives.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # NOTE(review): breaking out below stops streaming to the UI, but the
    # worker thread presumably keeps generating until EOS or max_new_tokens;
    # a StoppingCriteria would be needed to halt it early.
    combined_text = ""
    for text in streamer:
        combined_text += text
        # Check the accumulated text so a marker split across chunks is found.
        if _CONFLICT_START in combined_text:
            yield combined_text.partition(_CONFLICT_START)[0]
            break
        yield combined_text


# Sample merge conflicts offered as one-click prompts in the chat UI.
_EXAMPLE_CONFLICTS = [
    "<<<<<<<\n            var visibleSets = beatmapSets.Where(s => !s.Filtered).ToList();\n            if (!visibleSets.Any())\n                return;\n\n=======\n\n            var visible = beatmapSets.Where(s => !s.Filtered).ToList();\n            if (!visible.Any())\n                return false;\n\n>>>>>>>",
    "<<<<<<<\n// Related to JDK7\nimport java.nio.channels.FileChannel;\n\n=======\n\n// Branch-dependent imports\nimport java.nio.channels.SeekableByteChannel;\n\n>>>>>>>",
    "<<<<<<<\n    bind(BlobDirectoryAccess.class, DefaultBlobDirectoryAccess.class);\n\n=======\n\n    bind(new TypeLiteral<UpdateStepRepositoryMetadataAccess<Path>>() {}).to(new TypeLiteral<MetadataStore>() {});\n\n>>>>>>>",
]

# Chat front-end for generate(). Only the generation length is exposed; the
# temperature / top-p / top-k / repetition-penalty sliders are disabled, as are
# the matching sampling arguments in generate(). No stop button is shown.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
    ],
    stop_btn=None,
    # Each example row carries just the message; there are no additional-input values.
    examples=[[conflict] for conflict in _EXAMPLE_CONFLICTS],
)

# Page layout: description markdown, a "duplicate" button that is visible only
# when SHOW_DUPLICATE_BUTTON=1, then the chat interface.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    show_duplicate = os.environ.get("SHOW_DUPLICATE_BUTTON") == "1"
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=show_duplicate,
    )
    chat_interface.render()

if __name__ == "__main__":
    # Enable the request queue (at most 20 pending requests), then serve.
    demo.queue(max_size=20).launch()