dhuynh95's picture
Upload 3 files
09caaea
raw
history blame
7.43 kB
import gradio as gr
import pandas as pd
import os
from huggingface_hub import InferenceClient, login
from transformers import AutoTokenizer
import evaluate
# BLEU metric used to score the similarity between a sample and its completion.
bleu = evaluate.load("bleu")

# Hub access token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# The token is a credential: report only whether it is present, never its
# value, so it cannot end up in the application logs.
print("HF_TOKEN is set" if HF_TOKEN else "HF_TOKEN is not set")

# Single source of truth for the model id shared by the inference client
# and the tokenizer.
checkpoint = "bigcode/starcoder"
client = InferenceClient(model=checkpoint, token=HF_TOKEN)

# login() must run before the tokenizer is fetched, because
# use_auth_token=True relies on the credentials stored by login().
login(token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=True)

# Keep only the rows of samples.csv that have a value in `prediction_50`.
df = pd.read_csv("samples.csv")
sample_df = df.loc[~df.prediction_50.isna()]

# Page title rendered through gr.Markdown.
description = "<h1 style='text-align: center; color: #333333; font-size: 40px;'>StarCoder Memorization Verifier</h1>"
# Source snippets offered in the UI as the "High memorization samples"
# examples. Each text is kept in its own constant and then registered under
# the label shown for it.
_HIGH_BLEU_SAMPLE_1 = """from django.contrib import admin
from .models import SearchResult
# Register your models here.
class SearchResultAdmin(admin.ModelAdmin):
fields = ["query", "heading", "url", "text"]
admin.site.register(SearchResult, SearchResultAdmin)"""

_HIGH_BLEU_SAMPLE_2 = """class Solution:
def finalPrices(self, prices: List[int]) -> List[int]:
res = []
for i in range(len(prices)):
for j in range(i+1,len(prices)):
if prices[j]<=prices[i]:
res.append(prices[i]-prices[j])
break
if j==len(prices)-1:
res.append(prices[i])
res.append(prices[-1])
return res"""

_HIGH_BLEU_SAMPLE_3 = """from data_collection.management.commands import BaseXpressDemocracyClubCsvImporter
class Command(BaseXpressDemocracyClubCsvImporter):
council_id = 'E06000027'
addresses_name = 'parl.2017-06-08/Version 1/Torbay Democracy_Club__08June2017.tsv'
stations_name = 'parl.2017-06-08/Version 1/Torbay Democracy_Club__08June2017.tsv'
elections = ['parl.2017-06-08']
csv_delimiter = '\t'
"""

high_bleu_examples = {
    "Example 1": _HIGH_BLEU_SAMPLE_1,
    "Example 2": _HIGH_BLEU_SAMPLE_2,
    "Example 3": _HIGH_BLEU_SAMPLE_3,
}
# Source snippets offered in the UI as the "Low memorization samples"
# examples. Each text is kept in its own constant and then registered under
# the label shown for it.
_LOW_BLEU_SAMPLE_1 = """from zeit.cms.i18n import MessageFactory as _
import zope.interface
import zope.schema
class IGlobalSettings(zope.interface.Interface):
\"""Global CMS settings.\"""
default_year = zope.schema.Int(
title=_("Default year"),
min=1900,
max=2100)
default_volume = zope.schema.Int(
title=_("Default volume"),
min=1,
max=54)
def get_working_directory(template):
\"""Return the collection which is the main working directory.
template:
Template which will be filled with year and volume. In
``template`` the placeholders $year and $volume will be replaced.
Example: 'online/$year/$volume/foo'
If the respective collection does not exist, it will be created before
returning it.
\"""
"""

_LOW_BLEU_SAMPLE_2 = """# -*- coding: utf-8 -*-
\"""Context managers implemented for (mostly) internal use\"""
import contextlib
import functools
from io import UnsupportedOperation
import os
import sys
__all__ = ["RedirectStdout", "RedirectStderr"]
@contextlib.contextmanager
def _stdchannel_redirected(stdchannel, dest_filename, mode="w"):
\"""
A context manager to temporarily redirect stdout or stderr
Originally by Marc Abramowitz, 2013
(http://marc-abramowitz.com/archives/2013/07/19/python-context-manager-for-redirected-stdout-and-stderr/)
\"""
oldstdchannel = None
dest_file = None
try:
if stdchannel is None:
yield iter([None])
else:
oldstdchannel = os.dup(stdchannel.fileno())
dest_file = open(dest_filename, mode)
os.dup2(dest_file.fileno(), stdchannel.fileno())
yield
except (UnsupportedOperation, AttributeError):
yield iter([None])
finally:
if oldstdchannel is not None:
os.dup2(oldstdchannel, stdchannel.fileno())
if dest_file is not None:
dest_file.close()
RedirectStdout = functools.partial(_stdchannel_redirected, sys.stdout)
RedirectStderr = functools.partial(_stdchannel_redirected, sys.stderr)
RedirectNoOp = functools.partial(_stdchannel_redirected, None, "")
"""

