---
license: apache-2.0
tags:
  - TinyLlama
  - QLoRA
  - Politics
  - EU
  - sft
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
---

TinyParlaMintLlama-1.1B

TinyParlaMintLlama-1.1B is an SFT fine-tune of TinyLlama/TinyLlama-1.1B-Chat-v1.0, trained with QLoRA on a sample of a concentrated version of the English [ParlaMint](https://www.clarin.si/repository/xmlui/handle/11356/1864) dataset. The model was fine-tuned for about 12 hours on a single A100 (40 GB) GPU, using roughly 100M tokens.
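As a reminder of the general QLoRA technique (not a description of this model's exact setup): the pretrained weights $W_0$ are frozen and stored in 4-bit NormalFloat (NF4), and only a low-rank update $BA$ is trained. A sketch of the forward pass of one adapted linear layer is shown below; the rank $r$, scaling factor $\alpha$, and the set of adapted layers used for this model are not stated in this card.

```latex
h = \operatorname{dequant}\!\left(W_0^{\mathrm{NF4}}\right) x + \frac{\alpha}{r}\, B A x,
\qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; r \ll \min(d, k)
```

Only $A$ and $B$ receive gradient updates, which is what lets a 1.1B-parameter model be fine-tuned on a single 40 GB GPU.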

The goal of this project is to study whether the domain-specific (in this case, political) knowledge of small (<3B parameter) LLMs can be improved by concentrating the training dataset's TF-IDF with respect to the underlying topics found in the original dataset.
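The exact concentration procedure has not been published yet, so the following is only a hypothetical illustration of one way "concentrating TF-IDF with respect to topics" could work: score each speech by the TF-IDF weight it places on a set of topic terms, then keep the highest-scoring fraction. The function names `tfidf` and `concentrate`, the `topic_terms` set, `keep_fraction`, and the whitespace tokenization are all assumptions made for this sketch, not the author's pipeline.

```python
import math
from collections import Counter


def tfidf(docs):
    """Compute per-document TF-IDF weights for already-tokenized documents."""
    n = len(docs)
    # Document frequency: in how many documents each term appears
    df = Counter(term for doc in docs for term in set(doc))
    weights = []
    for doc in docs:
        counts = Counter(doc)
        total = len(doc)
        # Term frequency (normalized by length) times inverse document frequency
        weights.append({
            term: (count / total) * math.log(n / df[term])
            for term, count in counts.items()
        })
    return weights


def concentrate(speeches, topic_terms, keep_fraction=0.5):
    """Hypothetical filter: keep the speeches with the most TF-IDF mass on topic_terms."""
    docs = [speech.lower().split() for speech in speeches]  # naive whitespace tokenization
    weights = tfidf(docs)
    scores = [sum(w.get(term, 0.0) for term in topic_terms) for w in weights]
    k = max(1, int(len(speeches) * keep_fraction))
    ranked = sorted(range(len(speeches)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])  # restore the original order of the kept speeches
    return [speeches[i] for i in keep]


# Toy example with made-up speeches and a made-up topic vocabulary
speeches = [
    "the budget deficit and the tax reform",
    "we thank the speaker for the kind words",
    "brexit trade agreements and the customs union",
    "the weather was nice today",
]
topic_terms = {"brexit", "customs", "trade", "tax", "budget"}
kept = concentrate(speeches, topic_terms, keep_fraction=0.5)
print(kept)
```

Procedural speeches ("we thank the speaker…") score zero here and are dropped, which is the intuition behind concentrating a parliamentary corpus on its political content. A real pipeline would more likely use a proper tokenizer and a library implementation such as scikit-learn's `TfidfVectorizer`.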

The training data contains speeches from the Austrian, Danish, French, British, Hungarian, Dutch, Norwegian, Polish, Swedish and Turkish parliaments. The concentrated ParlaMint dataset, as well as more information about the sample used, will be added soon.

💻 Usage

```python
# Install the dependencies first (in a notebook: !pip install -qU transformers accelerate)
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "h4rz3rk4s3/TinyParlaMintLlama-1.1B"
messages = [
    {
        "role": "system",
        "content": "You are a professional writer of political speeches.",
    },
    {"role": "user", "content": "Write a short speech on Brexit and its impact on the European Union."},
]

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Render the conversation with the model's chat template and open the assistant turn
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Load the weights in half precision; device_map="auto" (via accelerate) places them on the available device(s)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
# Sample up to 256 new tokens using temperature, top-k and nucleus (top-p) sampling
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
```