Spaces:
Running
Running
from pathlib import Path | |
from typing import Any, Dict, Hashable | |
import streamlit as st | |
import streamlit.components.v1 as components | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, GPT2LMHeadModel, PreTrainedTokenizer | |
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper | |
from context_probing import estimate_unigram_logprobs | |
from context_probing.core import nll_score, kl_div_score | |
from context_probing.utils import columns_to_diagonals, get_windows, ids_to_readable_tokens | |
# Directory containing this script; the front-end component below is located relative to it.
root_dir = Path(__file__).resolve().parent
# Custom Streamlit component registered under the name "highlighted_text" and
# served from the "highlighted_text/build" directory next to this script.
highlighted_text_component = components.declare_component(
    "highlighted_text", path=root_dir / "highlighted_text" / "build"
)
# The "compact" query parameter (?compact=true) suppresses the title and the link header.
compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]
if not compact_layout:
    st.title("Context length probing")
    # The continuation lines of this literal are kept flush-left so that the
    # string's contents stay exactly as they appear in the file.
    st.markdown(
        """[📃 Paper](https://arxiv.org/abs/2212.14815) |
[🌍 Website](https://cifkao.github.io/context-probing) |
[🧑💻 Code](https://github.com/cifkao/context-probing)
"""
    )
# True when the second radio option ("Generation mode") is selected.
generation_mode = st.radio(
    "Mode", ["Basic mode", "Generation mode"],
    horizontal=True, label_visibility="collapsed"
) == "Generation mode"
st.caption(
    "In basic mode, we analyze the model's one-step-ahead predictions on the input text. "
    "In generation mode, we generate a continuation of the input text (prompt) "
    "and analyze the model's predictions influencing the generated tokens."
)
# Model identifier; passed further down to the `from_pretrained` loaders.
model_name = st.selectbox(
    "Model",
    [
        "distilgpt2",
        "gpt2",
        "EleutherAI/gpt-neo-125m",
        "roneneldan/TinyStories-8M",
        "roneneldan/TinyStories-33M",
    ]
)
# "KL divergence" is only offered outside generation mode; because index=0, the
# default is KL divergence in basic mode and NLL loss in generation mode.
metric_name = st.radio(
    "Metric",
    (["KL divergence"] if not generation_mode else []) + ["NLL loss"],
    index=0,
    horizontal=True,
    help="**KL divergence** is computed between the predictions with the reduced context "
         "(corresponding to the highlighted token) and the predictions with the full context "
         "($c_\\text{max}$ tokens). \n"
         "**NLL loss** is the negative log-likelihood loss (a.k.a. cross entropy) for the target "
         "token."
)
# Tokenizer for the selected model; the loader is wrapped in st.cache_resource
# and the slow (non-fast) tokenizer implementation is requested.
tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
# Make sure the logprobs do not use up more than ~4 GB of memory
# Despite its name, MAX_MEM is a number of float16 *elements*: 4e9 bytes divided
# by the size of one float16 value in bytes.
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
# Select window lengths such that we are allowed to fill the whole window without running out of memory
# (otherwise the window length is irrelevant); if using NLL, memory is not a consideration, but we want
# to limit runtime
multiplier = tokenizer.vocab_size if metric_name == "KL divergence" else 16384  # arbitrary number
# A window of 8 is always offered; larger windows must fit the element budget
# when the input is as long as the window itself (w * 2w * multiplier).
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
]
# Defaults to 128, or to the largest permitted window if 128 is not available.
window_len = st.select_slider(
    r"Window size ($c_\text{max}$)",
    options=window_len_options,
    value=min(128, window_len_options[-1]),
    help="The maximum context length $c_\\text{max}$ for which we compute the scores. Smaller "
         "windows are less computationally intensive, allowing for longer inputs."
)
# Now figure out how many tokens we are allowed to use:
# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
# Hard upper limit on the input length regardless of the memory budget.
max_tokens = min(max_tokens, 4096)
# When unchecked, the unigram log-probabilities used further down are replaced by NaNs.
enable_null_context = st.checkbox(
    "Enable length-1 context",
    value=True,
    help="This enables computing scores for context length 1 (i.e. the previous token), which "
         "involves using an estimate of the model's unigram distribution. This is not originally "
         "proposed in the paper."
)
# Sampling options collected from the widgets below; stays empty in basic mode.
# The keys are later expanded as keyword arguments when the probing function is
# run in generation mode.
generate_kwargs = {}
with st.empty():
    if generation_mode:
        with st.expander("Generation options", expanded=False):
            # The number of new tokens is capped by both 1024 and the token budget.
            generate_kwargs["max_new_tokens"] = st.slider(
                "Max. number of generated tokens",
                min_value=8, max_value=min(1024, max_tokens), step=8, value=min(128, max_tokens)
            )
            # One column per sampling parameter.
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                generate_kwargs["temperature"] = st.number_input(
                    min_value=0.01, value=0.9, step=0.05, label="`temperature`"
                )
            with col2:
                generate_kwargs["top_p"] = st.number_input(
                    min_value=0., value=0.95, max_value=1., step=0.05, label="`top_p`"
                )
            with col3:
                generate_kwargs["typical_p"] = st.number_input(
                    min_value=0., value=1., max_value=1., step=0.05, label="`typical_p`"
                )
            with col4:
                generate_kwargs["repetition_penalty"] = st.number_input(
                    min_value=1., value=1., step=0.05, label="`repetition_penalty`"
                )
