add readme
- README.md +138 -0
- data/example/example_error.jsonl +1 -0
- data/example/example_fix.jsonl +1 -0
- data/example/example_fix.jsonl_results.jsonl +1 -0
- data/example/example_problem.jsonl +1 -0
- evaluation/fixeval.jsonl +0 -0
- evaluation/fixeval_problems.jsonl +0 -0
- figures/evaluation.png +0 -0
README.md
CHANGED
@@ -1,3 +1,141 @@
---
license: gpl-3.0
base_model: deepseek-ai/deepseek-coder-7b-instruct-v1.5
library_name: peft
---

# OriGen: Enhancing RTL Code Generation with Code-to-Code Augmentation and Self-Reflection

### Introduction
OriGen is a fine-tuned LoRA model designed for Verilog code generation. It is trained on top of DeepSeek Coder 7B using datasets generated through code-to-code augmentation and self-reflection.

OriGen_Fix is a fine-tuned LoRA model designed for fixing syntax errors in Verilog code. It is trained on top of OriGen.

The models have been uploaded to Hugging Face, and the repository contains the inference scripts. The dataset and the data generation flow will be released soon.

- **Huggingface**:
  - https://huggingface.co/henryen/OriGen
  - https://huggingface.co/henryen/OriGen_Fix
- **Repository**: https://github.com/pku-liang/OriGen

### Evaluation Results
<img src="figures/evaluation.png" alt="evaluation" width="1000"/>
### Quick Start

Before running the following code, please install the required packages:

```bash
conda create -n origen python=3.11
conda activate origen
pip install -r requirements.txt
```

Here is an example of how to use the model. Note that the base model, DeepSeek Coder 7B, is loaded in float16 precision even though its default precision is bfloat16: in our experiments, LoRA adapters trained in float16 outperformed those trained in bfloat16.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
from peft import PeftModel
import json

model_name = "deepseek-ai/deepseek-coder-7b-instruct-v1.5"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" already places the model on the available GPU(s),
# so no extra .to("cuda") call is needed. flash_attention_2 requires
# the flash-attn package.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Attach the OriGen_Fix LoRA adapter on top of the base model.
model = PeftModel.from_pretrained(model, model_id="henryen/OriGen_Fix")
model.eval()

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Keep this prompt unchanged; it is the format shown for the fine-tuned model.
prompt_template = """
### Instruction: Please act as a professional Verilog designer. Your task is to debug a Verilog module.\nYou will receive a task, an original verilog code with syntax and function errors, and the corresponding error messages. \nYou should generate a corrected code based on the original code and the error messages
Your task:
{description}
Original code:
{original_code}
Error message:
{error}
You should now generate a correct code.
### Response:{header}
"""

def generate_code(data):
    description = data["description"]
    original_code = data["original_code"]
    error = data["error"]
    header = data["module_header"]
    prompt = prompt_template.format(description=description, original_code=original_code, error=error, header=header)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Greedy decoding; the streamer prints tokens as they are generated.
    outputs = model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer
    )

    # Strip the prompt tokens so only the completion is returned.
    input_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return completion

input_file = "./data/example/example_error.jsonl"
output_file = "./data/example/example_fix.jsonl"

with open(input_file, "r") as f, open(output_file, "w") as f2:
    for line in f:
        data = json.loads(line)
        completion = generate_code(data)
        json.dump({"task_id": data["task_id"], "completion": completion}, f2)
        f2.write("\n")
```
The output will be:
```verilog
wire and0_out;
wire and1_out;
wire or0_out;

assign and0_out = a & b;
assign and1_out = c & d;
assign or0_out = and0_out | and1_out;
assign out = or0_out;
assign out_n = ~or0_out;

endmodule
```

Note that the completion does not repeat the module header: the prompt ends with the header, so the model generates only the module body.

You can check its correctness using the testbench provided under `./data/example/`:
```bash
cd ./data/example/
evaluate_functional_correctness example_fix.jsonl --problem_file example_problem.jsonl
```
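The harness writes an `example_fix.jsonl_results.jsonl` file next to the completions, with a `passed` flag per task (see `data/example/` below). If you want to summarize the results programmatically, here is a minimal sketch; the field names follow the example results file in this repository:

```python
import json

# Summarize the results file written by evaluate_functional_correctness.
# Each line is a JSON record with at least "task_id" and "passed".
with open("./data/example/example_fix.jsonl_results.jsonl") as f:
    results = [json.loads(line) for line in f]

passed = sum(r["passed"] for r in results)
print(f"{passed}/{len(results)} tasks passed")
```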
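The plain OriGen adapter can be loaded the same way if you want direct Verilog generation rather than error fixing. The sketch below only swaps the adapter id; the instruction format OriGen expects for generation is not documented in this README, so treat the loading code as the reliable part and consult the repository for the exact prompt:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model exactly as in the example above, then attach the
# OriGen adapter instead of OriGen_Fix.
model_name = "deepseek-ai/deepseek-coder-7b-instruct-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, model_id="henryen/OriGen")
model.eval()
```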
### Paper
**Arxiv:** https://arxiv.org/abs/2407.16237

Please cite our paper if you find this model useful.

```
@article{2024origen,
  title={OriGen: Enhancing RTL Code Generation with Code-to-Code Augmentation and Self-Reflection},
  author={Cui, Fan and Yin, Chenyang and Zhou, Kexing and Xiao, Youwei and Sun, Guangyu and Xu, Qiang and Guo, Qipeng and Song, Demin and Lin, Dahua and Zhang, Xingcheng and others},
  journal={arXiv preprint arXiv:2407.16237},
  year={2024}
}
```
data/example/example_error.jsonl
ADDED
@@ -0,0 +1 @@
{"task_id": "wire_decl", "description": "Implement the following circuit. Create two intermediate wires (named anything you want) to connect the AND and OR gates together. Note that the wire that feeds the NOT gate is really wire `out`, so you do not necessarily need to declare a third wire here. Notice how wires are driven by exactly one source (output of a gate), but can feed multiple inputs.\n\n// The circuit is composed of two layers. The first layer, counting from the input, is two AND gates: one whose input is connected to a and b, and the second is connected to c and d. The second layer there is an OR gate to OR the two AND outputs, connected the output 'out'. Additionally, there is an inverted output 'out_n'.", "module_header": "module top_module (\n\tinput a,\n\tinput b,\n\tinput c,\n\tinput d,\n\toutput out,\n\toutput out_n );\n", "original_code": "module top_module (\n\tinput a,\n\tinput b,\n\tinput c,\n\tinput d,\n\toutput out,\n\toutput out_n );\n wire and0_out;\n wire and1_out;\n wire or0_out;\n\n and and0 (\n .a(a),\n .b(b),\n .o(and0_out)\n );\n and and1 (\n .a(c),\n .b(d),\n .o(and1_out)\n );\n or or0 (\n .a(and0_out),\n .b(and1_out),\n .o(or0_out)\n );\n assign out = or0_out;\n assign out_n = ~or0_out;\n endmodule", "error": "17: error: Gates do not have port names.\n22: error: Gates do not have port names.\n12: error: Gates do not have port names.\n"}
data/example/example_fix.jsonl
ADDED
@@ -0,0 +1 @@
{"task_id": "wire_decl", "completion": " wire and0_out;\n wire and1_out;\n wire or0_out;\n\n assign and0_out = a & b;\n assign and1_out = c & d;\n assign or0_out = and0_out | and1_out;\n assign out = or0_out;\n assign out_n = ~or0_out;\n\nendmodule\n"}
data/example/example_fix.jsonl_results.jsonl
ADDED
@@ -0,0 +1 @@
{"task_id": "wire_decl", "completion": " wire and0_out;\n wire and1_out;\n wire or0_out;\n\n assign and0_out = a & b;\n assign and1_out = c & d;\n assign or0_out = and0_out | and1_out;\n assign out = or0_out;\n assign out_n = ~or0_out;\n\nendmodule\n", "result": "passed", "passed": true}
data/example/example_problem.jsonl
ADDED
@@ -0,0 +1 @@
{"task_id": "wire_decl", "prompt": "module top_module (\n\tinput a,\n\tinput b,\n\tinput c,\n\tinput d,\n\toutput out,\n\toutput out_n );\n", "canonical_solution": "\t\n\twire w1, w2;\n\tassign w1 = a&b;\n\tassign w2 = c&d;\n\tassign out = w1|w2;\n\tassign out_n = ~out;\n\t\nendmodule\n", "test": "`timescale 1 ps/1 ps\n`define OK 12\n`define INCORRECT 13\n// hdlbits_prop {len: 5}\nmodule reference_module (\n\tinput a,\n\tinput b,\n\tinput c,\n\tinput d,\n\toutput out,\n\toutput out_n );\n\t\n\twire w1, w2;\n\tassign w1 = a&b;\n\tassign w2 = c&d;\n\tassign out = w1|w2;\n\tassign out_n = ~out;\n\t\nendmodule\n\n\nmodule stimulus_gen (\n\tinput clk,\n\toutput reg a,b,c,d,\n\toutput reg[511:0] wavedrom_title,\n\toutput reg wavedrom_enable\n);\n\n\n// Add two ports to module stimulus_gen:\n// output [511:0] wavedrom_title\n// output reg wavedrom_enable\n\n\ttask wavedrom_start(input[511:0] title = \"\");\n\tendtask\n\t\n\ttask wavedrom_stop;\n\t\t#1;\n\tendtask\t\n\n\n\n\tinitial begin\n\t\t{a,b,c,d} = 4'h0;\n\t\t@(negedge clk);\n\t\twavedrom_start(\"Exhaustive test\");\n\t\trepeat(20) @(posedge clk, negedge clk)\n\t\t\t{d,c,b,a} <= {d,c,b,a} + 1'b1;\n\t\twavedrom_stop();\n\t\trepeat(100) @(posedge clk, negedge clk) begin\n\t\t\t{a,b,c,d} <= $random;\n\t\tend\n\t\t\n\t\t#1 $finish;\n\tend\n\t\nendmodule\n\nmodule tb();\n\n\ttypedef struct packed {\n\t\tint errors;\n\t\tint errortime;\n\t\tint errors_out;\n\t\tint errortime_out;\n\t\tint errors_out_n;\n\t\tint errortime_out_n;\n\n\t\tint clocks;\n\t} stats;\n\t\n\tstats stats1;\n\t\n\t\n\twire[511:0] wavedrom_title;\n\twire wavedrom_enable;\n\tint wavedrom_hide_after_time;\n\t\n\treg clk=0;\n\tinitial forever\n\t\t#5 clk = ~clk;\n\n\tlogic a;\n\tlogic b;\n\tlogic c;\n\tlogic d;\n\tlogic out_ref;\n\tlogic out_dut;\n\tlogic out_n_ref;\n\tlogic out_n_dut;\n\n\tinitial begin \n\t\t$dumpfile(\"wave.vcd\");\n\t\t$dumpvars(1, stim1.clk, tb_mismatch ,a,b,c,d,out_ref,out_dut,out_n_ref,out_n_dut );\n\tend\n\n\n\twire tb_match;\t\t// Verification\n\twire tb_mismatch = ~tb_match;\n\t\n\tstimulus_gen stim1 (\n\t\t.clk,\n\t\t.* ,\n\t\t.a,\n\t\t.b,\n\t\t.c,\n\t\t.d );\n\treference_module good1 (\n\t\t.a,\n\t\t.b,\n\t\t.c,\n\t\t.d,\n\t\t.out(out_ref),\n\t\t.out_n(out_n_ref) );\n\t\t\n\ttop_module top_module1 (\n\t\t.a,\n\t\t.b,\n\t\t.c,\n\t\t.d,\n\t\t.out(out_dut),\n\t\t.out_n(out_n_dut) );\n\n\t\n\tbit strobe = 0;\n\ttask wait_for_end_of_timestep;\n\t\trepeat(5) begin\n\t\t\tstrobe <= !strobe; // Try to delay until the very end of the time step.\n\t\t\t@(strobe);\n\t\tend\n\tendtask\t\n\n\t\n\tfinal begin\n\t\tif (stats1.errors_out) $display(\"Hint: Output '%s' has %0d mismatches. First mismatch occurred at time %0d.\", \"out\", stats1.errors_out, stats1.errortime_out);\n\t\telse $display(\"Hint: Output '%s' has no mismatches.\", \"out\");\n\t\tif (stats1.errors_out_n) $display(\"Hint: Output '%s' has %0d mismatches. 
First mismatch occurred at time %0d.\", \"out_n\", stats1.errors_out_n, stats1.errortime_out_n);\n\t\telse $display(\"Hint: Output '%s' has no mismatches.\", \"out_n\");\n\n\t\t$display(\"Hint: Total mismatched samples is %1d out of %1d samples\\n\", stats1.errors, stats1.clocks);\n\t\t$display(\"Simulation finished at %0d ps\", $time);\n\t\t$display(\"Mismatches: %1d in %1d samples\", stats1.errors, stats1.clocks);\n\tend\n\t\n\t// Verification: XORs on the right makes any X in good_vector match anything, but X in dut_vector will only match X.\n\tassign tb_match = ( { out_ref, out_n_ref } === ( { out_ref, out_n_ref } ^ { out_dut, out_n_dut } ^ { out_ref, out_n_ref } ) );\n\t// Use explicit sensitivity list here. @(*) causes NetProc::nex_input() to be called when trying to compute\n\t// the sensitivity list of the @(strobe) process, which isn't implemented.\n\talways @(posedge clk, negedge clk) begin\n\n\t\tstats1.clocks++;\n\t\tif (!tb_match) begin\n\t\t\tif (stats1.errors == 0) stats1.errortime = $time;\n\t\t\tstats1.errors++;\n\t\tend\n\t\tif (out_ref !== ( out_ref ^ out_dut ^ out_ref ))\n\t\tbegin if (stats1.errors_out == 0) stats1.errortime_out = $time;\n\t\t\tstats1.errors_out = stats1.errors_out+1'b1; end\n\t\tif (out_n_ref !== ( out_n_ref ^ out_n_dut ^ out_n_ref ))\n\t\tbegin if (stats1.errors_out_n == 0) stats1.errortime_out_n = $time;\n\t\t\tstats1.errors_out_n = stats1.errors_out_n+1'b1; end\n\n\tend\nendmodule\n"}
evaluation/fixeval.jsonl
ADDED
The diff for this file is too large to render.
evaluation/fixeval_problems.jsonl
ADDED
The diff for this file is too large to render.
figures/evaluation.png
ADDED