Antoine Chaffin
Adding link to the ArXiv paper
c06b8a2
raw
history blame
4.31 kB
import torch
import argparse
import os
import numpy as np
from watermark import Watermarker
import time
import gradio as gr
# Select CUDA when it is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced anywhere else in this file -- confirm
# whether it is still needed.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
# A larger alternative model is "meta-llama/Llama-2-7b-chat-hf".
parser.add_argument('--model', '-m', type=str, default="facebook/opt-350m", help='Language model')
parser.add_argument('--key', '-k', type=int, default=42,
                    help='The seed of the pseudo random number generator')
# parse_known_args() instead of parse_args(): this runs at import time, and
# parse_args() exits the process on any flag it does not recognise (e.g. flags
# a Jupyter kernel or a test runner leaves in sys.argv). Recognised flags are
# parsed exactly as before; unknown ones are reported and ignored.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print("Ignoring unrecognized arguments: ", _unknown_args)

# Users whose identity can be embedded. A user's index in this list is the
# message embedded by `embed` and mapped back to a name by `detect`.
USERS = ['Alice', 'Bob', 'Charlie', 'Dan']
# Generation methods offered in the UI ('sampling' and 'greedy' generate
# without a watermark).
EMBED_METHODS = [ 'aaronson', 'kirchenbauer', 'sampling', 'greedy' ]
# Detection methods offered in the UI.
DETECT_METHODS = [ 'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer']
# Number of payload bits used to carry the user index.
PAYLOAD_BITS = 2
# An n-bit payload can carry at most 2**n distinct messages, so adding more
# users than that would make some of them indistinguishable at detection time.
if len(USERS) > 2 ** PAYLOAD_BITS:
    raise ValueError('PAYLOAD_BITS={} cannot encode {} users'.format(PAYLOAD_BITS, len(USERS)))
def embed(user, max_length, window_size, method, prompt):
    """Generate text from *prompt* that carries the watermark of *user*.

    The position of *user* in USERS is the message handed to the watermarker.

    Args:
        user: a name from USERS (``list.index`` raises ValueError otherwise).
        max_length: maximum text length, forwarded to ``Watermarker.embed``.
        window_size: watermarking window size, forwarded to ``Watermarker``.
        method: one of EMBED_METHODS.
        prompt: prompt the generation starts from.

    Returns:
        The first element of the texts returned by ``Watermarker.embed``
        (presumably one text per embedded message -- only one is embedded here).
    """
    message = USERS.index(user)
    generator = Watermarker(
        modelname=args.model,
        window_size=window_size,
        payload_bits=PAYLOAD_BITS,
    )
    texts = generator.embed(
        key=args.key,
        messages=[message],
        max_length=max_length,
        method=method,
        prompt=prompt,
    )
    # Echo the raw result to the server log.
    print("watermarked_texts: ", texts)
    first_text = texts[0]
    return first_text
def detect(attacked_text, window_size, method, prompt):
    """Run watermark detection on *attacked_text* and describe the outcome.

    Detection is non-blind: *prompt* must be the original generation prompt.

    Args:
        attacked_text: the generated text, possibly edited by the user.
        window_size: watermarking window size (the UI passes the same field
            that was used for embedding).
        method: one of DETECT_METHODS.
        prompt: the prompt the text was generated from.

    Returns:
        A sentence naming the decoded user and the p-value. If the decoded
        message is not a valid index into USERS, the sentence reports the raw
        message instead.
    """
    watermarker = Watermarker(modelname=args.model,
                              window_size=window_size, payload_bits=PAYLOAD_BITS)
    pvalues, messages = watermarker.detect([ attacked_text ], key=args.key,
                                           method=method, prompts=[prompt])
    print("messages: ", messages)
    print("p-values: ", pvalues)
    # Coerce to plain Python scalars: the decoded values may come back as
    # NumPy/torch scalars, and list indexing rejects non-integer types.
    message = int(messages[0])
    pf = float(pvalues[0])
    # Range-check instead of indexing blindly: a negative message would silently
    # wrap around to the wrong user, and a too-large one would raise IndexError.
    if not 0 <= message < len(USERS):
        return 'No user matches the decoded message {:d} (pvalue of {:.3e})'.format(message, pf)
    user = USERS[message]
    label = 'The user detected is {:s} with pvalue of {:.3e}'.format(user, pf)
    return label
# Gradio UI. The first rows generate watermarked text for the selected user;
# the last rows run detection on the (possibly edited) generated text.
with gr.Blocks() as demo:
    gr.Markdown("""# LLM generation watermarking
This Space lets you try different watermarking schemes for LLM generation.\n
It leverages the upgrades introduced in the paper **[Three Bricks to Consolidate Watermarks for Large Language Models](https://arxiv.org/abs/2308.00113)**, which reduce the gap between the empirical and theoretical false-positive rates of detection and make it possible to embed a message (of n bits). Here we use this capacity to embed the identity of the user generating the text, but it could also be used to identify different versions of a model or simply to convey a secret message.\n
Simply select a user name, set the maximum text length, the watermarking window size and the prompt. The Aaronson and Kirchenbauer watermarking schemes are available, along with traditional sampling and greedy search without watermarking.\n
Once the text is generated, you can optionally apply some attacks to it (e.g., remove words), select the associated detection method and run the detection. Please note that the detection is non-blind: it requires the original prompt to be known, so leave it untouched.\n
For Aaronson, the original detection function is available, along with the Neyman-Pearson and simplified-score versions.""")
    # Embedding inputs.
    with gr.Row():
        user = gr.Dropdown(choices=USERS, value=USERS[0], label="User")
        text_length = gr.Number(minimum=1, maximum=512, value=256, step=1, precision=0, label="Max text length")
        window_size = gr.Number(minimum=0, maximum=10, value=0, step=1, precision=0, label="Watermarking window size")
        embed_method = gr.Dropdown(choices=EMBED_METHODS, value=EMBED_METHODS[0], label="Sampling method")
        prompt = gr.Textbox(label="prompt")
    with gr.Row():
        btn1 = gr.Button("Embed")
    # Generated text (editable, so the user can "attack" it) and detection method.
    with gr.Row():
        watermarked_text = gr.Textbox(label="Generated text")
        detect_method = gr.Dropdown(choices=DETECT_METHODS, value=DETECT_METHODS[0], label="Detection method")
    with gr.Row():
        btn2 = gr.Button("Detect")
    with gr.Row():
        detection_label = gr.Label(label="Detection result")
    btn1.click(fn=embed, inputs=[user, text_length, window_size, embed_method, prompt], outputs=[watermarked_text], api_name="watermark")
    # Detection reuses the embedding row's window size and prompt, because the
    # detection is non-blind and needs the original prompt.
    btn2.click(fn=detect, inputs=[watermarked_text, window_size, detect_method, prompt], outputs=[detection_label], api_name="detect")
demo.launch()