Michael Ramos commited on
Commit
9522910
1 Parent(s): 64b8712

Adding Train and Validation set

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ spc
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
2
+ import torch
3
+ import os
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex
7
+ except:
8
+ pass
9
+
10
+ from PIL import Image
11
+ import numpy as np
12
+ import gradio as gr
13
+ import psutil
14
+ import time
15
+ import math
16
+
17
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
18
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
19
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
20
+ # check if MPS is available OSX only M1/M2/M3 chips
21
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
22
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
23
+ device = torch.device(
24
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
25
+ )
26
+ torch_device = device
27
+ torch_dtype = torch.float16
28
+
29
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
30
+ print(f"TORCH_COMPILE: {TORCH_COMPILE}")
31
+ print(f"device: {device}")
32
+
33
+ if mps_available:
34
+ device = torch.device("mps")
35
+ torch_device = "cpu"
36
+ torch_dtype = torch.float32
37
+
38
+ if SAFETY_CHECKER == "True":
39
+ i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
40
+ "stabilityai/sdxl-turbo",
41
+ torch_dtype=torch_dtype,
42
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
43
+ )
44
+ t2i_pipe = AutoPipelineForText2Image.from_pretrained(
45
+ "stabilityai/sdxl-turbo",
46
+ torch_dtype=torch_dtype,
47
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
48
+ )
49
+ else:
50
+ i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
51
+ "stabilityai/sdxl-turbo",
52
+ safety_checker=None,
53
+ torch_dtype=torch_dtype,
54
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
55
+ )
56
+ t2i_pipe = AutoPipelineForText2Image.from_pretrained(
57
+ "stabilityai/sdxl-turbo",
58
+ safety_checker=None,
59
+ torch_dtype=torch_dtype,
60
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
61
+ )
62
+
63
+ t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
64
+ t2i_pipe.set_progress_bar_config(disable=True)
65
+ i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
66
+ i2i_pipe.set_progress_bar_config(disable=True)
67
+
68
+ def resize_crop(image, size=512):
69
+ image = image.convert("RGB")
70
+ w, h = image.size
71
+ image = image.resize((size, int(size * (h / w))), Image.BICUBIC)
72
+ return image
73
+
74
+ # Global variable to store the selected image index
75
+ selected_image_index = None
76
+
77
+ # Load images from the 'images' folder
78
+ image_folder = 'images'
79
+ images = [Image.open(os.path.join(image_folder, img)) for img in sorted(os.listdir(image_folder)) if img.endswith(('.png', '.jpg', '.jpeg'))]
80
+
81
+ # Ensure that there are 34 images
82
+ assert len(images) == 34, "There should be exactly 34 images in the 'images' folder."
83
+
84
+ # Function to handle image selection
85
+ async def select_fn(data: gr.SelectData, prompt: str):
86
+ global selected_image_index
87
+ selected_image_index = data.index
88
+ return await predict(prompt)
89
+
90
+ async def predict(prompt):
91
+ global selected_image_index
92
+ strength = 0.49999999999999999
93
+ steps = 2
94
+ if selected_image_index is not None:
95
+ init_image = images[selected_image_index]
96
+ init_image = resize_crop(init_image)
97
+ generator = torch.manual_seed(123123)
98
+ last_time = time.time()
99
+
100
+ if int(steps * strength) < 1:
101
+ steps = math.ceil(1 / max(0.10, strength))
102
+
103
+ results = i2i_pipe(
104
+ prompt=prompt,
105
+ image=init_image,
106
+ generator=generator,
107
+ num_inference_steps=steps,
108
+ guidance_scale=0.0,
109
+ strength=strength,
110
+ width=512,
111
+ height=512,
112
+ output_type="pil",
113
+ )
114
+
115
+ print(f"Pipe took {time.time() - last_time} seconds")
116
+ nsfw_content_detected = (
117
+ results.nsfw_content_detected[0]
118
+ if "nsfw_content_detected" in results
119
+ else False
120
+ )
121
+ if nsfw_content_detected:
122
+ gr.Warning("NSFW content detected.")
123
+ return Image.new("RGB", (512, 512))
124
+ return results.images[0]
125
+
126
+ # Create the Gradio interface
127
+ with gr.Blocks() as app:
128
+ with gr.Row():
129
+ with gr.Column():
130
+ prompt = gr.Textbox(label="I see...")
131
+ image_gallery = gr.Gallery(value=images, columns=4) # Adjust number of columns as needed
132
+ with gr.Column():
133
+ output = gr.Image(label="Generation")
134
+
135
+ button = gr.Button("Rorschachify!")
136
+
137
+ image_gallery.select(select_fn, inputs=[prompt], outputs=output, show_progress=False)
138
+ button.click(fn=predict, inputs=[prompt], outputs=output, show_progress=False)
139
+ prompt.change(fn=predict, inputs=[prompt], outputs=output, show_progress=False)
140
+
141
+ # Run the app
142
+ app.queue()
143
+ app.launch()
images/002.png ADDED

Git LFS Details

  • SHA256: e2a8e9e73529736c6d0a5702b5d9dc595894ec773b79aeffe49f3405499e9d8e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
