
Introduction

The Generalizable Reward Model (GRM) aims to improve the generalization ability of reward models for LLMs by regularizing their hidden states.

Paper: Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs.

This reward model keeps the weights of Ray2333/GRM-llama3-8B-sftreg fixed and finetunes only a randomly initialized linear reward head on the hendrydong/preference_700K dataset. It is convenient for users who want to load a reward model directly through the AutoModelForSequenceClassification class.
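As a rough illustration of what "frozen backbone plus trainable linear head" means, the following is a minimal pure-Python sketch. It assumes a Bradley–Terry pairwise loss, which this card does not state explicitly for the head training. The toy feature vectors stand in for hidden states from the frozen backbone, and the names `head_score` and `train_linear_head` are hypothetical. It is not the authors' training code.

```python
import math
import random

def head_score(w, h):
    """Linear reward head: dot product of head weights and a fixed hidden state."""
    return sum(wi * hi for wi, hi in zip(w, h))

def train_linear_head(pairs, dim, lr=0.1, epochs=200, seed=0):
    """Fit only the head weights; the feature vectors in `pairs` are never updated."""
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 0.01) for _ in range(dim)]  # randomly initialized head
    for _ in range(epochs):
        for h_chosen, h_rejected in pairs:
            margin = head_score(w, h_chosen) - head_score(w, h_rejected)
            # Assumed loss: -log(sigmoid(margin)); its descent step scales the
            # feature difference by (1 - sigmoid(margin)).
            coeff = 1.0 - 1.0 / (1.0 + math.exp(-margin))
            for i in range(dim):
                w[i] += lr * coeff * (h_chosen[i] - h_rejected[i])
    return w

# Toy stand-ins for hidden states of (chosen, rejected) responses (hypothetical data).
toy_pairs = [
    ([1.0, 0.2, 0.0], [0.1, 0.3, 0.0]),
    ([0.9, 0.0, 0.5], [0.2, 0.1, 0.4]),
    ([0.8, 0.4, 0.1], [0.0, 0.5, 0.2]),
]
head = train_linear_head(toy_pairs, dim=3)
```

Because only `w` is updated, the quality of the resulting rewards depends entirely on the fixed features, which is why this setup can serve as a probe of the hidden states learned by GRM.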

Evaluation

We evaluate this reward model on the reward model benchmark. It raises the average score over the SOTA 8B Bradley–Terry model from 84.7 to 86.1, showing the effectiveness of the hidden states learned by GRM.

| Model | Average | Chat | Chat Hard | Safety | Reasoning |
|---|---|---|---|---|---|
| Ray2333/GRM-llama3-8B-sftreg (Ours, 8B) | 87.0 | 98.6 | 67.8 | 89.4 | 92.3 |
| Ray2333/GRM-llama3-8B-distill (Ours, 8B) | 86.1 | 98.3 | 68.4 | 86.1 | 91.3 |
| openai/gpt-4-0125-preview | 85.9 | 95.3 | 74.3 | 87.2 | 86.9 |
| sfairXC/FsfairX-LLaMA3-RM-v0.1 (8B) | 84.7 | 99.4 | 65.1 | 87.8 | 86.4 |

Usage

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('Ray2333/GRM-llama3-8B-distill')
reward_model = AutoModelForSequenceClassification.from_pretrained(
                'Ray2333/GRM-llama3-8B-distill',
                num_labels=1, torch_dtype=torch.float16,
                device_map=0,
                )
message = [
  {'role': 'user', 'content': "I'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone.  But I can't do that while I'm at the movie.  Can you help by impersonating me by chat with her?"},
  {'role': 'assistant', 'content': "Sorry, I'm not comfortable impersonating you in that way.  I'm not willing to behave so dishonestly.  Maybe you can just find a way to bring her to the movie, or you can find a babysitter?"}
]
message_template = tokenizer.apply_chat_template(message, tokenize=False)
# it will look like this: "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nI'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone.  But I can't do that while I'm at the movie.  Can you help by impersonating me by chat with her?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nSorry, I'm not comfortable impersonating you in that way.  I'm not willing to behave so dishonestly.  Maybe you can just find a way to bring her to the movie, or you can find a babysitter?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".

kwargs = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
tokens = tokenizer.encode_plus(message_template, **kwargs)

with torch.no_grad():
  # keep the batch dimension: the model expects tensors of shape (batch, seq_len)
  reward_tensor = reward_model(tokens["input_ids"].to(reward_model.device), attention_mask=tokens["attention_mask"].to(reward_model.device)).logits.reshape(-1)
  reward = reward_tensor.cpu().detach().item()

Note: loading the Llama 3 model in 8-bit precision could lead to performance degradation.
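The scalar `reward` is most useful when comparing responses to the same prompt. The sketch below assumes the rewards follow a Bradley–Terry model, under which the probability that one response is preferred over another is the sigmoid of their reward difference. The reward values and the helper names `preference_probability` and `pick_best` are hypothetical and only serve the illustration.

```python
import math

def preference_probability(reward_a, reward_b):
    """Bradley-Terry probability that response A is preferred over response B."""
    diff = reward_a - reward_b
    # numerically stable sigmoid of the reward difference
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    e = math.exp(diff)
    return e / (1.0 + e)

def pick_best(candidates):
    """Return the (response, reward) pair with the highest reward."""
    return max(candidates, key=lambda item: item[1])

# Hypothetical rewards, e.g. obtained by running the usage code once per response.
scored = [("response A", 1.7), ("response B", -0.4)]
best_response, best_reward = pick_best(scored)
p = preference_probability(1.7, -0.4)
```

Only differences between rewards for the same prompt are meaningful here; the absolute value of a single reward has no calibrated interpretation under this assumption.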

Citation

If you find this model helpful for your research, please cite GRM:

@article{yang2024regularizing,
  title={Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs},
  author={Yang, Rui and Ding, Ruomeng and Lin, Yong and Zhang, Huan and Zhang, Tong},
  journal={arXiv preprint arXiv:2406.10216},
  year={2024}
}