import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
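
# Generation limits: hard cap and default for new tokens, plus the maximum
# prompt length (overridable via the MAX_INPUT_TOKEN_LENGTH env var).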
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 512
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Chat with Patched Coder
"""

LICENSE = """\
---
This space was created by [patched](https://patched.codes).
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "Qwen/Qwen1.5-7B-Chat"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = 'right'
    # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    # tokenizer.use_default_system_prompt = False
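

# Assumed decorator: on a ZeroGPU Space the generation function is normally
# wrapped with @spaces.GPU (hence the `spaces` import above) so that a GPU is
# attached for the duration of each call.
@spaces.GPU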
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.2,
    top_p: float = 0.95,
    # top_k: int = 50,
    # repetition_penalty: float = 1.2,
) -> Iterator[str]:
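    # Rebuild the full conversation (optional system prompt, prior turns, new
    # user message) in the role/content format expected by the chat template.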
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # prompt = pipe.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # outputs = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p,
    #                eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
    # return outputs[0]['generated_text'][len(prompt):].strip()
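
    # Apply the model's chat template and, if the prompt is too long, keep only
    # the most recent MAX_INPUT_TOKEN_LENGTH tokens.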
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
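
    # Stream the reply: model.generate runs in a background thread while
    # TextIteratorStreamer yields decoded text as it is produced.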
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        # top_k=top_k,
        temperature=temperature,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        # num_beams=1,
        # repetition_penalty=1.2,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
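

# Example prompts: fixing vulnerable code (CWE-327, CWE-78), summarizing a code
# change, and a YES/NO vulnerability check (CWE-117).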
example1 = '''Fix vulnerability CWE-327: Use of a Broken or Risky Cryptographic Algorithm in the following code snippet.

def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()
'''
example2 = '''You are a software engineer who is the best in the world at summarizing code changes.
Carefully analyze the given old code and new code and generate a summary of the changes.

Old Code:
#include <stdio.h>
#include <stdlib.h>

typedef struct Node {
    int data;
    struct Node *next;
} Node;

void processList() {
    Node *head = (Node*)malloc(sizeof(Node));
    head->data = 1;
    head->next = (Node*)malloc(sizeof(Node));
    head->next->data = 2;
    printf("First element: %d\n", head->data);
    free(head->next);
    free(head);
    printf("Accessing freed list: %d\n", head->next->data);
}

New Code:
#include <stdio.h>
#include <stdlib.h>

typedef struct Node {
    int data;
    struct Node *next;
} Node;

void processList() {
    Node *head = (Node*)malloc(sizeof(Node));
    if (head == NULL) {
        perror("Failed to allocate memory for head");
        return;
    }
    head->data = 1;
    head->next = (Node*)malloc(sizeof(Node));
    if (head->next == NULL) {
        free(head);
        perror("Failed to allocate memory for next node");
        return;
    }
    head->next->data = 2;
    printf("First element: %d\n", head->data);
    free(head->next);
    head->next = NULL;
    free(head);
    head = NULL;
    if (head != NULL && head->next != NULL) {
        printf("Accessing freed list: %d\n", head->next->data);
    }
}
'''
example3 = '''Is the following code prone to CWE-117: Improper Output Neutralization for Logs? Respond only with YES or NO.

from flask import Flask, request, jsonify
import logging

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@app.route('/api/data', methods=['GET'])
def get_data():
    api_key = request.args.get('api_key')
    logger.info("Received request with API Key: %s", api_key)
    data = {"message": "Data processed"}
    return jsonify(data)
'''
example4 = '''Fix vulnerability CWE-78: Improper Neutralization of Special Elements used in an OS Command ('OS Command Injection') in the following code snippet.

def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
    if desc is not None:
        print(desc)

    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore',
    }

    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

    result = subprocess.run(**run_kwargs)  ##here

    if result.returncode != 0:
        error_bits = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))

    return (result.stdout or "")
'''
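

# Chat UI: the streaming `generate` function is exposed through gr.ChatInterface,
# with the system prompt and sampling parameters as additional inputs and the
# prompts above as clickable examples.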
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height="480px"),
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=4),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
    stop_btn=None,
    examples=[
        ["You are a helpful coding assistant. Create a snake game in Python."],
        [example1],
        [example2],
        [example3],
        [example4],
    ],
)
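
# Page layout and launch: description header, the chat interface, and the
# license footer, served with a bounded request queue.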
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE, elem_classes="contain")

if __name__ == "__main__":
    demo.queue(max_size=20).launch()