images/003-m.png ADDED

Git LFS Details

  • SHA256: d9dd5c8bca6b0b2bb75b906b2ffa8de8184458c64bd661164cb6af8e465483b1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
images/005.png ADDED

Git LFS Details

  • SHA256: 47e2bd19c42ca4942b950e1f7950342626f471a625c7715de10163f80152e122
  • Pointer size: 131 Bytes
  • Size of remote file: 429 kB
images/006.png ADDED

Git LFS Details

  • SHA256: 59070b7ee75956318619206480029a61ae02992aa3f71ab3519d55a9bae3b235
  • Pointer size: 131 Bytes
  • Size of remote file: 560 kB
images/008.png ADDED

Git LFS Details

  • SHA256: c10de0ee0ef69310cf4966557786e0f1656cadbed34da67fd8adb64f7ef46af4
  • Pointer size: 131 Bytes
  • Size of remote file: 640 kB
images/009.png ADDED

Git LFS Details

  • SHA256: 32442c9f49216f392603cb40cd3d8a5de1c8f07944e86552cde0118ff1c2f70c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
images/010.png ADDED

Git LFS Details

  • SHA256: 1fe27201622749f62a9ffce12b62523ba7da63e692392d1d19515208838fd41b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
images/011.png ADDED

Git LFS Details

  • SHA256: d56182c12df975080429bac8e09eb5f67e19b20b39b92a4ca6641215d2f4a865
  • Pointer size: 131 Bytes
  • Size of remote file: 382 kB
images/012.png ADDED

Git LFS Details

  • SHA256: d85cd0c486489db08b766bf1868afc99f00bbe62ba6f71fbbb3af62d6540410a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
images/014.png ADDED

Git LFS Details

  • SHA256: 825417e751da7c6fa3b1cedc74de42e2542393e7893aaf90e2c94114cb618713
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
images/015.png ADDED

Git LFS Details

  • SHA256: 6b1753f802c9b2c30c163f1599f9cacc8d7dd5b56c3ad629ecd8bea35ce93838
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
images/016.png ADDED

Git LFS Details

  • SHA256: 176158d1708e4054da7235eb8f90d2731b41781720bbf997b35bc95bb1cca202
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
images/017.png ADDED

Git LFS Details

  • SHA256: 2a9ae3b30c341277ac42255adb0a49c3821ccb66fc030efe6539a10c53cf7629
  • Pointer size: 131 Bytes
  • Size of remote file: 637 kB
images/021.png ADDED

Git LFS Details

  • SHA256: bba8198fbe5c8d00076db4fc9ddc980f297b8e2b2478ce87c1ae4ea05d5cf2f2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
images/024.png ADDED

Git LFS Details

  • SHA256: f14ae4cbc0756d6511ab297e9ab28514b2a7c3d84d7f8bf155f887fac917da7c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
images/025.png ADDED

Git LFS Details

  • SHA256: bef5dcba5fabb9311c3831338059e5e096d0148f21f87f6b6e02a7852382f318
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
images/3D0040.png ADDED

Git LFS Details

  • SHA256: f7b08cc9da3330b6397076840395bbb2b33b2292a551b2c0d9b99df251b66dfd
  • Pointer size: 131 Bytes
  • Size of remote file: 891 kB
images/Fig02_C.png ADDED

Git LFS Details

  • SHA256: 6cb677b319c073dad48a0ea7255bad9814d7d74474048c4ee721c2caf72df0ce
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
images/Fig02_F.png ADDED

Git LFS Details

  • SHA256: 81598338d891ba8d65ed449451ec53510654c2780c495a7591e5a93cbf591cd2
  • Pointer size: 131 Bytes
  • Size of remote file: 919 kB
images/Fig02_I.png ADDED

Git LFS Details

  • SHA256: f0f53a2a140c25ab3fce2e9e8601d959ce06cb19e01076f206ef6a518103e587
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
images/Fig06_C.png ADDED

Git LFS Details

  • SHA256: 782d1ba6dbdff590b036cf7650d722b52b4d6d0f4d06e7f5bee688bb587e8012
  • Pointer size: 131 Bytes
  • Size of remote file: 435 kB
images/Oct6_Page_04.png ADDED

Git LFS Details

  • SHA256: 9cf0a38d39cbdeefaf499f96047ceb16dd34fa1727272e8272ded739792a87ec
  • Pointer size: 131 Bytes
  • Size of remote file: 638 kB
images/Oct6_Page_05.png ADDED

Git LFS Details

  • SHA256: 6a74a87058efd21c2228a0687cd20043a00b3fce024461d4ac97c9cfdb59161d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
images/Oct6_Page_06.png ADDED

Git LFS Details

  • SHA256: 38d5f5c49eacb3e7e01a4bf6dea968ffc3ca13a1d0d0762e72c25ddb1c8f5ab8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
images/Oct6_Page_12-1.png ADDED

Git LFS Details

  • SHA256: 9f6d6e45d5f468434a0fdcdeaee4e6364230e3d52e8938a85f8261d69455a247
  • Pointer size: 131 Bytes
  • Size of remote file: 770 kB