# Default contents of the input box: the literal below with its newlines turned
# into spaces and surrounding whitespace stripped, i.e. a single-line string.
DEFAULT_TEXT = """
We present context length probing, a novel explanation technique for causal
language models, based on tracking the predictions of a model as a function of the length of
available context, and allowing to assign differential importance scores to different contexts.
The technique is model-agnostic and does not rely on access to model internals beyond computing
token-level probabilities. We apply context length probing to large pre-trained language models
and offer some initial analyses and insights, including the potential for studying long-range
dependencies.
""".replace("\n", " ").strip()
# The expander title shows the token limit only in basic mode (the limit is
# enforced on the input only in basic mode).
with st.expander(
    f"Prompt" if generation_mode else f"Input text (≤\u2009{max_tokens} tokens)", expanded=True
):
    # The widget is bound to the "input_text" session-state key; its initial
    # value is whatever is stored under that key, or DEFAULT_TEXT if nothing is.
    text = st.text_area(
        "Input text",
        st.session_state.get("input_text", DEFAULT_TEXT),
        key="input_text", label_visibility="collapsed"
    )
# Tokenize the text as a batch of one and unpack the single sequence.
inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]
# Next-token targets: the input shifted left by one position, with EOS as the
# target for the last input token.
label_ids = [*input_ids[1:], tokenizer.eos_token_id]
inputs["labels"] = [label_ids]
num_user_tokens = len(input_ids)
# st.stop() halts the script run, so nothing below executes for invalid input.
if num_user_tokens < 1:
    st.error("Please enter at least one token.", icon="🚨")
    st.stop()
# The length limit applies to basic mode only.
if not generation_mode and num_user_tokens > max_tokens:
    st.error(
        f"Your input has {num_user_tokens} tokens. Please enter at most {max_tokens} tokens "
        f"or try reducing the window size.",
        icon="🚨"
    )
    st.stop()
# Model for the selected name; the loader is wrapped in st.cache_resource.
with st.spinner("Loading model…"):
    model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
def get_unigram_logprobs(
    _model: GPT2LMHeadModel,
    _tokenizer: PreTrainedTokenizer,
    model_name: str
):
    """Return unigram log-probabilities for the given model.

    A precomputed array is looked up at
    ``data/unigram_logprobs/<model_name with "/" replaced by "_">.npy``
    (relative to the current working directory). If that file exists it is
    loaded (pickled payloads are refused) and returned as a tensor; otherwise
    the values are obtained from ``estimate_unigram_logprobs``.

    Args:
        _model: Model handed to ``estimate_unigram_logprobs`` when no
            precomputed file is found.
        _tokenizer: Tokenizer handed to ``estimate_unigram_logprobs`` in the
            same case.
        model_name: Model identifier used to derive the file name.

    Returns:
        A tensor built from the stored array, or whatever
        ``estimate_unigram_logprobs`` returns.
    """
    file_name = model_name.replace("/", "_") + ".npy"
    stored_path = Path("data", "unigram_logprobs", file_name)

    # No precomputed file: fall back to estimating with the model itself.
    if not stored_path.exists():
        return estimate_unigram_logprobs(_model, _tokenizer)

    stored = np.load(stored_path, allow_pickle=False)
    return torch.as_tensor(stored)
