hra commited on
Commit
cf9e738
1 Parent(s): 35d6bb5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import requests
3
+ import gradio as gr
4
+ import random
5
+ import time
6
+ import os
7
+ import datetime
8
+ from datetime import datetime
9
+ from PIL import Image
10
+ from PIL import ImageOps
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ from textwrap import wrap
13
+ import json
14
+ from io import BytesIO
15
+
16
+ print('for update')
17
+
18
+ API_TOKEN = os.getenv("API_TOKEN")
19
+ HRA_TOKEN=os.getenv("HRA_TOKEN")
20
+
21
+
22
+ from huggingface_hub import InferenceApi
23
+ inference = InferenceApi("bigscience/bloom",token=API_TOKEN)
24
+
25
+ headers = {'Content-type': 'application/json', 'Accept': 'text/plain'}
26
+ url_hraprompts='https://us-central1-createinsightsproject.cloudfunctions.net/gethrahfprompts'
27
+
28
+ data={"prompt_type":'stable_diffusion_tee_shirt_text',"hra_token":HRA_TOKEN}
29
+ try:
30
+ r = requests.post(url_hraprompts, data=json.dumps(data), headers=headers)
31
+ except requests.exceptions.ReadTimeout as e:
32
+ print(e)
33
+ #print(r.content)
34
+
35
+
36
+ prompt_text=str(r.content, 'UTF-8')
37
+ print(prompt_text)
38
+ data={"prompt_type":'stable_diffusion_tee_shirt_image',"hra_token":HRA_TOKEN}
39
+ try:
40
+ r = requests.post(url_hraprompts, data=json.dumps(data), headers=headers)
41
+ except requests.exceptions.ReadTimeout as e:
42
+ print(e)
43
+ #print(r.content)
44
+
45
+ prompt_image=str(r.content, 'UTF-8')
46
+ print(prompt_image)
47
+
48
+ ENDPOINT_URL="https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1" # url of your endpoint
49
+ #ENDPOINT_URL="https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-1-5" # url of your endpoint
50
+ HF_TOKEN=API_TOKEN# token where you deployed your endpoint
51
+
52
+ def generate_image(prompt_SD:str):
53
+ payload = {"inputs": prompt_SD,}
54
+ headers = {
55
+ "Authorization": f"Bearer {HF_TOKEN}",
56
+ "Content-Type": "application/json",
57
+ "Accept": "image/png" # important to get an image back
58
+ }
59
+ response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
60
+ #print(response.content)
61
+ img = Image.open(BytesIO(response.content))
62
+
63
+ return img
64
+
65
+ def infer(prompt,
66
+ max_length = 250,
67
+ top_k = 0,
68
+ num_beams = 0,
69
+ no_repeat_ngram_size = 2,
70
+ top_p = 0.9,
71
+ seed=42,
72
+ temperature=0.7,
73
+ greedy_decoding = False,
74
+ return_full_text = False):
75
+
76
+ print(seed)
77
+ top_k = None if top_k == 0 else top_k
78
+ do_sample = False if num_beams > 0 else not greedy_decoding
79
+ num_beams = None if (greedy_decoding or num_beams == 0) else num_beams
80
+ no_repeat_ngram_size = None if num_beams is None else no_repeat_ngram_size
81
+ top_p = None if num_beams else top_p
82
+ early_stopping = None if num_beams is None else num_beams > 0
83
+
84
+ params = {
85
+ "max_new_tokens": max_length,
86
+ "top_k": top_k,
87
+ "top_p": top_p,
88
+ "temperature": temperature,
89
+ "do_sample": do_sample,
90
+ "seed": seed,
91
+ "early_stopping":early_stopping,
92
+ "no_repeat_ngram_size":no_repeat_ngram_size,
93
+ "num_beams":num_beams,
94
+ "return_full_text":return_full_text
95
+ }
96
+
97
+ s = time.time()
98
+ response = inference(prompt, params=params)
99
+ #print(response)
100
+ proc_time = time.time()-s
101
+ #print(f"Processing time was {proc_time} seconds")
102
+ return response
103
+
104
+ def getadline(text_inp):
105
+ print(text_inp)
106
+ print(datetime.today().strftime("%d-%m-%Y"))
107
+
108
+ text = prompt_text+"\nInput:"+text_inp + "\nOutput:"
109
+ resp = infer(text,seed=random.randint(0,100))
110
+
111
+ generated_text=resp[0]['generated_text']
112
+ result = generated_text.replace(text,'').strip()
113
+ result = result.replace("Output:","")
114
+ parts = result.split("###")
115
+ topic = parts[0].strip()
116
+ topic="\n".join(topic.split('\n'))
117
+
118
+ response_nsfw = requests.get('https://github.com/coffee-and-fun/google-profanity-words/raw/main/data/list.txt')
119
+ data_nsfw = response_nsfw.text
120
+ nsfwlist=data_nsfw.split('\n')
121
+ nsfwlowerlist=[]
122
+ for each in nsfwlist:
123
+ if each!='':
124
+ nsfwlowerlist.append(each.lower())
125
+ nsfwlowerlist.extend(['bra','gay','lesbian',])
126
+ print(topic)
127
+ mainstring=text_inp
128
+ foundnsfw=0
129
+ for each_word in nsfwlowerlist:
130
+ raw_search_string = r"\b" + each_word + r"\b"
131
+ match_output = re.search(raw_search_string, mainstring)
132
+ no_match_was_found = ( match_output is None )
133
+ if no_match_was_found:
134
+ foundnsfw=0
135
+ else:
136
+ foundnsfw=1
137
+ print(each_word)
138
+ break
139
+ if foundnsfw==1:
140
+ topic="Unsafe content found. Please try again with different prompts."
141
+ print(topic)
142
+ return(topic)
143
+
144
+ def getadvertisement(topic):
145
+ if topic!='':
146
+ input_keyword=topic
147
+ else:
148
+ input_keyword=getadline(random.choice('abcdefghijklmnopqrstuvwxyz'))
149
+ if 'Unsafe content found' in input_keyword:
150
+ input_keyword='Abstarct art with splash of colors'
151
+ prompt_SD=input_keyword+','+prompt_image
152
+ # generate image
153
+ image = generate_image(prompt_SD)
154
+
155
+ # save to disk
156
+ image.save("finalimage.png")
157
+
158
+ return 'finalimage.png'
159
+
160
+
161
+ with gr.Blocks() as demo:
162
+ gr.Markdown("<h1><center>Tee Shirt Designs</center></h1>")
163
+ gr.Markdown(
164
+ """Enter a prompt and get the tee shirt design. Use examples as a guide. We use an equally powerful AI model bigscience/bloom."""
165
+ )
166
+ textbox = gr.Textbox(placeholder="Enter prompt...", lines=1,label='Your prompt')
167
+ btn = gr.Button("Generate")
168
+ #output1 = gr.Textbox(lines=2,label='Market Sizing Framework')
169
+ output_image = gr.components.Image(label="Your tee shirt")
170
+
171
+
172
+ btn.click(getadvertisement,inputs=[textbox], outputs=[output_image])
173
+ examples = gr.Examples(examples=['anime art of man fighting','intricate skull','heavy metal band cover','abstract art of plants',],
174
+ inputs=[textbox])
175
+
176
+
177
+ demo.launch()