images/Oct6_Page_12-2.png ADDED

Git LFS Details

  • SHA256: 830e57afc60c72fc50cf85591f0bc5c86c7ed7b24bfc63660b8b3173500221d7
  • Pointer size: 131 Bytes
  • Size of remote file: 925 kB
images/Oct6_Page_13.png ADDED

Git LFS Details

  • SHA256: 76ee57300031b0cd2f5337ca60a7835eb1b5ae2e162fce125d08acf8d9dc180f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
images/Oct6_Page_18.png ADDED

Git LFS Details

  • SHA256: 597e29576393ed653d2e977cfc3a3331bcb28d8000cf558135828fba25291054
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
images/Oct6_Page_23.png ADDED

Git LFS Details

  • SHA256: ab3de4d6e08d7ac2ee6d5d6e959821786bc26f6584e6eea75726b56d45af88dc
  • Pointer size: 131 Bytes
  • Size of remote file: 984 kB
images/Oct6_Page_25.png ADDED

Git LFS Details

  • SHA256: e916e69e243448ed683065db88dc4908d979e07def55aa007de1aea65c2c8869
  • Pointer size: 131 Bytes
  • Size of remote file: 524 kB
images/Oct6_Page_29.png ADDED

Git LFS Details

  • SHA256: 14c4190d07c63c3ae1eed23c682d6594398ad40f2315922bffce981848638c4b
  • Pointer size: 131 Bytes
  • Size of remote file: 902 kB
images/Oct6_Page_32-1.png ADDED

Git LFS Details

  • SHA256: cf9e92ceb0790d0f57856dd7757a4e615fb8d99c16fb238d9381cf111f86f112
  • Pointer size: 131 Bytes
  • Size of remote file: 798 kB
images/Oct6_Page_32-2.png ADDED

Git LFS Details

  • SHA256: 28876eae00ca0c6a7097f7abd1f31545b6432dbfe23b389cccc41687f5e74d6f
  • Pointer size: 131 Bytes
  • Size of remote file: 871 kB
images/Oct6_Page_33.png ADDED

Git LFS Details

  • SHA256: a9751bc0be929ae10497cdc5f56120c5247a23f64f4f9cfbc269848225a07b59
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
requirements.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.26.1
2
+ aiofiles==23.2.1
3
+ altair==5.2.0
4
+ annotated-types==0.6.0
5
+ anyio==4.2.0
6
+ attrs==23.2.0
7
+ certifi==2023.11.17
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ colorama==0.4.6
11
+ contourpy==1.2.0
12
+ cycler==0.12.1
13
+ diffusers==0.25.1
14
+ fastapi==0.109.0
15
+ ffmpy==0.3.1
16
+ filelock==3.13.1
17
+ fonttools==4.47.2
18
+ fsspec==2023.12.2
19
+ gradio==4.15.0
20
+ gradio_client==0.8.1
21
+ h11==0.14.0
22
+ httpcore==1.0.2
23
+ httpx==0.26.0
24
+ huggingface-hub==0.20.3
25
+ idna==3.6
26
+ importlib-metadata==7.0.1
27
+ importlib-resources==6.1.1
28
+ Jinja2==3.1.3
29
+ jsonschema==4.21.1
30
+ jsonschema-specifications==2023.12.1
31
+ kiwisolver==1.4.5
32
+ markdown-it-py==3.0.0
33
+ MarkupSafe==2.1.4
34
+ matplotlib==3.8.2
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
+ networkx==3.0
38
+ numpy==1.26.3
39
+ orjson==3.9.12
40
+ packaging==23.2
41
+ pandas==2.2.0
42
+ pillow==10.2.0
43
+ psutil==5.9.8
44
+ pydantic==2.5.3
45
+ pydantic_core==2.14.6
46
+ pydub==0.25.1
47
+ Pygments==2.17.2
48
+ pyparsing==3.1.1
49
+ python-dateutil==2.8.2
50
+ python-multipart==0.0.6
51
+ pytz==2023.3.post1
52
+ PyYAML==6.0.1
53
+ referencing==0.32.1
54
+ regex==2023.12.25
55
+ requests==2.31.0
56
+ rich==13.7.0
57
+ rpds-py==0.17.1
58
+ ruff==0.1.14
59
+ safetensors==0.4.2
60
+ semantic-version==2.10.0
61
+ shellingham==1.5.4
62
+ six==1.16.0
63
+ sniffio==1.3.0
64
+ starlette==0.35.1
65
+ sympy==1.12
66
+ tokenizers==0.15.1
67
+ tomlkit==0.12.0
68
+ toolz==0.12.1
69
+ torch==2.1.2+cu118
70
+ torchaudio==2.1.2+cu118
71
+ torchvision==0.16.2+cu118
72
+ tqdm==4.66.1
73
+ transformers==4.37.0
74
+ typer==0.9.0
75
+ typing_extensions==4.9.0
76
+ tzdata==2023.4
77
+ urllib3==2.1.0
78
+ uvicorn==0.27.0
79
+ websockets==11.0.3
80
+ zipp==3.17.0