# With the length-1 context enabled, use real unigram log-probabilities;
# otherwise use a vector of NaNs with one entry per tokenizer vocabulary item.
if enable_null_context:
    with st.spinner("Obtaining unigram probabilities…"):
        unigram_logprobs = get_unigram_logprobs(model, tokenizer, model_name=model_name)
else:
    unigram_logprobs = torch.full((tokenizer.vocab_size,), torch.nan)
# Converted to a plain tuple of floats before being passed on
# (presumably so that it is hashable as a cached-function argument — TODO confirm).
unigram_logprobs = tuple(unigram_logprobs.tolist())
@torch.no_grad()
def get_logprobs(model, inputs, metric):
    """Run the model over ``inputs`` in batches and collect float16 log-probabilities.

    Gradient tracking is disabled for the whole call: the outputs of every
    batch are kept until the end, and with autograd enabled each of them would
    keep its computation graph (and the activations it references) alive.

    Args:
        model: Callable invoked as ``model(**batch)``; its result must expose
            ``.logits``.
        inputs: Mapping whose values are sliced along their first dimension to
            form batches. Must contain ``"input_ids"`` and, unless ``metric``
            is ``"KL divergence"``, a ``"labels"`` tensor.
        metric: ``"KL divergence"`` keeps the log-probabilities over the whole
            last dimension; any other value keeps only the log-probability of
            each label, leaving a last dimension of size 1.

    Returns:
        The per-batch results concatenated along dimension 0.
    """
    logprobs = []
    batch_size = 8
    num_items = len(inputs["input_ids"])
    pbar = st.progress(0)
    for i in range(0, num_items, batch_size):
        pbar.progress(i / num_items, f"{i}/{num_items}")
        batch = {k: v[i:i + batch_size] for k, v in inputs.items()}
        # Reduce to float16 right away to halve the size of what is retained.
        batch_logprobs = model(**batch).logits.log_softmax(dim=-1).to(torch.float16)
        if metric != "KL divergence":
            # Only the target token's log-probability is needed.
            batch_logprobs = torch.gather(
                batch_logprobs, dim=-1, index=batch["labels"][..., None]
            )
        logprobs.append(batch_logprobs)
    logprobs = torch.cat(logprobs, dim=0)
    pbar.empty()
    return logprobs
def get_logits_processor(temperature, top_p, typical_p, repetition_penalty) -> LogitsProcessorList:
    """Assemble the logits processors implied by the sampling settings.

    A processor is included only when its setting differs from the neutral
    value (1.0 for ``repetition_penalty`` and ``temperature``; anything below
    1.0 for ``top_p`` and ``typical_p``). The order is always: repetition
    penalty, temperature, top-p, typical-p.

    Returns:
        A ``LogitsProcessorList`` holding the enabled processors, possibly empty.
    """
    # (enabled, factory) pairs in application order; factories are only called
    # for enabled stages, so disabled processors are never constructed.
    stages = (
        (
            repetition_penalty != 1.0,
            lambda: RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
        ),
        (
            temperature != 1.0,
            lambda: TemperatureLogitsWarper(temperature),
        ),
        (
            top_p < 1.0,
            lambda: TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1),
        ),
        (
            typical_p < 1.0,
            lambda: TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=1),
        ),
    )

    chain = LogitsProcessorList()
    for enabled, build in stages:
        if enabled:
            chain.append(build())
    return chain
