This model is a fine-tuned version of Qwen/Qwen2-1.5B-Instruct on the devanshamin/gem-viggo-function-calling dataset.
Note: The chat template supports multiple tools, but the model was fine-tuned on a dataset in which each example uses a single tool (a multi-tool prompt is sketched after the first extraction example below).
chat_template = (
    "{% for message in messages %}"
    "{% if loop.first and messages[0]['role'] != 'system' %}"
    "{% if tools %}"
    "<|im_start|>system\nYou are a helpful assistant with access to the following tools. Use them if required - \n"
    "```json\n{{ tools | tojson }}\n```<|im_end|>\n"
    "{% else %}"
    "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n"
    "{% endif %}"
    "{% endif %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)
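If the saved tokenizer does not already bundle this template, it can be attached manually before rendering prompts (a minimal sketch; assumes the tokenizer loaded in the next snippet):
# Assigning to chat_template overrides whatever template, if any, ships with the checkpoint.
tokenizer.chat_template = chat_template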
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "Qwen2-1.5B-Instruct-Function-Calling-v1"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
def inference(prompt: str) -> str:
    # Tokenize the rendered prompt and move it to the model's device.
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    # Drop the prompt tokens so only the newly generated completion is decoded.
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
messages = [{"role": "user", "content": "What is the speed of light?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response = inference(prompt)
print(response)
import json
from typing import List, Dict
def get_prompt(user_input: str, tools: List[Dict] | None = None) -> str:
    # Wrap the raw text in an extraction instruction and render the chat template,
    # passing any tool schemas so they are embedded in the system prompt.
    prompt = 'Extract the information from the following - \n{}'.format(user_input)
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        tools=tools,
    )
    return prompt
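Calling the helper directly shows how the template branches; the exact rendered string depends on the chat template above:
# With no tools, the {% if tools %} branch is skipped and the plain system prompt is used.
print(get_prompt("Pluto raised $4 million."))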
tool = {
    "type": "function",
    "function": {
        "name": "get_company_info",
        "description": "Correctly extracted company information with all the required parameters with correct types",
        "parameters": {
            "properties": {
                "name": {"title": "Name", "type": "string"},
                "investors": {
                    "items": {"type": "string"},
                    "title": "Investors",
                    "type": "array"
                },
                "valuation": {"title": "Valuation", "type": "string"},
                "source": {"title": "Source", "type": "string"}
            },
            "required": ["investors", "name", "source", "valuation"],
            "type": "object"
        }
    }
}
input_text = "Founded in 2021, Pluto raised $4 million across multiple seed funding rounds, valuing the company at $12 million (pre-money), according to PitchBook. The startup was backed by investors including Switch Ventures, Caffeinated Capital and Maxime Seguineau."
prompt = get_prompt(input_text, tools=[tool])
response = inference(prompt)
print(response)
# ```json
# {
# "name": "get_company_info",
# "arguments": {
# "name": "Pluto",
# "investors": [
# "Switch Ventures",
# "Caffeinated Capital",
# "Maxime Seguineau"
# ],
# "valuation": "$12 million",
# "source": "PitchBook"
# }
# }
# ```
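As noted above, the template accepts multiple tools even though fine-tuning used a single tool per example. A hedged sketch of a multi-tool prompt, where the second tool, get_person_info, is hypothetical and only illustrates the format:
# Hypothetical second tool, shown only to illustrate the multi-tool prompt format;
# since fine-tuning used single-tool examples, quality may degrade with several tools.
second_tool = {
    "type": "function",
    "function": {
        "name": "get_person_info",
        "description": "Extract information about a person mentioned in the text",
        "parameters": {
            "properties": {
                "name": {"title": "Name", "type": "string"},
                "role": {"title": "Role", "type": "string"}
            },
            "required": ["name", "role"],
            "type": "object"
        }
    }
}
multi_tool_prompt = get_prompt(input_text, tools=[tool, second_tool])
print(inference(multi_tool_prompt))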
import re
from enum import Enum
from pydantic import BaseModel, Field # pip install pydantic
from instructor.function_calls import openai_schema # pip install instructor
# Define functions using pydantic classes
class PaperCategory(str, Enum):
    TYPE_1_DIABETES = 'Type 1 Diabetes'
    TYPE_2_DIABETES = 'Type 2 Diabetes'

class Classification(BaseModel):
    label: PaperCategory = Field(..., description='Provide the most likely category')
    reason: str = Field(..., description='Give a detailed explanation with quotes from the abstract explaining why the paper is related to the chosen label.')
function_definition = openai_schema(Classification).openai_schema
tool = dict(type='function', function=function_definition)
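To inspect the JSON schema that openai_schema derives from the Pydantic model before it is embedded in the system prompt (the exact fields depend on the installed instructor version):
# The enum values from PaperCategory and the Field descriptions are carried into the schema.
print(json.dumps(tool, indent=2))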
input_text = "1,25-dihydroxyvitamin D(3) (1,25(OH)(2)D(3)), the biologically active form of vitamin D, is widely recognized as a modulator of the immune system as well as a regulator of mineral metabolism. The objective of this study was to determine the effects of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice, a murine model of human type I diabetes. We have found that vitamin D-deficiency increases the incidence of diabetes in female mice from 46% (n=13) to 88% (n=8) and from 0% (n=10) to 44% (n=9) in male mice as of 200 days of age when compared to vitamin D-sufficient animals. Addition of 50 ng of 1,25(OH)(2)D(3)/day to the diet prevented disease onset as of 200 days and caused a significant rise in serum calcium levels, regardless of gender or vitamin D status. Our results indicate that vitamin D status is a determining factor of disease susceptibility and oral administration of 1,25(OH)(2)D(3) prevents diabetes onset in NOD mice through 200 days of age."
prompt = get_prompt(input_text, tools=[tool])
output = inference(prompt)
print(output)
# ```json
# {
# "name": "Classification",
# "arguments": {
# "label": "Type 1 Diabetes",
# "reason": "The study investigated the effect of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice. It also concluded that vitamin D deficiency leads to an increase in diabetes incidence and that the addition of 1,25(OH)(2)D(3) can prevent diabetes onset in NOD mice."
# }
# }
# ```
# Extract the fenced JSON string using regex (re.DOTALL lets '.' span newlines)
json_str = re.search(r'```json\s*(\{.*?\})\s*```', output, re.DOTALL).group(1)
output = Classification(**json.loads(json_str)['arguments'])
print(output)
# Classification(label=<PaperCategory.TYPE_1_DIABETES: 'Type 1 Diabetes'>, reason='The study investigated the effect of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice. It also concluded that vitamin D deficiency leads to an increase in diabetes incidence and that the addition of 1,25(OH)(2)D(3) can prevent diabetes onset in NOD mice.')
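Since the model emits the tool call as a fenced JSON block, the extraction and validation steps above can be folded into a small reusable helper; a sketch, assuming the output always follows the fence format shown in the examples:
from typing import Optional, Type, TypeVar

T = TypeVar('T', bound=BaseModel)

def parse_tool_call(raw_output: str, schema: Type[T]) -> Optional[T]:
    # Extract the fenced JSON payload and validate it against the given Pydantic model.
    # Returns None when no fenced block is found or validation fails.
    match = re.search(r'```json\s*(\{.*?\})\s*```', raw_output, re.DOTALL)
    if match is None:
        return None
    try:
        payload = json.loads(match.group(1))
        return schema(**payload['arguments'])
    except (json.JSONDecodeError, KeyError, ValueError):
        return None

# Usage:
# classification = parse_tool_call(output, Classification)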
Training and validation loss recorded during fine-tuning:

| Training Loss | Epoch | Step | Validation Loss |
|---|---|---|---|
| 0.4004 | 0.0101 | 20 | 0.4852 |
| 0.3624 | 0.0201 | 40 | 0.3221 |
| 0.2855 | 0.0302 | 60 | 0.2818 |
| 0.2652 | 0.0402 | 80 | 0.2592 |
| 0.2214 | 0.0503 | 100 | 0.2463 |
| 0.2471 | 0.0603 | 120 | 0.2358 |
| 0.2122 | 0.0704 | 140 | 0.2310 |
| 0.2048 | 0.0804 | 160 | 0.2275 |
| 0.2406 | 0.0905 | 180 | 0.2251 |
| 0.2445 | 0.1006 | 200 | 0.2248 |
Framework versions used for fine-tuning:
peft==0.11.1
transformers==4.42.3
torch==2.3.1+cu121
datasets==2.20.0
tokenizers==0.19.1
Base model: Qwen/Qwen2-1.5B-Instruct