# Spaces: Runtime error
import gradio as gr | |
import pandas as pd | |
import os | |
from huggingface_hub import InferenceClient, login | |
from transformers import AutoTokenizer | |
import evaluate | |
# BLEU metric used to score how closely a completion matches the original code.
bleu = evaluate.load("bleu")

# Hugging Face access token from the environment (None when unset).
# It is deliberately NOT printed: it is a secret and would end up in the logs.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

checkpoint = "bigcode/starcoder"
# Remote inference client for the same checkpoint the tokenizer is loaded from.
client = InferenceClient(model=checkpoint, token=HF_TOKEN)
# Calling login() without a token asks for one interactively (see the
# huggingface_hub docs), which a headless deployment cannot answer, so only
# log in when a token was actually supplied.
if HF_TOKEN:
    login(token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=True)

df = pd.read_csv("samples.csv")
# Rows that have a recorded ``prediction_50`` value.
# NOTE(review): ``sample_df`` is not referenced elsewhere in this file; confirm it is still needed.
sample_df = df.loc[~df.prediction_50.isna()]
# Page title rendered through gr.Markdown; the <h1> element is closed so any
# markup that follows it is not swallowed into the heading.
description = "<h1 style='text-align: center; color: #333333; font-size: 40px;'>StarCoder Memorization Verifier</h1>"
# Code samples offered in the UI under "High memorization samples".
# Each key is the example name shown in the UI; high_bleu_mirror() maps it to the code.
# NOTE(review): in this copy the samples appear to have lost their original
# indentation and blank lines (e.g. class bodies are flush-left). Restore them
# verbatim from the source if exact reproduction of the memorized text matters.
high_bleu_examples = {
    "Example 1": """from django.contrib import admin
from .models import SearchResult
# Register your models here.
class SearchResultAdmin(admin.ModelAdmin):
fields = ["query", "heading", "url", "text"]
admin.site.register(SearchResult, SearchResultAdmin)""",
    "Example 2": """class Solution:
def finalPrices(self, prices: List[int]) -> List[int]:
res = []
for i in range(len(prices)):
for j in range(i+1,len(prices)):
if prices[j]<=prices[i]:
res.append(prices[i]-prices[j])
break
if j==len(prices)-1:
res.append(prices[i])
res.append(prices[-1])
return res""",
    "Example 3": """from data_collection.management.commands import BaseXpressDemocracyClubCsvImporter
class Command(BaseXpressDemocracyClubCsvImporter):
council_id = 'E06000027'
addresses_name = 'parl.2017-06-08/Version 1/Torbay Democracy_Club__08June2017.tsv'
stations_name = 'parl.2017-06-08/Version 1/Torbay Democracy_Club__08June2017.tsv'
elections = ['parl.2017-06-08']
csv_delimiter = '\t'
"""
}
# Code samples offered in the UI under "Low memorization samples".
# Each key is the example name shown in the UI; low_bleu_mirror() maps it to the code.
# The samples are data: their text (including any bugs inside them) is kept
# verbatim because it is what gets compared against the model's completion.
# NOTE(review): in this copy the samples appear to have lost their original
# indentation and blank lines. Restore them verbatim from the source if exact
# reproduction matters.
low_bleu_examples = {
    "Example 1": """from zeit.cms.i18n import MessageFactory as _
import zope.interface
import zope.schema
class IGlobalSettings(zope.interface.Interface):
\"""Global CMS settings.\"""
default_year = zope.schema.Int(
title=_("Default year"),
min=1900,
max=2100)
default_volume = zope.schema.Int(
title=_("Default volume"),
min=1,
max=54)
def get_working_directory(template):
\"""Return the collection which is the main working directory.
template:
Template which will be filled with year and volume. In
``template`` the placeholders $year and $volume will be replaced.
Example: 'online/$year/$volume/foo'
If the respective collection does not exist, it will be created before
returning it.
\"""
""",
    "Example 2": """# -*- coding: utf-8 -*-
\"""Context managers implemented for (mostly) internal use\"""
import contextlib
import functools
from io import UnsupportedOperation
import os
import sys
__all__ = ["RedirectStdout", "RedirectStderr"]
@contextlib.contextmanager
def _stdchannel_redirected(stdchannel, dest_filename, mode="w"):
\"""
A context manager to temporarily redirect stdout or stderr
Originally by Marc Abramowitz, 2013
(http://marc-abramowitz.com/archives/2013/07/19/python-context-manager-for-redirected-stdout-and-stderr/)
\"""
oldstdchannel = None
dest_file = None
try:
if stdchannel is None:
yield iter([None])
else:
oldstdchannel = os.dup(stdchannel.fileno())
dest_file = open(dest_filename, mode)
os.dup2(dest_file.fileno(), stdchannel.fileno())
yield
except (UnsupportedOperation, AttributeError):
yield iter([None])
finally:
if oldstdchannel is not None:
os.dup2(oldstdchannel, stdchannel.fileno())
if dest_file is not None:
dest_file.close()
RedirectStdout = functools.partial(_stdchannel_redirected, sys.stdout)
RedirectStderr = functools.partial(_stdchannel_redirected, sys.stderr)
RedirectNoOp = functools.partial(_stdchannel_redirected, None, "")
""",
    "Example 3": """\"""Utils for criterion.\"""
import torch
import torch.nn.functional as F
def normalize(x, axis=-1):
\"""Performs L2-Norm.\"""
num = x
denom = torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12
return num / denom
# Source : https://github.com/earhian/Humpback-Whale-Identification-1st-/blob/master/models/triplet_loss.py
def euclidean_dist(x, y):
\"""Computes Euclidean distance.\"""
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(x, 2).sum(1, keepdim=True).expand(m, m).t()
dist = xx + yy - 2 * torch.matmul(x, y.t())
dist = dist.clamp(min=1e-12).sqrt()
return dist
def cosine_dist(x, y):
\"""Computes Cosine Distance.\"""
x = F.normalize(x, dim=1)
y = F.normalize(y, dim=1)
dist = 2 - 2 * torch.mm(x, y.t())
return dist
"""
}
def _bleu_label(prediction, reference):
    """Build the ``gr.Label`` update showing the BLEU score of *prediction*.

    *prediction* is the model completion and *reference* is the original
    code. When either string is blank, the BLEU implementation can divide by
    zero on the empty length ratio, so the score is reported as 0.0 instead.
    """
    if not prediction.strip() or not reference.strip():
        score = 0.0
    else:
        # Hypothesis goes in ``predictions``; the ground truth goes in ``references``.
        score = bleu.compute(predictions=[prediction], references=[reference])["bleu"]
    return gr.Label.update(value={"BLEU": score})


