import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    pipeline,
)
# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 1024
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 512
# Prompts longer than this many tokens are trimmed (oldest tokens first) before
# generation; overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown header shown above the chat UI.
DESCRIPTION = """\
# Chat with Patched Coder
"""
# Markdown footer shown below the chat UI.
LICENSE = """\
---
This space was created by [patched](https://patched.codes).
"""

# The model is only loaded when CUDA is available, so warn visitors on CPU hosts.
# The literal must stay on one line: a plain "..." string cannot span lines.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Load the chat model and tokenizer only on GPU hosts. On CPU-only hosts
# `model` and `tokenizer` stay undefined, so `generate` cannot run there.
if torch.cuda.is_available():
    model_id = "Qwen/Qwen1.5-7B-Chat"
    # 4-bit bitsandbytes quantization. Passing `load_in_4bit=True` directly to
    # from_pretrained is deprecated; the supported spelling is quantization_config.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = 'right'
@spaces.GPU(duration=60)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.2,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Relies on the module-level ``model`` and ``tokenizer``, which are only
    created when CUDA is available.

    Args:
        message: The new user turn.
        chat_history: Earlier ``(user, assistant)`` turn pairs, oldest first.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The assistant reply accumulated so far, growing with each streamed chunk.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant-turn header, so the model
    # answers as the assistant instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context (which may include
        # the system prompt) is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    # model.generate blocks until generation finishes, so run it on a worker
    # thread and consume decoded text from the streamer as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation ends; reap the worker thread.
    worker.join()
# Example prompt: ask the model to replace MD5 (CWE-327) in a small Python
# function. The snippet keeps its indentation so it is valid Python.
example1 = '''Fix vulnerability CWE-327: Use of a Broken or Risky Cryptographic Algorithm in the following code snippet.
def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()
'''
# Example prompt: summarize the change between two versions of a C function.
# Raw string so the C "\n" escapes reach the model verbatim instead of becoming
# real line breaks inside the printf format strings.
example2 = r'''You are a software engineer who is best in the world at summarizing code changes.
Carefully analyze the given old code and new code and generate a summary of the changes.
Old Code:
#include <stdio.h>
#include <stdlib.h>
typedef struct Node {
    int data;
    struct Node *next;
} Node;
void processList() {
    Node *head = (Node*)malloc(sizeof(Node));
    head->data = 1;
    head->next = (Node*)malloc(sizeof(Node));
    head->next->data = 2;
    printf("First element: %d\n", head->data);
    free(head->next);
    free(head);
    printf("Accessing freed list: %d\n", head->next->data);
}
New Code:
#include <stdio.h>
#include <stdlib.h>
typedef struct Node {
    int data;
    struct Node *next;
} Node;
void processList() {
    Node *head = (Node*)malloc(sizeof(Node));
    if (head == NULL) {
        perror("Failed to allocate memory for head");
        return;
    }
    head->data = 1;
    head->next = (Node*)malloc(sizeof(Node));
    if (head->next == NULL) {
        free(head);
        perror("Failed to allocate memory for next node");
        return;
    }
    head->next->data = 2;
    printf("First element: %d\n", head->data);
    free(head->next);
    head->next = NULL;
    free(head);
    head = NULL;
    if (head != NULL && head->next != NULL) {
        printf("Accessing freed list: %d\n", head->next->data);
    }
}
'''
# Example prompt: YES/NO classification of a Flask handler for CWE-117 (log
# injection). The snippet keeps its indentation so it is valid Python.
example3 = '''Is the following code prone to CWE-117: Improper Output Neutralization for Logs. Respond only with YES or NO.
from flask import Flask, request, jsonify
import logging
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.route('/api/data', methods=['GET'])
def get_data():
    api_key = request.args.get('api_key')
    logger.info("Received request with API Key: %s", api_key)
    data = {"message": "Data processed"}
    return jsonify(data)
'''
# Example prompt: ask the model to fix an OS command injection (CWE-78) in a
# subprocess wrapper. Raw string so the "\n" in the snippet stays an escape.
# Braces are single: the string is never passed through str.format, so doubled
# braces would reach the model verbatim as invalid Python.
example4 = r'''Fix vulnerability CWE-78: Improper Neutralization of Special Elements used in an OS Command ('OS Command Injection') in the following code snippet.
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
    if desc is not None:
        print(desc)
    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore',
    }
    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
    result = subprocess.run(**run_kwargs) ##here
    if result.returncode != 0:
        error_bits = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))
    return (result.stdout or "")
'''
def _build_chat_interface():
    """Assemble the chat UI around ``generate``.

    The components are created in the same order as the keyword arguments are
    evaluated when the ChatInterface is built inline: the chatbot first, then
    the system-prompt box and the three sampling sliders.
    """
    chatbot = gr.Chatbot(height="480px")
    # Extra inputs, passed to generate() after (message, chat_history) in this order.
    system_prompt_box = gr.Textbox(label="System prompt", lines=4)
    max_tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature_slider = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=4.0,
        step=0.1,
        value=0.2,
    )
    top_p_slider = gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.95,
    )
    starter_prompts = (
        "You are a helpful coding assistant. Create a snake game in Python.",
        example1,
        example2,
        example3,
        example4,
    )
    return gr.ChatInterface(
        fn=generate,
        chatbot=chatbot,
        additional_inputs=[system_prompt_box, max_tokens_slider, temperature_slider, top_p_slider],
        stop_btn=None,
        # Each example row supplies only the message; other inputs keep their values.
        examples=[[prompt] for prompt in starter_prompts],
    )


chat_interface = _build_chat_interface()
# Page layout: Markdown header, the pre-built chat_interface, Markdown footer.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE, elem_classes="contain")

if __name__ == "__main__":
    # Enable request queuing (at most 20 pending requests), then start the server.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()