---
license: mit
datasets:
  - Nebulous/gpt4all_pruned
  - sahil2801/CodeAlpaca-20k
  - yahma/alpaca-cleaned
language:
  - en
tags:
  - sft
pipeline_tag: text-generation
widget:
  - text: >-
      <|prompter|>What is a meme, and what's the history behind this
      word?</s><|assistant|>
  - text: <|prompter|>What's the Earth's total population?</s><|assistant|>
  - text: <|prompter|>Write a story about the future of AI development</s><|assistant|>
---

This repo contains a low-rank adapter (LoRA) for LLaMA-7b, fine-tuned on:

  • Nebulous/gpt4all_pruned
  • sahil2801/CodeAlpaca-20k
  • yahma/alpaca-cleaned
  • datasets part of the OpenAssistant project.

This version of the weights was trained with the following hyperparameters:

  • Epochs: 2
  • Batch size: 128
  • Max Length: 2048
  • Learning rate: 4e-6
  • LoRA r: 8
  • LoRA alpha: 32
  • LoRA target modules: q_proj, k_proj, v_proj, o_proj

The model was trained with flash attention and gradient checkpointing.
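The hyperparameters above can be collected as plain Python values, e.g. to rebuild the adapter configuration (a sketch: the keyword names mirror peft's LoraConfig, but peft itself is not imported here, and any field not listed in the card is an assumption):

```python
# Hyperparameters from this card, as keyword arguments in the shape expected
# by peft.LoraConfig (peft is not imported; this is just the data).
lora_kwargs = dict(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Training settings from this card.
train_kwargs = dict(
    num_epochs=2,
    batch_size=128,
    max_length=2048,
    learning_rate=4e-6,
)

# peft scales the LoRA update by alpha / r.
scaling = lora_kwargs["lora_alpha"] / lora_kwargs["r"]
print(scaling)  # 4.0
```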


Model Details

  • Developed as part of the OpenAssistant project
  • Model type: low-rank adapter (LoRA) for the LLaMA-7b Transformer-based language model
  • Language: English

Prompting

Two special tokens are used to mark the beginning of user and assistant turns: <|prompter|> and <|assistant|>. Each turn ends with the end-of-sequence token (</s> for LLaMA).

Input prompt example:

<|prompter|>What is a meme, and what's the history behind this word?</s><|assistant|>

The input ends with the <|assistant|> token to signal that the model should start generating the assistant reply.
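A multi-turn prompt in this format can be assembled with a small helper like the following (a sketch; `build_prompt` is a hypothetical function, not part of this repo):

```python
PROMPTER, ASSISTANT, EOS = "<|prompter|>", "<|assistant|>", "</s>"

def build_prompt(turns):
    """turns: list of (role, text) pairs, where role is 'user' or 'assistant'."""
    parts = []
    for role, text in turns:
        marker = PROMPTER if role == "user" else ASSISTANT
        parts.append(marker + text + EOS)
    # End with <|assistant|> so the model starts generating the reply.
    return "".join(parts) + ASSISTANT

print(build_prompt([("user", "What is a meme, and what's the history behind this word?")]))
# <|prompter|>What is a meme, and what's the history behind this word?</s><|assistant|>
```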

Example code (note that several special-token embeddings must be loaded along with the LoRA weights):

import torch
import transformers
from huggingface_hub import hf_hub_download
from peft import PeftModel
from transformers import GenerationConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = transformers.AutoTokenizer.from_pretrained("jordiclive/gpt4all-alpaca-oa-codealpaca-lora-7b")


model = transformers.AutoModelForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf", torch_dtype=torch.float16
)  # Load Base Model
model.resize_token_embeddings(
    32016
)  # This model repo also contains several embeddings for special tokens that need to be loaded.

model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

lora_weights = "jordiclive/gpt4all-alpaca-oa-codealpaca-lora-7b"
model = PeftModel.from_pretrained(
    model,
    lora_weights,
    torch_dtype=torch.float16,
)  # Load Lora model

model.eos_token_id = tokenizer.eos_token_id
filename = hf_hub_download("jordiclive/gpt4all-alpaca-oa-codealpaca-lora-7b", "extra_embeddings.pt")
embed_weights = torch.load(
    filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)  # Load embeddings for special tokens
with torch.no_grad():  # In-place writes to a parameter must not be tracked by autograd
    model.base_model.model.model.embed_tokens.weight[32000:, :] = embed_weights.to(
        model.base_model.model.model.embed_tokens.weight.dtype
    ).to(device)  # Add special token embeddings


model = model.half().to(device)
generation_config = GenerationConfig(
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
)


def format_system_prompt(prompt, eos_token="</s>"):
    # The input must end with <|assistant|> so the model starts generating
    # the assistant reply (see the prompting section above).
    return "{}{}{}{}".format(
        "<|prompter|>",
        prompt,
        eos_token,
        "<|assistant|>",
    )


def generate(prompt, generation_config=generation_config, max_new_tokens=2048, device=device):
    prompt = format_system_prompt(prompt)  # OpenAssistant Prompt Format expected
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=2,  # </s> for LLaMA
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print("Text generated:")
    print(output)
    return output


generate("What is a meme, and what's the history behind this word?")
generate("What's the Earth's total population?")
generate("Write a story about the future of AI development")
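Note that `generation_output.sequences[0]` contains the prompt tokens followed by the generated ones; if only the reply is wanted, the prompt portion can be sliced off before decoding (a sketch using toy token lists; in the example above the slice length would be `input_ids.shape[1]`):

```python
def strip_prompt(sequence, prompt_len):
    """Keep only the tokens generated after the prompt."""
    return sequence[prompt_len:]

# Toy sequence: [BOS, three prompt tokens, two generated tokens, EOS].
sequence = [1, 11, 12, 13, 21, 22, 2]
reply_tokens = strip_prompt(sequence, prompt_len=4)  # BOS + 3 prompt tokens
print(reply_tokens)  # [21, 22, 2]
```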