Spaces:
Runtime error
Runtime error
import streamlit as st | |
import time | |
import pandas as pd | |
import altair as alt | |
from multipage import MultiPage | |
from transformers import pipeline | |
def app(): | |
st.markdown('## Mask Fill task') | |
st.write('Write a sentence with a [MASK] gap to fill') | |
st.markdown('## ') | |
def get_model(model): | |
return pipeline('fill-mask', model=model) | |
def create_graph(answer): | |
x_bar = [i['token_str'] for i in answer] | |
y_bar = [i['score'] for i in answer] | |
chart_data = pd.DataFrame(y_bar, index=x_bar) | |
data = pd.melt(chart_data.reset_index(), id_vars=["index"]) | |
# Horizontal stacked bar chart | |
chart = ( | |
alt.Chart(data) | |
.mark_bar(color='#d7abf5') | |
.encode( | |
x=alt.X("index", type="nominal", title='',sort=alt.EncodingSortField(field="index", op="count", order='ascending')), | |
y=alt.Y("value", type="quantitative", title="Score", sort='-x'), | |
) | |
) | |
st.altair_chart(chart, use_container_width=True) | |
col1, col2 = st.columns([2,1]) | |
with col1: | |
prompt= st.text_area('Your prompt here', | |
'''Who is Elon [MASK]?''') | |
with col2: | |
select_model = st.radio( | |
"Select the model to use:", | |
('Bert cased', 'Bert Un-cased'), index = 1) | |
if select_model == 'Bert cased': | |
model = 'bert-base-cased' | |
elif select_model == 'Bert Un-cased': | |
model = 'bert-base-uncased' | |
with st.spinner('Loading Model... (This may take a while)'): | |
unmasker = get_model(model) | |
st.success('Model loaded correctly!') | |
gen = st.info('Generating Mask...') | |
answer = unmasker(prompt) | |
gen.empty() | |
with col1: | |
create_graph(answer) | |