@torch.no_grad()
def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
    """Sample a continuation token by token, recording label log-probabilities per window.

    Gradient tracking is disabled for the whole call: the loop keeps model
    outputs (log-probability rows and, while caching is on, past key/values)
    across steps, and with autograd enabled these would keep every step's
    computation graph alive.

    The code assumes a batch of exactly one sequence (it squeezes dimension 0
    and unpacks a single row). The EOS id is read from the module-level
    ``tokenizer``.

    Args:
        model: Causal LM exposing ``prepare_inputs_for_generation`` and
            ``_update_model_kwargs_for_generation``.
        inputs: Mapping with an ``"input_ids"`` tensor for the prompt.
        metric: Must be ``"NLL loss"``.
        window_len: Maximum number of tokens fed to the model at once.
        max_new_tokens: Upper bound on the number of sampled tokens.
        **kwargs: Forwarded to ``get_logits_processor``.

    Returns:
        A tuple ``(new_ids, label_ids, logprobs)`` of tensors; ``logprobs`` is
        the stack of recorded rows with a trailing dimension of size 1 added.
    """
    assert metric == "NLL loss"
    # Keep at most window_len - 1 prompt tokens so that the first sampled token
    # completes the window.
    start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
    input_ids = inputs["input_ids"][:, start:]
    logits_warper = get_logits_processor(**kwargs)

    new_ids, logprobs = [], []
    eos_idx = None  # step at which sampling stopped (EOS sampled or limit reached)
    pbar = st.progress(0)
    # After sampling stops, EOS is fed for up to window_len - 1 further steps.
    max_steps = max_new_tokens + window_len - 1
    model_kwargs = dict(use_cache=True)
    for i in range(max_steps):
        pbar.progress(i / max_steps, f"{i}/{max_steps}")
        if input_ids.shape[1] == window_len:
            # Once the window is full it slides, so the key/value cache is
            # dropped and the whole window is re-encoded at every step.
            model_kwargs.update(use_cache=False)
            if "past_key_values" in model_kwargs:
                del model_kwargs["past_key_values"]
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        model_outputs = model(**model_inputs)
        model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
        logits_window = model_outputs.logits.squeeze(0)
        logprobs_window = logits_window.log_softmax(dim=-1)
        if eos_idx is None:
            # Sample the next token from the processed logits of the last position.
            probs_next = logits_warper(input_ids, logits_window[[-1]]).softmax(dim=-1)
            next_token = torch.multinomial(probs_next, num_samples=1).item()
            if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
                eos_idx = i
        else:
            # Sampling has stopped: keep feeding EOS.
            next_token = tokenizer.eos_token_id
        new_ids.append(next_token)

        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
        if input_ids.shape[1] > window_len:
            input_ids = input_ids[:, 1:]
        if logprobs_window.shape[0] == window_len:
            # The shifted window now holds the target of every position, so
            # pick each position's log-probability of its own target.
            logprobs.append(
                logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
            )

        if eos_idx is not None and i - eos_idx >= window_len - 1:
            break
    pbar.empty()

    # NOTE(review): `input_ids` at this point is the final sliding window (it is
    # reassigned inside the loop), not the original prompt — confirm that
    # `label_ids` is meant to be built from it.
    [input_ids] = input_ids.tolist()
    # Drop the EOS tokens fed after sampling stopped.
    new_ids = new_ids[:eos_idx + 1]
    label_ids = [*input_ids, *new_ids][1:]
    return torch.as_tensor(new_ids), torch.as_tensor(label_ids), torch.stack(logprobs)[:, :, None]