_LOW_BLEU_SAMPLE_3 = """\"""Utils for criterion.\"""
import torch
import torch.nn.functional as F
def normalize(x, axis=-1):
\"""Performs L2-Norm.\"""
num = x
denom = torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12
return num / denom
# Source : https://github.com/earhian/Humpback-Whale-Identification-1st-/blob/master/models/triplet_loss.py
def euclidean_dist(x, y):
\"""Computes Euclidean distance.\"""
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(x, 2).sum(1, keepdim=True).expand(m, m).t()
dist = xx + yy - 2 * torch.matmul(x, y.t())
dist = dist.clamp(min=1e-12).sqrt()
return dist
def cosine_dist(x, y):
\"""Computes Cosine Distance.\"""
x = F.normalize(x, dim=1)
y = F.normalize(y, dim=1)
dist = 2 - 2 * torch.mm(x, y.t())
return dist
"""

low_bleu_examples = {
    "Example 1": _LOW_BLEU_SAMPLE_1,
    "Example 2": _LOW_BLEU_SAMPLE_2,
    "Example 3": _LOW_BLEU_SAMPLE_3,
}
def _bleu_label_update(sample, output):
    """Build the ``gr.Label`` update carrying the BLEU score of ``sample`` vs ``output``."""
    if not sample.strip() or not output.strip():
        # A blank side has no n-grams to compare; report 0 instead of
        # handing empty text to the metric.
        score = 0.0
    else:
        # NOTE(review): the original sample is passed as the prediction and
        # the completion as the reference; BLEU is not symmetric, so confirm
        # this orientation is intended before swapping it.
        score = bleu.compute(predictions=[sample], references=[output])["bleu"]
    return gr.Label.update(value={"BLEU": score})


def complete(sample, k):
    """Stream a StarCoder completion of ``sample`` and its running BLEU score.

    The first ``k`` tokens of ``sample`` are decoded back to text and sent to
    the inference client as the prompt. After every streamed token the
    accumulated text (prompt + completion so far) is yielded together with a
    ``gr.Label`` update holding the BLEU score against ``sample``.

    Args:
        sample: Original source text entered by the user.
        k: Number of leading tokens of ``sample`` to use as the prompt.

    Yields:
        Tuples ``(text, label_update)`` for the completion textbox and the
        similarity label.
    """
    sample = sample or ""
    if not sample:
        # Nothing to complete: show an empty result rather than querying the
        # model with an empty prompt.
        yield sample, _bleu_label_update(sample, sample)
        return

    # int() guards against the slider delivering a float, which cannot be
    # used as a slice bound.
    prefix_tokens = tokenizer(sample)["input_ids"][:int(k)]
    prefix = tokenizer.decode(prefix_tokens)
    output = prefix
    for token in client.text_generation(prefix, do_sample=False, max_new_tokens=512, stream=True):
        if token == "<|endoftext|>":
            break
        output += token
        yield output, _bleu_label_update(sample, output)

    # This is a generator, so a value passed to `return` would be discarded
    # by the caller. Yield the final state explicitly so the UI is updated
    # even when the stream produced no tokens at all.
    yield output, _bleu_label_update(sample, output)
def high_bleu_mirror(x):
    """Return the text stored in ``high_bleu_examples`` under the label ``x``."""
    return high_bleu_examples[x]
def low_bleu_mirror(x):
    """Return the text stored in ``low_bleu_examples`` under the label ``x``."""
    return low_bleu_examples[x]
# UI layout: the original code, prompt-length slider and example pickers on
# the left; the streamed completion and its BLEU score on the right.
# NOTE(review): the nesting below is reconstructed from the layout calls;
# confirm it against the intended page layout.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    lines=5,
                    label="Original",
                )
                with gr.Accordion("Advanced parameters", open=False):
                    # Number of leading tokens of the input used as the prompt.
                    k = gr.Slider(minimum=1, maximum=250, value=50)
                submit = gr.Button("Check", variant="primary")
                # The Examples components get their own names so the
                # module-level `high_bleu_examples` / `low_bleu_examples`
                # dicts stay intact: the mirror functions look those dicts
                # up by name every time they are called.
                high_bleu_gallery = gr.Examples(
                    list(high_bleu_examples.keys()),
                    label="High memorization samples",
                    inputs=instruction,
                    outputs=instruction,
                    fn=high_bleu_mirror,
                    cache_examples=True,
                )
                low_bleu_gallery = gr.Examples(
                    list(low_bleu_examples.keys()),
                    label="Low memorization samples",
                    inputs=instruction,
                    outputs=instruction,
                    fn=low_bleu_mirror,
                    cache_examples=True,
                )
            with gr.Column():
                output = gr.Textbox(
                    lines=5,
                    label="Completion",
                    interactive=False,
                )
                label = gr.Label(
                    value={"BLEU": 0},
                    label="Similarity score (BLEU)",
                )
    # `complete` is a generator, so the textbox and label update as tokens
    # stream in.
    submit.click(
        complete,
        inputs=[instruction, k],
        outputs=[output, label],
    )

# Generator handlers need the queue to stream results; launch afterwards.
demo.queue(concurrency_count=16).launch(debug=True)