def complete(sample, k):
    """Stream StarCoder's greedy continuation of the first *k* tokens of *sample*.

    Args:
        sample: Original code entered by the user.
        k: Number of leading tokens of *sample* used as the prompt. The UI
            slider may deliver it as a float, so it is cast to ``int``.

    Yields:
        ``(completion, label_update)`` pairs. ``completion`` is the prompt
        plus the tokens generated so far, and ``label_update`` is a
        ``gr.Label`` update with its BLEU score against *sample*. At least
        one pair is always yielded.
    """
    if not sample or not sample.strip():
        # Nothing to prompt with; show an empty completion instead of calling the model.
        yield "", _bleu_label("", "")
        return

    prefix_tokens = tokenizer(sample)["input_ids"][:int(k)]
    prefix = tokenizer.decode(prefix_tokens)
    output = prefix
    streamed_any = False
    for token in client.text_generation(prefix, do_sample=False, max_new_tokens=512, stream=True):
        if token == "<|endoftext|>":
            break
        output += token
        streamed_any = True
        # Gradio only displays *yielded* values. A generator's ``return value``
        # is discarded, so every UI update goes through ``yield``.
        yield output, _bleu_label(output, sample)

    if not streamed_any:
        # The stream produced no text; still show the prompt and its score.
        yield output, _bleu_label(output, sample)
def high_bleu_mirror(x):
    """Look up the code of the high-memorization example named *x*.

    Used as the ``gr.Examples`` callback so that selecting an example name
    fills the input box with that example's source text.
    """
    return high_bleu_examples[x]
def low_bleu_mirror(x):
    """Look up the code of the low-memorization example named *x*.

    Used as the ``gr.Examples`` callback so that selecting an example name
    fills the input box with that example's source text.
    """
    return low_bleu_examples[x]
# UI layout: input code, prefix-length control, and example pickers on the
# left; the streamed completion and its BLEU score on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    lines=5,
                    label="Original",
                )
                with gr.Accordion("Advanced parameters", open=False):
                    # Number of leading tokens of the original used as the prompt.
                    k = gr.Slider(
                        minimum=1,
                        maximum=250,
                        value=50,
                        step=1,
                        label="Prefix length (tokens)",
                    )
                submit = gr.Button("Check", variant="primary")
                # Distinct names are used so the module-level example dicts are
                # not rebound. high_bleu_mirror/low_bleu_mirror look those dicts
                # up by global name at call time.
                high_bleu_picker = gr.Examples(
                    list(high_bleu_examples.keys()),
                    label="High memorization samples",
                    inputs=instruction,
                    outputs=instruction,
                    fn=high_bleu_mirror,
                    cache_examples=True,
                )
                low_bleu_picker = gr.Examples(
                    list(low_bleu_examples.keys()),
                    label="Low memorization samples",
                    inputs=instruction,
                    outputs=instruction,
                    fn=low_bleu_mirror,
                    cache_examples=True,
                )
            with gr.Column():
                output = gr.Textbox(
                    lines=5,
                    label="Completion",
                    interactive=False,
                )
                label = gr.Label(
                    value={"BLEU": 0},
                    label="Similarity score (BLEU)",
                )
        # complete() is a generator, so the outputs update while tokens stream in.
        submit.click(
            complete,
            inputs=[instruction, k],
            outputs=[output, label],
        )
# Enable request queueing (concurrency_count=16), then start the server in debug mode.
demo.queue(concurrency_count=16)
demo.launch(debug=True)