# zhtwbloomdemo / app.py
# ------------------- LIBRARIES -------------------- #
import os
import logging
import torch
import streamlit as st
from transformers import (
    AutoTokenizer, AutoModelForCausalLM)
# --------------------- HELPER --------------------- #
def C(text, color="yellow"):
    """Wrap `text` in ANSI escape codes for colored terminal logging."""
    color_dict: dict = dict(
        red="\033[01;31m",
        green="\033[01;32m",
        yellow="\033[01;33m",
        blue="\033[01;34m",
        magenta="\033[01;35m",
        cyan="\033[01;36m",
    )
    color_dict[None] = "\033[0m"  # reset code
    return (
        f"{color_dict.get(color, color_dict[None])}"
        f"{text}{color_dict[None]}")
def stcache():
    """Return the resource-caching decorator available in this Streamlit."""
    from packaging import version
    if version.parse(st.__version__) < version.parse("1.18"):
        # st.cache_resource only exists from Streamlit 1.18 on;
        # fall back to the legacy st.cache API on older versions.
        return st.cache(suppress_st_warning=True)
    return st.cache_resource()
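# Streamlit re-executes the whole script on every user interaction, so the
# decorator chosen here keeps the tokenizer and the ~1.1B-parameter model
# in memory across reruns instead of reloading them each time.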
st.title("`ckip-joint/bloom-1b1-zh` demo")
# ------------------ ENVIRONMENT ------------------- #
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
device = ("cuda"
if torch.cuda.is_available() else "cpu")
logging.info(C("[INFO] "f"device = {device}"))
# ------------------- INITIALIZE ------------------- #
stdec = stcache()
@stdec
def model_init():
    logging.info(C("[INFO] Model init start!"))
    from transformers import GenerationConfig
    # generation_config, unused_kwargs = GenerationConfig.from_pretrained(
    #     "ckip-joint/bloom-1b1-zh",
    #     max_new_tokens=200,
    #     return_unused_kwargs=True)
    tokenizer = AutoTokenizer.from_pretrained(
        "ckip-joint/bloom-1b1-zh")
    model = AutoModelForCausalLM.from_pretrained(
        "ckip-joint/bloom-1b1-zh",
        # Ref.: Eric, Thanks!
        # torch_dtype="auto",
        # device_map="auto",
        # Ref. for `half`: Chan-Jan, Thanks!
    ).eval().to(device)  # eval mode (no dropout), moved to the target device
    st.balloons()
    logging.info(C("[INFO] Model init success!"))
    return tokenizer, model
tokenizer, model = model_init()
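# Note: the torch_dtype="auto" / device_map="auto" kwargs left commented in
# model_init are the transformers/accelerate route to automatic dtype and
# weight placement; this app instead loads the full-precision model onto a
# single device.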
try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")
    # =================== INFERENCE ==================== #
    if prompt:
        # placeholder = st.empty()
        # st.title(prompt)
        with st.container():
            # Echo the prompt while generation runs.
            st.markdown(f":violet[{prompt}]⋯⋯")
        # st.empty()
        with torch.no_grad():
            [texts_out] = model.generate(
                **tokenizer(
                    prompt, return_tensors="pt",
                ).to(device),
                min_new_tokens=0,
                max_new_tokens=100,
            )
        output_text = tokenizer.decode(
            texts_out,
            skip_special_tokens=True,
        )
        st.empty()
        # Separate the echoed prompt from the generated continuation.
        if output_text.startswith(prompt):
            out_gens = output_text[len(prompt):]
            assert prompt + out_gens == output_text
        else:
            out_gens = output_text
            prompt = ""
        st.balloons()
        # Keep only the first generated line.
        out_gens = out_gens.split('\n')[0]

        def multiline(string):
            # Join lines with a Markdown hard break (backslash-newline),
            # coloring each line red.
            lines = string.split('\n')
            return '\\\n'.join(f"**:red[{l}]**" for l in lines)

        # st.empty()
        st.caption("Result: ")
        # multiline() already wraps each line in **:red[...]**, so the
        # result is not wrapped in a second **:red[...]** layer here.
        st.markdown(f":blue[{prompt}]{multiline(out_gens)}")
        # st.text(repr(out_gens0))
except Exception as err:
    st.write(str(err))
    st.snow()
# import streamlit as st
# st.markdown('Streamlit is **_really_ cool**.')
# st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.")
# st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:")
# def multiline(string):
# lines = string.split('\n')
# return '\\\n'.join([f"**:red[{l}]**"
# for l in lines])
# st.markdown(multiline("1234 \n5616"))
# st.markdown("1234\\\n5616")
# https://docs.streamlit.io/library/api-reference/status/st.spinner
# https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging
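# --------------- (appended notes) ------------------ #
# A minimal sketch of the same inference loop outside Streamlit, useful for
# local debugging. It assumes only the model id used above; decoding is
# greedy, as in the app, since model.generate defaults to do_sample=False:
# from transformers import AutoTokenizer, AutoModelForCausalLM
# tok = AutoTokenizer.from_pretrained("ckip-joint/bloom-1b1-zh")
# lm = AutoModelForCausalLM.from_pretrained("ckip-joint/bloom-1b1-zh").eval()
# ids = tok("今天天氣", return_tensors="pt")
# out = lm.generate(**ids, max_new_tokens=20)[0]
# print(tok.decode(out, skip_special_tokens=True))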