jeffeux commited on
Commit
d64ab4a
1 Parent(s): 856d35d

ok? hope...

Browse files
Files changed (2) hide show
  1. app.py +76 -22
  2. requirements.txt +2 -1
app.py CHANGED
@@ -18,6 +18,14 @@ def C(text, color="yellow"):
18
  f"{color_dict.get(color, None)}"
19
  f"{text}{color_dict[None]}")
20
 
 
 
 
 
 
 
 
 
21
  # ------------------ ENVIORNMENT ------------------- #
22
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
23
  device = ("cuda"
@@ -25,9 +33,8 @@ device = ("cuda"
25
  logging.info(C("[INFO] "f"device = {device}"))
26
 
27
  # ------------------ INITITALIZE ------------------- #
28
- @st.cache(
29
- suppress_st_warning=True
30
- )
31
  def model_init():
32
 
33
  logging.info(C("[INFO] "f"Model init start!"))
@@ -60,25 +67,72 @@ def model_init():
60
 
61
  tokenizer, model = model_init()
62
 
63
- try:
64
- # ===================== INPUT ====================== #
65
- prompt = st.text_input("Prompt: ")
66
 
67
- # =================== INFERENCE ==================== #
68
- if prompt:
69
- st.title(prompt)
70
- with torch.no_grad():
71
- [texts_out] = model.generate(
72
- **tokenizer(
73
- prompt, return_tensors="pt",
74
-
75
- ).to(device),
76
- max_new_tokens=100,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  )
78
- output_text = tokenizer.decode(texts_out)
79
- st.balloons()
80
- st.markdown(output_text)
 
 
 
 
 
81
 
82
- except Exception as err:
83
- st.write(str(err))
84
- st.snow()
 
 
 
 
 
 
 
 
 
18
  f"{color_dict.get(color, None)}"
19
  f"{text}{color_dict[None]}")
20
 
21
+ def stcache():
22
+ from packaging import version
23
+ if version.parse(st.__version__) < version.parse("1.18"):
24
+ return lambda f: st.cache(suppress_st_warning=True)(f)
25
+ return lambda f: st.cache_resource()(f)
26
+
27
+ st.title("`ckip-joint/bloom-1b1-zh` demo")
28
+
29
  # ------------------ ENVIORNMENT ------------------- #
30
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
31
  device = ("cuda"
 
33
  logging.info(C("[INFO] "f"device = {device}"))
34
 
35
  # ------------------ INITITALIZE ------------------- #
36
+ stdec = stcache()
37
+ @stdec
 
38
  def model_init():
39
 
40
  logging.info(C("[INFO] "f"Model init start!"))
 
67
 
68
  tokenizer, model = model_init()
69
 
 
 
 
70
 
71
+ if 1:
72
+ try:
73
+ # ===================== INPUT ====================== #
74
+ prompt = st.text_input("Prompt: ")
75
+
76
+ # =================== INFERENCE ==================== #
77
+ if prompt:
78
+ # placeholder = st.empty()
79
+ # st.title(prompt)
80
+ with st.container():
81
+ st.markdown(f""
82
+ f":violet[{prompt}]⋯⋯"
83
+ )
84
+ # st.empty()
85
+
86
+ with torch.no_grad():
87
+ [texts_out] = model.generate(
88
+ **tokenizer(
89
+ prompt, return_tensors="pt",
90
+
91
+ ).to(device),
92
+ min_new_tokens=0,
93
+ max_new_tokens=100,
94
+ )
95
+ output_text = tokenizer.decode(texts_out,
96
+ skip_special_tokens=True,
97
+ )
98
+ st.empty()
99
+ if output_text.startswith(prompt):
100
+ out_gens = output_text[len(prompt):]
101
+ assert prompt + out_gens == output_text
102
+ else:
103
+ out_gens = output_text
104
+ prompt = ""
105
+ st.balloons()
106
+
107
+ def multiline(string):
108
+ lines = string.split('\n')
109
+ return '\\\n'.join([f"**:red[{l}]**"
110
+ for l in lines])
111
+
112
+
113
+
114
+ # st.empty()
115
+ st.caption("Result: ")
116
+ st.markdown(f""
117
+ f":blue[{prompt}]**:red[{multiline(out_gens)}]**"
118
  )
119
+ # st.text(repr(out_gens0))
120
+
121
+ except Exception as err:
122
+ st.write(str(err))
123
+ st.snow()
124
+
125
+
126
+ # import streamlit as st
127
 
128
+ # st.markdown('Streamlit is **_really_ cool**.')
129
+ # st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.")
130
+ # st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:")
131
+ # def multiline(string):
132
+ # lines = string.split('\n')
133
+ # return '\\\n'.join([f"**:red[{l}]**"
134
+ # for l in lines])
135
+ # st.markdown(multiline("1234 \n5616"))
136
+ # st.markdown("1234\\\n5616")
137
+ # https://docs.streamlit.io/library/api-reference/status/st.spinner
138
+ # https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch
2
  transformers
3
  streamlit==1.17.0
4
- gradio==3.19.1
 
 
1
  torch
2
  transformers
3
  streamlit==1.17.0
4
+ gradio==3.19.1
5
+ packaging