def run_context_length_probing(
    _model: GPT2LMHeadModel,
    _tokenizer: PreTrainedTokenizer,
    _inputs: Dict[str, torch.Tensor],
    window_len: int,
    unigram_logprobs: tuple,
    metric: str,
    generation_mode: bool,
    generate_kwargs: Dict[str, Any],
    cache_key: Hashable
):
    """Compute per-token context-length scores for the input (and generated) text.

    Args:
        _model: Causal language model to run.
        _tokenizer: Tokenizer; its EOS id is used as the padding id for windows.
        _inputs: Tokenized input holding exactly one sequence under
            ``"input_ids"`` and ``"labels"``; must provide
            ``convert_to_tensors``.
        window_len: Maximum context length. In basic mode it is clipped to the
            input length; in generation mode it is replaced by the second
            dimension of the log-probabilities returned by ``generate``.
        unigram_logprobs: One log-probability per vocabulary item; non-finite
            entries are turned into NaN.
        metric: ``"NLL loss"`` or ``"KL divergence"``.
        generation_mode: If true, a continuation is sampled with ``generate``;
            otherwise the input is scored with sliding windows.
        generate_kwargs: Keyword arguments forwarded to ``generate``.
        cache_key: Not used by the computation and discarded immediately
            (presumably present only to take part in the cache key when this
            function is wrapped by a caching decorator — TODO confirm).

    Returns:
        A tuple ``(output_ids, scores)``; ``scores`` is a float16 tensor whose
        rows are normalized by their largest absolute value. In generation
        mode, zero rows are prepended for the part of the prompt that precedes
        the first window.

    Raises:
        ValueError: If ``metric`` is not one of the two supported names.
    """
    # Fail before the expensive model run instead of hitting an unbound
    # `scores` after it.
    if metric not in ("NLL loss", "KL divergence"):
        raise ValueError(f"Unknown metric: {metric!r}")

    del cache_key

    [input_ids] = _inputs["input_ids"]
    [label_ids] = _inputs["labels"]

    with st.spinner("Running model…"):
        if generation_mode:
            new_ids, label_ids, logprobs = generate(
                model=_model,
                inputs=_inputs.convert_to_tensors("pt"),
                metric=metric,
                window_len=window_len,
                **generate_kwargs
            )
            output_ids = [*input_ids, *new_ids]
            window_len = logprobs.shape[1]
        else:
            window_len = min(window_len, len(input_ids))
            inputs_sliding = get_windows(
                _inputs,
                window_len=window_len,
                start=0,
                pad_id=_tokenizer.eos_token_id
            ).convert_to_tensors("pt")
            logprobs = get_logprobs(model=_model, inputs=inputs_sliding, metric=metric)
            # The displayed sequence is the input followed by the final target.
            output_ids = [*input_ids, label_ids[-1]]
    num_tgt_tokens = logprobs.shape[0]

    with st.spinner("Computing scores…"):
        logprobs = logprobs.transpose(0, 1)
        logprobs = columns_to_diagonals(logprobs)
        logprobs = logprobs[:, :num_tgt_tokens]

        label_ids = label_ids[-num_tgt_tokens:]
        unigram_logprobs = torch.as_tensor(unigram_logprobs)
        unigram_logprobs[~torch.isfinite(unigram_logprobs)] = torch.nan
        if logprobs.shape[-1] == 1:
            # Only label log-probabilities are present: select the matching
            # unigram entry per target.
            unigram_logprobs = unigram_logprobs[label_ids].unsqueeze(-1)
        else:
            # Full distributions are present: repeat the unigram vector once
            # per target.
            unigram_logprobs = unigram_logprobs.unsqueeze(0).repeat(num_tgt_tokens, 1)
        # The unigram values become the first slice along dimension 0.
        logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)

        if metric == "NLL loss":
            scores = nll_score(logprobs=logprobs, labels=label_ids, allow_overwrite=True)
        else:  # "KL divergence" (validated at the top)
            scores = kl_div_score(logprobs, labels=label_ids, allow_overwrite=True)
        del logprobs  # possibly overwritten by the score computation to save memory

        # Differences of the negated scores between consecutive slices of dimension 0.
        scores = (-scores).diff(dim=0).transpose(0, 1)
        scores = scores.nan_to_num()
        # Row-wise normalization; the epsilon avoids division by zero.
        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
        scores = scores.to(torch.float16)

    if generation_mode:
        # Prepend zero rows along the second-to-last dimension.
        scores = F.pad(scores, (0, 0, max(0, len(input_ids) - window_len + 1), 0), value=0.)

    return output_ids, scores
# Results are cached in basic mode only; in generation mode the function runs
# on every script run.
if not generation_mode:
    run_context_length_probing = st.cache_data(run_context_length_probing, show_spinner=False)
if generation_mode:
    # The button's return value is not used; clicking it triggers a new script run.
    st.button("Rerun", type="primary")
output_ids, scores = run_context_length_probing(
    _model=model,
    _tokenizer=tokenizer,
    _inputs=inputs,
    window_len=window_len,
    unigram_logprobs=unigram_logprobs,
    metric=metric_name,
    generation_mode=generation_mode,
    generate_kwargs=generate_kwargs,
    # The function discards this argument; it is passed so that the model name
    # and input text are among the (cached) call's arguments.
    cache_key=(model_name, text),
)
tokens = ids_to_readable_tokens(tokenizer, output_ids, strip_whitespace=False)
# Hand-written label for the custom component below.
st.markdown('<label style="font-size: 14px;">Output</label>', unsafe_allow_html=True)
# In generation mode, prefix_len is the number of prompt tokens; 0 otherwise.
highlighted_text_component(
    tokens=tokens,
    scores=scores.tolist(),
    prefix_len=len(input_ids) if generation_mode else 0
)