BeveledCube committed on
Commit
3b3a783
1 Parent(s): 953f815

Trying sum

Dockerfile.fastapi ADDED
@@ -0,0 +1,14 @@
+ FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y git
6
+
7
+ COPY . /app
8
+
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+ RUN pip install --no-cache-dir uvicorn gunicorn fastapi pytest ruff pytest-asyncio httpx
11
+
12
+ EXPOSE 80
13
+
14
+ CMD ["uvicorn", "tld.app:app", "--host", "0.0.0.0", "--port", "80"]
Dockerfile.gradio ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y git
6
+
7
+ COPY . /app
8
+
9
+ RUN pip install --no-cache-dir gradio Pillow
10
+
11
+ EXPOSE 80
12
+
13
+ CMD ["python", "tld/gradio_app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
2
+
3
+ Copyright (c) 2023 Alexandru Papiu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
docker-compose.yml ADDED
@@ -0,0 +1,17 @@
+ version: '3.8'
2
+ services:
3
+ fastapi:
4
+ image: apapiu89/tld-app:latest
5
+ ports:
6
+ - "80:80"
7
+ environment:
8
+ - API_TOKEN=${API_TOKEN}
9
+
10
+ gradio:
11
+ image: apapiu89/gradio-app:latest
12
+ ports:
13
+ - "7860:7860"
14
+ environment:
15
+ - API_URL=http://fastapi:80
16
+ depends_on:
17
+ - fastapi
main.py DELETED
@@ -1,38 +0,0 @@
1
- import os
2
- from transformers import CLIPProcessor, CLIPModel
3
- import torch
4
- from PIL import Image
5
-
6
- # Get the directory of the script
7
- script_directory = os.path.dirname(os.path.realpath(__file__))
8
- # Specify the directory where the cache will be stored (same folder as the script)
9
- cache_directory = os.path.join(script_directory, "cache")
10
- # Create the cache directory if it doesn't exist
11
- os.makedirs(cache_directory, exist_ok=True)
12
-
13
- # Load the CLIP processor and model
14
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
15
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_directory)
16
-
17
- # Text description to generate image
18
- text = "a cat sitting on a table"
19
-
20
- # Tokenize text and get features
21
- inputs = clip_processor(text, return_tensors="pt", padding=True)
22
-
23
- # Generate image from text
24
- generated_image = clip_model.generate_images(
25
- input_ids=inputs.input_ids,
26
- attention_mask=inputs.attention_mask,
27
- visual_input=None, # We don't provide image input
28
- return_tensors="pt" # Return PyTorch tensor
29
- )
30
-
31
- # Convert the generated image tensor to a NumPy array
32
- generated_image_np = generated_image[0].cpu().numpy()
33
-
34
- # Save the generated image
35
- output_image_path = "generated_image.png"
36
- Image.fromarray(generated_image_np).save(output_image_path)
37
-
38
- print("Image generated and saved as:", output_image_path)
mainHistory.py DELETED
@@ -1,46 +0,0 @@
1
- from fastapi.staticfiles import StaticFiles
2
- from fastapi.responses import FileResponse
3
- from pydantic import BaseModel
4
- from fastapi import FastAPI
5
-
6
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
-
8
- model_name = "facebook/blenderbot-1B-distill"
9
-
10
- # https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbot
11
- # facebook/blenderbot-3B
12
- # facebook/blenderbot-1B-distill
13
- # facebook/blenderbot-400M-distill
14
- # facebook/blenderbot-90M
15
- # facebook/blenderbot_small-90M
16
-
17
- # https://www.youtube.com/watch?v=irjYqV6EebU
18
-
19
- app = FastAPI()
20
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
21
- tokenizer = AutoTokenizer.from_pretrained(model_name)
22
-
23
- class req(BaseModel):
24
- prompt: str
25
-
26
- @app.get("/")
27
- def read_root():
28
- return FileResponse(path="templates/index.html", media_type="text/html")
29
-
30
- @app.post("/api")
31
- def read_root(data: req):
32
- print("Prompt:", data.prompt)
33
-
34
- input_text = data.prompt
35
-
36
- # Tokenize the input text
37
- input_ids = tokenizer.encode(input_text, return_tensors="pt")
38
-
39
- # Generate output using the model
40
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
41
- generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
42
-
43
- answer_data = { "answer": generated_text }
44
- print("Answer:", generated_text)
45
-
46
- return answer_data
models.md DELETED
@@ -1,7 +0,0 @@
1
- # Model list
2
- * [microsoft/DialoGPT-small](https://huggingface.co/microsoft/DialoGPT-small)
3
- * [microsoft/DialoGPT-medium](https://huggingface.co/microsoft/DialoGPT-medium)
4
- * [microsoft/DialoGPT-large](https://huggingface.co/microsoft/DialoGPT-large)
5
-
6
- # Download locations
7
- * Github Codespaces: /home/codespace/.local/lib/python3.10/site-packages/transformers/models/
og readme.md ADDED
@@ -0,0 +1,211 @@
+ # Transformer Latent Diffusion
2
+ Text to Image Latent Diffusion using a Transformer core in PyTorch.
3
+
4
+ [Original Github](https://github.com/apapiu/transformer_latent_diffusion)
5
+
6
+ **Try with own inputs**: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VaCe01YG9rnPwAfwVLBKdXEX7D_tk1U5?usp=sharing)
7
+
8
+ Below are some random examples (at 256 resolution) from a 100MM model trained from scratch for 260k iterations (about 32 hours on 1 A100):
9
+
10
+ <img width="760" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/e01e3094-2487-4c04-bc0f-d9b03eeaed00">
11
+
12
+ #### Clip interpolation Examples:
13
+
14
+ a photo of a cat → an anime drawing of a super saiyan cat, artstation:
15
+
16
+ <img width="1361" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/a079458b-9bd5-4557-aa7a-5a3e78f31b53">
17
+
18
+ a cute great gray owl → starry night by van gogh:
19
+
20
+ <img width="1399" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/8731d87a-89fa-43a2-847d-c7ff772de286">
21
+
22
+ Note that the model has not converged yet and could use more training.
23
+
24
+ #### High(er) Resolution:
25
+ By upsampling the positional encoding, the model can also generate 512 or 1024 px images with minimal fine-tuning. See below for some examples from a model fine-tuned on 100k extra 512 px images and 30k 1024 px images for about 2 hours on an A100. The images do sometimes lack global coherence at 1024 px - more to come here:
26
+
27
+ <img width="600" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/adba64f0-b43c-423e-9a7d-033a4afea207">
28
+ <img width="600" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/5a94515b-313e-420d-89d4-6bdc376d9a00">
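+
+ The exact upsampling procedure isn't shown in this snapshot, but one common way to do it (a sketch, assuming a learned 16x16 grid of positional embeddings as in `tld/denoiser.py`) is to bilinearly interpolate the embedding table to the larger token grid before fine-tuning:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ pos_embed = torch.randn(256, 768)  # learned table: (16*16 tokens) x embed_dim
+ grid = pos_embed.reshape(16, 16, -1).permute(2, 0, 1).unsqueeze(0)  # 1 x embed_dim x 16 x 16
+ grid = F.interpolate(grid, size=(32, 32), mode="bilinear", align_corners=False)
+ new_pos_embed = grid.squeeze(0).permute(1, 2, 0).reshape(32 * 32, -1)  # (32*32 tokens) x embed_dim
+ ```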
29
+
30
+
31
+
32
+ ### Intro:
33
+
34
+ The main goal of this project is to build an accessible diffusion model in PyTorch that is:
35
+ - fast (close to real time generation)
36
+ - small (~100MM params)
37
+ - reasonably good (of course not SOTA)
38
+ - can be trained in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent).
39
+ - simple self-contained codebase (model + train loop is ~400 lines of PyTorch with few dependencies)
40
+ - uses ~ 1 million images with a focus on data quality over quantity
41
+
42
+ This is part II of a previous [project](https://github.com/apapiu/guided-diffusion-keras) I did where I trained a pixel level diffusion model in Keras. Even though this model outputs 4x higher resolution images (256px vs 64px), it's actually faster to both train and sample from, which shows the power of training in the latent space and speed of transformer architectures.
43
+
44
+ ## Table of Contents:
45
+ - [Codebase](#codebase)
46
+ - [Usage](#usage)
47
+ - [Examples](#examples)
48
+ - [Data Processing](#data-processing)
49
+ - [Architecture](#architecture)
50
+ - [TO-DOs](#todos)
51
+
52
+
53
+ ## Codebase:
54
+ The code is written in pure PyTorch with as few dependencies as possible.
55
+
56
+ - [transformer_blocks.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/transformer_blocks.py) - basic transformer building blocks relevant to the transformer denoiser
57
+ - [denoiser.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/denoiser.py) - the architecture of the denoiser transformer
58
+ - [train.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/train.py). The train loop uses `accelerate` so training can scale to multiple GPUs if needed.
59
+ - [diffusion.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/diffusion.py). Class to generate image from noise using reverse diffusion. Short (~60 lines) and self-contained.
60
+ - [data.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/data.py). Data utils to download images/text and process necessary features for the diffusion model.
61
+
62
+ ### Usage:
63
+ If you have your own dataset of URLs + captions, the process to train a model on the data consists of two steps:
64
+
65
+ 1. Use `train.download_and_process_data` to obtain the latent and text encodings as numpy files. See [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BPDFDBdsP9SSKBNEFJysmlBjfoxKK13r?usp=sharing) for a notebook example downloading and processing 2000 images from this HuggingFace [dataset](https://huggingface.co/datasets/zzliang/GRIT).
66
+
67
+ 2. Use the `train.main` function in an accelerate `notebook_launcher` - see [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1sKk0usxEF4bmdCDcNQJQNMt4l9qBOeAM?usp=sharing) for a Colab notebook that trains a model on 100k images from scratch. Note that this downloads already preprocessed latents and embeddings from [here](https://huggingface.co/apapiu/small_ldt/tree/main), but you could just use whatever `.npy` files you saved from step 1.
68
+
69
+ #### Fine-Tuning - TODO, but it is the same as step 2 above, except you start from a pre-trained model (a sketch follows the training example below).
70
+
71
+ ```python
72
+ !wandb login
73
+ import os
74
+ from tld.train import main, DataConfig, ModelConfig
75
+ from accelerate import notebook_launcher
76
+
77
+ data_config = DataConfig(latent_path='path/to/image_latents.npy',
78
+ text_emb_path='path/to/text_encodings.npy',
79
+ val_path='path/to/val_encodings.npy')
80
+
81
+ model_config = ModelConfig(embed_dim=512, n_layers=6) #see ModelConfig for more params
82
+
83
+ #run the training process on 2 GPUs:
84
+ notebook_launcher(main, (model_config, data_config), num_processes=2)
85
+ ```
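+
+ A hedged sketch of the fine-tuning path (same launcher as step 2, but with `from_scratch=False` so `train.main` restores a checkpoint via `wandb.restore`; `run_id` and `model_name` below are placeholders):
+
+ ```python
+ from tld.train import main, DataConfig, ModelConfig
+ from accelerate import notebook_launcher
+
+ data_config = DataConfig(latent_path='path/to/image_latents.npy',
+                          text_emb_path='path/to/text_encodings.npy',
+                          val_path='path/to/val_encodings.npy')
+
+ model_config = ModelConfig(embed_dim=768, n_layers=12,  # must match the pre-trained architecture
+                            from_scratch=False,
+                            run_id='your_run_id',        # placeholder: wandb run holding the checkpoint
+                            model_name='state_dict.pth') # placeholder: checkpoint file name
+
+ notebook_launcher(main, (model_config, data_config), num_processes=1)
+ ```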
86
+
87
+ ### Dependencies:
88
+ - `PyTorch` `numpy` `einops` for model building
89
+ - `wandb` `tqdm` for logging + progress bars
90
+ - `accelerate` for train loop and multi-GPU support
91
+ - `img2dataset` `webdataset` `torchvision` for data downloading and image processing
92
+ - `diffusers` `clip` for pretrained VAE and CLIP text model
93
+
94
+ ### Codebases used for inspiration:
95
+ - [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha)
96
+ - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
97
+ - [nanoGPT](https://github.com/karpathy/nanoGPT/tree/master)
98
+ - [LocalViT](https://github.com/ofsoundof/LocalViT)
99
+
100
+ #### Speed:
101
+
102
+ I try to speed up training and inference as much as possible by:
103
+ - using mixed precision for training + scaled dot-product attention (SDPA)
104
+ - precomputing all latent and text embeddings
105
+ - using float16 precision for inference
106
+ - using native SDPA for the attention + torch.compile() (compile doesn't always work).
107
+ - using a highly performant sampler (DPM-Solver++(2M)) that gets good results in ~15 steps.
108
+ - TODO: would distillation or something like LCM work here?
109
+ - TODO: use flash-attention2?
110
+ - TODO: use smaller vae?
111
+
112
+ The time to generate a batch of 36 images (15 iterations) on a:
113
+ - T4: ~ 3.5 seconds
114
+ - A100: ~ 0.6 seconds
115
+ In fact, on an A100 the VAE becomes the bottleneck even though it is only used once.
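+
+ A rough way to reproduce these timings (a sketch; the default `LTDConfig` builds a small randomly initialized model, so point `file_url` at a real checkpoint for real outputs, and expect numbers to vary by GPU):
+
+ ```python
+ import time
+ from tld.diffusion import DiffusionTransformer, LTDConfig
+
+ dt = DiffusionTransformer(LTDConfig())  # pass LTDConfig(file_url=...) to load trained weights
+ start = time.perf_counter()
+ dt.generate_image_from_text(prompt="a cute cat", num_imgs=36, n_iter=15)
+ print(f"36 images in {time.perf_counter() - start:.2f}s")
+ ```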
116
+
117
+
118
+ ## Examples:
119
+
120
+ More examples generated with the 100MM model - click the photo to see the prompt and other params like cfg and seed:
121
+ ![image](tld/img_examples/a%20cute%20grey%20great%20owl_cfg_8_seed_11.png)
122
+ ![image](tld/img_examples/watercolor%20of%20a%20cute%20cat%20riding%20a%20motorcycle_cfg_7_seed_11.png)
123
+ ![image](tld/img_examples/painting%20of%20a%20cyberpunk%20market_cfg_7_seed_11.png)
124
+ ![image](tld/img_examples/isometric%20view%20of%20small%20japanese%20village%20with%20blooming%20trees_cfg_7_seed_11.png)
125
+ ![image](tld/img_examples/a%20beautiful%20woman%20with%20blonde%20hair%20in%20her%2050s_cfg_7_seed_11.png)
126
+ ![image](tld/img_examples/painting%20of%20a%20cute%20fox%20in%20a%20suit%20in%20a%20field%20of%20poppies_cfg_8_seed_11.png)
127
+ ![image](tld/img_examples/an%20aerial%20view%20of%20manhattan%2C%20isometric%20view%2C%20as%20pantinted%20by%20mondrian_cfg_7_seed_11.png)
128
+
129
+ ## Outpainting model:
130
+
131
+ I also fine-tuned an outpainting model on top of the original 101MM model. I had to modify the input conv2d patch layer to 8 channels and initialize the mask channels' parameters to zero. The rest of the architecture remained the same.
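+
+ A hedged sketch of that 4 → 8 channel surgery (it assumes the patchify `nn.Conv2d` is the first layer of `patchify_and_embed`, as in `tld/denoiser.py`; the four new mask-channel filters are zero-initialized so fine-tuning starts from the original model's behavior):
+
+ ```python
+ import torch
+ from torch import nn
+ from tld.denoiser import Denoiser
+
+ denoiser = Denoiser(image_size=32, noise_embed_dims=128, patch_size=2,
+                     embed_dim=768, dropout=0.1, n_layers=12)
+ old_conv = denoiser.denoiser_trans_block.patchify_and_embed[0]
+ new_conv = nn.Conv2d(8, old_conv.out_channels,
+                      kernel_size=old_conv.kernel_size, stride=old_conv.stride)
+ with torch.no_grad():
+     new_conv.weight.zero_()                    # new mask channels start at zero
+     new_conv.weight[:, :4] = old_conv.weight   # copy the original latent-channel filters
+     new_conv.bias.copy_(old_conv.bias)
+ denoiser.denoiser_trans_block.patchify_and_embed[0] = new_conv
+ ```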
132
+
133
+ Below I apply the outpainting model repeatedly to generate a somewhat consistent scene from the prompt "a cyberpunk marketplace":
134
+
135
+ <img width="1440" alt="image" src="https://github.com/apapiu/transformer_latent_diffusion/assets/13619417/4451719f-d45a-4a86-a7bb-06c021b34996">
136
+
137
+ ## Data Processing:
138
+
139
+ In [data.py](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/data.py), I have some helper functions to process images and captions. The flow is as follows:
140
+ - Use `img2dataset` to download images from a dataframe containing URLs and captions.
141
+ - Use `CLIP` to encode the prompts and the `VAE` to encode images to latents on a web2dataset data generator.
142
+ - Save the latents and text embedding for future training.
143
+
144
+ There are two advantages to this approach. One is that the VAE encoding is somewhat expensive, so doing it every epoch would affect training times. The other is that we can discard the images after processing. For `3*256*256` images, the latent dimension is `4*32*32`, so every latent is around 4KB (when quantized in uint8; see [here](https://pub.towardsai.net/stable-diffusion-based-image-compresssion-6f1f0a399202?gi=1f45c6522d3b)). This means that 1 million latents will be "only" 4GB in size, which is easy to handle even in RAM. Storing the raw images would have been 48x larger in size.
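+
+ A quick sanity check of the numbers above:
+
+ ```python
+ latent_bytes = 4 * 32 * 32        # uint8 latent -> 4096 bytes ~ 4 KB
+ image_bytes = 3 * 256 * 256       # uint8 RGB image -> 196608 bytes ~ 192 KB
+ print(image_bytes / latent_bytes)        # 48.0 -> the "48x larger" figure
+ print(1_000_000 * latent_bytes / 1e9)    # ~4.1 GB for 1 million latents
+ ```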
145
+
146
+ ## Architecture:
147
+
148
+ See [here](https://github.com/apapiu/transformer_latent_diffusion/blob/main/tld/denoiser.py) for the denoiser class.
149
+
150
+ The denoiser is a Transformer-based model built on the architecture in [DiT](https://arxiv.org/abs/2203.02378) and [Pixart-Alpha](https://pixart-alpha.github.io/), albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser differs from most diffusion models, which typically use a CNN-based U-Net as the denoising backbone. I decided to use a Transformer for a few reasons. One was that I just wanted to experiment and learn how to build and train Transformers from the ground up. Secondly, Transformers are fast both to train and to run inference on, and they will benefit most from future advances (in both hardware and software) in performance.
151
+
152
+ Transformers are not natively built for spatial data, and at first I found a lot of the outputs to be very "patchy". To remedy that, I added a depth-wise convolution in the FFN layer of the transformer (this was introduced in the [Local ViT](https://arxiv.org/abs/2104.05707) paper). This allows the model to mix pixels that are close to each other with very little added compute cost.
153
+
154
+
155
+ ### Img+Text+Noise Encoding:
156
+
157
+ The image latent inputs are `4*32*32` and we use a patch size of 2 to build 256 flattened `4*2*2=16` dimensional input "pixels". These are then projected into the embedding dimension and fed through the transformer blocks.
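+
+ A shape check for these numbers (the actual model does the patchify with a strided `nn.Conv2d`, but the resulting token layout is the same):
+
+ ```python
+ import torch
+ from einops import rearrange
+
+ latent = torch.randn(1, 4, 32, 32)  # bs, channels, h, w
+ tokens = rearrange(latent, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=2, p2=2)
+ print(tokens.shape)  # torch.Size([1, 256, 16]) -> 256 tokens of dimension 16
+ ```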
158
+
159
+ The text and noise conditioning is very simple - we concatenate a pooled CLIP text embedding (`ViT/L14` - 768-dimensional) and the sinusoidal noise embedding and feed it as input in the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
160
+
161
+ ### Training:
162
+ The base model has 101MM parameters, 12 layers, and an embedding dimension of 768. I train it with a batch size of 256 on an A100 with a learning rate of `3e-4` and 1000 warmup steps. Due to computational constraints I did not do any ablations for this configuration.
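+
+ Expressed with the `ModelConfig` from `tld/train.py`, that base configuration is roughly the following (warmup is handled outside the config, so it is not shown):
+
+ ```python
+ from tld.train import ModelConfig
+
+ base_config = ModelConfig(embed_dim=768, n_layers=12, batch_size=256, lr=3e-4)
+ ```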
163
+
164
+
165
+ ## Train and Diffusion Setup:
166
+
167
+ We train a denoising transformer that takes the following three inputs:
168
+ - `noise_level` (sampled from 0 to 1 with more values concentrated close to 0 - I use a beta distribution)
169
+ - Image latent (x) corrupted with a level of random noise
170
+ - For a given `noise_level` between 0 and 1, the corruption is as follows:
171
+ - `x_noisy = x*(1-noise_level) + eps*noise_level where eps ~ np.random.normal(0, 1)`
172
+ - CLIP embeddings of a text prompt
173
+ - You can think of this as a numerical representation of a text prompt.
174
+ - We use the pooled text embedding here (768 dimensional for `ViT/L14`)
175
+
176
+ The output is a prediction of the denoised image latent - call it `f(x_noisy)`.
177
+
178
+ The model is trained to minimize the mean squared error `|f(x_noisy) - x|` between the prediction and the actual image latent
179
+ (you can also use absolute error here). Note that I don't reparameterize the loss in terms of the noise here to keep things simple.
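+
+ A condensed sketch of this objective (it mirrors the loop in `tld/train.py`; `x` is a batch of latents already divided by the VAE scale factor, and `text_emb` the pooled CLIP embeddings):
+
+ ```python
+ import torch
+ from torch import nn
+
+ def train_step(model, x, text_emb, beta_a=0.75, beta_b=0.75):
+     noise_level = torch.distributions.Beta(beta_a, beta_b).sample((x.size(0),)).to(x)
+     noise = torch.randn_like(x)
+     x_noisy = noise_level.view(-1, 1, 1, 1) * noise + (1 - noise_level).view(-1, 1, 1, 1) * x
+     pred = model(x_noisy, noise_level.view(-1, 1), text_emb)
+     return nn.functional.mse_loss(pred, x)  # predict the clean latent, not the noise
+ ```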
180
+
181
+ Using this model, we then iteratively generate an image from random noise as follows:
182
+
183
+ for i in range(len(self.noise_levels) - 1):
184
+
185
+ curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
186
+
187
+ # Predict original denoised image:
188
+ x0_pred = predict_x_zero(new_img, label, curr_noise)
189
+
190
+ # New image at next_noise level is a weighted average of old image and predicted x0:
191
+ new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
192
+
193
+ The `predict_x_zero` method uses classifier-free guidance by combining the conditional and unconditional
194
+ prediction: `x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional`
195
+
196
+ A bit of math: the approach above falls within the VDM parametrization; see Section 3.1 in [Kingma et al.](https://arxiv.org/pdf/2107.00630.pdf):
197
+
198
+ $$z_t = \alpha_t x + \sigma_t \epsilon, \epsilon \sim \mathcal{N}(0,1)$$
199
+
200
+ Where $z_t$ is the noisy version of $x$ at time $t$.
201
+
202
+ Generally, $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here, I chose $\alpha_t=1-\sigma_t$ so that we linearly interpolate between the image and random noise. Why? For one, it simplifies the updating equation quite a bit, and it's easier to understand what the noise to signal ratio will look like. I also found that the model produces sharper images faster - more validation here is needed. The updating equation above is the DDIM model for this parametrization, which simplifies to a simple weighted average. Note that the DDIM model deterministically maps random normal noise to images - this has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to achieve decent image quality.
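+
+ For completeness, the one-line derivation of that weighted-average update: with $\hat{x}$ the model's estimate of $x$ and $\hat{\epsilon} = (z_t - (1-\sigma_t)\hat{x})/\sigma_t$ the implied noise, the DDIM step to a lower noise level $\sigma_s$ is
+
+ $$z_s = (1-\sigma_s)\hat{x} + \sigma_s\hat{\epsilon} = \left(1 - \frac{\sigma_s}{\sigma_t}\right)\hat{x} + \frac{\sigma_s}{\sigma_t} z_t = \frac{(\sigma_t-\sigma_s)\,\hat{x} + \sigma_s\, z_t}{\sigma_t},$$
+
+ which is exactly the `new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise` update in the sampling loop above.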
203
+
204
+ ## TODOS:
205
+ - better config in the train file
206
+ - how to speed up generation even more - LCMs or other sampling strategies?
207
+ - add script to compute FID
208
+
209
+
210
+
211
+
pyproject.toml DELETED
@@ -1,17 +0,0 @@
1
- [tool.poetry]
2
- name = "img gen"
3
- version = "0.0.1"
4
- description = "A project to test image generation with AI models"
5
- authors = ["CubeBeveled <[email protected]>"]
6
- readme = "README.md"
7
-
8
- [tool.poetry.dependencies]
9
- python = "^3.11"
10
- transformers = "^4.39.1"
11
- torch = "^2.2.1"
12
- pillow = "^10.2.0"
13
-
14
-
15
- [build-system]
16
- requires = ["poetry-core"]
17
- build-backend = "poetry.core.masonry.api"
requirements.txt CHANGED
@@ -1,4 +1,10 @@
- transformers
  torch
- poetry
- pillow
+ numpy
+ einops
+ torchvision
+ tqdm
+ diffusers
+ accelerate
+ transformers
+ Pillow
+ git+https://github.com/openai/CLIP.git
setup.py ADDED
@@ -0,0 +1,36 @@
+ from setuptools import setup, find_packages
2
+
3
+
4
+ def load_requirements(filename="requirements.txt"):
5
+ with open(filename, "r") as file:
6
+ lines = [line.strip() for line in file.readlines() if line.strip() and not line.startswith("#")]
7
+ return lines
8
+
9
+
10
+ setup(
11
+ name="tld",
12
+ version="0.1.0",
13
+ author="Alexandru Papiu",
14
+ author_email="[email protected]",
15
+ description="Transformer Latent Diffusion",
16
+ url="https://github.com/apapiu/transformer_latent_diffusion",
17
+ packages=find_packages(exclude=["tests*"]),
18
+ classifiers=[
19
+ "Programming Language :: Python :: 3",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Operating System :: OS Independent",
22
+ ],
23
+ python_requires=">=3.6",
24
+ install_requires=[
25
+ "torch",
26
+ "numpy",
27
+ "einops",
28
+ "torchvision",
29
+ "tqdm",
30
+ "diffusers",
31
+ "accelerate",
32
+ "transformers",
33
+ "Pillow",
34
+ "clip @ git+https://github.com/openai/CLIP.git",
35
+ ],
36
+ )
start.sh DELETED
@@ -1,5 +0,0 @@
1
- pip install --upgrade pip
2
- pip install -r requirements.txt
3
- poetry install --no-root
4
-
5
- python main.py
tests/__init__.py ADDED
File without changes
tests/test_api.py ADDED
@@ -0,0 +1,31 @@
+ import os
2
+
3
+ from fastapi.testclient import TestClient
4
+ from tld.app import app
5
+ import PIL
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ client = TestClient(app)
10
+
11
+ def test_read_main():
12
+ response = client.get("/")
13
+ assert response.status_code == 200
14
+ assert response.json() == {"message": "Welcome to Image Generator"}
15
+
16
+
17
+ def test_generate_image_unauthorized():
18
+ response = client.post("/generate-image/", json={})
19
+ assert response.status_code == 401
20
+ assert response.json() == {"detail": "Not authenticated"}
21
+
22
+
23
+ def test_generate_image_authorized():
24
+ api_token = os.getenv("API_TOKEN")
25
+ response = client.post(
26
+ "/generate-image/", json={"prompt": "a cute cat"}, headers={"Authorization": f"Bearer {api_token}"}
27
+ )
28
+ assert response.status_code == 200
29
+
30
+ image = Image.open(BytesIO(response.content))
31
+ assert type(image) == PIL.JpegImagePlugin.JpegImageFile
tests/test_diffuser.py ADDED
@@ -0,0 +1,99 @@
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5
+ import time
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+ import torchvision.utils as vutils
11
+ from diffusers import AutoencoderKL
12
+
13
+ from tld.denoiser import Denoiser
14
+ from tld.diffusion import DiffusionGenerator, DiffusionTransformer, LTDConfig
15
+ from PIL.Image import Image
16
+
17
+ to_pil = transforms.ToPILImage()
18
+
19
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
+
21
+
22
+ def test_outputs(num_imgs=4):
23
+ model = Denoiser(
24
+ image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=768, dropout=0.1, n_layers=12
25
+ )
26
+ x = torch.rand(num_imgs, 4, 32, 32)
27
+ noise_level = torch.rand(num_imgs, 1)
28
+ label = torch.rand(num_imgs, 768)
29
+
30
+ print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
31
+
32
+ with torch.no_grad():
33
+ start_time = time.time()
34
+ output = model(x, noise_level, label)
35
+ end_time = time.time()
36
+
37
+ execution_time = end_time - start_time
38
+ print(f"Model execution took {execution_time:.4f} seconds.")
39
+
40
+ assert output.shape == torch.Size([num_imgs, 4, 32, 32])
41
+ print("Basic tests passed.")
42
+
43
+ # model = Denoiser(image_size=16, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=6)
44
+ # x = torch.rand(8, 4, 32, 32)
45
+ # noise_level = torch.rand(8, 1)
46
+ # label = torch.rand(8, 768)
47
+
48
+ # with torch.no_grad():
49
+ # output = model(x, noise_level, label)
50
+
51
+ # assert output.shape == torch.Size([8, 4, 32, 32])
52
+ # print("Upscale tests passed.")
53
+
54
+
55
+ def test_diffusion_generator():
56
+ model_dtype = torch.float32 ##float 16 will not work on cpu
57
+ num_imgs = 1
58
+ nrow = int(np.sqrt(num_imgs))
59
+
60
+ denoiser = Denoiser(
61
+ image_size=32, noise_embed_dims=128, patch_size=2, embed_dim=256, dropout=0.1, n_layers=3
62
+ )
63
+ print(f"Model has {sum(p.numel() for p in denoiser.parameters())} parameters")
64
+
65
+ denoiser.to(model_dtype)
66
+
67
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=model_dtype).to(device)
68
+
69
+ labels = torch.rand(num_imgs, 768)
70
+
71
+ diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
72
+
73
+ out, _ = diffuser.generate(
74
+ labels=labels,
75
+ num_imgs=num_imgs,
76
+ class_guidance=3,
77
+ seed=1,
78
+ n_iter=5,
79
+ exponent=1,
80
+ scale_factor=8,
81
+ sharp_f=0,
82
+ bright_f=0,
83
+ )
84
+
85
+ out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
86
+ out.save("test.png")
87
+ print("Images generated at test.png")
88
+
89
+
90
+ def test_full_generation_pipeline():
91
+ ltdconfig = LTDConfig()
92
+ diffusion_transformer = DiffusionTransformer(ltdconfig)
93
+
94
+ out = diffusion_transformer.generate_image_from_text(prompt="a cute cat")
95
+ print(out)
96
+ assert type(out) == Image
97
+
98
+
99
+ # TODO: should add tests for train loop and data processing
tld/__init__.py ADDED
File without changes
tld/app.py ADDED
@@ -0,0 +1,67 @@
+ import io
2
+ import os
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ from fastapi import Depends, FastAPI, HTTPException, status
8
+ from fastapi.responses import StreamingResponse
9
+ from fastapi.security import OAuth2PasswordBearer
10
+ from pydantic import BaseModel
11
+
12
+ from tld.diffusion import DiffusionTransformer, LTDConfig
13
+
14
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15
+ to_pil = transforms.ToPILImage()
16
+
17
+ ltdconfig = LTDConfig()
18
+ diffusion_transformer = DiffusionTransformer(ltdconfig)
19
+
20
+ app = FastAPI()
21
+
22
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
23
+
24
+
25
+ def validate_token(token: str = Depends(oauth2_scheme)):
26
+ if token != os.getenv("API_TOKEN"):
27
+ raise HTTPException(
28
+ status_code=status.HTTP_401_UNAUTHORIZED,
29
+ detail="Invalid authentication credentials",
30
+ headers={"WWW-Authenticate": "Bearer"},
31
+ )
32
+
33
+
34
+ class ImageRequest(BaseModel):
35
+ prompt: str
36
+ class_guidance: Optional[int] = 6
37
+ seed: Optional[int] = 11
38
+ num_imgs: Optional[int] = 1
39
+ img_size: Optional[int] = 32
40
+
41
+
42
+ @app.get("/")
43
+ def read_root():
44
+ return {"message": "Welcome to Image Generator"}
45
+
46
+
47
+ @app.post("/generate-image/")
48
+ async def generate_image(request: ImageRequest, token: str = Depends(validate_token)):
49
+ try:
50
+ img = diffusion_transformer.generate_image_from_text(
51
+ prompt=request.prompt,
52
+ class_guidance=request.class_guidance,
53
+ seed=request.seed,
54
+ num_imgs=request.num_imgs,
55
+ img_size=request.img_size,
56
+ )
57
+ # Convert PIL image to byte stream suitable for HTTP response
58
+ img_byte_arr = io.BytesIO()
59
+ img.save(img_byte_arr, format="JPEG")
60
+ img_byte_arr.seek(0)
61
+
62
+ return StreamingResponse(img_byte_arr, media_type="image/jpeg")
63
+ except Exception as e:
64
+ raise HTTPException(status_code=500, detail=str(e))
65
+
66
+
67
+ # build job to test and deploy the API on a docker image (maybe in Azure?)
tld/data.py ADDED
@@ -0,0 +1,243 @@
+ ####data util to get and preprocess data from a text and image pair to latents and text embeddings.
2
+ ### all that is required is a csv file with an image url and text caption:
3
+ #!pip install datasets img2dataset accelerate diffusers
4
+ #!pip install git+https://github.com/openai/CLIP.git
5
+
6
+ import json
7
+ import os
8
+ from dataclasses import dataclass
9
+ from typing import List, Union
10
+
11
+ import clip
12
+ import h5py
13
+ import numpy as np
14
+ import pandas as pd
15
+ import torch
16
+ import torchvision.transforms as transforms
17
+ import webdataset as wds
18
+ from diffusers import AutoencoderKL
19
+ from img2dataset import download
20
+ from torch import Tensor, nn
21
+ from torch.utils.data import DataLoader
22
+ from tqdm import tqdm
23
+
24
+
25
+ @torch.no_grad()
26
+ def encode_text(label: Union[str, List[str]], model: nn.Module, device: str) -> Tensor:
27
+ text_tokens = clip.tokenize(label, truncate=True).to(device)
28
+ text_encoding = model.encode_text(text_tokens)
29
+ return text_encoding.cpu()
30
+
31
+
32
+ @torch.no_grad()
33
+ def encode_image(img: Tensor, vae: AutoencoderKL) -> Tensor:
34
+ x = img.to("cuda").to(torch.float16)
35
+
36
+ x = x * 2 - 1 # to make it between -1 and 1.
37
+ encoded = vae.encode(x, return_dict=False)[0].sample()
38
+ return encoded.cpu()
39
+
40
+
41
+ @torch.no_grad()
42
+ def decode_latents(out_latents: torch.FloatTensor, vae: AutoencoderKL) -> Tensor:
43
+ # expected to be in the unscaled latent space
44
+ out = vae.decode(out_latents.cuda())[0].cpu()
45
+
46
+ return ((out + 1) / 2).clip(0, 1)
47
+
48
+
49
+ def quantize_latents(lat: Tensor, clip_val: float = 20) -> Tensor:
50
+ """scale and quantize latents to uint8"""
51
+ lat_norm = lat.clip(-clip_val, clip_val) / clip_val
52
+ return (((lat_norm + 1) / 2) * 255).to(torch.uint8)
53
+
54
+
55
+ def dequantize_latents(lat: Tensor, clip_val: float = 20) -> Tensor:
56
+ lat_norm = (lat.to(torch.float16) / 255) * 2 - 1
57
+ return lat_norm * clip_val
58
+
59
+
60
+ def append_to_dataset(dataset: h5py.File, new_data: Tensor) -> None:
61
+ """Appends new data to an HDF5 dataset."""
62
+ new_size = dataset.shape[0] + new_data.shape[0]
63
+ dataset.resize(new_size, axis=0)
64
+ dataset[-new_data.shape[0] :] = new_data
65
+
66
+
67
+ def get_text_and_latent_embeddings_hdf5(
68
+ dataloader: DataLoader, vae: AutoencoderKL, model: nn.Module, drive_save_path: str
69
+ ) -> None:
70
+ """Process image/text inputs into latent and text embeddings (plus a metadata CSV of prompts/urls), saving encodings as float16."""
71
+
72
+ img_latent_path = os.path.join(drive_save_path, "image_latents.hdf5")
73
+ text_embed_path = os.path.join(drive_save_path, "text_encodings.hdf5")
74
+ metadata_csv_path = os.path.join(drive_save_path, "metadata.csv")
75
+
76
+ with h5py.File(img_latent_path, "a") as img_file, h5py.File(text_embed_path, "a") as text_file:
77
+ if "image_latents" not in img_file:
78
+ img_ds = img_file.create_dataset(
79
+ "image_latents",
80
+ shape=(0, 4, 32, 32),
81
+ maxshape=(None, 4, 32, 32),
82
+ dtype="float16",
83
+ chunks=True,
84
+ )
85
+ else:
86
+ img_ds = img_file["image_latents"]
87
+
88
+ if "text_encodings" not in text_file:
89
+ text_ds = text_file.create_dataset(
90
+ "text_encodings", shape=(0, 768), maxshape=(None, 768), dtype="float16", chunks=True
91
+ )
92
+ else:
93
+ text_ds = text_file["text_encodings"]
94
+
95
+ for img, (label, url) in tqdm(dataloader):
96
+ text_encoding = encode_text(label, model, "cuda").cpu().numpy().astype(np.float16)  # encode_text also expects a device; "cuda" matches encode_image above
97
+ img_encoding = encode_image(img, vae).cpu().numpy().astype(np.float16)
98
+
99
+ append_to_dataset(img_ds, img_encoding)
100
+ append_to_dataset(text_ds, text_encoding)
101
+
102
+ metadata_df = pd.DataFrame({"text": label, "url": url})
103
+ if os.path.exists(metadata_csv_path):
104
+ metadata_df.to_csv(metadata_csv_path, mode="a", header=False, index=False)
105
+ else:
106
+ metadata_df.to_csv(metadata_csv_path, mode="w", header=True, index=False)
107
+
108
+
109
+ def download_and_process_data(
110
+ latent_save_path="latents",
111
+ raw_imgs_save_path="raw_imgs",
112
+ csv_path="imgs.csv",
113
+ image_size=256,
114
+ bs=64,
115
+ caption_col="captions",
116
+ url_col="url",
117
+ download_data=True,
118
+ number_sample_per_shard=10000,
119
+ ):
120
+ if not os.path.exists(raw_imgs_save_path):
121
+ os.mkdir(raw_imgs_save_path)
122
+
123
+ if not os.path.exists(latent_save_path):
124
+ os.mkdir(latent_save_path)
125
+
126
+ if download_data:
127
+ download(
128
+ processes_count=8,
129
+ thread_count=64,
130
+ url_list=csv_path,
131
+ image_size=image_size,
132
+ output_folder=raw_imgs_save_path,
133
+ output_format="webdataset",
134
+ input_format="csv",
135
+ url_col=url_col,
136
+ caption_col=caption_col,
137
+ enable_wandb=False,
138
+ number_sample_per_shard=number_sample_per_shard,
139
+ distributor="multiprocessing",
140
+ resize_mode="center_crop",
141
+ )
142
+
143
+ files = os.listdir(raw_imgs_save_path)
144
+ tar_files = [os.path.join(raw_imgs_save_path, file) for file in files if file.endswith(".tar")]
145
+ print(tar_files)
146
+ dataset = wds.WebDataset(tar_files)
147
+
148
+ transform = transforms.Compose(
149
+ [
150
+ transforms.ToTensor(),
151
+ ]
152
+ )
153
+
154
+ # output is (img_tensor, (caption , url_col)) per batch:
155
+ dataset = (
156
+ dataset.decode("pil")
157
+ .to_tuple("jpg;png", "json")
158
+ .map_tuple(transform, lambda x: (x["caption"], x[url_col]))
159
+ )
160
+
161
+ dataloader = DataLoader(dataset, batch_size=bs, shuffle=False)
162
+
163
+ model, _ = clip.load("ViT-L/14")
164
+
165
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
166
+ vae = vae.to("cuda")
167
+ model.to("cuda")
168
+
169
+ print("Starting to encode latents and text:")
170
+ get_text_and_latent_embeddings_hdf5(dataloader, vae, model, latent_save_path)
171
+ print("Finished encode latents and text:")
172
+
173
+
174
+ @dataclass
175
+ class DataConfiguration:
176
+ data_link: str
177
+ caption_col: str = "caption"
178
+ url_col: str = "url"
179
+ latent_save_path: str = "latents_folder"
180
+ raw_imgs_save_path: str = "raw_imgs_folder"
181
+ use_drive: bool = False
182
+ initial_csv_path: str = "imgs.csv"
183
+ number_sample_per_shard: int = 10000
184
+ image_size: int = 256
185
+ batch_size: int = 64
186
+ download_data: bool = True
187
+
188
+
189
+ if __name__ == "__main__":
190
+ use_wandb = False
191
+
192
+ if use_wandb:
193
+ import wandb
194
+
195
+ os.environ["WANDB_API_KEY"] = "key"
196
+ #!wandb login
197
+
198
+ data_link = "https://huggingface.co/datasets/zzliang/GRIT/resolve/main/grit-20m/coyo_0_snappy.parquet?download=true"
199
+
200
+ data_config = DataConfiguration(
201
+ data_link=data_link,
202
+ latent_save_path="latent_folder",
203
+ raw_imgs_save_path="raw_imgs_folder",
204
+ download_data=False,
205
+ number_sample_per_shard=1000,
206
+ )
207
+
208
+ if use_wandb:
209
+ wandb.init(project="image_vae_processing", entity="apapiu", config=data_config)
210
+
211
+ if not os.path.exists(data_config.latent_save_path):
212
+ os.mkdir(data_config.latent_save_path)
213
+
214
+ config_file_path = os.path.join(data_config.latent_save_path, "config.json")
215
+ with open(config_file_path, "w") as f:
216
+ json.dump(data_config.__dict__, f)
217
+
218
+ print("Config saved to:", config_file_path)
219
+
220
+ df = pd.read_parquet(data_link)
221
+ ### add additional data cleaning here if needed
222
+ df = df.iloc[:3000]
223
+ df[["key", "url", "caption"]].to_csv("imgs.csv", index=None)
224
+
225
+ if data_config.use_drive:
226
+ from google.colab import drive
227
+
228
+ drive.mount("/content/drive")
229
+
230
+ download_and_process_data(
231
+ latent_save_path=data_config.latent_save_path,
232
+ raw_imgs_save_path=data_config.raw_imgs_save_path,
233
+ csv_path=data_config.initial_csv_path,
234
+ image_size=data_config.image_size,
235
+ bs=data_config.batch_size,
236
+ caption_col=data_config.caption_col,
237
+ url_col=data_config.url_col,
238
+ download_data=data_config.download_data,
239
+ number_sample_per_shard=data_config.number_sample_per_shard,
240
+ )
241
+
242
+ if use_wandb:
243
+ wandb.finish()
tld/denoiser.py ADDED
@@ -0,0 +1,123 @@
+ """transformer based denoiser"""
2
+
3
+ import torch
4
+ from einops.layers.torch import Rearrange
5
+ from torch import nn
6
+
7
+ from tld.transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding
8
+
9
+
10
+ class DenoiserTransBlock(nn.Module):
11
+ def __init__(
12
+ self,
13
+ patch_size: int,
14
+ img_size: int,
15
+ embed_dim: int,
16
+ dropout: float,
17
+ n_layers: int,
18
+ mlp_multiplier: int = 4,
19
+ n_channels: int = 4,
20
+ ):
21
+ super().__init__()
22
+
23
+ self.patch_size = patch_size
24
+ self.img_size = img_size
25
+ self.n_channels = n_channels
26
+ self.embed_dim = embed_dim
27
+ self.dropout = dropout
28
+ self.n_layers = n_layers
29
+ self.mlp_multiplier = mlp_multiplier
30
+
31
+ seq_len = int((self.img_size / self.patch_size) * (self.img_size / self.patch_size))
32
+ patch_dim = self.n_channels * self.patch_size * self.patch_size
33
+
34
+ self.patchify_and_embed = nn.Sequential(
35
+ nn.Conv2d(
36
+ self.n_channels,
37
+ patch_dim,
38
+ kernel_size=self.patch_size,
39
+ stride=self.patch_size,
40
+ ),
41
+ Rearrange("bs d h w -> bs (h w) d"),
42
+ nn.LayerNorm(patch_dim),
43
+ nn.Linear(patch_dim, self.embed_dim),
44
+ nn.LayerNorm(self.embed_dim),
45
+ )
46
+
47
+ self.rearrange2 = Rearrange(
48
+ "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
49
+ h=int(self.img_size / self.patch_size),
50
+ p1=self.patch_size,
51
+ p2=self.patch_size,
52
+ )
53
+
54
+ self.pos_embed = nn.Embedding(seq_len, self.embed_dim)
55
+ self.register_buffer("precomputed_pos_enc", torch.arange(0, seq_len).long())
56
+
57
+ self.decoder_blocks = nn.ModuleList(
58
+ [
59
+ DecoderBlock(
60
+ embed_dim=self.embed_dim,
61
+ mlp_multiplier=self.mlp_multiplier,
62
+ # note that this is a non-causal block since we are
63
+ # denoising the entire image no need for masking
64
+ is_causal=False,
65
+ dropout_level=self.dropout,
66
+ mlp_class=MLPSepConv,
67
+ )
68
+ for _ in range(self.n_layers)
69
+ ]
70
+ )
71
+
72
+ self.out_proj = nn.Sequential(nn.Linear(self.embed_dim, patch_dim), self.rearrange2)
73
+
74
+ def forward(self, x, cond):
75
+ x = self.patchify_and_embed(x)
76
+ pos_enc = self.precomputed_pos_enc[: x.size(1)].expand(x.size(0), -1)
77
+ x = x + self.pos_embed(pos_enc)
78
+
79
+ for block in self.decoder_blocks:
80
+ x = block(x, cond)
81
+
82
+ return self.out_proj(x)
83
+
84
+
85
+ class Denoiser(nn.Module):
86
+ def __init__(
87
+ self,
88
+ image_size: int,
89
+ noise_embed_dims: int,
90
+ patch_size: int,
91
+ embed_dim: int,
92
+ dropout: float,
93
+ n_layers: int,
94
+ text_emb_size: int = 768,
95
+ ):
96
+ super().__init__()
97
+
98
+ self.image_size = image_size
99
+ self.noise_embed_dims = noise_embed_dims
100
+ self.embed_dim = embed_dim
101
+
102
+ self.fourier_feats = nn.Sequential(
103
+ SinusoidalEmbedding(embedding_dims=noise_embed_dims),
104
+ nn.Linear(noise_embed_dims, self.embed_dim),
105
+ nn.GELU(),
106
+ nn.Linear(self.embed_dim, self.embed_dim),
107
+ )
108
+
109
+ self.denoiser_trans_block = DenoiserTransBlock(patch_size, image_size, embed_dim, dropout, n_layers)
110
+ self.norm = nn.LayerNorm(self.embed_dim)
111
+ self.label_proj = nn.Linear(text_emb_size, self.embed_dim)
112
+
113
+ def forward(self, x, noise_level, label):
114
+ noise_level = self.fourier_feats(noise_level).unsqueeze(1)
115
+
116
+ label = self.label_proj(label).unsqueeze(1)
117
+
118
+ noise_label_emb = torch.cat([noise_level, label], dim=1) # bs, 2, d
119
+ noise_label_emb = self.norm(noise_label_emb)
120
+
121
+ x = self.denoiser_trans_block(x, noise_label_emb)
122
+
123
+ return x
tld/diffusion.py ADDED
@@ -0,0 +1,198 @@
+ from dataclasses import dataclass
2
+
3
+ import clip
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ import torchvision.utils as vutils
9
+ from diffusers import AutoencoderKL
10
+ from torch import Tensor
11
+ from tqdm import tqdm
12
+
13
+ from tld.denoiser import Denoiser
14
+
15
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+ to_pil = transforms.ToPILImage()
17
+
18
+
19
+ @dataclass
20
+ class DiffusionGenerator:
21
+ model: Denoiser
22
+ vae: AutoencoderKL
23
+ device: torch.device
24
+ model_dtype: torch.dtype = torch.float32
25
+
26
+ @torch.no_grad()
27
+ def generate(
28
+ self,
29
+ labels: Tensor, # embeddings to condition on
30
+ n_iter: int = 30,
31
+ num_imgs: int = 16,
32
+ class_guidance: float = 3,
33
+ seed: int = 10,
34
+ scale_factor: int = 8, # latent scaling before decoding - should be ~ std of latent space
35
+ img_size: int = 32, # height, width of latent
36
+ sharp_f: float = 0.1,
37
+ bright_f: float = 0.1,
38
+ exponent: float = 1,
39
+ seeds: Tensor | None = None,
40
+ noise_levels=None,
41
+ use_ddpm_plus: bool = True,
42
+ ):
43
+ """Generate images via reverse diffusion.
44
+ if use_ddpm_plus=True uses Algorithm 2 DPM-Solver++(2M) here: https://arxiv.org/pdf/2211.01095.pdf
45
+ else use ddim with alpha = 1-sigma
46
+ """
47
+ if noise_levels is None:
48
+ noise_levels = (1 - torch.pow(torch.arange(0, 1, 1 / n_iter), exponent)).tolist()
49
+ noise_levels[0] = 0.99
50
+
51
+ if use_ddpm_plus:
52
+ lambdas = [np.log((1 - sigma) / sigma) for sigma in noise_levels] # log snr
53
+ hs = [lambdas[i] - lambdas[i - 1] for i in range(1, len(lambdas))]
54
+ rs = [hs[i - 1] / hs[i] for i in range(1, len(hs))]
55
+
56
+ x_t = self.initialize_image(seeds, num_imgs, img_size, seed)
57
+
58
+ labels = torch.cat([labels, torch.zeros_like(labels)])
59
+ self.model.eval()
60
+
61
+ x0_pred_prev = None
62
+
63
+ for i in tqdm(range(len(noise_levels) - 1)):
64
+ curr_noise, next_noise = noise_levels[i], noise_levels[i + 1]
65
+
66
+ x0_pred = self.pred_image(x_t, labels, curr_noise, class_guidance)
67
+
68
+ if x0_pred_prev is None:
69
+ x_t = ((curr_noise - next_noise) * x0_pred + next_noise * x_t) / curr_noise
70
+ else:
71
+ if use_ddpm_plus:
72
+ # x0_pred is a combination of the two previous x0_pred:
73
+ D = (1 + 1 / (2 * rs[i - 1])) * x0_pred - (1 / (2 * rs[i - 1])) * x0_pred_prev
74
+ else:
75
+ # ddim:
76
+ D = x0_pred
77
+
78
+ x_t = ((curr_noise - next_noise) * D + next_noise * x_t) / curr_noise
79
+
80
+ x0_pred_prev = x0_pred
81
+
82
+ x0_pred = self.pred_image(x_t, labels, next_noise, class_guidance)
83
+
84
+ # shifting latents works a bit like an image editor:
85
+ x0_pred[:, 3, :, :] += sharp_f
86
+ x0_pred[:, 0, :, :] += bright_f
87
+
88
+ x0_pred_img = self.vae.decode((x0_pred * scale_factor).to(self.model_dtype))[0].cpu()
89
+ return x0_pred_img, x0_pred
90
+
91
+ def pred_image(self, noisy_image, labels, noise_level, class_guidance):
92
+ num_imgs = noisy_image.size(0)
93
+ noises = torch.full((2 * num_imgs, 1), noise_level)
94
+ x0_pred = self.model(
95
+ torch.cat([noisy_image, noisy_image]),
96
+ noises.to(self.device, self.model_dtype),
97
+ labels.to(self.device, self.model_dtype),
98
+ )
99
+ x0_pred = self.apply_classifier_free_guidance(x0_pred, num_imgs, class_guidance)
100
+ return x0_pred
101
+
102
+ def initialize_image(self, seeds, num_imgs, img_size, seed):
103
+ """Initialize the seed tensor."""
104
+ if seeds is None:
105
+ generator = torch.Generator(device=self.device)
106
+ generator.manual_seed(seed)
107
+ return torch.randn(
108
+ num_imgs,
109
+ 4,
110
+ img_size,
111
+ img_size,
112
+ dtype=self.model_dtype,
113
+ device=self.device,
114
+ generator=generator,
115
+ )
116
+ else:
117
+ return seeds.to(self.device, self.model_dtype)
118
+
119
+ def apply_classifier_free_guidance(self, x0_pred, num_imgs, class_guidance):
120
+ """Apply classifier-free guidance to the predictions."""
121
+ x0_pred_label, x0_pred_no_label = x0_pred[:num_imgs], x0_pred[num_imgs:]
122
+ return class_guidance * x0_pred_label + (1 - class_guidance) * x0_pred_no_label
123
+
124
+
125
+ @dataclass
126
+ class LTDConfig:
127
+ vae_scale_factor: float = 8
128
+ img_size: int = 32
129
+ model_dtype: torch.dtype = torch.float32
130
+ file_url: str = None # = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
131
+ local_filename: str = "state_dict_378000.pth"
132
+ vae_name: str = "madebyollin/sdxl-vae-fp16-fix"
133
+ clip_model_name: str = "ViT-L/14"
134
+ denoiser: Denoiser = Denoiser(
135
+ image_size=32,
136
+ noise_embed_dims=256,
137
+ patch_size=2,
138
+ embed_dim=256,
139
+ dropout=0,
140
+ n_layers=4,
141
+ )
142
+
143
+
144
+ def download_file(url, filename):
145
+ with requests.get(url, stream=True) as r:
146
+ r.raise_for_status()
147
+ with open(filename, "wb") as f:
148
+ for chunk in r.iter_content(chunk_size=8192):
149
+ f.write(chunk)
150
+
151
+
152
+ @torch.no_grad()
153
+ def encode_text(label, model):
154
+ text_tokens = clip.tokenize(label, truncate=True).to(device)
155
+ text_encoding = model.encode_text(text_tokens)
156
+ return text_encoding.cpu()
157
+
158
+
159
+ class DiffusionTransformer:
160
+ def __init__(self, config: LTDConfig):
161
+ denoiser = config.denoiser.to(config.model_dtype)
162
+
163
+ if config.file_url is not None:
164
+ print(f"Downloading model from {config.file_url}")
165
+ download_file(config.file_url, config.local_filename)
166
+ state_dict = torch.load(config.local_filename, map_location=torch.device("cpu"))
167
+ denoiser.load_state_dict(state_dict)
168
+
169
+ denoiser = denoiser.to(device)
170
+
171
+ vae = AutoencoderKL.from_pretrained(config.vae_name, torch_dtype=config.model_dtype).to(device)
172
+
173
+ self.clip_model, preprocess = clip.load(config.clip_model_name)
174
+ self.clip_model = self.clip_model.to(device)
175
+
176
+ self.diffuser = DiffusionGenerator(denoiser, vae, device, config.model_dtype)
177
+
178
+ def generate_image_from_text(
179
+ self, prompt: str, class_guidance=6, seed=11, num_imgs=1, img_size=32, n_iter=15
180
+ ):
181
+ nrow = int(np.sqrt(num_imgs))
182
+
183
+ cur_prompts = [prompt] * num_imgs
184
+ labels = encode_text(cur_prompts, self.clip_model)
185
+ out, out_latent = self.diffuser.generate(
186
+ labels=labels,
187
+ num_imgs=num_imgs,
188
+ class_guidance=class_guidance,
189
+ seed=seed,
190
+ n_iter=n_iter,
191
+ exponent=1,
192
+ scale_factor=8,
193
+ sharp_f=0,
194
+ bright_f=0,
195
+ )
196
+
197
+ out = to_pil((vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)).float().clip(0, 1))
198
+ return out
tld/gradio_app.py ADDED
@@ -0,0 +1,40 @@
+ import os
2
+ from io import BytesIO
3
+
4
+ import gradio as gr
5
+ import requests
6
+ from PIL import Image
7
+
8
+ # runpod_id = os.environ['RUNPOD_ID']
9
+ # token_id = os.environ['AUTH_TOKEN']
10
+ # url = f'https://{runpod_id}-8000.proxy.runpod.net/generate-image/'
11
+
12
+ url = os.getenv("API_URL")
13
+ token_id = os.getenv("API_TOKEN")
14
+
15
+
16
+ def generate_image_from_text(prompt, class_guidance):
17
+ headers = {"Authorization": f"Bearer {token_id}"}
18
+
19
+ data = {"prompt": prompt, "class_guidance": class_guidance, "seed": 11, "num_imgs": 1, "img_size": 32}
20
+
21
+ response = requests.post(url, json=data, headers=headers)
22
+
23
+ if response.status_code == 200:
+ image = Image.open(BytesIO(response.content))
+ else:
+ print("Failed to fetch image:", response.status_code, response.text)
+ raise gr.Error(f"Image generation failed with status {response.status_code}")
+
+ return image
29
+
30
+
31
+ iface = gr.Interface(
32
+ fn=generate_image_from_text,
33
+ inputs=["text", "slider"],
34
+ outputs="image",
35
+ title="Text-to-Image Generator",
36
+ description="Enter a text prompt to generate an image.",
37
+ )
38
+
39
+ # Launch the app
40
+ iface.launch()
tld/img_examples/a beautiful woman with blonde hair in her 50s_cfg_7_seed_11.png ADDED
tld/img_examples/a cute grey great owl_cfg_8_seed_11.png ADDED
tld/img_examples/a lake in mountains in the fall at sunset_cfg_7_seed_11.png ADDED
tld/img_examples/a woman cyborg with red curly hair, 8k_cfg_9.5_seed_11.png ADDED
tld/img_examples/an aerial view of manhattan, isometric view, as pantinted by mondrian_cfg_7_seed_11.png ADDED
tld/img_examples/isometric view of small japanese village with blooming trees_cfg_7_seed_11.png ADDED
tld/img_examples/painting of a cute fox in a suit in a field of poppies_cfg_8_seed_11.png ADDED
tld/img_examples/painting of a cyberpunk market_cfg_7_seed_11.png ADDED
tld/img_examples/watercolor of a cute cat riding a motorcycle_cfg_7_seed_11.png ADDED
tld/train.py ADDED
@@ -0,0 +1,208 @@
+ #!/usr/bin/env python3
2
+
3
+ import copy
4
+ from dataclasses import asdict, dataclass
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ import torchvision.utils as vutils
10
+ import wandb
11
+ from accelerate import Accelerator
12
+ from diffusers import AutoencoderKL
13
+ from PIL.Image import Image
14
+ from torch import Tensor, nn
15
+ from torch.utils.data import DataLoader, TensorDataset
16
+ from tqdm import tqdm
17
+
18
+ from tld.denoiser import Denoiser
19
+ from tld.diffusion import DiffusionGenerator
20
+
21
+
22
+ def eval_gen(diffuser: DiffusionGenerator, labels: Tensor) -> Image:
23
+ class_guidance = 4.5
24
+ seed = 10
25
+ out, _ = diffuser.generate(
26
+ labels=torch.repeat_interleave(labels, 8, dim=0),
27
+ num_imgs=64,
28
+ class_guidance=class_guidance,
29
+ seed=seed,
30
+ n_iter=40,
31
+ exponent=1,
32
+ sharp_f=0.1,
33
+ )
34
+
35
+ out = to_pil((vutils.make_grid((out + 1) / 2, nrow=8, padding=4)).float().clip(0, 1))
36
+ out.save(f"emb_val_cfg:{class_guidance}_seed:{seed}.png")
37
+
38
+ return out
39
+
40
+
41
+ def count_parameters(model: nn.Module):
42
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
43
+
44
+
45
+ def count_parameters_per_layer(model: nn.Module):
46
+ for name, param in model.named_parameters():
47
+ print(f"{name}: {param.numel()} parameters")
48
+
49
+
50
+ to_pil = torchvision.transforms.ToPILImage()
51
+
52
+
53
+ def update_ema(ema_model: nn.Module, model: nn.Module, alpha: float = 0.999):
54
+ with torch.no_grad():
55
+ for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
56
+ ema_param.data.mul_(alpha).add_(model_param.data, alpha=1 - alpha)
57
+
58
+
59
+ @dataclass
60
+ class ModelConfig:
61
+ embed_dim: int = 512
62
+ n_layers: int = 6
63
+ clip_embed_size: int = 768
64
+ scaling_factor: int = 8
65
+ patch_size: int = 2
66
+ image_size: int = 32
67
+ n_channels: int = 4
68
+ dropout: float = 0
69
+ mlp_multiplier: int = 4
70
+ batch_size: int = 128
71
+ class_guidance: int = 3
72
+ lr: float = 3e-4
73
+ n_epoch: int = 100
74
+ alpha: float = 0.999
75
+ noise_embed_dims: int = 128
76
+ diffusion_n_iter: int = 35
77
+ from_scratch: bool = True
78
+ run_id: str = ""
79
+ model_name: str = ""
80
+ beta_a: float = 0.75
81
+ beta_b: float = 0.75
82
+ save_and_eval_every_iters: int = 1000
83
+
84
+
85
+ @dataclass
86
+ class DataConfig:
87
+ latent_path: str # path to a numpy file containing latents
88
+ text_emb_path: str
89
+ val_path: str
90
+
91
+
92
+ def main(config: ModelConfig, dataconfig: DataConfig) -> None:
93
+ """main train loop to be used with accelerate"""
94
+
95
+ accelerator = Accelerator(mixed_precision="fp16", log_with="wandb")
96
+
97
+ accelerator.print("Loading Data:")
98
+ latent_train_data = torch.tensor(np.load(dataconfig.latent_path), dtype=torch.float32)
99
+ train_label_embeddings = torch.tensor(np.load(dataconfig.text_emb_path), dtype=torch.float32)
100
+ emb_val = torch.tensor(np.load(dataconfig.val_path), dtype=torch.float32)
101
+ dataset = TensorDataset(latent_train_data, train_label_embeddings)
102
+ train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
103
+
104
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
105
+
106
+ if accelerator.is_main_process:
107
+ vae = vae.to(accelerator.device)
108
+
109
+ model = Denoiser(
110
+ image_size=config.image_size,
111
+ noise_embed_dims=config.noise_embed_dims,
112
+ patch_size=config.patch_size,
113
+ embed_dim=config.embed_dim,
114
+ dropout=config.dropout,
115
+ n_layers=config.n_layers,
116
+ )
117
+
118
+ loss_fn = nn.MSELoss()
119
+ optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
120
+
121
+ accelerator.print("Compiling model:")
122
+ model = torch.compile(model)
123
+
124
+ if not config.from_scratch:
125
+ accelerator.print("Loading Model:")
126
+ wandb.restore(
127
+ config.model_name, run_path=f"apapiu/cifar_diffusion/runs/{config.run_id}", replace=True
128
+ )
129
+ full_state_dict = torch.load(config.model_name)
130
+ model.load_state_dict(full_state_dict["model_ema"])
131
+ optimizer.load_state_dict(full_state_dict["opt_state"])
132
+ global_step = full_state_dict["global_step"]
133
+ else:
134
+ global_step = 0
135
+
136
+ if accelerator.is_local_main_process:
137
+ ema_model = copy.deepcopy(model).to(accelerator.device)
138
+ diffuser = DiffusionGenerator(ema_model, vae, accelerator.device, torch.float32)
139
+
140
+ accelerator.print("model prep")
141
+ model, train_loader, optimizer = accelerator.prepare(model, train_loader, optimizer)
142
+
143
+ accelerator.init_trackers(project_name="cifar_diffusion", config=asdict(config))
144
+
145
+ accelerator.print(count_parameters(model))
146
+ count_parameters_per_layer(model)  # prints per-layer counts itself; returns None
147
+
148
+ # training loop:
149
+ for i in range(1, config.n_epoch + 1):
150
+ accelerator.print(f"epoch: {i}")
151
+
152
+ for x, y in tqdm(train_loader):
153
+ x = x / config.scaling_factor
154
+
155
+ noise_level = torch.tensor(
156
+ np.random.beta(config.beta_a, config.beta_b, len(x)), device=accelerator.device
157
+ )
158
+ signal_level = 1 - noise_level
159
+ noise = torch.randn_like(x)
160
+
161
+ x_noisy = noise_level.view(-1, 1, 1, 1) * noise + signal_level.view(-1, 1, 1, 1) * x
162
+
163
+ x_noisy = x_noisy.float()
164
+ noise_level = noise_level.float()
165
+ label = y
166
+
167
+ prob = 0.15  # probability of dropping the text conditioning (classifier-free guidance training)
168
+ mask = torch.rand(y.size(0), device=accelerator.device) < prob
169
+ label[mask] = 0  # zero out dropped conditioning (a learned replacement vector would also work)
170
+
171
+ if global_step % config.save_and_eval_every_iters == 0:
172
+ accelerator.wait_for_everyone()
173
+ if accelerator.is_main_process:
174
+ # evaluation and checkpoint saving:
175
+ out = eval_gen(diffuser=diffuser, labels=emb_val)
176
+ out.save("img.jpg")
177
+ accelerator.log({f"step: {global_step}": wandb.Image("img.jpg")})
178
+
179
+ opt_unwrapped = accelerator.unwrap_model(optimizer)
180
+ full_state_dict = {
181
+ "model_ema": ema_model.state_dict(),
182
+ "opt_state": opt_unwrapped.state_dict(),
183
+ "global_step": global_step,
184
+ }
185
+ accelerator.save(full_state_dict, config.model_name)
186
+ wandb.save(config.model_name)
187
+
188
+ model.train()
189
+
190
+ with accelerator.accumulate(model):  # pass the model so accelerate can manage gradient sync when accumulating
191
+ # training step:
192
+ optimizer.zero_grad()
193
+
194
+ pred = model(x_noisy, noise_level.view(-1, 1), label)
195
+ loss = loss_fn(pred, x)
196
+ accelerator.log({"train_loss": loss.item()}, step=global_step)
197
+ accelerator.backward(loss)
198
+ optimizer.step()
199
+
200
+ if accelerator.is_main_process:
201
+ update_ema(ema_model, model, alpha=config.alpha)
202
+
203
+ global_step += 1
204
+ accelerator.end_training()
205
+
206
+
207
+ # args = (config, data_path, val_path)
208
+ # notebook_launcher(training_loop)
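The main() above takes a ModelConfig and a DataConfig and is meant to run under Hugging Face accelerate (the Accelerator also logs to wandb, so a wandb login is assumed). A minimal launch sketch; the .npy paths and settings below are illustrative assumptions, not files shipped in this commit:

    from accelerate import notebook_launcher  # or run the script via `accelerate launch`

    data_config = DataConfig(
        latent_path="latents.npy",       # hypothetical precomputed VAE latents
        text_emb_path="text_emb.npy",    # hypothetical CLIP text embeddings
        val_path="val_emb.npy",          # hypothetical validation text embeddings
    )
    config = ModelConfig(model_name="model.pt", from_scratch=True)

    notebook_launcher(main, (config, data_config), num_processes=1)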
tld/transformer_blocks.py ADDED
@@ -0,0 +1,139 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+
7
+ class SinusoidalEmbedding(nn.Module):
8
+ def __init__(self, emb_min_freq=1.0, emb_max_freq=1000.0, embedding_dims=32):
9
+ super(SinusoidalEmbedding, self).__init__()
10
+
11
+ frequencies = torch.exp(
12
+ torch.linspace(np.log(emb_min_freq), np.log(emb_max_freq), embedding_dims // 2)
13
+ )
14
+
15
+ self.register_buffer("angular_speeds", 2.0 * torch.pi * frequencies)
16
+
17
+ def forward(self, x):
18
+ embeddings = torch.cat(
19
+ [torch.sin(self.angular_speeds * x), torch.cos(self.angular_speeds * x)], dim=-1
20
+ )
21
+ return embeddings
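A quick shape check for the noise-level embedding, assuming the denoiser passes one scalar noise level per example shaped (bs, 1), as noise_level.view(-1, 1) in the training loop above suggests; illustrative only:

    emb = SinusoidalEmbedding(embedding_dims=128)
    noise_levels = torch.rand(16, 1)     # one scalar noise level per example
    print(emb(noise_levels).shape)       # torch.Size([16, 128])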
22
+
23
+
24
+ class MHAttention(nn.Module):
25
+ def __init__(self, is_causal=False, dropout_level=0.0, n_heads=4):
26
+ super().__init__()
27
+ self.is_causal = is_causal
28
+ self.dropout_level = dropout_level
29
+ self.n_heads = n_heads
30
+
31
+ def forward(self, q, k, v, attn_mask=None):
32
+ assert q.size(-1) == k.size(-1)
33
+ assert k.size(-2) == v.size(-2)
34
+
35
+ q, k, v = [rearrange(x, "bs n (h d) -> bs h n d", h=self.n_heads) for x in [q, k, v]]
36
+
37
+ out = nn.functional.scaled_dot_product_attention(
38
+ q,
39
+ k,
40
+ v,
41
+ attn_mask=attn_mask,
42
+ is_causal=self.is_causal,
43
+ dropout_p=self.dropout_level if self.training else 0,
44
+ )
45
+
46
+ out = rearrange(out, "bs h n d -> bs n (h d)", h=self.n_heads)
47
+
48
+ return out
49
+
50
+
51
+ class SelfAttention(nn.Module):
52
+ def __init__(self, embed_dim, is_causal=False, dropout_level=0.0, n_heads=4):
53
+ super().__init__()
54
+ self.qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
55
+ self.mha = MHAttention(is_causal, dropout_level, n_heads)
56
+
57
+ def forward(self, x):
58
+ q, k, v = self.qkv_linear(x).chunk(3, dim=2)
59
+ return self.mha(q, k, v)
60
+
61
+
62
+ class CrossAttention(nn.Module):
63
+ def __init__(self, embed_dim, is_causal=False, dropout_level=0, n_heads=4):
64
+ super().__init__()
65
+ self.kv_linear = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
66
+ self.q_linear = nn.Linear(embed_dim, embed_dim, bias=False)
67
+ self.mha = MHAttention(is_causal, dropout_level, n_heads)
68
+
69
+ def forward(self, x, y):
70
+ q = self.q_linear(x)
71
+ k, v = self.kv_linear(y).chunk(2, dim=2)
72
+ return self.mha(q, k, v)
73
+
74
+
75
+ class MLP(nn.Module):
76
+ def __init__(self, embed_dim, mlp_multiplier, dropout_level):
77
+ super().__init__()
78
+ self.mlp = nn.Sequential(
79
+ nn.Linear(embed_dim, mlp_multiplier * embed_dim),
80
+ nn.GELU(),
81
+ nn.Linear(mlp_multiplier * embed_dim, embed_dim),
82
+ nn.Dropout(dropout_level),
83
+ )
84
+
85
+ def forward(self, x):
86
+ return self.mlp(x)
87
+
88
+
89
+ class MLPSepConv(nn.Module):
90
+ def __init__(self, embed_dim, mlp_multiplier, dropout_level):
91
+ """see: https://github.com/ofsoundof/LocalViT"""
92
+ super().__init__()
93
+ self.mlp = nn.Sequential(
94
+ # this Conv with kernel size 1 is equivalent to the Linear layer in a "regular" transformer MLP
95
+ nn.Conv2d(embed_dim, mlp_multiplier * embed_dim, kernel_size=1, padding="same"),
96
+ nn.Conv2d(
97
+ mlp_multiplier * embed_dim,
98
+ mlp_multiplier * embed_dim,
99
+ kernel_size=3,
100
+ padding="same",
101
+ groups=mlp_multiplier * embed_dim,
102
+ ), # <- depthwise conv
103
+ nn.GELU(),
104
+ nn.Conv2d(mlp_multiplier * embed_dim, embed_dim, kernel_size=1, padding="same"),
105
+ nn.Dropout(dropout_level),
106
+ )
107
+
108
+ def forward(self, x):
109
+ w = h = int(np.sqrt(x.size(1))) # only square images for now
110
+ x = rearrange(x, "bs (h w) d -> bs d h w", h=h, w=w)
111
+ x = self.mlp(x)
112
+ x = rearrange(x, "bs d h w -> bs (h w) d")
113
+ return x
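The comment above claims the kernel-size-1 convolution matches the Linear layer of a standard transformer MLP; a tiny equivalence check (illustrative, not part of the commit) that should print True up to floating-point tolerance:

    lin = nn.Linear(8, 16)
    conv = nn.Conv2d(8, 16, kernel_size=1)
    conv.weight.data = lin.weight.data.view(16, 8, 1, 1).clone()
    conv.bias.data = lin.bias.data.clone()
    tokens = torch.randn(2, 4 * 4, 8)                        # (bs, h*w, d)
    as_image = rearrange(tokens, "bs (h w) d -> bs d h w", h=4, w=4)
    out_conv = rearrange(conv(as_image), "bs d h w -> bs (h w) d")
    print(torch.allclose(lin(tokens), out_conv, atol=1e-6))  # True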
114
+
115
+
116
+ class DecoderBlock(nn.Module):
117
+ def __init__(
118
+ self,
119
+ embed_dim: int,
120
+ is_causal: bool,
121
+ mlp_multiplier: int,
122
+ dropout_level: float,
123
+ mlp_class: type[MLP] | type[MLPSepConv],
124
+ ):
125
+ super().__init__()
126
+ self.self_attention = SelfAttention(embed_dim, is_causal, dropout_level, n_heads=embed_dim // 64)
127
+ self.cross_attention = CrossAttention(
128
+ embed_dim, is_causal=False, dropout_level=0, n_heads=embed_dim // 64
129
+ )
130
+ self.mlp = mlp_class(embed_dim, mlp_multiplier, dropout_level)
131
+ self.norm1 = nn.LayerNorm(embed_dim)
132
+ self.norm2 = nn.LayerNorm(embed_dim)
133
+ self.norm3 = nn.LayerNorm(embed_dim)
134
+
135
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
136
+ x = self.self_attention(self.norm1(x)) + x
137
+ x = self.cross_attention(self.norm2(x), y) + x
138
+ x = self.mlp(self.norm3(x)) + x
139
+ return x
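A minimal sketch of how these blocks compose, assuming the conditioning sequence y has already been projected to embed_dim (presumably done by the Denoiser, whose code is not shown here) and that the token count is a perfect square as MLPSepConv requires; shapes below are assumptions based on ModelConfig above, not code from this commit:

    embed_dim = 512
    block = DecoderBlock(
        embed_dim=embed_dim,
        is_causal=False,
        mlp_multiplier=4,
        dropout_level=0.0,
        mlp_class=MLPSepConv,
    )
    x = torch.randn(2, 16 * 16, embed_dim)   # image tokens: (bs, seq, d); seq must be a square number
    y = torch.randn(2, 1, embed_dim)         # conditioning token(s), already at embed_dim
    print(block(x, y).shape)                 # torch.Size([2, 256, 512])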