Files changed (1) hide show
  1. app.py +30 -7
app.py CHANGED
@@ -1,9 +1,25 @@
1
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import io
 
3
  import random
 
4
 
5
  import requests
6
  from PIL import Image
 
7
  from dataset_viber import AnnotatorInterFace
8
 
9
  HF_TOKEN = os.environ["HF_TOKEN"]
@@ -32,10 +48,17 @@ def get_rows():
32
 
33
 
34
  def generate_response(prompt):
35
- payload = {
36
- "inputs": prompt,
37
- }
38
- response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
 
 
 
 
 
 
 
39
  image = Image.open(io.BytesIO(response.content))
40
  return image
41
 
@@ -46,9 +69,9 @@ def next_input(_prompt, _completion_a, _completion_b):
46
  generated_image = generate_response(prompt)
47
  return (prompt, img_url, generated_image)
48
 
 
49
  if __name__ == "__main__":
50
  interface = AnnotatorInterFace.for_image_generation_preference(
51
- fn=next_input,
52
- dataset_name=None,
53
  )
54
  interface.launch()
 
1
+ # Copyright 2024-present, David Berenstein, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
  import io
16
+ import os
17
  import random
18
+ import time
19
 
20
  import requests
21
  from PIL import Image
22
+
23
  from dataset_viber import AnnotatorInterFace
24
 
25
  HF_TOKEN = os.environ["HF_TOKEN"]
 
48
 
49
 
50
  def generate_response(prompt):
51
+ def _get_response(prompt):
52
+ payload = {
53
+ "inputs": prompt,
54
+ }
55
+ response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
56
+ if response.status_code != 200:
57
+ time.sleep(10)
58
+ return _get_response(prompt)
59
+ return response
60
+
61
+ response = _get_response(prompt)
62
  image = Image.open(io.BytesIO(response.content))
63
  return image
64
 
 
69
  generated_image = generate_response(prompt)
70
  return (prompt, img_url, generated_image)
71
 
72
+
73
  if __name__ == "__main__":
74
  interface = AnnotatorInterFace.for_image_generation_preference(
75
+ interactive=False, fn_next_input=next_input
 
76
  )
77
  interface.launch()