Spaces:
Sleeping
Sleeping
File size: 8,757 Bytes
6875735 25f7388 6875735 1dce944 7fd01f5 1dce944 25f7388 1dce944 25f7388 1dce944 50f543c 1dce944 7cdbd19 cecdb11 1dce944 7cdbd19 cecdb11 bc8a52d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import subprocess
import sys
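## The subprocess/sys imports above belong to a setup step from the original app that installs a package from disk (that step is not present in this copy). It was needed because the normal requirements.txt path for installing a package from disk doesn't work on HF Spaces. Thank you to Omar Sanseviero for the help!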
import numpy as np
import pandas as pd
import shap
import streamlit as st
import streamlit.components.v1 as components
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoTokenizer,
pipeline)
# Page configuration must be the first Streamlit call of the script run.
st.set_page_config(page_title="HF-SHAP")

# --- HF-SHAP front end: title and author credits ---
st.title("HF-SHAP: A front end for SHAP")
for author_caption in (
    "By Allen Roush",
    "github: https://github.com/Hellisotherpeople",
    "Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/",
):
    st.caption(author_caption)

# --- SHAP library: banner, credits and references ---
st.title("SHAP (SHapley Additive exPlanations)")
st.image("https://shap.readthedocs.io/en/latest/_images/shap_header.png", width=700)
for shap_caption in (
    "By Lundberg, Scott M and Lee, Su-In",
    "Slightly modified by Allen Roush to fix a bug with text plotting not working outside of Jupyter Notebooks",
    "Full Citation: https://raw.githubusercontent.com/slundberg/shap/master/docs/references/shap_nips.bib",
    "See on github:: https://github.com/slundberg/shap",
    "More details of how SHAP works: https://christophm.github.io/interpretable-ml-book/shap.html",
):
    st.caption(shap_caption)
# Sidebar form holding all settings; nothing below is applied until "Submit".
form = st.sidebar.form("Main Settings")
form.header("Main Settings")
task_done = form.selectbox(
    "Which NLP task do you want to solve?",
    ["Text Generation", "Sentiment Analysis", "Translation", "Summarization"],
)
custom_doc = form.checkbox("Use a document from an existing dataset?", value=False)
if custom_doc:
    # Dataset mode: these values are consumed by load_and_process_data() and the
    # index slice taken further down the script.
    dataset_name = form.text_area("Enter the name of the huggingface Dataset to do analysis of:", value="Hellisotherpeople/DebateSum")
    dataset_name_2 = form.text_area("Enter the name of the config for the dataset if it has one", value="")
    split_name = form.text_area("Enter the name of the split of the dataset that you want to use", value="train")
    # Counts and indices are never meaningful below zero; reject them at the widget.
    number_of_records = form.number_input("Enter the number of documents that you want to analyze from the dataset", value=200, min_value=1)
    column_name = form.text_area("Enter the name of the column that we are doing analysis on (the X value)", value="Full-Document")
    index_to_analyze_start = form.number_input("Enter the index start of the document that you want to analyze of the dataset", value=1, min_value=0)
    # The end index is exclusive (it is used as a slice bound), so it must be at least 1.
    index_to_analyze_end = form.number_input("Enter the index end of the document that you want to analyze of the dataset", value=2, min_value=1)
    form.caption("Multiple documents may not work on certain tasks")
else:
    # Free-text mode: this box is rendered on the main page (st), not inside the sidebar form.
    doc = st.text_area("Enter a custom document", value="This is an example custom document")
# Task-specific settings. Every reachable branch defines `model_name`; the Text
# Generation branch also defines the decoding settings that load_model() reads.
if task_done == "Text Generation":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value="gpt2")
    form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
    decoder = form.checkbox("Is this a decoder model?", value=True)
    form.caption("This should be true for models like GPT-2, and false for models like BERT")
    max_length = form.number_input("What's the max length of the text?", value=50, min_value=1)
    # The default is clamped to max_length: a default above max_value makes
    # Streamlit raise instead of rendering the widget.
    min_length = form.number_input("What's the min length of the text?", value=min(20, max_length), min_value=0, max_value=max_length)
    # load_model() passes this value to generation as `no_repeat_ngram_size`.
    penalize_repetion = form.number_input("How strongly do we want to penalize repetition in the text generation?", value=2, min_value=0)
    form.caption("This is used as no_repeat_ngram_size: no sequence of this many words may repeat in the output. Set to zero to disable")
    sample = form.checkbox("Shall we use top-k and top-p decoding?", value=True)
    form.caption("Setting this to false makes it greedy")
    if sample:
        # top_k / top_p / temperature only exist when sampling is enabled;
        # load_model() reads them under the same condition.
        top_k = form.number_input("What value of K should we use for Top-K sampling? Set to zero to disable", value=50, min_value=0)
        form.caption("In Top-K sampling, the K most likely next words are filtered and the probability mass is redistributed among only those K next words. ")
        # top_p == 1.0 disables nucleus filtering; top_p == 0 would keep only the single most likely word.
        top_p = form.number_input("What value of P should we use for Top-p sampling? Set to one to disable", value=0.95, max_value=1.0, min_value=0.0)
        form.caption("Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. The probability mass is then redistributed among this set of words.")
        # Sampling requires a strictly positive temperature.
        temperature = form.number_input("How spicy/interesting do we want our models output to be", value=1.05, min_value=0.01)
        form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
    form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
elif task_done == "Sentiment Analysis":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Sentiment Analysis", value="nateraw/bert-base-uncased-emotion")
    rescale_logits = form.checkbox("Do we rescale the probabilities in terms of log odds?", value=False)
elif task_done == "Translation":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Translation", value="Helsinki-NLP/opus-mt-en-es")
elif task_done == "Summarization":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Summarization", value="sshleifer/distilbart-xsum-12-1")
else:
    # Question Answering: currently unreachable, since the task selectbox does not
    # offer it. Kept for when QA support is finished.
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Question Answering", value="deepset/roberta-base-squad2")
# Size (in pixels) of the embedded SHAP explanation HTML, then the button that
# submits every setting in the sidebar form at once.
form.header("Model Explanation Display Settings")
_display_size_inputs = (
    ("Enter the number of pixels for width of model explanation html display", 800),
    ("Enter the number of pixels for height of model explanation html display", 1000),
)
output_width, output_height = [
    form.number_input(size_label, value=size_default)
    for size_label, size_default in _display_size_inputs
]
form.form_submit_button("Submit")
@st.cache
def load_and_process_data(path, name, streaming, split_name, number_of_records):
    """Take the first ``number_of_records`` rows of a Hugging Face dataset split.

    Args:
        path: Dataset name/path passed to ``datasets.load_dataset``.
        name: Dataset config name. An empty string (the sidebar default when
            the dataset has no config) is treated as "no config".
        streaming: Forwarded to ``load_dataset``. The caller passes ``True``,
            which is what makes ``.take()`` available on the split.
        split_name: Name of the split to read, e.g. ``"train"``.
        number_of_records: How many leading rows of the split to take.

    Returns:
        The ``column_name`` column (a module-level sidebar setting) of those
        rows, as a pandas Series.
    """
    # The sidebar text box yields "" for "no config"; pass None rather than an
    # empty config name.
    dataset = load_dataset(path=path, name=name or None, streaming=streaming)
    dataset_head = dataset[split_name].take(number_of_records)
    df = pd.DataFrame.from_dict(dataset_head)
    return df[column_name]
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and the task-appropriate model for ``model_name``.

    The model class is chosen from the module-level ``task_done`` sidebar
    selection. For Text Generation, the sidebar decoding settings are stored in
    ``model.config.task_specific_params["text-generation"]``.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: if ``task_done`` is not one of the supported tasks.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if task_done == "Text Generation":
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model.config.is_decoder = decoder
        generation_params = {
            "do_sample": sample,
            "max_length": max_length,
            "min_length": min_length,
            "no_repeat_ngram_size": penalize_repetion,
        }
        if sample:
            # These sidebar widgets only exist when sampling is enabled.
            generation_params.update(temperature=temperature, top_k=top_k, top_p=top_p)
        # Not every model config defines task_specific_params; it may be None.
        if model.config.task_specific_params is None:
            model.config.task_specific_params = {}
        model.config.task_specific_params["text-generation"] = generation_params
    elif task_done == "Sentiment Analysis":
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    elif task_done in ("Translation", "Summarization"):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    elif task_done == "Question Answering":
        # TODO: This one is going to be harder...
        # https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/question_answering/Explaining%20a%20Question%20Answering%20Transformers%20Model.html
        model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on `model`.
        raise ValueError(f"Unsupported task: {task_done!r}")
    return tokenizer, model
# Load (cached) tokenizer + model for the selected task.
tokenizer, model = load_model(model_name)

if custom_doc:
    df = load_and_process_data(dataset_name, dataset_name_2, True, split_name, number_of_records)
    # Positional slice of the loaded column; the end index is exclusive.
    doc = list(df[index_to_analyze_start:index_to_analyze_end])
    st.write(doc)
    if not doc:
        # An empty index range leaves nothing to explain; stop with a clear
        # message instead of handing SHAP an empty batch.
        st.error("No documents selected: make sure the index end is greater than the index start and within the number of loaded records.")
        st.stop()

if task_done == "Sentiment Analysis":
    # Wrap the classifier in a pipeline returning scores for every label.
    pred = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
    explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)
else:
    explainer = shap.Explainer(model, tokenizer)

# The explainer is called on a batch: dataset mode already has a list, and a
# single custom document is wrapped in one.
shap_values = explainer(doc if custom_doc else [doc])
the_plot = shap.plots.text(shap_values, display=False)
st.caption("The plot is interactive! Try Hovering over or clicking on the input or output text")
components.html(the_plot, height=output_height, width=output_width, scrolling=True)
|