|
--- |
|
library_name: transformers |
|
datasets: |
|
- kaist-ai/CoT-Collection |
|
--- |
|
|
|
# Model Card for b1ade-1b |
|
|
|
|
|
Instruction-fine-tuned 1B-parameter model. Pass in:
|
|
|
1. `context: <...>` |
|
2. `question: <...>` |
|
|
|
and expect an `answer: <...>` in return.
|
|
|
See the implementation example below (also see the demo Space at https://huggingface.co/spaces/w601sxs/b1ade-1b):
|
|
|
```python
|
import torch |
|
import transformers |
|
import os, time |
|
import tempfile |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
BASE_MODEL = "w601sxs/b1ade-1b-bf16"

# Tokenizer for the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load the bf16 weights with automatic device placement; offload_folder names
# the local directory used if weights have to be offloaded. nn.Module.eval()
# returns the module itself, so switching to inference mode is chained here.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
).eval()
|
|
|
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList |
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the newest token is one of the given token ids.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, so
    this criterion is meant for single-prompt generation.
    """

    def __init__(self, keywords_ids: list):
        super().__init__()
        # Token ids that should end generation.
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of sequence 0 is a keyword id.

        Args:
            input_ids: Token ids generated so far; row 0 is the sequence checked.
            scores: Unused here; part of the StoppingCriteria call signature.

        Returns:
            True if generation should stop, False otherwise.
        """
        # .item() turns the 0-d tensor into a plain Python int, so membership is
        # an ordinary int comparison rather than a tensor == per keyword.
        return input_ids[0, -1].item() in self.keywords
|
|
|
|
|
# The expected output format is "answer: <...>", so generation stops once the
# model emits a closing ">" (listed here with different surrounding spaces).
stop_words = ['>', ' >', '> ']

# add_special_tokens=False: for tokenizers that prepend a BOS/special token,
# encode(w)[0] would otherwise be that special token, not the stop word.
# Only the first token of each stop word is used.
stop_ids = [tokenizer.encode(w, add_special_tokens=False)[0] for w in stop_words]

stop_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(keywords_ids=stop_ids)])
|
|
|
def predict(text):
    """Generate an answer for a ``context: <...>`` / ``question: <...>`` prompt.

    The prompt is tokenized, moved to the model's device and passed to
    ``model.generate`` (at most 128 new tokens, stopping early via
    ``stop_criteria``). Only the newly generated tokens are decoded, and the
    text after the last ``"answer:"`` marker is printed and returned.

    Args:
        text: Prompt string, e.g. ``"context: <...>\\n question: <...>\\n"``.

    Returns:
        The extracted answer text (it is also printed to stdout).
    """
    # Follow the model rather than hard-coding 'cuda': with device_map="auto"
    # the weights may be on CPU (no GPU available) or partly offloaded.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Pass the attention mask explicitly so generate() need not infer it.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=128,
            stopping_criteria=stop_criteria,
        )

    # Decode only the continuation: splitting the full decode on the prompt
    # string fails whenever decoding does not reproduce the prompt exactly.
    generated = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
    answer = generated.split("answer:")[-1]

    print(answer)
    return answer
|
|
|
|
|
|
|
# Example query: the context and question are each wrapped in <...>.
predict(
    "context: <The center contact of the bulb typically connects to the "
    "medium-power filament, and the ring connects to the low-power "
    "filament. Thus, if a 3-way bulb is screwed into a standard light "
    "socket that has only a center contact, only the medium-power "
    "filament operates. In the case of the 50 W / 100 W / 150 W bulb, "
    "putting this bulb in a regular lamp socket will result in it "
    "behaving like a normal 100W bulb.>\n"
    " question: <Question: Do 3 way light bulbs work in any lamp?>\n"
)
|
``` |