---
base_model: Qwen/Qwen2-1.5B-Instruct
datasets:
- devanshamin/gem-viggo-function-calling
library_name: peft
license: apache-2.0
pipeline_tag: text-generation
tags:
- trl
- sft
- generated_from_trainer
model-index:
- name: Qwen2-1.5B-Instruct-Function-Calling-v1
  results: []
---
# Qwen2-1.5B-Instruct-Function-Calling-v1

This model is a fine-tuned version of [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct), trained for function calling on the [devanshamin/gem-viggo-function-calling](https://huggingface.co/datasets/devanshamin/gem-viggo-function-calling) dataset.
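
The repository is tagged with `library_name: peft`, so the published weights are presumably a PEFT adapter on top of the base model. If you prefer loading the adapter explicitly through `peft` instead of the `transformers` route used in the examples below, a minimal sketch (the dtype and merge step are assumptions, not taken from this card):

```python
import torch
from peft import AutoPeftModelForCausalLM  # pip install peft
from transformers import AutoTokenizer

adapter_id = "devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1"

# Loads the base model declared in the adapter config, then attaches the adapter weights
model = AutoPeftModelForCausalLM.from_pretrained(adapter_id, torch_dtype=torch.float32, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(adapter_id)

# Optional: if this is a LoRA adapter, fold it into the base weights for faster inference
model = model.merge_and_unload()
```

The Basic Usage section below uses plain `AutoModelForCausalLM`, which relies on the built-in PEFT integration in recent transformers releases (`peft` must be installed).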

## Updated Chat Template

> Note: The template supports multiple tools, but the model was fine-tuned on examples that each contain a single tool.

- The chat template has been added to the [tokenizer_config.json](https://huggingface.co/devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1/blob/7ee7c020cefdb0101939469de608acc2afa7809e/tokenizer_config.json#L34).
- It supports prompts with and without tools; see the rendering example after the template below.
```python
chat_template = (
    "{% for message in messages %}"
    "{% if loop.first and messages[0]['role'] != 'system' %}"
    "{% if tools %}"
    "<|im_start|>system\nYou are a helpful assistant with access to the following tools. Use them if required - \n"
    "```json\n{{ tools | tojson }}\n```<|im_end|>\n"
    "{% else %}"
    "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n"
    "{% endif %}"
    "{% endif %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)
```
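
To see how the template behaves in both modes, the snippet below renders the same conversation once without tools and once with a tool. The `dummy_tool` definition is purely illustrative; the tokenizer and template are the ones shipped with this repository.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1")
messages = [{"role": "user", "content": "What is the speed of light?"}]

# Without tools -> generic "You are a helpful assistant." system turn
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

# With tools -> the system turn embeds the tool definitions as JSON
dummy_tool = {"type": "function", "function": {"name": "noop", "parameters": {"type": "object", "properties": {}}}}
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=[dummy_tool]))
```

The second call should show the JSON tool definition injected into the system prompt.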

## Basic Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def inference(prompt: str) -> str:
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    # Drop the prompt tokens so only the newly generated tokens are decoded
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

messages = [{"role": "user", "content": "What is the speed of light?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response = inference(prompt)
print(response)
```

## Tool Usage

### Basic

```python
import json
from typing import List, Dict

def get_prompt(user_input: str, tools: List[Dict] | None = None):
    prompt = 'Extract the information from the following - \n{}'.format(user_input)
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        tools=tools
    )
    return prompt

tool = {
    "type": "function",
    "function": {
        "name": "get_company_info",
        "description": "Correctly extracted company information with all the required parameters with correct types",
        "parameters": {
            "properties": {
                "name": {"title": "Name", "type": "string"},
                "investors": {
                    "items": {"type": "string"},
                    "title": "Investors",
                    "type": "array"
                },
                "valuation": {"title": "Valuation", "type": "string"},
                "source": {"title": "Source", "type": "string"}
            },
            "required": ["investors", "name", "source", "valuation"],
            "type": "object"
        }
    }
}
input_text = "Founded in 2021, Pluto raised $4 million across multiple seed funding rounds, valuing the company at $12 million (pre-money), according to PitchBook. The startup was backed by investors including Switch Ventures, Caffeinated Capital and Maxime Seguineau."
prompt = get_prompt(input_text, tools=[tool])
response = inference(prompt)
print(response)
# ```json
# {
#   "name": "get_company_info",
#   "arguments": {
#     "name": "Pluto",
#     "investors": [
#       "Switch Ventures",
#       "Caffeinated Capital",
#       "Maxime Seguineau"
#     ],
#     "valuation": "$12 million",
#     "source": "PitchBook"
#   }
# }
# ```
```
### Advanced

```python
import re
import json
from enum import Enum

from pydantic import BaseModel, Field  # pip install pydantic
from instructor.function_calls import openai_schema  # pip install instructor

# Define functions using pydantic classes
class PaperCategory(str, Enum):
    TYPE_1_DIABETES = 'Type 1 Diabetes'
    TYPE_2_DIABETES = 'Type 2 Diabetes'

class Classification(BaseModel):
    label: PaperCategory = Field(..., description='Provide the most likely category')
    reason: str = Field(..., description='Give a detailed explanation with quotes from the abstract explaining why the paper is related to the chosen label.')

function_definition = openai_schema(Classification).openai_schema
tool = dict(type='function', function=function_definition)
input_text = "1,25-dihydroxyvitamin D(3) (1,25(OH)(2)D(3)), the biologically active form of vitamin D, is widely recognized as a modulator of the immune system as well as a regulator of mineral metabolism. The objective of this study was to determine the effects of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice, a murine model of human type I diabetes. We have found that vitamin D-deficiency increases the incidence of diabetes in female mice from 46% (n=13) to 88% (n=8) and from 0% (n=10) to 44% (n=9) in male mice as of 200 days of age when compared to vitamin D-sufficient animals. Addition of 50 ng of 1,25(OH)(2)D(3)/day to the diet prevented disease onset as of 200 days and caused a significant rise in serum calcium levels, regardless of gender or vitamin D status. Our results indicate that vitamin D status is a determining factor of disease susceptibility and oral administration of 1,25(OH)(2)D(3) prevents diabetes onset in NOD mice through 200 days of age."
prompt = get_prompt(input_text, tools=[tool])
output = inference(prompt)
print(output)
# ```json
# {
#   "name": "Classification",
#   "arguments": {
#     "label": "Type 1 Diabetes",
#     "reason": "The study investigated the effect of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice. It also concluded that vitamin D deficiency leads to an increase in diabetes incidence and that the addition of 1,25(OH)(2)D(3) can prevent diabetes onset in NOD mice."
#   }
# }
# ```

# Extract the JSON string from the fenced block using a regex (re.DOTALL lets '.' span newlines)
json_str = re.search(r'```json\s*(\{.*?\})\s*```', output, re.DOTALL).group(1)
output = Classification(**json.loads(json_str)['arguments'])
print(output)
# Classification(label=<PaperCategory.TYPE_1_DIABETES: 'Type 1 Diabetes'>, reason='The study investigated the effect of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice. It also concluded that vitamin D deficiency leads to an increase in diabetes incidence and that the addition of 1,25(OH)(2)D(3) can prevent diabetes onset in NOD mice.')
```
## Training procedure

### Training hyperparameters

The following hyperparameters were used during training (a hedged sketch of a matching trainer configuration follows the list):
- learning_rate: 0.0001
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- lr_scheduler_warmup_steps: 10
- training_steps: 200
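
The original training script is not included in this card, but the hyperparameters above map onto a standard `trl` SFT run with a PEFT adapter (consistent with the `trl`, `sft`, and `peft` tags). A minimal reproduction sketch, where the LoRA settings, dataset split/column names, and sequence length are assumptions rather than recorded values:

```python
# Hedged reproduction sketch -- not the author's actual training script.
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, TrainingArguments
from trl import SFTTrainer  # pip install trl

base_model_id = "Qwen/Qwen2-1.5B-Instruct"
dataset = load_dataset("devanshamin/gem-viggo-function-calling")
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

args = TrainingArguments(
    output_dir="Qwen2-1.5B-Instruct-Function-Calling-v1",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    seed=42,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    max_steps=200,
    eval_strategy="steps",
    eval_steps=20,   # evaluation every 20 steps, matching the results table below
    logging_steps=20,
)

# Illustrative LoRA hyperparameters; the adapter config shipped in this repo is authoritative.
peft_config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32, lora_dropout=0.05)

trainer = SFTTrainer(
    model=base_model_id,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],  # split name assumed
    tokenizer=tokenizer,
    peft_config=peft_config,
    dataset_text_field="text",  # column name assumed
    max_seq_length=1024,        # assumed
)
trainer.train()
```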

### Training results

| Training Loss | Epoch  | Step | Validation Loss |
|:-------------:|:------:|:----:|:---------------:|
| 0.4004        | 0.0101 | 20   | 0.4852          |
| 0.3624        | 0.0201 | 40   | 0.3221          |
| 0.2855        | 0.0302 | 60   | 0.2818          |
| 0.2652        | 0.0402 | 80   | 0.2592          |
| 0.2214        | 0.0503 | 100  | 0.2463          |
| 0.2471        | 0.0603 | 120  | 0.2358          |
| 0.2122        | 0.0704 | 140  | 0.2310          |
| 0.2048        | 0.0804 | 160  | 0.2275          |
| 0.2406        | 0.0905 | 180  | 0.2251          |
| 0.2445        | 0.1006 | 200  | 0.2248          |
### Framework versions

```text
peft==0.11.1
transformers==4.42.3
torch==2.3.1+cu121
datasets==2.20.0
tokenizers==0.19.1
```