邬彦泽 committed on
Commit
aa8012e
1 Parent(s): dbb55f6
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. LICENSE +201 -0
  2. README.md +2 -2
  3. app.py +298 -0
  4. eva_clip/__init__.py +10 -0
  5. eva_clip/constants.py +2 -0
  6. eva_clip/eva_vit_model.py +633 -0
  7. eva_clip/factory.py +517 -0
  8. eva_clip/hf_configs.py +57 -0
  9. eva_clip/hf_model.py +248 -0
  10. eva_clip/loss.py +138 -0
  11. eva_clip/model.py +440 -0
  12. eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
  13. eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
  14. eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
  15. eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
  16. eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
  17. eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
  18. eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +25 -0
  19. eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
  20. eva_clip/modified_resnet.py +181 -0
  21. eva_clip/openai.py +144 -0
  22. eva_clip/pretrained.py +332 -0
  23. eva_clip/rope.py +137 -0
  24. eva_clip/timm_model.py +123 -0
  25. eva_clip/tokenizer.py +201 -0
  26. eva_clip/transform.py +103 -0
  27. eva_clip/transformer.py +792 -0
  28. eva_clip/utils.py +326 -0
  29. example_inputs/hinton.jpeg +0 -0
  30. example_inputs/lecun.jpg +0 -0
  31. example_inputs/lifeifei.jpg +0 -0
  32. example_inputs/liuyifei.png +0 -0
  33. example_inputs/rihanna.webp +0 -0
  34. example_inputs/zcy.webp +0 -0
  35. flux/__init__.py +11 -0
  36. flux/__main__.py +4 -0
  37. flux/api.py +194 -0
  38. flux/cli.py +261 -0
  39. flux/math.py +31 -0
  40. flux/model.py +135 -0
  41. flux/modules/__init__.py +0 -0
  42. flux/modules/autoencoder.py +312 -0
  43. flux/modules/conditioner.py +37 -0
  44. flux/modules/layers.py +253 -0
  45. flux/sampling.py +161 -0
  46. flux/util.py +201 -0
  47. models/.gitkeep +0 -0
  48. pulid/attention_processor.py +422 -0
  49. pulid/encoders.py +64 -0
  50. pulid/encoders_flux.py +207 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
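For reference, here is the appendix boilerplate applied as a source-file header in Python comment syntax; the bracketed fields are placeholders from the license text and are intentionally left unfilled:

```python
# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```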
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: PuLID FLUX
3
- emoji: 👁
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
 
1
  ---
2
+ title: PuLID-FLUX
3
+ emoji: 🤗
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,298 @@
1
+ import spaces
2
+ import time
3
+ import os
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from einops import rearrange
8
+ from PIL import Image
9
+
10
+ from flux.cli import SamplingOptions
11
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
12
+ from flux.util import load_ae, load_clip, load_flow_model, load_t5
13
+ from pulid.pipeline_flux import PuLIDPipeline
14
+ from pulid.utils import resize_numpy_image_long
15
+
16
+
17
+ def get_models(name: str, device: torch.device, offload: bool):
18
+ t5 = load_t5(device, max_length=128)
19
+ clip = load_clip(device)
20
+ model = load_flow_model(name, device="cpu" if offload else device)
21
+ model.eval()
22
+ ae = load_ae(name, device="cpu" if offload else device)
23
+ return model, ae, t5, clip
24
+
25
+
26
+ class FluxGenerator:
27
+ def __init__(self, model_name: str, device: str, offload: bool, args):
28
+ self.device = torch.device(device)
29
+ self.offload = offload
30
+ self.model_name = model_name
31
+ self.model, self.ae, self.t5, self.clip = get_models(
32
+ model_name,
33
+ device=self.device,
34
+ offload=self.offload,
35
+ )
36
+ self.pulid_model = PuLIDPipeline(self.model, device, weight_dtype=torch.bfloat16)
37
+ self.pulid_model.load_pretrain(args.pretrained_model)
38
+
39
+ @spaces.GPU
40
+ @torch.inference_mode()
41
+ def generate_image(
42
+ self,
43
+ width,
44
+ height,
45
+ num_steps,
46
+ start_step,
47
+ guidance,
48
+ seed,
49
+ prompt,
50
+ id_image=None,
51
+ id_weight=1.0,
52
+ neg_prompt="",
53
+ true_cfg=1.0,
54
+ timestep_to_start_cfg=1,
55
+ max_sequence_length=128,
56
+ ):
57
+ self.t5.max_length = max_sequence_length
58
+
59
+ seed = int(seed)
60
+ if seed == -1:
61
+ seed = None
62
+
63
+ opts = SamplingOptions(
64
+ prompt=prompt,
65
+ width=width,
66
+ height=height,
67
+ num_steps=num_steps,
68
+ guidance=guidance,
69
+ seed=seed,
70
+ )
71
+
72
+ if opts.seed is None:
73
+ opts.seed = torch.Generator(device="cpu").seed()
74
+ print(f"Generating '{opts.prompt}' with seed {opts.seed}")
75
+ t0 = time.perf_counter()
76
+
77
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
78
+
79
+ if id_image is not None:
80
+ id_image = resize_numpy_image_long(id_image, 1024)
81
+ id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
82
+ else:
83
+ id_embeddings = None
84
+ uncond_id_embeddings = None
85
+
86
+ # prepare input
87
+ x = get_noise(
88
+ 1,
89
+ opts.height,
90
+ opts.width,
91
+ device=self.device,
92
+ dtype=torch.bfloat16,
93
+ seed=opts.seed,
94
+ )
95
+ timesteps = get_schedule(
96
+ opts.num_steps,
97
+ x.shape[-1] * x.shape[-2] // 4,
98
+ shift=True,
99
+ )
100
+
101
+ if self.offload:
102
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
103
+ inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
104
+ inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
105
+
106
+ # offload TEs to CPU, load model to gpu
107
+ if self.offload:
108
+ self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
109
+ torch.cuda.empty_cache()
110
+ self.model = self.model.to(self.device)
111
+
112
+ # denoise initial noise
113
+ x = denoise(
114
+ self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
115
+ start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
116
+ timestep_to_start_cfg=timestep_to_start_cfg,
117
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
118
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
119
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
120
+ )
121
+
122
+ # offload model, load autoencoder to gpu
123
+ if self.offload:
124
+ self.model.cpu()
125
+ torch.cuda.empty_cache()
126
+ self.ae.decoder.to(x.device)
127
+
128
+ # decode latents to pixel space
129
+ x = unpack(x.float(), opts.height, opts.width)
130
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
131
+ x = self.ae.decode(x)
132
+
133
+ if self.offload:
134
+ self.ae.decoder.cpu()
135
+ torch.cuda.empty_cache()
136
+
137
+ t1 = time.perf_counter()
138
+
139
+ print(f"Done in {t1 - t0:.1f}s.")
140
+ # bring into PIL format
141
+ x = x.clamp(-1, 1)
142
+ # x = embed_watermark(x.float())
143
+ x = rearrange(x[0], "c h w -> h w c")
144
+
145
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
146
+ return img, str(opts.seed), self.pulid_model.debug_img_list
147
+
148
+ _HEADER_ = '''
149
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
150
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
151
+ <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
152
+ </div>
153
+
154
+ ❗️❗️❗️**Tips:**
155
+ - `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. If you are not satisfied with the similarity, you can lower this value; conversely, if you are not satisfied with the editability, you can increase this value.
156
+ - `true CFG scale:` In most scenarios, we recommend using fake CFG, i.e., setting the true CFG scale to 1 and adjusting only the guidance scale; this is also more efficient. However, in a few cases, using true CFG can yield better results. For more details, please refer to XX.
157
+ - Please refer to the <a href='URL_ADDRESS' target='_blank'>GitHub doc</a> for more details and information about the model; it provides a detailed explanation of the two parameters above.
158
+ - We provide some examples at the bottom; you can try these example prompts first.
159
+
160
+ ''' # noqa E501
161
+
162
+ _CITE_ = r"""
163
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks!
164
+ ---
165
+
166
+ 📧 **Contact**
167
+ If you have any questions or feedback, feel free to open a discussion or contact <b>[email protected]</b>.
168
+ """ # noqa E501
169
+
170
+
171
+ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
172
+ offload: bool = False):
173
+ generator = FluxGenerator(model_name, device, offload, args)
174
+
175
+ with gr.Blocks() as demo:
176
+ gr.Markdown(_HEADER_)
177
+
178
+ with gr.Row():
179
+ with gr.Column():
180
+ prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
181
+ id_image = gr.Image(label="ID Image")
182
+ id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
183
+
184
+ width = gr.Slider(256, 1536, 896, step=16, label="Width")
185
+ height = gr.Slider(256, 1536, 1152, step=16, label="Height")
186
+ num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
187
+ start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
188
+ guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
189
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
190
+ max_sequence_length = gr.Slider(128, 512, 128, step=128,
191
+ label="max_sequence_length for prompt (T5); smaller is faster")
192
+
193
+ with gr.Accordion("Advanced Options (True CFG: true_cfg_scale=1 means fake CFG, >1 means true CFG; if using true CFG, we recommend setting the guidance scale to 1)", open=False): # noqa E501
194
+ neg_prompt = gr.Textbox(
195
+ label="Negative Prompt",
196
+ value="bad quality, worst quality, text, signature, watermark, extra limbs")
197
+ true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
198
+ timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
199
+
200
+ generate_btn = gr.Button("Generate")
201
+
202
+ with gr.Column():
203
+ output_image = gr.Image(label="Generated Image")
204
+ seed_output = gr.Textbox(label="Used Seed")
205
+ intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
206
+ gr.Markdown(_CITE_)
207
+
208
+ with gr.Row(), gr.Column():
209
+ gr.Markdown("## Examples")
210
+ example_inps = [
211
+ [
212
+ 'a woman holding sign with glowing green text \"PuLID for FLUX\"',
213
+ 'example_inputs/liuyifei.png',
214
+ 4, 4, 2680261499100305976, 1
215
+ ],
216
+ [
217
+ 'portrait, side view',
218
+ 'example_inputs/liuyifei.png',
219
+ 4, 4, 1205240166692517553, 1
220
+ ],
221
+ [
222
+ 'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details', # noqa E501
223
+ 'example_inputs/liuyifei.png',
224
+ 4, 4, 6349424134217931066, 1
225
+ ],
226
+ [
227
+ 'a young child is eating Icecream',
228
+ 'example_inputs/liuyifei.png',
229
+ 4, 4, 10606046113565776207, 1
230
+ ],
231
+ [
232
+ 'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain',
233
+ 'example_inputs/pengwei.jpg',
234
+ 4, 4, 2410129802683836089, 1
235
+ ],
236
+ [
237
+ 'portrait, candle light',
238
+ 'example_inputs/pengwei.jpg',
239
+ 4, 4, 17522759474323955700, 1
240
+ ],
241
+ [
242
+ 'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501
243
+ 'example_inputs/pengwei.jpg',
244
+ 4, 4, 17733156847328193625, 1
245
+ ],
246
+ [
247
+ 'American Comics, 1boy',
248
+ 'example_inputs/pengwei.jpg',
249
+ 1, 4, 13223174453874179686, 1
250
+ ],
251
+ [
252
+ 'portrait, pixar',
253
+ 'example_inputs/pengwei.jpg',
254
+ 1, 4, 9445036702517583939, 1
255
+ ],
256
+ ]
257
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
258
+ label='fake CFG')
259
+
260
+ example_inps = [
261
+ [
262
+ 'portrait, made of ice sculpture',
263
+ 'example_inputs/lecun.jpg',
264
+ 1, 1, 3811899118709451814, 5
265
+ ],
266
+ ]
267
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
268
+ label='true CFG')
269
+
270
+ generate_btn.click(
271
+ fn=generator.generate_image,
272
+ inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
273
+ true_cfg, timestep_to_start_cfg, max_sequence_length],
274
+ outputs=[output_image, seed_output, intermediate_output],
275
+ )
276
+
277
+ return demo
278
+
279
+
280
+ if __name__ == "__main__":
281
+ import argparse
282
+
283
+ parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
284
+ parser.add_argument("--name", type=str, default="flux-dev", choices=["flux-dev"],
285
+ help="currently only support flux-dev")
286
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
287
+ help="Device to use")
288
+ parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
289
+ parser.add_argument("--port", type=int, default=8080, help="Port to use")
290
+ parser.add_argument("--dev", action='store_true', help="Development mode")
291
+ parser.add_argument("--pretrained_model", type=str, help='for development')
292
+ args = parser.parse_args()
293
+
294
+ import huggingface_hub
295
+ huggingface_hub.login(os.getenv('HF_TOKEN'))
296
+
297
+ demo = create_demo(args, args.name, args.device, args.offload)
298
+ demo.launch()
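As a standalone illustration (not part of the diff), the sketch below isolates two small details of app.py's generation path: how the decoded latent in [-1, 1] is converted to an 8-bit PIL image, and the threshold that decides whether the true-CFG branch (a second text pass with the negative prompt) is taken. The helper names are hypothetical.

```python
# Hypothetical helpers mirroring app.py's post-processing and CFG toggle; not part of the commit.
import torch
from einops import rearrange
from PIL import Image


def to_pil(x: torch.Tensor) -> Image.Image:
    """Convert a decoded image tensor of shape (1, c, h, w) in [-1, 1] to a PIL image."""
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")  # channels-last layout for PIL
    return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())


def uses_true_cfg(true_cfg_scale: float) -> bool:
    """true_cfg_scale == 1 means 'fake CFG' (guidance scale only); larger values enable a negative-prompt pass."""
    return abs(true_cfg_scale - 1.0) > 1e-2


if __name__ == "__main__":
    img = to_pil(torch.rand(1, 3, 64, 64) * 2 - 1)
    print(img.size, uses_true_cfg(1.0), uses_true_cfg(4.0))  # (64, 64) False True
```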
eva_clip/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
3
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
5
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
6
+ from .openai import load_openai_model, list_openai_models
7
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
8
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
9
+ from .tokenizer import SimpleTokenizer, tokenize
10
+ from .transform import image_transform
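A hedged usage sketch of the API re-exported by this package, following the upstream EVA-CLIP convention; the exact keyword arguments (e.g. `force_custom_clip`) and the pretrained tag are assumptions and should be verified against factory.py and pretrained.py:

```python
# Hedged sketch of the eva_clip factory API; argument names and values are assumptions.
import torch
from PIL import Image
from eva_clip import create_model_and_transforms, get_tokenizer

model_name = "EVA02-CLIP-L-14-336"   # one of the JSON configs under eva_clip/model_configs/
pretrained = "eva_clip"              # assumption: a pretrained tag or a local checkpoint path

model, _, preprocess = create_model_and_transforms(model_name, pretrained, force_custom_clip=True)
tokenizer = get_tokenizer(model_name)

image = preprocess(Image.open("example_inputs/hinton.jpeg")).unsqueeze(0)
text = tokenizer(["a portrait photo", "a diagram", "a cat"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    print(image_features @ text_features.T)  # cosine similarities
```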
eva_clip/constants.py ADDED
@@ -0,0 +1,2 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
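For context (not part of the diff): these are the standard OpenAI CLIP normalization statistics. A minimal sketch of how they are typically consumed in an image preprocessing pipeline, assuming torchvision is available:

```python
# Minimal preprocessing sketch using the constants above; torchvision is an assumed dependency here.
from torchvision import transforms

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # scales pixel values to [0, 1]
    transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
])
```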
eva_clip/eva_vit_model.py ADDED
@@ -0,0 +1,633 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ from functools import partial
7
+ from itertools import repeat
8
+ import collections.abc
9
+ import torch
10
+ import torch.nn as nn
11
+ import warnings
12
+ import torch.nn.functional as F
13
+
14
+ from .transformer import PatchDropout
15
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
16
+
17
+ if os.getenv('ENV_TYPE') == 'deepspeed':
18
+ try:
19
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
20
+ except:
21
+ from torch.utils.checkpoint import checkpoint
22
+ else:
23
+ from torch.utils.checkpoint import checkpoint
24
+
25
+ try:
26
+ import xformers
27
+ import xformers.ops as xops
28
+ XFORMERS_IS_AVAILBLE = True
29
+ except:
30
+ XFORMERS_IS_AVAILBLE = False
31
+
32
+
33
+ def _ntuple(n):
34
+ def parse(x):
35
+ if isinstance(x, collections.abc.Iterable):
36
+ return x
37
+ return tuple(repeat(x, n))
38
+ return parse
39
+
40
+ to_2tuple = _ntuple(2)
41
+
42
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
43
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
44
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
45
+ def norm_cdf(x):
46
+ # Computes standard normal cumulative distribution function
47
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
48
+
49
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
50
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
51
+ "The distribution of values may be incorrect.",
52
+ stacklevel=2)
53
+
54
+ with torch.no_grad():
55
+ # Values are generated by using a truncated uniform distribution and
56
+ # then using the inverse CDF for the normal distribution.
57
+ # Get upper and lower cdf values
58
+ l = norm_cdf((a - mean) / std)
59
+ u = norm_cdf((b - mean) / std)
60
+
61
+ # Uniformly fill tensor with values from [l, u], then translate to
62
+ # [2l-1, 2u-1].
63
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
64
+
65
+ # Use inverse cdf transform for normal distribution to get truncated
66
+ # standard normal
67
+ tensor.erfinv_()
68
+
69
+ # Transform to proper mean, std
70
+ tensor.mul_(std * math.sqrt(2.))
71
+ tensor.add_(mean)
72
+
73
+ # Clamp to ensure it's in the proper range
74
+ tensor.clamp_(min=a, max=b)
75
+ return tensor
76
+
77
+
78
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
79
+ # type: (Tensor, float, float, float, float) -> Tensor
80
+ r"""Fills the input Tensor with values drawn from a truncated
81
+ normal distribution. The values are effectively drawn from the
82
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
83
+ with values outside :math:`[a, b]` redrawn until they are within
84
+ the bounds. The method used for generating the random values works
85
+ best when :math:`a \leq \text{mean} \leq b`.
86
+ Args:
87
+ tensor: an n-dimensional `torch.Tensor`
88
+ mean: the mean of the normal distribution
89
+ std: the standard deviation of the normal distribution
90
+ a: the minimum cutoff value
91
+ b: the maximum cutoff value
92
+ Examples:
93
+ >>> w = torch.empty(3, 5)
94
+ >>> nn.init.trunc_normal_(w)
95
+ """
96
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
97
+
98
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
99
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
100
+
101
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
102
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
103
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
104
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
105
+ 'survival rate' as the argument.
106
+
107
+ """
108
+ if drop_prob == 0. or not training:
109
+ return x
110
+ keep_prob = 1 - drop_prob
111
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
112
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
113
+ if keep_prob > 0.0 and scale_by_keep:
114
+ random_tensor.div_(keep_prob)
115
+ return x * random_tensor
116
+
117
+
118
+ class DropPath(nn.Module):
119
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
120
+ """
121
+ def __init__(self, drop_prob=None):
122
+ super(DropPath, self).__init__()
123
+ self.drop_prob = drop_prob
124
+
125
+ def forward(self, x):
126
+ return drop_path(x, self.drop_prob, self.training)
127
+
128
+ def extra_repr(self) -> str:
129
+ return 'p={}'.format(self.drop_prob)
130
+
131
+
132
+ class Mlp(nn.Module):
133
+ def __init__(
134
+ self,
135
+ in_features,
136
+ hidden_features=None,
137
+ out_features=None,
138
+ act_layer=nn.GELU,
139
+ norm_layer=nn.LayerNorm,
140
+ drop=0.,
141
+ subln=False,
142
+
143
+ ):
144
+ super().__init__()
145
+ out_features = out_features or in_features
146
+ hidden_features = hidden_features or in_features
147
+ self.fc1 = nn.Linear(in_features, hidden_features)
148
+ self.act = act_layer()
149
+
150
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
151
+
152
+ self.fc2 = nn.Linear(hidden_features, out_features)
153
+ self.drop = nn.Dropout(drop)
154
+
155
+ def forward(self, x):
156
+ x = self.fc1(x)
157
+ x = self.act(x)
158
+ # x = self.drop(x)
159
+ # dropout here is commented out, following the original BERT implementation
160
+ x = self.ffn_ln(x)
161
+
162
+ x = self.fc2(x)
163
+ x = self.drop(x)
164
+ return x
165
+
166
+ class SwiGLU(nn.Module):
167
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
168
+ norm_layer=nn.LayerNorm, subln=False):
169
+ super().__init__()
170
+ out_features = out_features or in_features
171
+ hidden_features = hidden_features or in_features
172
+
173
+ self.w1 = nn.Linear(in_features, hidden_features)
174
+ self.w2 = nn.Linear(in_features, hidden_features)
175
+
176
+ self.act = act_layer()
177
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
178
+ self.w3 = nn.Linear(hidden_features, out_features)
179
+
180
+ self.drop = nn.Dropout(drop)
181
+
182
+ def forward(self, x):
183
+ x1 = self.w1(x)
184
+ x2 = self.w2(x)
185
+ hidden = self.act(x1) * x2
186
+ x = self.ffn_ln(hidden)
187
+ x = self.w3(x)
188
+ x = self.drop(x)
189
+ return x
190
+
191
+ class Attention(nn.Module):
192
+ def __init__(
193
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
194
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
195
+ super().__init__()
196
+ self.num_heads = num_heads
197
+ head_dim = dim // num_heads
198
+ if attn_head_dim is not None:
199
+ head_dim = attn_head_dim
200
+ all_head_dim = head_dim * self.num_heads
201
+ self.scale = qk_scale or head_dim ** -0.5
202
+
203
+ self.subln = subln
204
+ if self.subln:
205
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
206
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
207
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
208
+ else:
209
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
210
+
211
+ if qkv_bias:
212
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
213
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
214
+ else:
215
+ self.q_bias = None
216
+ self.v_bias = None
217
+
218
+ if window_size:
219
+ self.window_size = window_size
220
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
221
+ self.relative_position_bias_table = nn.Parameter(
222
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
223
+ # cls to token & token 2 cls & cls to cls
224
+
225
+ # get pair-wise relative position index for each token inside the window
226
+ coords_h = torch.arange(window_size[0])
227
+ coords_w = torch.arange(window_size[1])
228
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
229
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
230
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
231
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
232
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
233
+ relative_coords[:, :, 1] += window_size[1] - 1
234
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
235
+ relative_position_index = \
236
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
237
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
238
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
239
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
240
+ relative_position_index[0, 0] = self.num_relative_distance - 1
241
+
242
+ self.register_buffer("relative_position_index", relative_position_index)
243
+ else:
244
+ self.window_size = None
245
+ self.relative_position_bias_table = None
246
+ self.relative_position_index = None
247
+
248
+ self.attn_drop = nn.Dropout(attn_drop)
249
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
250
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
251
+ self.proj = nn.Linear(all_head_dim, dim)
252
+ self.proj_drop = nn.Dropout(proj_drop)
253
+ self.xattn = xattn
254
+ self.xattn_drop = attn_drop
255
+
256
+ self.rope = rope
257
+
258
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
259
+ B, N, C = x.shape
260
+ if self.subln:
261
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
262
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
263
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
264
+
265
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
266
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
267
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
268
+ else:
269
+
270
+ qkv_bias = None
271
+ if self.q_bias is not None:
272
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
273
+
274
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
275
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
276
+ q, k, v = qkv[0], qkv[1], qkv[2]
277
+
278
+ if self.rope:
279
+ # slightly fast impl
280
+ q_t = q[:, :, 1:, :]
281
+ ro_q_t = self.rope(q_t)
282
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
283
+
284
+ k_t = k[:, :, 1:, :]
285
+ ro_k_t = self.rope(k_t)
286
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
287
+
288
+ if self.xattn:
289
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
290
+ k = k.permute(0, 2, 1, 3)
291
+ v = v.permute(0, 2, 1, 3)
292
+
293
+ x = xops.memory_efficient_attention(
294
+ q, k, v,
295
+ p=self.xattn_drop,
296
+ scale=self.scale,
297
+ )
298
+ x = x.reshape(B, N, -1)
299
+ x = self.inner_attn_ln(x)
300
+ x = self.proj(x)
301
+ x = self.proj_drop(x)
302
+ else:
303
+ q = q * self.scale
304
+ attn = (q @ k.transpose(-2, -1))
305
+
306
+ if self.relative_position_bias_table is not None:
307
+ relative_position_bias = \
308
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
309
+ self.window_size[0] * self.window_size[1] + 1,
310
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
311
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
312
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
313
+
314
+ if rel_pos_bias is not None:
315
+ attn = attn + rel_pos_bias.type_as(attn)
316
+
317
+ if attn_mask is not None:
318
+ attn_mask = attn_mask.bool()
319
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
320
+
321
+ attn = attn.softmax(dim=-1)
322
+ attn = self.attn_drop(attn)
323
+
324
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
325
+ x = self.inner_attn_ln(x)
326
+ x = self.proj(x)
327
+ x = self.proj_drop(x)
328
+ return x
329
+
330
+
331
+ class Block(nn.Module):
332
+
333
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
334
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
335
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
336
+ subln=False, naiveswiglu=False):
337
+ super().__init__()
338
+ self.norm1 = norm_layer(dim)
339
+ self.attn = Attention(
340
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
341
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
342
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
343
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
344
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
345
+ self.norm2 = norm_layer(dim)
346
+ mlp_hidden_dim = int(dim * mlp_ratio)
347
+
348
+ if naiveswiglu:
349
+ self.mlp = SwiGLU(
350
+ in_features=dim,
351
+ hidden_features=mlp_hidden_dim,
352
+ subln=subln,
353
+ norm_layer=norm_layer,
354
+ )
355
+ else:
356
+ self.mlp = Mlp(
357
+ in_features=dim,
358
+ hidden_features=mlp_hidden_dim,
359
+ act_layer=act_layer,
360
+ subln=subln,
361
+ drop=drop
362
+ )
363
+
364
+ if init_values is not None and init_values > 0:
365
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
366
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
367
+ else:
368
+ self.gamma_1, self.gamma_2 = None, None
369
+
370
+ self.postnorm = postnorm
371
+
372
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
373
+ if self.gamma_1 is None:
374
+ if self.postnorm:
375
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
376
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
377
+ else:
378
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
379
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
380
+ else:
381
+ if self.postnorm:
382
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
383
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
384
+ else:
385
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
386
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
387
+ return x
388
+
389
+
390
+ class PatchEmbed(nn.Module):
391
+ """ Image to Patch Embedding
392
+ """
393
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
394
+ super().__init__()
395
+ img_size = to_2tuple(img_size)
396
+ patch_size = to_2tuple(patch_size)
397
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
398
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
399
+ self.img_size = img_size
400
+ self.patch_size = patch_size
401
+ self.num_patches = num_patches
402
+
403
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
404
+
405
+ def forward(self, x, **kwargs):
406
+ B, C, H, W = x.shape
407
+ # FIXME look at relaxing size constraints
408
+ assert H == self.img_size[0] and W == self.img_size[1], \
409
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
410
+ x = self.proj(x).flatten(2).transpose(1, 2)
411
+ return x
412
+
413
+
414
+ class RelativePositionBias(nn.Module):
415
+
416
+ def __init__(self, window_size, num_heads):
417
+ super().__init__()
418
+ self.window_size = window_size
419
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
420
+ self.relative_position_bias_table = nn.Parameter(
421
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
422
+ # cls to token & token 2 cls & cls to cls
423
+
424
+ # get pair-wise relative position index for each token inside the window
425
+ coords_h = torch.arange(window_size[0])
426
+ coords_w = torch.arange(window_size[1])
427
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
428
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
429
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
430
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
431
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
432
+ relative_coords[:, :, 1] += window_size[1] - 1
433
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
434
+ relative_position_index = \
435
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
436
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
437
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
438
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
439
+ relative_position_index[0, 0] = self.num_relative_distance - 1
440
+
441
+ self.register_buffer("relative_position_index", relative_position_index)
442
+
443
+ def forward(self):
444
+ relative_position_bias = \
445
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
446
+ self.window_size[0] * self.window_size[1] + 1,
447
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
448
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
449
+
450
+
451
+ class EVAVisionTransformer(nn.Module):
452
+ """ Vision Transformer with support for patch or hybrid CNN input stage
453
+ """
454
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
455
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
456
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
457
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
458
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
459
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
460
+ super().__init__()
461
+
462
+ if not XFORMERS_IS_AVAILBLE:
463
+ xattn = False
464
+
465
+ self.image_size = img_size
466
+ self.num_classes = num_classes
467
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
468
+
469
+ self.patch_embed = PatchEmbed(
470
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
471
+ num_patches = self.patch_embed.num_patches
472
+
473
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
474
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
475
+ if use_abs_pos_emb:
476
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
477
+ else:
478
+ self.pos_embed = None
479
+ self.pos_drop = nn.Dropout(p=drop_rate)
480
+
481
+ if use_shared_rel_pos_bias:
482
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
483
+ else:
484
+ self.rel_pos_bias = None
485
+
486
+ if rope:
487
+ half_head_dim = embed_dim // num_heads // 2
488
+ hw_seq_len = img_size // patch_size
489
+ self.rope = VisionRotaryEmbeddingFast(
490
+ dim=half_head_dim,
491
+ pt_seq_len=pt_hw_seq_len,
492
+ ft_seq_len=hw_seq_len if intp_freq else None,
493
+ # patch_dropout=patch_dropout
494
+ )
495
+ else:
496
+ self.rope = None
497
+
498
+ self.naiveswiglu = naiveswiglu
499
+
500
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
501
+ self.use_rel_pos_bias = use_rel_pos_bias
502
+ self.blocks = nn.ModuleList([
503
+ Block(
504
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
505
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
506
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
507
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
508
+ for i in range(depth)])
509
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
510
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
511
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
512
+
513
+ if self.pos_embed is not None:
514
+ trunc_normal_(self.pos_embed, std=.02)
515
+
516
+ trunc_normal_(self.cls_token, std=.02)
517
+ # trunc_normal_(self.mask_token, std=.02)
518
+
519
+ self.apply(self._init_weights)
520
+ self.fix_init_weight()
521
+
522
+ if isinstance(self.head, nn.Linear):
523
+ trunc_normal_(self.head.weight, std=.02)
524
+ self.head.weight.data.mul_(init_scale)
525
+ self.head.bias.data.mul_(init_scale)
526
+
527
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
528
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
529
+
530
+ self.grad_checkpointing = grad_checkpointing
531
+
532
+ def fix_init_weight(self):
533
+ def rescale(param, layer_id):
534
+ param.div_(math.sqrt(2.0 * layer_id))
535
+
536
+ for layer_id, layer in enumerate(self.blocks):
537
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
538
+ if self.naiveswiglu:
539
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
540
+ else:
541
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
542
+
543
+ def get_cast_dtype(self) -> torch.dtype:
544
+ return self.blocks[0].mlp.fc2.weight.dtype
545
+
546
+ def _init_weights(self, m):
547
+ if isinstance(m, nn.Linear):
548
+ trunc_normal_(m.weight, std=.02)
549
+ if m.bias is not None:
550
+ nn.init.constant_(m.bias, 0)
551
+ elif isinstance(m, nn.LayerNorm):
552
+ nn.init.constant_(m.bias, 0)
553
+ nn.init.constant_(m.weight, 1.0)
554
+
555
+ def get_num_layers(self):
556
+ return len(self.blocks)
557
+
558
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
559
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
560
+ for param in self.parameters():
561
+ param.requires_grad = False
562
+
563
+ @torch.jit.ignore
564
+ def set_grad_checkpointing(self, enable=True):
565
+ self.grad_checkpointing = enable
566
+
567
+ @torch.jit.ignore
568
+ def no_weight_decay(self):
569
+ return {'pos_embed', 'cls_token'}
570
+
571
+ def get_classifier(self):
572
+ return self.head
573
+
574
+ def reset_classifier(self, num_classes, global_pool=''):
575
+ self.num_classes = num_classes
576
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
577
+
578
+ def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
579
+
580
+ x = self.patch_embed(x)
581
+ batch_size, seq_len, _ = x.size()
582
+
583
+ if shuffle:
584
+ idx = torch.randperm(x.shape[1]) + 1
585
+ zero = torch.LongTensor([0, ])
586
+ idx = torch.cat([zero, idx])
587
+ pos_embed = self.pos_embed[:, idx]
588
+
589
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
590
+ x = torch.cat((cls_tokens, x), dim=1)
591
+ if shuffle:
592
+ x = x + pos_embed
593
+ elif self.pos_embed is not None:
594
+ x = x + self.pos_embed
595
+ x = self.pos_drop(x)
596
+
597
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
598
+ if os.getenv('RoPE') == '1':
599
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
600
+ x, patch_indices_keep = self.patch_dropout(x)
601
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
602
+ else:
603
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
604
+ x = self.patch_dropout(x)
605
+ else:
606
+ x = self.patch_dropout(x)
607
+
608
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
609
+ hidden_states = []
610
+ for idx, blk in enumerate(self.blocks):
611
+ if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
612
+ hidden_states.append(x)
613
+ if self.grad_checkpointing:
614
+ x = checkpoint(blk, x, (rel_pos_bias,))
615
+ else:
616
+ x = blk(x, rel_pos_bias=rel_pos_bias)
617
+
618
+ if not return_all_features:
619
+ x = self.norm(x)
620
+ if self.fc_norm is not None:
621
+ return self.fc_norm(x.mean(1)), hidden_states
622
+ else:
623
+ return x[:, 0], hidden_states
624
+ return x
625
+
626
+ def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
627
+ if return_all_features:
628
+ return self.forward_features(x, return_all_features, return_hidden, shuffle)
629
+ x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
630
+ x = self.head(x)
631
+ if return_hidden:
632
+ return x, hidden_states
633
+ return x
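A self-contained sketch (separate from the diff) of the stochastic-depth idea behind `drop_path`/`DropPath` above: during training, each sample's residual branch is zeroed with probability p, and survivors are rescaled by 1/(1 - p) so that the expected output matches the no-drop case.

```python
# Standalone stochastic-depth demo; mirrors the logic of drop_path above in isolation.
import torch


def drop_path_demo(x: torch.Tensor, drop_prob: float = 0.2) -> torch.Tensor:
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dimensions.
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
    return x * mask / keep_prob  # rescale survivors to preserve the expectation


if __name__ == "__main__":
    x = torch.ones(100_000, 8)
    y = drop_path_demo(x, drop_prob=0.2)
    print(round(y.mean().item(), 2))  # ~1.0: expectation is unchanged
```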
eva_clip/factory.py ADDED
@@ -0,0 +1,517 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
+ from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
+ get_cast_dtype
14
+ from .openai import load_openai_model
15
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
+ from .transform import image_transform
17
+ from .tokenizer import HFTokenizer, tokenize
18
+ from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
+
20
+
21
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
+ _MODEL_CONFIGS = {}  # dictionary of model architecture configs (model_name: config)
23
+
24
+
25
+ def _natural_key(string_):
26
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
27
+
28
+
29
+ def _rescan_model_configs():
30
+ global _MODEL_CONFIGS
31
+
32
+ config_ext = ('.json',)
33
+ config_files = []
34
+ for config_path in _MODEL_CONFIG_PATHS:
35
+ if config_path.is_file() and config_path.suffix in config_ext:
36
+ config_files.append(config_path)
37
+ elif config_path.is_dir():
38
+ for ext in config_ext:
39
+ config_files.extend(config_path.glob(f'*{ext}'))
40
+
41
+ for cf in config_files:
42
+ with open(cf, "r", encoding="utf8") as f:
43
+ model_cfg = json.load(f)
44
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
45
+ _MODEL_CONFIGS[cf.stem] = model_cfg
46
+
47
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def list_models():
54
+ """ enumerate available model architectures based on config files """
55
+ return list(_MODEL_CONFIGS.keys())
56
+
57
+
58
+ def add_model_config(path):
59
+ """ add model config path or file and update registry """
60
+ if not isinstance(path, Path):
61
+ path = Path(path)
62
+ _MODEL_CONFIG_PATHS.append(path)
63
+ _rescan_model_configs()
64
+
65
+
66
+ def get_model_config(model_name):
67
+ if model_name in _MODEL_CONFIGS:
68
+ return deepcopy(_MODEL_CONFIGS[model_name])
69
+ else:
70
+ return None
71
+
72
+
73
+ def get_tokenizer(model_name):
74
+ config = get_model_config(model_name)
75
+ tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
76
+ return tokenizer
77
+
78
+
79
+ # loading openai CLIP weights when is_openai=True for training
80
+ def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
81
+ if is_openai:
82
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
83
+ state_dict = model.state_dict()
84
+ for key in ["input_resolution", "context_length", "vocab_size"]:
85
+ state_dict.pop(key, None)
86
+ else:
87
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
+ for mk in model_key.split('|'):
89
+ if isinstance(checkpoint, dict) and mk in checkpoint:
90
+ state_dict = checkpoint[mk]
91
+ break
92
+ else:
93
+ state_dict = checkpoint
94
+ if next(iter(state_dict.items()))[0].startswith('module'):
95
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
96
+
97
+ for k in skip_list:
98
+ if k in list(state_dict.keys()):
99
+ logging.info(f"Removing key {k} from pretrained checkpoint")
100
+ del state_dict[k]
101
+
102
+ if os.getenv('RoPE') == '1':
103
+ for k in list(state_dict.keys()):
104
+ if 'freqs_cos' in k or 'freqs_sin' in k:
105
+ del state_dict[k]
106
+ return state_dict
107
+
108
+
109
+
110
+ def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
111
+ state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
112
+ # detect old format and make compatible with new format
113
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
114
+ state_dict = convert_to_custom_text_state_dict(state_dict)
115
+ if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
116
+ state_dict['logit_scale'] = state_dict['text.logit_scale']
117
+ del state_dict['text.logit_scale']
118
+
119
+ # resize_clip_pos_embed for CLIP and open CLIP
120
+ if 'visual.positional_embedding' in state_dict:
121
+ resize_clip_pos_embed(state_dict, model)
122
+ # specified to eva_vit_model
123
+ elif 'visual.pos_embed' in state_dict:
124
+ resize_evaclip_pos_embed(state_dict, model)
125
+
126
+ # resize_clip_pos_embed(state_dict, model)
127
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
128
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
129
+ return incompatible_keys
130
+
131
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
132
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
133
+
134
+ for k in list(state_dict.keys()):
135
+ if not k.startswith('visual.'):
136
+ del state_dict[k]
137
+ for k in list(state_dict.keys()):
138
+ if k.startswith('visual.'):
139
+ new_k = k[7:]
140
+ state_dict[new_k] = state_dict[k]
141
+ del state_dict[k]
142
+ return state_dict
143
+
144
+ def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
145
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
146
+
147
+ for k in list(state_dict.keys()):
148
+ if k.startswith('visual.'):
149
+ del state_dict[k]
150
+ return state_dict
151
+
152
+ def get_pretrained_tag(pretrained_model):
153
+ pretrained_model = pretrained_model.lower()
154
+ if "laion" in pretrained_model or "open_clip" in pretrained_model:
155
+ return "open_clip"
156
+ elif "openai" in pretrained_model:
157
+ return "clip"
158
+ elif "eva" in pretrained_model and "clip" in pretrained_model:
159
+ return "eva_clip"
160
+ else:
161
+ return "other"
162
+
163
+ def load_pretrained_checkpoint(
164
+ model,
165
+ visual_checkpoint_path,
166
+ text_checkpoint_path,
167
+ strict=True,
168
+ visual_model=None,
169
+ text_model=None,
170
+ model_key="model|module|state_dict",
171
+ skip_list=[]):
172
+ visual_tag = get_pretrained_tag(visual_model)
173
+ text_tag = get_pretrained_tag(text_model)
174
+
175
+ logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
176
+ visual_incompatible_keys, text_incompatible_keys = None, None
177
+ if visual_checkpoint_path:
178
+ if visual_tag == "eva_clip" or visual_tag == "open_clip":
179
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
180
+ elif visual_tag == "clip":
181
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
182
+ else:
183
+ visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
184
+
185
+ # resize_clip_pos_embed for CLIP and open CLIP
186
+ if 'positional_embedding' in visual_state_dict:
187
+ resize_visual_pos_embed(visual_state_dict, model)
188
+ # specified to EVA model
189
+ elif 'pos_embed' in visual_state_dict:
190
+ resize_eva_pos_embed(visual_state_dict, model)
191
+
192
+ visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
193
+ logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
194
+ logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
195
+
196
+ if text_checkpoint_path:
197
+ if text_tag == "eva_clip" or text_tag == "open_clip":
198
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
199
+ elif text_tag == "clip":
200
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
201
+ else:
202
+ text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
203
+
204
+ text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
205
+
206
+ logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
207
+ logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
208
+
209
+ return visual_incompatible_keys, text_incompatible_keys
210
+
211
+ def create_model(
212
+ model_name: str,
213
+ pretrained: Optional[str] = None,
214
+ precision: str = 'fp32',
215
+ device: Union[str, torch.device] = 'cpu',
216
+ jit: bool = False,
217
+ force_quick_gelu: bool = False,
218
+ force_custom_clip: bool = False,
219
+ force_patch_dropout: Optional[float] = None,
220
+ pretrained_image: str = '',
221
+ pretrained_text: str = '',
222
+ pretrained_hf: bool = True,
223
+ pretrained_visual_model: str = None,
224
+ pretrained_text_model: str = None,
225
+ cache_dir: Optional[str] = None,
226
+ skip_list: list = [],
227
+ ):
228
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
229
+ if isinstance(device, str):
230
+ device = torch.device(device)
231
+
232
+ if pretrained and pretrained.lower() == 'openai':
233
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
234
+ model = load_openai_model(
235
+ model_name,
236
+ precision=precision,
237
+ device=device,
238
+ jit=jit,
239
+ cache_dir=cache_dir,
240
+ )
241
+ else:
242
+ model_cfg = get_model_config(model_name)
243
+ if model_cfg is not None:
244
+ logging.info(f'Loaded {model_name} model config.')
245
+ else:
246
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
247
+ raise RuntimeError(f'Model config for {model_name} not found.')
248
+
249
+ if 'rope' in model_cfg.get('vision_cfg', {}):
250
+ if model_cfg['vision_cfg']['rope']:
251
+ os.environ['RoPE'] = "1"
252
+ else:
253
+ os.environ['RoPE'] = "0"
254
+
255
+ if force_quick_gelu:
256
+ # override for use of QuickGELU on non-OpenAI transformer models
257
+ model_cfg["quick_gelu"] = True
258
+
259
+ if force_patch_dropout is not None:
260
+ # override the default patch dropout value
261
+ model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
262
+
263
+ cast_dtype = get_cast_dtype(precision)
264
+ custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
265
+
266
+
267
+ if custom_clip:
268
+ if 'hf_model_name' in model_cfg.get('text_cfg', {}):
269
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
270
+ model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
271
+ else:
272
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
273
+
274
+ pretrained_cfg = {}
275
+ if pretrained:
276
+ checkpoint_path = ''
277
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
278
+ if pretrained_cfg:
279
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
280
+ elif os.path.exists(pretrained):
281
+ checkpoint_path = pretrained
282
+
283
+ if checkpoint_path:
284
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
285
+ load_checkpoint(model,
286
+ checkpoint_path,
287
+ model_key="model|module|state_dict",
288
+ strict=False
289
+ )
290
+ else:
291
+ error_str = (
292
+ f'Pretrained weights ({pretrained}) not found for model {model_name}. '
293
+ f'Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
294
+ logging.warning(error_str)
295
+ raise RuntimeError(error_str)
296
+ else:
297
+ visual_checkpoint_path = ''
298
+ text_checkpoint_path = ''
299
+
300
+ if pretrained_image:
301
+ pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
302
+ pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
303
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
304
+ # pretrained weight loading for timm models set via vision_cfg
305
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
306
+ elif pretrained_image_cfg:
307
+ visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
308
+ elif os.path.exists(pretrained_image):
309
+ visual_checkpoint_path = pretrained_image
310
+ else:
311
+ logging.warning(f'Pretrained weights ({pretrained_image}) not found for model {model_name}.visual.')
312
+ raise RuntimeError(f'Pretrained weights ({pretrained_image}) not found for model {model_name}.visual.')
313
+
314
+ if pretrained_text:
315
+ pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
316
+ pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
317
+ if pretrained_text_cfg:
318
+ text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
319
+ elif os.path.exists(pretrained_text):
320
+ text_checkpoint_path = pretrained_text
321
+ else:
322
+ logging.warning(f'Pretrained weights ({pretrained_text}) not found for model {model_name}.text.')
323
+ raise RuntimeError(f'Pretrained weights ({pretrained_text}) not found for model {model_name}.text.')
324
+
325
+ if visual_checkpoint_path:
326
+ logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
327
+ if text_checkpoint_path:
328
+ logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
329
+
330
+ if visual_checkpoint_path or text_checkpoint_path:
331
+ load_pretrained_checkpoint(
332
+ model,
333
+ visual_checkpoint_path,
334
+ text_checkpoint_path,
335
+ strict=False,
336
+ visual_model=pretrained_visual_model,
337
+ text_model=pretrained_text_model,
338
+ model_key="model|module|state_dict",
339
+ skip_list=skip_list
340
+ )
341
+
342
+ if "fp16" in precision or "bf16" in precision:
343
+ logging.info(f'convert precision to {precision}')
344
+ model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
345
+
346
+ model.to(device=device)
347
+
348
+ # set image / mean metadata from pretrained_cfg if available, or use default
349
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
350
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
351
+
352
+ if jit:
353
+ model = torch.jit.script(model)
354
+
355
+ return model
356
+
357
+
358
+ def create_model_and_transforms(
359
+ model_name: str,
360
+ pretrained: Optional[str] = None,
361
+ precision: str = 'fp32',
362
+ device: Union[str, torch.device] = 'cpu',
363
+ jit: bool = False,
364
+ force_quick_gelu: bool = False,
365
+ force_custom_clip: bool = False,
366
+ force_patch_dropout: Optional[float] = None,
367
+ pretrained_image: str = '',
368
+ pretrained_text: str = '',
369
+ pretrained_hf: bool = True,
370
+ pretrained_visual_model: str = None,
371
+ pretrained_text_model: str = None,
372
+ image_mean: Optional[Tuple[float, ...]] = None,
373
+ image_std: Optional[Tuple[float, ...]] = None,
374
+ cache_dir: Optional[str] = None,
375
+ skip_list: list = [],
376
+ ):
377
+ model = create_model(
378
+ model_name,
379
+ pretrained,
380
+ precision=precision,
381
+ device=device,
382
+ jit=jit,
383
+ force_quick_gelu=force_quick_gelu,
384
+ force_custom_clip=force_custom_clip,
385
+ force_patch_dropout=force_patch_dropout,
386
+ pretrained_image=pretrained_image,
387
+ pretrained_text=pretrained_text,
388
+ pretrained_hf=pretrained_hf,
389
+ pretrained_visual_model=pretrained_visual_model,
390
+ pretrained_text_model=pretrained_text_model,
391
+ cache_dir=cache_dir,
392
+ skip_list=skip_list,
393
+ )
394
+
395
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
396
+ image_std = image_std or getattr(model.visual, 'image_std', None)
397
+ preprocess_train = image_transform(
398
+ model.visual.image_size,
399
+ is_train=True,
400
+ mean=image_mean,
401
+ std=image_std
402
+ )
403
+ preprocess_val = image_transform(
404
+ model.visual.image_size,
405
+ is_train=False,
406
+ mean=image_mean,
407
+ std=image_std
408
+ )
409
+
410
+ return model, preprocess_train, preprocess_val
411
+
412
+
413
+ def create_transforms(
414
+ model_name: str,
415
+ pretrained: Optional[str] = None,
416
+ precision: str = 'fp32',
417
+ device: Union[str, torch.device] = 'cpu',
418
+ jit: bool = False,
419
+ force_quick_gelu: bool = False,
420
+ force_custom_clip: bool = False,
421
+ force_patch_dropout: Optional[float] = None,
422
+ pretrained_image: str = '',
423
+ pretrained_text: str = '',
424
+ pretrained_hf: bool = True,
425
+ pretrained_visual_model: str = None,
426
+ pretrained_text_model: str = None,
427
+ image_mean: Optional[Tuple[float, ...]] = None,
428
+ image_std: Optional[Tuple[float, ...]] = None,
429
+ cache_dir: Optional[str] = None,
430
+ skip_list: list = [],
431
+ ):
432
+ model = create_model(
433
+ model_name,
434
+ pretrained,
435
+ precision=precision,
436
+ device=device,
437
+ jit=jit,
438
+ force_quick_gelu=force_quick_gelu,
439
+ force_custom_clip=force_custom_clip,
440
+ force_patch_dropout=force_patch_dropout,
441
+ pretrained_image=pretrained_image,
442
+ pretrained_text=pretrained_text,
443
+ pretrained_hf=pretrained_hf,
444
+ pretrained_visual_model=pretrained_visual_model,
445
+ pretrained_text_model=pretrained_text_model,
446
+ cache_dir=cache_dir,
447
+ skip_list=skip_list,
448
+ )
449
+
450
+
451
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
452
+ image_std = image_std or getattr(model.visual, 'image_std', None)
453
+ preprocess_train = image_transform(
454
+ model.visual.image_size,
455
+ is_train=True,
456
+ mean=image_mean,
457
+ std=image_std
458
+ )
459
+ preprocess_val = image_transform(
460
+ model.visual.image_size,
461
+ is_train=False,
462
+ mean=image_mean,
463
+ std=image_std
464
+ )
465
+ del model
466
+
467
+ return preprocess_train, preprocess_val
468
+
469
+ def create_model_from_pretrained(
470
+ model_name: str,
471
+ pretrained: str,
472
+ precision: str = 'fp32',
473
+ device: Union[str, torch.device] = 'cpu',
474
+ jit: bool = False,
475
+ force_quick_gelu: bool = False,
476
+ force_custom_clip: bool = False,
477
+ force_patch_dropout: Optional[float] = None,
478
+ return_transform: bool = True,
479
+ image_mean: Optional[Tuple[float, ...]] = None,
480
+ image_std: Optional[Tuple[float, ...]] = None,
481
+ cache_dir: Optional[str] = None,
482
+ is_frozen: bool = False,
483
+ ):
484
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
485
+ raise RuntimeError(
486
+ f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
487
+ f' Use open_clip.list_pretrained() to find one.')
488
+
489
+ model = create_model(
490
+ model_name,
491
+ pretrained,
492
+ precision=precision,
493
+ device=device,
494
+ jit=jit,
495
+ force_quick_gelu=force_quick_gelu,
496
+ force_custom_clip=force_custom_clip,
497
+ force_patch_dropout=force_patch_dropout,
498
+ cache_dir=cache_dir,
499
+ )
500
+
501
+ if is_frozen:
502
+ for param in model.parameters():
503
+ param.requires_grad = False
504
+
505
+ if not return_transform:
506
+ return model
507
+
508
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
509
+ image_std = image_std or getattr(model.visual, 'image_std', None)
510
+ preprocess = image_transform(
511
+ model.visual.image_size,
512
+ is_train=False,
513
+ mean=image_mean,
514
+ std=image_std
515
+ )
516
+
517
+ return model, preprocess
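A rough usage sketch of the factory above (illustrative, not part of the committed files): the model name is one of the configs added under eva_clip/model_configs/, pretrained is left as None so nothing is downloaded, and the repo root plus its requirements (torch, timm, etc.) are assumed to be available.

    import torch
    from eva_clip.factory import create_model_and_transforms, get_tokenizer

    model, preprocess_train, preprocess_val = create_model_and_transforms(
        "EVA02-CLIP-B-16",        # config file shipped in this commit
        pretrained=None,          # or a checkpoint path / pretrained tag
        precision="fp32",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    tokenizer = get_tokenizer("EVA02-CLIP-B-16")
    tokens = tokenizer(["a photo of a cat"])   # LongTensor of shape [1, 77]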
eva_clip/hf_configs.py ADDED
@@ -0,0 +1,57 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ "bert": {
46
+ "config_names": {
47
+ "context_length": "max_position_embeddings",
48
+ "vocab_size": "vocab_size",
49
+ "width": "hidden_size",
50
+ "heads": "num_attention_heads",
51
+ "layers": "num_hidden_layers",
52
+ "layer_attr": "layer",
53
+ "token_embeddings_attr": "embeddings"
54
+ },
55
+ "pooler": "mean_pooler",
56
+ }
57
+ }
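For orientation (an illustrative sketch, not part of the committed files): arch_dict is consumed by HFTextEncoder in the next file to translate generic attribute names into each architecture's HuggingFace config field names.

    from eva_clip.hf_configs import arch_dict

    names = arch_dict["xlm-roberta"]["config_names"]
    print(names["width"])                        # "hidden_size"
    print(names["layers"])                       # "num_hidden_layers"
    print(arch_dict["xlm-roberta"]["pooler"])    # "mean_pooler"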
eva_clip/hf_model.py ADDED
@@ -0,0 +1,248 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from torch import TensorType
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+ # TODO: ?last - for gpt-like models
35
+ _POOLERS = {}
36
+
37
+ def register_pooler(cls):
38
+ """Decorator registering pooler class"""
39
+ _POOLERS[_camel2snake(cls.__name__)] = cls
40
+ return cls
41
+
42
+
43
+ @register_pooler
44
+ class MeanPooler(nn.Module):
45
+ """Mean pooling"""
46
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
47
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
48
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
49
+
50
+ @register_pooler
51
+ class MaxPooler(nn.Module):
52
+ """Max pooling"""
53
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
54
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
55
+ return masked_output.max(1).values
56
+
57
+ @register_pooler
58
+ class ClsPooler(nn.Module):
59
+ """CLS token pooling"""
60
+ def __init__(self, use_pooler_output=True):
61
+ super().__init__()
62
+ self.cls_token_position = 0
63
+ self.use_pooler_output = use_pooler_output
64
+
65
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
66
+
67
+ if (self.use_pooler_output and
68
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
69
+ (x.pooler_output is not None)
70
+ ):
71
+ return x.pooler_output
72
+
73
+ return x.last_hidden_state[:, self.cls_token_position, :]
74
+
75
+ class HFTextEncoder(nn.Module):
76
+ """HuggingFace model adapter"""
77
+ def __init__(
78
+ self,
79
+ model_name_or_path: str,
80
+ output_dim: int,
81
+ tokenizer_name: str = None,
82
+ config: PretrainedConfig = None,
83
+ pooler_type: str = None,
84
+ proj: str = None,
85
+ pretrained: bool = True,
86
+ masked_language_modeling: bool = False):
87
+ super().__init__()
88
+
89
+ self.output_dim = output_dim
90
+
91
+ # TODO: find better way to get this information
92
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
93
+
94
+ if transformers is None:
95
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
96
+ if config is None:
97
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
98
+ if masked_language_modeling:
99
+ create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
100
+ AutoModelForMaskedLM.from_config, self.config)
101
+ else:
102
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
103
+ AutoModel.from_config, self.config)
104
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
105
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
106
+ self.transformer = create_func(model_args)
107
+ self.transformer = self.transformer.encoder
108
+ else:
109
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
110
+ else:
111
+ self.config = config
112
+ if masked_language_modeling:
113
+ self.transformer = AutoModelForMaskedLM.from_config(config)
114
+ else:
115
+ self.transformer = AutoModel.from_config(config)
116
+
117
+ if pooler_type is None: # get default arch pooler
118
+ self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
119
+ else:
120
+ self.pooler = _POOLERS[pooler_type]()
121
+
122
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
123
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
124
+ self.proj = nn.Identity()
125
+ elif proj == 'linear':
126
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
127
+ elif proj == 'mlp':
128
+ hidden_size = (d_model + output_dim) // 2
129
+ self.proj = nn.Sequential(
130
+ nn.Linear(d_model, hidden_size, bias=False),
131
+ nn.GELU(),
132
+ nn.Linear(hidden_size, output_dim, bias=False),
133
+ )
134
+
135
+ # self.itm_proj = nn.Linear(d_model, 2, bias=False)
136
+ # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
137
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
138
+
139
+ # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
140
+ # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
141
+ # attn_mask = (x != self.config.pad_token_id).long()
142
+ # out = self.transformer(
143
+ # input_ids=x,
144
+ # attention_mask=attn_mask,
145
+ # encoder_hidden_states = image_embeds,
146
+ # encoder_attention_mask = image_atts,
147
+ # )
148
+ # pooled_out = self.pooler(out, attn_mask)
149
+
150
+ # return self.itm_proj(pooled_out)
151
+
152
+ def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
153
+ if masked_indices is None:
154
+ masked_indices = torch.bernoulli(probability_matrix).bool()
155
+
156
+ masked_indices[input_ids == self.tokenizer.pad_token_id] = False
157
+ masked_indices[input_ids == self.tokenizer.cls_token_id] = False
158
+
159
+ if targets is not None:
160
+ targets[~masked_indices] = -100 # We only compute loss on masked tokens
161
+
162
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
163
+ indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
164
+ input_ids[indices_replaced] = self.tokenizer.mask_token_id
165
+
166
+ # 10% of the time, we replace masked input tokens with random word
167
+ indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
168
+ random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
169
+ input_ids[indices_random] = random_words[indices_random]
170
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
171
+
172
+ if targets is not None:
173
+ return input_ids, targets
174
+ else:
175
+ return input_ids
176
+
177
+ def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
178
+ labels = input_ids.clone()
179
+ attn_mask = (input_ids != self.config.pad_token_id).long()
180
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
181
+ vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
182
+ probability_matrix = torch.full(labels.shape, mlm_probability)
183
+ input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
184
+ probability_matrix = probability_matrix)
185
+ mlm_output = self.transformer(input_ids,
186
+ attention_mask = attn_mask,
187
+ encoder_hidden_states = image_embeds,
188
+ encoder_attention_mask = image_atts,
189
+ return_dict = True,
190
+ labels = labels,
191
+ )
192
+ return mlm_output.loss
193
+ # mlm_output = self.transformer(input_ids,
194
+ # attention_mask = attn_mask,
195
+ # encoder_hidden_states = image_embeds,
196
+ # encoder_attention_mask = image_atts,
197
+ # return_dict = True,
198
+ # ).last_hidden_state
199
+ # logits = self.mlm_proj(mlm_output)
200
+
201
+ # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
202
+ # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
203
+ # labels = labels[:, 1:].contiguous().view(-1)
204
+
205
+ # mlm_loss = F.cross_entropy(
206
+ # logits,
207
+ # labels,
208
+ # # label_smoothing=0.1,
209
+ # )
210
+ # return mlm_loss
211
+
212
+
213
+ def forward(self, x:TensorType) -> TensorType:
214
+ attn_mask = (x != self.config.pad_token_id).long()
215
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
216
+ pooled_out = self.pooler(out, attn_mask)
217
+
218
+ return self.proj(pooled_out)
219
+
220
+ def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
221
+ if not unlocked_layers: # full freezing
222
+ for n, p in self.transformer.named_parameters():
223
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
224
+ return
225
+
226
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
227
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
228
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
229
+ embeddings = getattr(
230
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
231
+ modules = [embeddings, *layer_list][:-unlocked_layers]
232
+ # freeze layers
233
+ for module in modules:
234
+ for n, p in module.named_parameters():
235
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
236
+
237
+
238
+ @torch.jit.ignore
239
+ def set_grad_checkpointing(self, enable=True):
240
+ self.transformer.gradient_checkpointing_enable()
241
+
242
+ def get_num_layers(self):
243
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
244
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
245
+ return len(layer_list)
246
+
247
+ def init_parameters(self):
248
+ pass
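A hedged usage sketch for HFTextEncoder as a standalone text tower (not part of the commit). It assumes transformers is installed; "xlm-roberta-base" is an illustrative HuggingFace model id, and pretrained=False keeps the sketch to a config/tokenizer download with randomly initialised weights.

    import torch
    from eva_clip.hf_model import HFTextEncoder

    text_tower = HFTextEncoder(
        "xlm-roberta-base",              # illustrative HF model id (assumption)
        output_dim=512,
        tokenizer_name="xlm-roberta-base",
        pooler_type="mean_pooler",
        proj="linear",
        pretrained=False,                # random init: avoids downloading model weights
    )
    token_ids = text_tower.tokenizer(["a photo of a cat"], return_tensors="pt", padding=True).input_ids
    with torch.no_grad():
        text_features = text_tower(token_ids)    # shape [1, 512]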
eva_clip/loss.py ADDED
@@ -0,0 +1,138 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+ from timm.loss import LabelSmoothingCrossEntropy
19
+
20
+
21
+ def gather_features(
22
+ image_features,
23
+ text_features,
24
+ local_loss=False,
25
+ gather_with_grad=False,
26
+ rank=0,
27
+ world_size=1,
28
+ use_horovod=False
29
+ ):
30
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31
+ if use_horovod:
32
+ assert hvd is not None, 'Please install horovod'
33
+ if gather_with_grad:
34
+ all_image_features = hvd.allgather(image_features)
35
+ all_text_features = hvd.allgather(text_features)
36
+ else:
37
+ with torch.no_grad():
38
+ all_image_features = hvd.allgather(image_features)
39
+ all_text_features = hvd.allgather(text_features)
40
+ if not local_loss:
41
+ # ensure grads for local rank when all_* features don't have a gradient
42
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44
+ gathered_image_features[rank] = image_features
45
+ gathered_text_features[rank] = text_features
46
+ all_image_features = torch.cat(gathered_image_features, dim=0)
47
+ all_text_features = torch.cat(gathered_text_features, dim=0)
48
+ else:
49
+ # We gather tensors from all gpus
50
+ if gather_with_grad:
51
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53
+ # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54
+ # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55
+ else:
56
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58
+ dist.all_gather(gathered_image_features, image_features)
59
+ dist.all_gather(gathered_text_features, text_features)
60
+ if not local_loss:
61
+ # ensure grads for local rank when all_* features don't have a gradient
62
+ gathered_image_features[rank] = image_features
63
+ gathered_text_features[rank] = text_features
64
+ all_image_features = torch.cat(gathered_image_features, dim=0)
65
+ all_text_features = torch.cat(gathered_text_features, dim=0)
66
+
67
+ return all_image_features, all_text_features
68
+
69
+
70
+ class ClipLoss(nn.Module):
71
+
72
+ def __init__(
73
+ self,
74
+ local_loss=False,
75
+ gather_with_grad=False,
76
+ cache_labels=False,
77
+ rank=0,
78
+ world_size=1,
79
+ use_horovod=False,
80
+ smoothing=0.,
81
+ ):
82
+ super().__init__()
83
+ self.local_loss = local_loss
84
+ self.gather_with_grad = gather_with_grad
85
+ self.cache_labels = cache_labels
86
+ self.rank = rank
87
+ self.world_size = world_size
88
+ self.use_horovod = use_horovod
89
+ self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90
+
91
+ # cache state
92
+ self.prev_num_logits = 0
93
+ self.labels = {}
94
+
95
+ def forward(self, image_features, text_features, logit_scale=1.):
96
+ device = image_features.device
97
+ if self.world_size > 1:
98
+ all_image_features, all_text_features = gather_features(
99
+ image_features, text_features,
100
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101
+
102
+ if self.local_loss:
103
+ logits_per_image = logit_scale * image_features @ all_text_features.T
104
+ logits_per_text = logit_scale * text_features @ all_image_features.T
105
+ else:
106
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
107
+ logits_per_text = logits_per_image.T
108
+ else:
109
+ logits_per_image = logit_scale * image_features @ text_features.T
110
+ logits_per_text = logit_scale * text_features @ image_features.T
111
+ # calculated ground-truth and cache if enabled
112
+ num_logits = logits_per_image.shape[0]
113
+ if self.prev_num_logits != num_logits or device not in self.labels:
114
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
115
+ if self.world_size > 1 and self.local_loss:
116
+ labels = labels + num_logits * self.rank
117
+ if self.cache_labels:
118
+ self.labels[device] = labels
119
+ self.prev_num_logits = num_logits
120
+ else:
121
+ labels = self.labels[device]
122
+
123
+ if self.label_smoothing_cross_entropy:
124
+ total_loss = (
125
+ self.label_smoothing_cross_entropy(logits_per_image, labels) +
126
+ self.label_smoothing_cross_entropy(logits_per_text, labels)
127
+ ) / 2
128
+ else:
129
+ total_loss = (
130
+ F.cross_entropy(logits_per_image, labels) +
131
+ F.cross_entropy(logits_per_text, labels)
132
+ ) / 2
133
+
134
+ acc = None
135
+ i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136
+ t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137
+ acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138
+ return total_loss, acc
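A self-contained sketch of ClipLoss in the single-process case (world_size=1), using random features in place of real encode_image/encode_text outputs; note the module imports timm at import time, so timm must be installed.

    import torch
    import torch.nn.functional as F
    from eva_clip.loss import ClipLoss

    loss_fn = ClipLoss(rank=0, world_size=1)
    image_features = F.normalize(torch.randn(8, 512), dim=-1)
    text_features = F.normalize(torch.randn(8, 512), dim=-1)
    logit_scale = torch.tensor(100.0)    # roughly exp(log(1/0.07)), the learned scale

    loss, acc = loss_fn(image_features, text_features, logit_scale)
    print(loss.item(), acc["i2t"].item(), acc["t2i"].item())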
eva_clip/model.py ADDED
@@ -0,0 +1,440 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ try:
16
+ from .hf_model import HFTextEncoder
17
+ except:
18
+ HFTextEncoder = None
19
+ from .modified_resnet import ModifiedResNet
20
+ # from .timm_model import TimmModel
21
+ from .eva_vit_model import EVAVisionTransformer
22
+ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
+
24
+ try:
25
+ from apex.normalization import FusedLayerNorm
26
+ except:
27
+ FusedLayerNorm = LayerNorm
28
+ print("Please 'pip install apex'")
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
+ @dataclass
37
+ class CLIPVisionCfg:
38
+ layers: Union[Tuple[int, int, int, int], int] = 12
39
+ width: int = 768
40
+ head_width: int = 64
41
+ mlp_ratio: float = 4.0
42
+ patch_size: int = 16
43
+ image_size: Union[Tuple[int, int], int] = 224
44
+ ls_init_value: Optional[float] = None # layer scale initial value
45
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
46
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
47
+ drop_path_rate: Optional[float] = None # drop path rate
48
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
49
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
50
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
51
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
52
+ timm_proj_bias: bool = False # enable bias final projection
53
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
54
+ qkv_bias: bool = True
55
+ fusedLN: bool = False
56
+ xattn: bool = False
57
+ postnorm: bool = False
58
+ rope: bool = False
59
+ pt_hw_seq_len: int = 16 # 224/14
60
+ intp_freq: bool = False
61
+ naiveswiglu: bool = False
62
+ subln: bool = False
63
+
64
+
65
+ @dataclass
66
+ class CLIPTextCfg:
67
+ context_length: int = 77
68
+ vocab_size: int = 49408
69
+ width: int = 512
70
+ heads: int = 8
71
+ layers: int = 12
72
+ ls_init_value: Optional[float] = None # layer scale initial value
73
+ hf_model_name: str = None
74
+ hf_tokenizer_name: str = None
75
+ hf_model_pretrained: bool = True
76
+ proj: str = 'mlp'
77
+ pooler_type: str = 'mean_pooler'
78
+ masked_language_modeling: bool = False
79
+ fusedLN: bool = False
80
+ xattn: bool = False
81
+ attn_mask: bool = True
82
+
83
+ def get_cast_dtype(precision: str):
84
+ cast_dtype = None
85
+ if precision == 'bf16':
86
+ cast_dtype = torch.bfloat16
87
+ elif precision == 'fp16':
88
+ cast_dtype = torch.float16
89
+ return cast_dtype
90
+
91
+
92
+ def _build_vision_tower(
93
+ embed_dim: int,
94
+ vision_cfg: CLIPVisionCfg,
95
+ quick_gelu: bool = False,
96
+ cast_dtype: Optional[torch.dtype] = None
97
+ ):
98
+ if isinstance(vision_cfg, dict):
99
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
100
+
101
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
102
+ # memory efficient in recent PyTorch releases (>= 1.10).
103
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
104
+ act_layer = QuickGELU if quick_gelu else nn.GELU
105
+
106
+ if vision_cfg.eva_model_name:
107
+ vision_heads = vision_cfg.width // vision_cfg.head_width
108
+ norm_layer = LayerNorm
109
+
110
+ visual = EVAVisionTransformer(
111
+ img_size=vision_cfg.image_size,
112
+ patch_size=vision_cfg.patch_size,
113
+ num_classes=embed_dim,
114
+ use_mean_pooling=vision_cfg.global_average_pool, #False
115
+ init_values=vision_cfg.ls_init_value,
116
+ patch_dropout=vision_cfg.patch_dropout,
117
+ embed_dim=vision_cfg.width,
118
+ depth=vision_cfg.layers,
119
+ num_heads=vision_heads,
120
+ mlp_ratio=vision_cfg.mlp_ratio,
121
+ qkv_bias=vision_cfg.qkv_bias,
122
+ drop_path_rate=vision_cfg.drop_path_rate,
123
+ norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
124
+ xattn=vision_cfg.xattn,
125
+ rope=vision_cfg.rope,
126
+ postnorm=vision_cfg.postnorm,
127
+ pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
128
+ intp_freq= vision_cfg.intp_freq,
129
+ naiveswiglu= vision_cfg.naiveswiglu,
130
+ subln= vision_cfg.subln
131
+ )
132
+ elif vision_cfg.timm_model_name:
133
+ # visual = TimmModel(
134
+ # vision_cfg.timm_model_name,
135
+ # pretrained=vision_cfg.timm_model_pretrained,
136
+ # pool=vision_cfg.timm_pool,
137
+ # proj=vision_cfg.timm_proj,
138
+ # proj_bias=vision_cfg.timm_proj_bias,
139
+ # embed_dim=embed_dim,
140
+ # image_size=vision_cfg.image_size
141
+ # )
142
+ # act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
143
+ raise ValueError('timm vision towers are disabled in this build (TimmModel import is commented out)')
144
+ elif isinstance(vision_cfg.layers, (tuple, list)):
145
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
146
+ visual = ModifiedResNet(
147
+ layers=vision_cfg.layers,
148
+ output_dim=embed_dim,
149
+ heads=vision_heads,
150
+ image_size=vision_cfg.image_size,
151
+ width=vision_cfg.width
152
+ )
153
+ else:
154
+ vision_heads = vision_cfg.width // vision_cfg.head_width
155
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
156
+ visual = VisionTransformer(
157
+ image_size=vision_cfg.image_size,
158
+ patch_size=vision_cfg.patch_size,
159
+ width=vision_cfg.width,
160
+ layers=vision_cfg.layers,
161
+ heads=vision_heads,
162
+ mlp_ratio=vision_cfg.mlp_ratio,
163
+ ls_init_value=vision_cfg.ls_init_value,
164
+ patch_dropout=vision_cfg.patch_dropout,
165
+ global_average_pool=vision_cfg.global_average_pool,
166
+ output_dim=embed_dim,
167
+ act_layer=act_layer,
168
+ norm_layer=norm_layer,
169
+ )
170
+
171
+ return visual
172
+
173
+
174
+ def _build_text_tower(
175
+ embed_dim: int,
176
+ text_cfg: CLIPTextCfg,
177
+ quick_gelu: bool = False,
178
+ cast_dtype: Optional[torch.dtype] = None,
179
+ ):
180
+ if isinstance(text_cfg, dict):
181
+ text_cfg = CLIPTextCfg(**text_cfg)
182
+
183
+ if text_cfg.hf_model_name:
184
+ text = HFTextEncoder(
185
+ text_cfg.hf_model_name,
186
+ output_dim=embed_dim,
187
+ tokenizer_name=text_cfg.hf_tokenizer_name,
188
+ proj=text_cfg.proj,
189
+ pooler_type=text_cfg.pooler_type,
190
+ masked_language_modeling=text_cfg.masked_language_modeling
191
+ )
192
+ else:
193
+ act_layer = QuickGELU if quick_gelu else nn.GELU
194
+ norm_layer = LayerNorm
195
+
196
+ text = TextTransformer(
197
+ context_length=text_cfg.context_length,
198
+ vocab_size=text_cfg.vocab_size,
199
+ width=text_cfg.width,
200
+ heads=text_cfg.heads,
201
+ layers=text_cfg.layers,
202
+ ls_init_value=text_cfg.ls_init_value,
203
+ output_dim=embed_dim,
204
+ act_layer=act_layer,
205
+ norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
206
+ xattn=text_cfg.xattn,
207
+ attn_mask=text_cfg.attn_mask,
208
+ )
209
+ return text
210
+
211
+ class CLIP(nn.Module):
212
+ def __init__(
213
+ self,
214
+ embed_dim: int,
215
+ vision_cfg: CLIPVisionCfg,
216
+ text_cfg: CLIPTextCfg,
217
+ quick_gelu: bool = False,
218
+ cast_dtype: Optional[torch.dtype] = None,
219
+ ):
220
+ super().__init__()
221
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
222
+
223
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
224
+ self.transformer = text.transformer
225
+ self.vocab_size = text.vocab_size
226
+ self.token_embedding = text.token_embedding
227
+ self.positional_embedding = text.positional_embedding
228
+ self.ln_final = text.ln_final
229
+ self.text_projection = text.text_projection
230
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
231
+
232
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
233
+
234
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
235
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
236
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
237
+
238
+ @torch.jit.ignore
239
+ def set_grad_checkpointing(self, enable=True):
240
+ self.visual.set_grad_checkpointing(enable)
241
+ self.transformer.grad_checkpointing = enable
242
+
243
+ @torch.jit.ignore
244
+ def no_weight_decay(self):
245
+ return {'logit_scale'}
246
+
247
+ def encode_image(self, image, normalize: bool = False):
248
+ features = self.visual(image)
249
+ return F.normalize(features, dim=-1) if normalize else features
250
+
251
+ def encode_text(self, text, normalize: bool = False):
252
+ cast_dtype = self.transformer.get_cast_dtype()
253
+
254
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
255
+
256
+ x = x + self.positional_embedding.to(cast_dtype)
257
+ x = x.permute(1, 0, 2) # NLD -> LND
258
+ x = self.transformer(x, attn_mask=self.attn_mask)
259
+ x = x.permute(1, 0, 2) # LND -> NLD
260
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
261
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
262
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
263
+ return F.normalize(x, dim=-1) if normalize else x
264
+
265
+ def forward(self, image, text):
266
+ image_features = self.encode_image(image, normalize=True)
267
+ text_features = self.encode_text(text, normalize=True)
268
+ return image_features, text_features, self.logit_scale.exp()
269
+
270
+
271
+ class CustomCLIP(nn.Module):
272
+ def __init__(
273
+ self,
274
+ embed_dim: int,
275
+ vision_cfg: CLIPVisionCfg,
276
+ text_cfg: CLIPTextCfg,
277
+ quick_gelu: bool = False,
278
+ cast_dtype: Optional[torch.dtype] = None,
279
+ itm_task: bool = False,
280
+ ):
281
+ super().__init__()
282
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
283
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
284
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
285
+
286
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
287
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
288
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
289
+
290
+ def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
291
+ self.text.lock(unlocked_layers, freeze_layer_norm)
292
+
293
+ @torch.jit.ignore
294
+ def set_grad_checkpointing(self, enable=True):
295
+ self.visual.set_grad_checkpointing(enable)
296
+ self.text.set_grad_checkpointing(enable)
297
+
298
+ @torch.jit.ignore
299
+ def no_weight_decay(self):
300
+ return {'logit_scale'}
301
+
302
+ def encode_image(self, image, normalize: bool = False):
303
+ features = self.visual(image)
304
+ return F.normalize(features, dim=-1) if normalize else features
305
+
306
+ def encode_text(self, text, normalize: bool = False):
307
+ features = self.text(text)
308
+ return F.normalize(features, dim=-1) if normalize else features
309
+
310
+ def forward(self, image, text):
311
+ image_features = self.encode_image(image, normalize=True)
312
+ text_features = self.encode_text(text, normalize=True)
313
+ return image_features, text_features, self.logit_scale.exp()
314
+
315
+
316
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
317
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
318
+
319
+ def _convert_weights(l):
320
+
321
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
322
+ l.weight.data = l.weight.data.to(dtype)
323
+ if l.bias is not None:
324
+ l.bias.data = l.bias.data.to(dtype)
325
+
326
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
327
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
328
+ tensor = getattr(l, attr, None)
329
+ if tensor is not None:
330
+ tensor.data = tensor.data.to(dtype)
331
+
332
+ if isinstance(l, nn.Parameter):
333
+ l.data = l.data.to(dtype)
334
+
335
+ for name in ["text_projection", "proj"]:
336
+ if hasattr(l, name) and isinstance(l, nn.Parameter):
337
+ attr = getattr(l, name, None)
338
+ if attr is not None:
339
+ attr.data = attr.data.to(dtype)
340
+
341
+ model.apply(_convert_weights)
342
+
343
+
344
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
345
+
346
+
347
+ # used to maintain checkpoint compatibility
348
+ def convert_to_custom_text_state_dict(state_dict: dict):
349
+ if 'text_projection' in state_dict:
350
+ # old format state_dict, move text tower -> .text
351
+ new_state_dict = {}
352
+ for k, v in state_dict.items():
353
+ if any(k.startswith(p) for p in (
354
+ 'text_projection',
355
+ 'positional_embedding',
356
+ 'token_embedding',
357
+ 'transformer',
358
+ 'ln_final',
359
+ 'logit_scale'
360
+ )):
361
+ k = 'text.' + k
362
+ new_state_dict[k] = v
363
+ return new_state_dict
364
+ return state_dict
365
+
366
+
367
+ def build_model_from_openai_state_dict(
368
+ state_dict: dict,
369
+ quick_gelu=True,
370
+ cast_dtype=torch.float16,
371
+ ):
372
+ vit = "visual.proj" in state_dict
373
+
374
+ if vit:
375
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
376
+ vision_layers = len(
377
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
378
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
379
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
380
+ image_size = vision_patch_size * grid_size
381
+ else:
382
+ counts: list = [
383
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
384
+ vision_layers = tuple(counts)
385
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
386
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
387
+ vision_patch_size = None
388
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
389
+ image_size = output_width * 32
390
+
391
+ embed_dim = state_dict["text_projection"].shape[1]
392
+ context_length = state_dict["positional_embedding"].shape[0]
393
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
394
+ transformer_width = state_dict["ln_final.weight"].shape[0]
395
+ transformer_heads = transformer_width // 64
396
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
397
+
398
+ vision_cfg = CLIPVisionCfg(
399
+ layers=vision_layers,
400
+ width=vision_width,
401
+ patch_size=vision_patch_size,
402
+ image_size=image_size,
403
+ )
404
+ text_cfg = CLIPTextCfg(
405
+ context_length=context_length,
406
+ vocab_size=vocab_size,
407
+ width=transformer_width,
408
+ heads=transformer_heads,
409
+ layers=transformer_layers
410
+ )
411
+ model = CLIP(
412
+ embed_dim,
413
+ vision_cfg=vision_cfg,
414
+ text_cfg=text_cfg,
415
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
416
+ cast_dtype=cast_dtype,
417
+ )
418
+
419
+ for key in ["input_resolution", "context_length", "vocab_size"]:
420
+ state_dict.pop(key, None)
421
+
422
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
423
+ model.load_state_dict(state_dict)
424
+ return model.eval()
425
+
426
+
427
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
428
+ model.eval()
429
+ image_size = model.visual.image_size
430
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
431
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
432
+ model = torch.jit.trace_module(
433
+ model,
434
+ inputs=dict(
435
+ forward=(example_images, example_text),
436
+ encode_text=(example_text,),
437
+ encode_image=(example_images,)
438
+ ))
439
+ model.visual.image_size = image_size
440
+ return model
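For context (a small illustrative sketch, not part of the committed files): CLIP.forward and CustomCLIP.forward above return L2-normalized features plus exp(logit_scale), so downstream code forms scaled cosine-similarity logits along these lines.

    import torch
    import torch.nn.functional as F

    # stand-ins for the outputs of model(image, text)
    image_features = F.normalize(torch.randn(4, 512), dim=-1)
    text_features = F.normalize(torch.randn(4, 512), dim=-1)
    logit_scale = torch.tensor(1 / 0.07)   # exp of the np.log(1/0.07) init used above

    logits_per_image = logit_scale * image_features @ text_features.T   # [4, 4]
    probs = logits_per_image.softmax(dim=-1)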
eva_clip/model_configs/EVA01-CLIP-B-16.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16,
8
+ "eva_model_name": "eva-clip-b-16",
9
+ "ls_init_value": 0.1,
10
+ "drop_path_rate": 0.0
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12
18
+ }
19
+ }
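Each JSON under eva_clip/model_configs/ follows this shape. A quick sketch (assuming it is run from the repo root; the file path is this config's own) of how such a file maps onto the CLIPVisionCfg / CLIPTextCfg dataclasses defined in eva_clip/model.py:

    import json
    from eva_clip.model import CLIPVisionCfg, CLIPTextCfg

    with open("eva_clip/model_configs/EVA01-CLIP-B-16.json") as f:
        cfg = json.load(f)

    vision_cfg = CLIPVisionCfg(**cfg["vision_cfg"])
    text_cfg = CLIPTextCfg(**cfg["text_cfg"])
    print(cfg["embed_dim"], vision_cfg.width, text_cfg.width)   # 512 768 512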
eva_clip/model_configs/EVA01-CLIP-g-14-plus.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 1024,
19
+ "heads": 16,
20
+ "layers": 24,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
eva_clip/model_configs/EVA01-CLIP-g-14.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0.4,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 768,
19
+ "heads": 12,
20
+ "layers": 12,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
eva_clip/model_configs/EVA02-CLIP-B-16.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "head_width": 64,
8
+ "patch_size": 16,
9
+ "mlp_ratio": 2.6667,
10
+ "eva_model_name": "eva-clip-b-16-X",
11
+ "drop_path_rate": 0.0,
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 512,
24
+ "heads": 8,
25
+ "layers": 12,
26
+ "xattn": true,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-L-14-336.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14-336",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-L-14.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1280,
20
+ "heads": 20,
21
+ "layers": 32,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
eva_clip/model_configs/EVA02-CLIP-bigE-14.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1024,
20
+ "heads": 16,
21
+ "layers": 24,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
eva_clip/modified_resnet.py ADDED
@@ -0,0 +1,181 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from eva_clip.utils import freeze_batch_norm_2d
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.act1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.act2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.act3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.act1(self.bn1(self.conv1(x)))
46
+ out = self.act2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.act3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x, key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0.,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+
92
+ return x[0]
93
+
94
+
95
+ class ModifiedResNet(nn.Module):
96
+ """
97
+ A ResNet class that is similar to torchvision's but contains the following changes:
98
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
+ - The final pooling layer is a QKV attention instead of an average pool
101
+ """
102
+
103
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104
+ super().__init__()
105
+ self.output_dim = output_dim
106
+ self.image_size = image_size
107
+
108
+ # the 3-layer stem
109
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(width // 2)
111
+ self.act1 = nn.ReLU(inplace=True)
112
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
+ self.bn2 = nn.BatchNorm2d(width // 2)
114
+ self.act2 = nn.ReLU(inplace=True)
115
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
+ self.bn3 = nn.BatchNorm2d(width)
117
+ self.act3 = nn.ReLU(inplace=True)
118
+ self.avgpool = nn.AvgPool2d(2)
119
+
120
+ # residual layers
121
+ self._inplanes = width # this is a *mutable* variable used during construction
122
+ self.layer1 = self._make_layer(width, layers[0])
123
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
+
127
+ embed_dim = width * 32 # the ResNet feature dimension
128
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129
+
130
+ self.init_parameters()
131
+
132
+ def _make_layer(self, planes, blocks, stride=1):
133
+ layers = [Bottleneck(self._inplanes, planes, stride)]
134
+
135
+ self._inplanes = planes * Bottleneck.expansion
136
+ for _ in range(1, blocks):
137
+ layers.append(Bottleneck(self._inplanes, planes))
138
+
139
+ return nn.Sequential(*layers)
140
+
141
+ def init_parameters(self):
142
+ if self.attnpool is not None:
143
+ std = self.attnpool.c_proj.in_features ** -0.5
144
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148
+
149
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150
+ for name, param in resnet_block.named_parameters():
151
+ if name.endswith("bn3.weight"):
152
+ nn.init.zeros_(param)
153
+
154
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156
+ for param in self.parameters():
157
+ param.requires_grad = False
158
+ if freeze_bn_stats:
159
+ freeze_batch_norm_2d(self)
160
+
161
+ @torch.jit.ignore
162
+ def set_grad_checkpointing(self, enable=True):
163
+ # FIXME support for non-transformer
164
+ pass
165
+
166
+ def stem(self, x):
167
+ x = self.act1(self.bn1(self.conv1(x)))
168
+ x = self.act2(self.bn2(self.conv2(x)))
169
+ x = self.act3(self.bn3(self.conv3(x)))
170
+ x = self.avgpool(x)
171
+ return x
172
+
173
+ def forward(self, x):
174
+ x = self.stem(x)
175
+ x = self.layer1(x)
176
+ x = self.layer2(x)
177
+ x = self.layer3(x)
178
+ x = self.layer4(x)
179
+ x = self.attnpool(x)
180
+
181
+ return x
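As a quick shape check (my sketch, assuming the `eva_clip` package is importable from the repository root), a CLIP-RN50-sized ModifiedResNet pools a 224x224 image down to a single embedding:

```python
import torch
from eva_clip.modified_resnet import ModifiedResNet

# RN50-like configuration: width*32 = 2048 stem features, attention-pooled to output_dim
model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                       image_size=224, width=64)
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)   # torch.Size([1, 1024])
```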
eva_clip/openai.py ADDED
@@ -0,0 +1,144 @@
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
+
15
+ __all__ = ["list_openai_models", "load_openai_model"]
16
+
17
+
18
+ def list_openai_models() -> List[str]:
19
+ """Returns the names of available CLIP models"""
20
+ return list_pretrained_models_by_tag('openai')
21
+
22
+
23
+ def load_openai_model(
24
+ name: str,
25
+ precision: Optional[str] = None,
26
+ device: Optional[Union[str, torch.device]] = None,
27
+ jit: bool = True,
28
+ cache_dir: Optional[str] = None,
29
+ ):
30
+ """Load a CLIP model
31
+
32
+ Parameters
33
+ ----------
34
+ name : str
35
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
+ precision: str
37
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
+ device : Union[str, torch.device]
39
+ The device to put the loaded model
40
+ jit : bool
41
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42
+ cache_dir : Optional[str]
43
+ The directory to cache the downloaded model weights
44
+
45
+ Returns
46
+ -------
47
+ model : torch.nn.Module
48
+ The CLIP model
49
+ preprocess : Callable[[PIL.Image], torch.Tensor]
50
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51
+ """
52
+ if device is None:
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ if precision is None:
55
+ precision = 'fp32' if device == 'cpu' else 'fp16'
56
+
57
+ if get_pretrained_url(name, 'openai'):
58
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59
+ elif os.path.isfile(name):
60
+ model_path = name
61
+ else:
62
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
+
64
+ try:
65
+ # loading JIT archive
66
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
+ state_dict = None
68
+ except RuntimeError:
69
+ # loading saved state dict
70
+ if jit:
71
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
+ jit = False
73
+ state_dict = torch.load(model_path, map_location="cpu")
74
+
75
+ if not jit:
76
+ # Build a non-jit model from the OpenAI jitted model state dict
77
+ cast_dtype = get_cast_dtype(precision)
78
+ try:
79
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
+ except KeyError:
81
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
+
84
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
+ model = model.to(device)
86
+ if precision.startswith('amp') or precision == 'fp32':
87
+ model.float()
88
+ elif precision == 'bf16':
89
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
90
+
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
+
97
+ def patch_device(module):
98
+ try:
99
+ graphs = [module.graph] if hasattr(module, "graph") else []
100
+ except RuntimeError:
101
+ graphs = []
102
+
103
+ if hasattr(module, "forward1"):
104
+ graphs.append(module.forward1.graph)
105
+
106
+ for graph in graphs:
107
+ for node in graph.findAllNodes("prim::Constant"):
108
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
+ node.copyAttributes(device_node)
110
+
111
+ model.apply(patch_device)
112
+ patch_device(model.encode_image)
113
+ patch_device(model.encode_text)
114
+
115
+ # patch dtype to float32 (typically for CPU)
116
+ if precision == 'fp32':
117
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
+ float_node = float_input.node()
120
+
121
+ def patch_float(module):
122
+ try:
123
+ graphs = [module.graph] if hasattr(module, "graph") else []
124
+ except RuntimeError:
125
+ graphs = []
126
+
127
+ if hasattr(module, "forward1"):
128
+ graphs.append(module.forward1.graph)
129
+
130
+ for graph in graphs:
131
+ for node in graph.findAllNodes("aten::to"):
132
+ inputs = list(node.inputs())
133
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
+ if inputs[i].node()["value"] == 5:
135
+ inputs[i].node().copyAttributes(float_node)
136
+
137
+ model.apply(patch_float)
138
+ patch_float(model.encode_image)
139
+ patch_float(model.encode_text)
140
+ model.float()
141
+
142
+ # ensure image_size attr available at consistent location for both jit and non-jit
143
+ model.visual.image_size = model.input_resolution.item()
144
+ return model
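A hedged usage sketch (not in the commit): loading an OpenAI checkpoint through the helper above downloads the weights on first use, so it needs network access and a few hundred MB of disk:

```python
import torch
from eva_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())                       # names registered under the 'openai' tag
model = load_openai_model("OpenaiCLIP-B-16", device="cpu", jit=False)
with torch.no_grad():
    feats = model.encode_image(torch.randn(1, 3, 224, 224))
print(feats.shape)                                # expected (1, 512) for the B/16 model
```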
eva_clip/pretrained.py ADDED
@@ -0,0 +1,332 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from functools import partial
6
+ from typing import Dict, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+ try:
11
+ from huggingface_hub import hf_hub_download
12
+ _has_hf_hub = True
13
+ except ImportError:
14
+ hf_hub_download = None
15
+ _has_hf_hub = False
16
+
17
+
18
+ def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
19
+ return dict(
20
+ url=url,
21
+ hf_hub=hf_hub,
22
+ mean=mean,
23
+ std=std,
24
+ )
25
+
26
+ _VITB32 = dict(
27
+ openai=_pcfg(
28
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
+ laion400m_e31=_pcfg(
30
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31
+ laion400m_e32=_pcfg(
32
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33
+ laion2b_e16=_pcfg(
34
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36
+ )
37
+
38
+ _VITB32_quickgelu = dict(
39
+ openai=_pcfg(
40
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41
+ laion400m_e31=_pcfg(
42
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43
+ laion400m_e32=_pcfg(
44
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45
+ )
46
+
47
+ _VITB16 = dict(
48
+ openai=_pcfg(
49
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50
+ laion400m_e31=_pcfg(
51
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52
+ laion400m_e32=_pcfg(
53
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55
+ )
56
+
57
+ _EVAB16 = dict(
58
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62
+ )
63
+
64
+ _VITB16_PLUS_240 = dict(
65
+ laion400m_e31=_pcfg(
66
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67
+ laion400m_e32=_pcfg(
68
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69
+ )
70
+
71
+ _VITL14 = dict(
72
+ openai=_pcfg(
73
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74
+ laion400m_e31=_pcfg(
75
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76
+ laion400m_e32=_pcfg(
77
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78
+ laion2b_s32b_b82k=_pcfg(
79
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81
+ )
82
+
83
+ _EVAL14 = dict(
84
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88
+ )
89
+
90
+ _VITL14_336 = dict(
91
+ openai=_pcfg(
92
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93
+ )
94
+
95
+ _EVAL14_336 = dict(
96
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98
+ eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99
+ eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100
+ )
101
+
102
+ _VITH14 = dict(
103
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104
+ )
105
+
106
+ _VITg14 = dict(
107
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109
+ )
110
+
111
+ _EVAg14 = dict(
112
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116
+ )
117
+
118
+ _EVAg14_PLUS = dict(
119
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123
+ )
124
+
125
+ _VITbigG14 = dict(
126
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127
+ )
128
+
129
+ _EVAbigE14 = dict(
130
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134
+ )
135
+
136
+ _EVAbigE14_PLUS = dict(
137
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141
+ )
142
+
143
+
144
+ _PRETRAINED = {
145
+ # "ViT-B-32": _VITB32,
146
+ "OpenaiCLIP-B-32": _VITB32,
147
+ "OpenCLIP-B-32": _VITB32,
148
+
149
+ # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150
+ "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151
+ "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152
+
153
+ # "ViT-B-16": _VITB16,
154
+ "OpenaiCLIP-B-16": _VITB16,
155
+ "OpenCLIP-B-16": _VITB16,
156
+
157
+ "EVA02-B-16": _EVAB16,
158
+ "EVA02-CLIP-B-16": _EVAB16,
159
+
160
+ # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161
+ "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162
+
163
+ # "ViT-L-14": _VITL14,
164
+ "OpenaiCLIP-L-14": _VITL14,
165
+ "OpenCLIP-L-14": _VITL14,
166
+
167
+ "EVA02-L-14": _EVAL14,
168
+ "EVA02-CLIP-L-14": _EVAL14,
169
+
170
+ # "ViT-L-14-336": _VITL14_336,
171
+ "OpenaiCLIP-L-14-336": _VITL14_336,
172
+
173
+ "EVA02-CLIP-L-14-336": _EVAL14_336,
174
+
175
+ # "ViT-H-14": _VITH14,
176
+ # "ViT-g-14": _VITg14,
177
+ "OpenCLIP-H-14": _VITH14,
178
+ "OpenCLIP-g-14": _VITg14,
179
+
180
+ "EVA01-CLIP-g-14": _EVAg14,
181
+ "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182
+
183
+ # "ViT-bigG-14": _VITbigG14,
184
+ "OpenCLIP-bigG-14": _VITbigG14,
185
+
186
+ "EVA02-CLIP-bigE-14": _EVAbigE14,
187
+ "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188
+ }
189
+
190
+
191
+ def _clean_tag(tag: str):
192
+ # normalize pretrained tags
193
+ return tag.lower().replace('-', '_')
194
+
195
+
196
+ def list_pretrained(as_str: bool = False):
197
+ """ returns list of pretrained models
198
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
199
+ """
200
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
201
+
202
+
203
+ def list_pretrained_models_by_tag(tag: str):
204
+ """ return all models having the specified pretrain tag """
205
+ models = []
206
+ tag = _clean_tag(tag)
207
+ for k in _PRETRAINED.keys():
208
+ if tag in _PRETRAINED[k]:
209
+ models.append(k)
210
+ return models
211
+
212
+
213
+ def list_pretrained_tags_by_model(model: str):
214
+ """ return all pretrain tags for the specified model architecture """
215
+ tags = []
216
+ if model in _PRETRAINED:
217
+ tags.extend(_PRETRAINED[model].keys())
218
+ return tags
219
+
220
+
221
+ def is_pretrained_cfg(model: str, tag: str):
222
+ if model not in _PRETRAINED:
223
+ return False
224
+ return _clean_tag(tag) in _PRETRAINED[model]
225
+
226
+
227
+ def get_pretrained_cfg(model: str, tag: str):
228
+ if model not in _PRETRAINED:
229
+ return {}
230
+ model_pretrained = _PRETRAINED[model]
231
+ return model_pretrained.get(_clean_tag(tag), {})
232
+
233
+
234
+ def get_pretrained_url(model: str, tag: str):
235
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
236
+ return cfg.get('url', '')
237
+
238
+
239
+ def download_pretrained_from_url(
240
+ url: str,
241
+ cache_dir: Union[str, None] = None,
242
+ ):
243
+ if not cache_dir:
244
+ cache_dir = os.path.expanduser("~/.cache/clip")
245
+ os.makedirs(cache_dir, exist_ok=True)
246
+ filename = os.path.basename(url)
247
+
248
+ if 'openaipublic' in url:
249
+ expected_sha256 = url.split("/")[-2]
250
+ elif 'mlfoundations' in url:
251
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
252
+ else:
253
+ expected_sha256 = ''
254
+
255
+ download_target = os.path.join(cache_dir, filename)
256
+
257
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
258
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
259
+
260
+ if os.path.isfile(download_target):
261
+ if expected_sha256:
262
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263
+ return download_target
264
+ else:
265
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
266
+ else:
267
+ return download_target
268
+
269
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
270
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
271
+ while True:
272
+ buffer = source.read(8192)
273
+ if not buffer:
274
+ break
275
+
276
+ output.write(buffer)
277
+ loop.update(len(buffer))
278
+
279
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
280
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
281
+
282
+ return download_target
283
+
284
+
285
+ def has_hf_hub(necessary=False):
286
+ if not _has_hf_hub and necessary:
287
+ # if no HF Hub module installed, and it is necessary to continue, raise error
288
+ raise RuntimeError(
289
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
290
+ return _has_hf_hub
291
+
292
+
293
+ def download_pretrained_from_hf(
294
+ model_id: str,
295
+ filename: str = 'open_clip_pytorch_model.bin',
296
+ revision=None,
297
+ cache_dir: Union[str, None] = None,
298
+ ):
299
+ has_hf_hub(True)
300
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
301
+ return cached_file
302
+
303
+
304
+ def download_pretrained(
305
+ cfg: Dict,
306
+ force_hf_hub: bool = False,
307
+ cache_dir: Union[str, None] = None,
308
+ ):
309
+ target = ''
310
+ if not cfg:
311
+ return target
312
+
313
+ download_url = cfg.get('url', '')
314
+ download_hf_hub = cfg.get('hf_hub', '')
315
+ if download_hf_hub and force_hf_hub:
316
+ # use HF hub even if url exists
317
+ download_url = ''
318
+
319
+ if download_url:
320
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
321
+ elif download_hf_hub:
322
+ has_hf_hub(True)
323
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
324
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
325
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
326
+ model_id, filename = os.path.split(download_hf_hub)
327
+ if filename:
328
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
329
+ else:
330
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
331
+
332
+ return target
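A sketch of the registry helpers above (my example; only the commented-out last line would trigger a download):

```python
from eva_clip.pretrained import (
    list_pretrained_tags_by_model, get_pretrained_cfg, download_pretrained)

print(list_pretrained_tags_by_model("EVA02-CLIP-L-14"))
# ['eva', 'eva02', 'eva_clip', 'eva02_clip']
cfg = get_pretrained_cfg("EVA02-CLIP-L-14", "eva_clip")
print(cfg["hf_hub"])        # 'QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'
# ckpt_path = download_pretrained(cfg)   # would fetch the checkpoint from the HF Hub
```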
eva_clip/rope.py ADDED
@@ -0,0 +1,137 @@
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+ def broadcat(tensors, dim = -1):
8
+ num_tensors = len(tensors)
9
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
+ shape_len = list(shape_lens)[0]
12
+ dim = (dim + shape_len) if dim < 0 else dim
13
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
16
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
+ expanded_dims.insert(dim, (dim, dims[dim]))
19
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
+ return torch.cat(tensors, dim = dim)
22
+
23
+ def rotate_half(x):
24
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
+ x1, x2 = x.unbind(dim = -1)
26
+ x = torch.stack((-x2, x1), dim = -1)
27
+ return rearrange(x, '... d r -> ... (d r)')
28
+
29
+
30
+ class VisionRotaryEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ pt_seq_len,
35
+ ft_seq_len=None,
36
+ custom_freqs = None,
37
+ freqs_for = 'lang',
38
+ theta = 10000,
39
+ max_freq = 10,
40
+ num_freqs = 1,
41
+ ):
42
+ super().__init__()
43
+ if custom_freqs:
44
+ freqs = custom_freqs
45
+ elif freqs_for == 'lang':
46
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
+ elif freqs_for == 'pixel':
48
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
+ elif freqs_for == 'constant':
50
+ freqs = torch.ones(num_freqs).float()
51
+ else:
52
+ raise ValueError(f'unknown modality {freqs_for}')
53
+
54
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
55
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
+
57
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59
+
60
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62
+
63
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64
+
65
+ self.register_buffer("freqs_cos", freqs.cos())
66
+ self.register_buffer("freqs_sin", freqs.sin())
67
+
68
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69
+
70
+ def forward(self, t, start_index = 0):
71
+ rot_dim = self.freqs_cos.shape[-1]
72
+ end_index = start_index + rot_dim
73
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76
+
77
+ return torch.cat((t_left, t, t_right), dim = -1)
78
+
79
+ class VisionRotaryEmbeddingFast(nn.Module):
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ pt_seq_len,
84
+ ft_seq_len=None,
85
+ custom_freqs = None,
86
+ freqs_for = 'lang',
87
+ theta = 10000,
88
+ max_freq = 10,
89
+ num_freqs = 1,
90
+ patch_dropout = 0.
91
+ ):
92
+ super().__init__()
93
+ if custom_freqs:
94
+ freqs = custom_freqs
95
+ elif freqs_for == 'lang':
96
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97
+ elif freqs_for == 'pixel':
98
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99
+ elif freqs_for == 'constant':
100
+ freqs = torch.ones(num_freqs).float()
101
+ else:
102
+ raise ValueError(f'unknown modality {freqs_for}')
103
+
104
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
105
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106
+
107
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
108
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110
+
111
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113
+
114
+ self.patch_dropout = patch_dropout
115
+
116
+ self.register_buffer("freqs_cos", freqs_cos)
117
+ self.register_buffer("freqs_sin", freqs_sin)
118
+
119
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120
+
121
+ def forward(self, t, patch_indices_keep=None):
122
+ if patch_indices_keep is not None:
123
+ batch = t.size()[0]
124
+ batch_indices = torch.arange(batch)
125
+ batch_indices = batch_indices[..., None]
126
+
127
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129
+
130
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134
+
135
+ return t * freqs_cos + rotate_half(t) * freqs_sin
136
+
137
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
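A minimal sketch (mine) of the fast 2D rotary embedding above: for a 16x16 patch grid and 64-dim attention heads, `dim` is passed as half the head size, because the frequency table is doubled and then concatenated over the two spatial axes:

```python
import torch
from eva_clip.rope import VisionRotaryEmbeddingFast

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)   # freqs_cos has shape (256, 64)
q = torch.randn(2, 12, 256, 64)                           # (batch, heads, patches, head_dim)
q_rot = rope(q)                                           # rotated queries, same shape as q
print(q_rot.shape)                                        # torch.Size([2, 12, 256, 64])
```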
eva_clip/timm_model.py ADDED
@@ -0,0 +1,123 @@
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import timm
13
+ from timm.models.layers import Mlp, to_2tuple
14
+ try:
15
+ # old timm imports < 0.8.1
16
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
+ except ImportError:
19
+ # new timm imports >= 0.8.1
20
+ from timm.layers import RotAttentionPool2d
21
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
+ except ImportError:
23
+ timm = None
24
+
25
+ from .utils import freeze_batch_norm_2d
26
+
27
+
28
+ class TimmModel(nn.Module):
29
+ """ timm model adapter
30
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ model_name,
36
+ embed_dim,
37
+ image_size=224,
38
+ pool='avg',
39
+ proj='linear',
40
+ proj_bias=False,
41
+ drop=0.,
42
+ pretrained=False):
43
+ super().__init__()
44
+ if timm is None:
45
+ # raise RuntimeError("Please `pip install timm` to use timm models.")
46
+ return
47
+
48
+ self.image_size = to_2tuple(image_size)
49
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
50
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
51
+ feature_ndim = 1 if not feat_size else 2
52
+ if pool in ('abs_attn', 'rot_attn'):
53
+ assert feature_ndim == 2
54
+ # if attn pooling used, remove both classifier and default pool
55
+ self.trunk.reset_classifier(0, global_pool='')
56
+ else:
57
+ # reset global pool if pool config set, otherwise leave as network default
58
+ reset_kwargs = dict(global_pool=pool) if pool else {}
59
+ self.trunk.reset_classifier(0, **reset_kwargs)
60
+ prev_chs = self.trunk.num_features
61
+
62
+ head_layers = OrderedDict()
63
+ if pool == 'abs_attn':
64
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
65
+ prev_chs = embed_dim
66
+ elif pool == 'rot_attn':
67
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
68
+ prev_chs = embed_dim
69
+ else:
70
+ assert proj, 'projection layer needed if non-attention pooling is used.'
71
+
72
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
73
+ if proj == 'linear':
74
+ head_layers['drop'] = nn.Dropout(drop)
75
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
76
+ elif proj == 'mlp':
77
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
78
+
79
+ self.head = nn.Sequential(head_layers)
80
+
81
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
82
+ """ lock modules
83
+ Args:
84
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
85
+ """
86
+ if not unlocked_groups:
87
+ # lock full model
88
+ for param in self.trunk.parameters():
89
+ param.requires_grad = False
90
+ if freeze_bn_stats:
91
+ freeze_batch_norm_2d(self.trunk)
92
+ else:
93
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
94
+ try:
95
+ # FIXME import here until API stable and in an official release
96
+ from timm.models.helpers import group_parameters, group_modules
97
+ except ImportError:
98
+ raise RuntimeError(
99
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
100
+ matcher = self.trunk.group_matcher()
101
+ gparams = group_parameters(self.trunk, matcher)
102
+ max_layer_id = max(gparams.keys())
103
+ max_layer_id = max_layer_id - unlocked_groups
104
+ for group_idx in range(max_layer_id + 1):
105
+ group = gparams[group_idx]
106
+ for param in group:
107
+ self.trunk.get_parameter(param).requires_grad = False
108
+ if freeze_bn_stats:
109
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
110
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
111
+ freeze_batch_norm_2d(self.trunk, gmodules)
112
+
113
+ @torch.jit.ignore
114
+ def set_grad_checkpointing(self, enable=True):
115
+ try:
116
+ self.trunk.set_grad_checkpointing(enable)
117
+ except Exception as e:
118
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
119
+
120
+ def forward(self, x):
121
+ x = self.trunk(x)
122
+ x = self.head(x)
123
+ return x
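A hedged sketch of the adapter above (requires `pip install timm`; 'resnet50' is just an illustrative backbone name, and pretrained=False avoids any weight download):

```python
import torch
from eva_clip.timm_model import TimmModel

tower = TimmModel("resnet50", embed_dim=512, image_size=224,
                  pool="avg", proj="linear", pretrained=False)
feats = tower(torch.randn(1, 3, 224, 224))
print(feats.shape)   # torch.Size([1, 512]) after the linear projection head
```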
eva_clip/tokenizer.py ADDED
@@ -0,0 +1,201 @@
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+ # https://stackoverflow.com/q/62691279
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+
20
+ @lru_cache()
21
+ def default_bpe():
22
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23
+
24
+
25
+ @lru_cache()
26
+ def bytes_to_unicode():
27
+ """
28
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
29
+ The reversible bpe codes work on unicode strings.
30
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
+ This is a significant percentage of your normal, say, 32K bpe vocab.
33
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
35
+ """
36
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
+ cs = bs[:]
38
+ n = 0
39
+ for b in range(2**8):
40
+ if b not in bs:
41
+ bs.append(b)
42
+ cs.append(2**8+n)
43
+ n += 1
44
+ cs = [chr(n) for n in cs]
45
+ return dict(zip(bs, cs))
46
+
47
+
48
+ def get_pairs(word):
49
+ """Return set of symbol pairs in a word.
50
+ Word is represented as tuple of symbols (symbols being variable-length strings).
51
+ """
52
+ pairs = set()
53
+ prev_char = word[0]
54
+ for char in word[1:]:
55
+ pairs.add((prev_char, char))
56
+ prev_char = char
57
+ return pairs
58
+
59
+
60
+ def basic_clean(text):
61
+ text = ftfy.fix_text(text)
62
+ text = html.unescape(html.unescape(text))
63
+ return text.strip()
64
+
65
+
66
+ def whitespace_clean(text):
67
+ text = re.sub(r'\s+', ' ', text)
68
+ text = text.strip()
69
+ return text
70
+
71
+
72
+ class SimpleTokenizer(object):
73
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74
+ self.byte_encoder = bytes_to_unicode()
75
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77
+ merges = merges[1:49152-256-2+1]
78
+ merges = [tuple(merge.split()) for merge in merges]
79
+ vocab = list(bytes_to_unicode().values())
80
+ vocab = vocab + [v+'</w>' for v in vocab]
81
+ for merge in merges:
82
+ vocab.append(''.join(merge))
83
+ if not special_tokens:
84
+ special_tokens = ['<start_of_text>', '<end_of_text>']
85
+ else:
86
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87
+ vocab.extend(special_tokens)
88
+ self.encoder = dict(zip(vocab, range(len(vocab))))
89
+ self.decoder = {v: k for k, v in self.encoder.items()}
90
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
91
+ self.cache = {t:t for t in special_tokens}
92
+ special = "|".join(special_tokens)
93
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94
+
95
+ self.vocab_size = len(self.encoder)
96
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
97
+
98
+ def bpe(self, token):
99
+ if token in self.cache:
100
+ return self.cache[token]
101
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102
+ pairs = get_pairs(word)
103
+
104
+ if not pairs:
105
+ return token+'</w>'
106
+
107
+ while True:
108
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109
+ if bigram not in self.bpe_ranks:
110
+ break
111
+ first, second = bigram
112
+ new_word = []
113
+ i = 0
114
+ while i < len(word):
115
+ try:
116
+ j = word.index(first, i)
117
+ new_word.extend(word[i:j])
118
+ i = j
119
+ except:
120
+ new_word.extend(word[i:])
121
+ break
122
+
123
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
124
+ new_word.append(first+second)
125
+ i += 2
126
+ else:
127
+ new_word.append(word[i])
128
+ i += 1
129
+ new_word = tuple(new_word)
130
+ word = new_word
131
+ if len(word) == 1:
132
+ break
133
+ else:
134
+ pairs = get_pairs(word)
135
+ word = ' '.join(word)
136
+ self.cache[token] = word
137
+ return word
138
+
139
+ def encode(self, text):
140
+ bpe_tokens = []
141
+ text = whitespace_clean(basic_clean(text)).lower()
142
+ for token in re.findall(self.pat, text):
143
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145
+ return bpe_tokens
146
+
147
+ def decode(self, tokens):
148
+ text = ''.join([self.decoder[token] for token in tokens])
149
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150
+ return text
151
+
152
+
153
+ _tokenizer = SimpleTokenizer()
154
+
155
+
156
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157
+ """
158
+ Returns the tokenized representation of given input string(s)
159
+
160
+ Parameters
161
+ ----------
162
+ texts : Union[str, List[str]]
163
+ An input string or a list of input strings to tokenize
164
+ context_length : int
165
+ The context length to use; all CLIP models use 77 as the context length
166
+
167
+ Returns
168
+ -------
169
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170
+ """
171
+ if isinstance(texts, str):
172
+ texts = [texts]
173
+
174
+ sot_token = _tokenizer.encoder["<start_of_text>"]
175
+ eot_token = _tokenizer.encoder["<end_of_text>"]
176
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178
+
179
+ for i, tokens in enumerate(all_tokens):
180
+ if len(tokens) > context_length:
181
+ tokens = tokens[:context_length] # Truncate
182
+ tokens[-1] = eot_token
183
+ result[i, :len(tokens)] = torch.tensor(tokens)
184
+
185
+ return result
186
+
187
+
188
+ class HFTokenizer:
189
+ "HuggingFace tokenizer wrapper"
190
+ def __init__(self, tokenizer_name:str):
191
+ from transformers import AutoTokenizer
192
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193
+
194
+ def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195
+ # same cleaning as for default tokenizer, except lowercasing
196
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197
+ if isinstance(texts, str):
198
+ texts = [texts]
199
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
200
+ input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201
+ return input_ids
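A short sketch of the tokenizer API above (needs the bundled bpe_simple_vocab_16e6.txt.gz plus the ftfy and regex packages):

```python
from eva_clip.tokenizer import tokenize, _tokenizer

tokens = tokenize(["a photo of a cat", "a diagram"])
print(tokens.shape)                        # torch.Size([2, 77]), zero-padded LongTensor

ids = _tokenizer.encode("a photo of a cat")
print(_tokenizer.decode(ids))              # 'a photo of a cat ' (decode re-inserts word spaces)
```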
eva_clip/transform.py ADDED
@@ -0,0 +1,103 @@
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms.functional as F
6
+
7
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8
+ CenterCrop
9
+
10
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11
+
12
+
13
+ class ResizeMaxSize(nn.Module):
14
+
15
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16
+ super().__init__()
17
+ if not isinstance(max_size, int):
18
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
19
+ self.max_size = max_size
20
+ self.interpolation = interpolation
21
+ self.fn = min if fn == 'min' else max
22
+ self.fill = fill
23
+
24
+ def forward(self, img):
25
+ if isinstance(img, torch.Tensor):
26
+ height, width = img.shape[:2]
27
+ else:
28
+ width, height = img.size
29
+ scale = self.max_size / float(max(height, width))
30
+ if scale != 1.0:
31
+ new_size = tuple(round(dim * scale) for dim in (height, width))
32
+ img = F.resize(img, new_size, self.interpolation)
33
+ pad_h = self.max_size - new_size[0]
34
+ pad_w = self.max_size - new_size[1]
35
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36
+ return img
37
+
38
+
39
+ def _convert_to_rgb(image):
40
+ return image.convert('RGB')
41
+
42
+
43
+ # class CatGen(nn.Module):
44
+ # def __init__(self, num=4):
45
+ # self.num = num
46
+ # def mixgen_batch(image, text):
47
+ # batch_size = image.shape[0]
48
+ # index = np.random.permutation(batch_size)
49
+
50
+ # cat_images = []
51
+ # for i in range(batch_size):
52
+ # # image mixup
53
+ # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54
+ # # text concat
55
+ # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56
+ # text = torch.stack(text)
57
+ # return image, text
58
+
59
+
60
+ def image_transform(
61
+ image_size: int,
62
+ is_train: bool,
63
+ mean: Optional[Tuple[float, ...]] = None,
64
+ std: Optional[Tuple[float, ...]] = None,
65
+ resize_longest_max: bool = False,
66
+ fill_color: int = 0,
67
+ ):
68
+ mean = mean or OPENAI_DATASET_MEAN
69
+ if not isinstance(mean, (list, tuple)):
70
+ mean = (mean,) * 3
71
+
72
+ std = std or OPENAI_DATASET_STD
73
+ if not isinstance(std, (list, tuple)):
74
+ std = (std,) * 3
75
+
76
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78
+ image_size = image_size[0]
79
+
80
+ normalize = Normalize(mean=mean, std=std)
81
+ if is_train:
82
+ return Compose([
83
+ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84
+ _convert_to_rgb,
85
+ ToTensor(),
86
+ normalize,
87
+ ])
88
+ else:
89
+ if resize_longest_max:
90
+ transforms = [
91
+ ResizeMaxSize(image_size, fill=fill_color)
92
+ ]
93
+ else:
94
+ transforms = [
95
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96
+ CenterCrop(image_size),
97
+ ]
98
+ transforms.extend([
99
+ _convert_to_rgb,
100
+ ToTensor(),
101
+ normalize,
102
+ ])
103
+ return Compose(transforms)
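A sketch of the preprocessing pipeline above: in eval mode it resizes, center-crops, converts to an RGB tensor, and normalizes with the OpenAI CLIP statistics (the gray image is just a stand-in for a real photo):

```python
from PIL import Image
from eva_clip.transform import image_transform

preprocess = image_transform(image_size=224, is_train=False)
img = Image.new("RGB", (640, 480), color="gray")
x = preprocess(img)
print(x.shape)   # torch.Size([3, 224, 224])
```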
eva_clip/transformer.py ADDED
@@ -0,0 +1,792 @@
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ import math
5
+ import warnings
6
+ from typing import Callable, Optional, Sequence
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
13
+ from .utils import to_2tuple
14
+
15
+ if os.getenv('ENV_TYPE') == 'deepspeed':
16
+ try:
17
+ import deepspeed
18
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
19
+ except:
20
+ print("Please 'pip install deepspeed'")
21
+ deepspeed = None
22
+ from torch.utils.checkpoint import checkpoint
23
+ else:
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ try:
27
+ import xformers.ops as xops
28
+ except ImportError:
29
+ xops = None
30
+ print("Please 'pip install xformers'")
31
+
32
+
33
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
34
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
35
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
36
+ def norm_cdf(x):
37
+ # Computes standard normal cumulative distribution function
38
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
39
+
40
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
41
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
42
+ "The distribution of values may be incorrect.",
43
+ stacklevel=2)
44
+
45
+ with torch.no_grad():
46
+ # Values are generated by using a truncated uniform distribution and
47
+ # then using the inverse CDF for the normal distribution.
48
+ # Get upper and lower cdf values
49
+ l = norm_cdf((a - mean) / std)
50
+ u = norm_cdf((b - mean) / std)
51
+
52
+ # Uniformly fill tensor with values from [l, u], then translate to
53
+ # [2l-1, 2u-1].
54
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
55
+
56
+ # Use inverse cdf transform for normal distribution to get truncated
57
+ # standard normal
58
+ tensor.erfinv_()
59
+
60
+ # Transform to proper mean, std
61
+ tensor.mul_(std * math.sqrt(2.))
62
+ tensor.add_(mean)
63
+
64
+ # Clamp to ensure it's in the proper range
65
+ tensor.clamp_(min=a, max=b)
66
+ return tensor
67
+
68
+
69
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
70
+ # type: (Tensor, float, float, float, float) -> Tensor
71
+ r"""Fills the input Tensor with values drawn from a truncated
72
+ normal distribution. The values are effectively drawn from the
73
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
74
+ with values outside :math:`[a, b]` redrawn until they are within
75
+ the bounds. The method used for generating the random values works
76
+ best when :math:`a \leq \text{mean} \leq b`.
77
+ Args:
78
+ tensor: an n-dimensional `torch.Tensor`
79
+ mean: the mean of the normal distribution
80
+ std: the standard deviation of the normal distribution
81
+ a: the minimum cutoff value
82
+ b: the maximum cutoff value
83
+ Examples:
84
+ >>> w = torch.empty(3, 5)
85
+ >>> nn.init.trunc_normal_(w)
86
+ """
87
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
88
+
89
+
90
+
91
+ class LayerNormFp32(nn.LayerNorm):
92
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
93
+ def __init__(self, *args, **kwargs):
94
+ super().__init__(*args, **kwargs)
95
+
96
+ def forward(self, x: torch.Tensor):
97
+ output = F.layer_norm(
98
+ x.float(),
99
+ self.normalized_shape,
100
+ self.weight.float() if self.weight is not None else None,
101
+ self.bias.float() if self.bias is not None else None,
102
+ self.eps,
103
+ )
104
+ return output.type_as(x)
105
+
106
+
107
+ class LayerNorm(nn.LayerNorm):
108
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
109
+
110
+ def forward(self, x: torch.Tensor):
111
+ orig_type = x.dtype
112
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
113
+ return x.to(orig_type)
114
+
115
+ class QuickGELU(nn.Module):
116
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
117
+ def forward(self, x: torch.Tensor):
118
+ return x * torch.sigmoid(1.702 * x)
119
+
120
+
121
+ class LayerScale(nn.Module):
122
+ def __init__(self, dim, init_values=1e-5, inplace=False):
123
+ super().__init__()
124
+ self.inplace = inplace
125
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
126
+
127
+ def forward(self, x):
128
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
129
+
130
+ class PatchDropout(nn.Module):
131
+ """
132
+ https://arxiv.org/abs/2212.00794
133
+ """
134
+
135
+ def __init__(self, prob, exclude_first_token=True):
136
+ super().__init__()
137
+ assert 0 <= prob < 1.
138
+ self.prob = prob
139
+ self.exclude_first_token = exclude_first_token # exclude CLS token
140
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
141
+
142
+ def forward(self, x):
143
+ if not self.training or self.prob == 0.:
144
+ return x
145
+
146
+ if self.exclude_first_token:
147
+ cls_tokens, x = x[:, :1], x[:, 1:]
148
+ else:
149
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
150
+
151
+ batch = x.size()[0]
152
+ num_tokens = x.size()[1]
153
+
154
+ batch_indices = torch.arange(batch)
155
+ batch_indices = batch_indices[..., None]
156
+
157
+ keep_prob = 1 - self.prob
158
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
159
+
160
+ rand = torch.randn(batch, num_tokens)
161
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
162
+
163
+ x = x[batch_indices, patch_indices_keep]
164
+
165
+ if self.exclude_first_token:
166
+ x = torch.cat((cls_tokens, x), dim=1)
167
+
168
+ if self.training and os.getenv('RoPE') == '1':
169
+ return x, patch_indices_keep
170
+
171
+ return x
172
+
173
+
174
+ def _in_projection_packed(
175
+ q: torch.Tensor,
176
+ k: torch.Tensor,
177
+ v: torch.Tensor,
178
+ w: torch.Tensor,
179
+ b: Optional[torch.Tensor] = None,
180
+ ):
181
+ """
182
+ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
183
+ """
184
+ E = q.size(-1)
185
+ if k is v:
186
+ if q is k:
187
+ # self-attention
188
+ return F.linear(q, w, b).chunk(3, dim=-1)
189
+ else:
190
+ # encoder-decoder attention
191
+ w_q, w_kv = w.split([E, E * 2])
192
+ if b is None:
193
+ b_q = b_kv = None
194
+ else:
195
+ b_q, b_kv = b.split([E, E * 2])
196
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
197
+ else:
198
+ w_q, w_k, w_v = w.chunk(3)
199
+ if b is None:
200
+ b_q = b_k = b_v = None
201
+ else:
202
+ b_q, b_k, b_v = b.chunk(3)
203
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
204
+
205
+ class Attention(nn.Module):
206
+ def __init__(
207
+ self,
208
+ dim,
209
+ num_heads=8,
210
+ qkv_bias=True,
211
+ scaled_cosine=False,
212
+ scale_heads=False,
213
+ logit_scale_max=math.log(1. / 0.01),
214
+ attn_drop=0.,
215
+ proj_drop=0.,
216
+ xattn=False,
217
+ rope=False
218
+ ):
219
+ super().__init__()
220
+ self.scaled_cosine = scaled_cosine
221
+ self.scale_heads = scale_heads
222
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
223
+ self.num_heads = num_heads
224
+ self.head_dim = dim // num_heads
225
+ self.scale = self.head_dim ** -0.5
226
+ self.logit_scale_max = logit_scale_max
227
+
228
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
229
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
230
+ if qkv_bias:
231
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
232
+ else:
233
+ self.in_proj_bias = None
234
+
235
+ if self.scaled_cosine:
236
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
237
+ else:
238
+ self.logit_scale = None
239
+ self.attn_drop = nn.Dropout(attn_drop)
240
+ if self.scale_heads:
241
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
242
+ else:
243
+ self.head_scale = None
244
+ self.out_proj = nn.Linear(dim, dim)
245
+ self.out_drop = nn.Dropout(proj_drop)
246
+ self.xattn = xattn
247
+ self.xattn_drop = attn_drop
248
+ self.rope = rope
249
+
250
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
251
+ L, N, C = x.shape
252
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
253
+ if self.xattn:
254
+ q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
255
+ k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
256
+ v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
257
+
258
+ x = xops.memory_efficient_attention(
259
+ q, k, v,
260
+ p=self.xattn_drop,
261
+ scale=self.scale if self.logit_scale is None else None,
262
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
263
+ )
264
+ else:
265
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
266
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
267
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
268
+
269
+ if self.logit_scale is not None:
270
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
271
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
272
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
273
+ attn = attn.view(-1, L, L)
274
+ else:
275
+ q = q * self.scale
276
+ attn = torch.bmm(q, k.transpose(-1, -2))
277
+
278
+ if attn_mask is not None:
279
+ if attn_mask.dtype == torch.bool:
280
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
281
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
282
+ attn_mask = new_attn_mask
283
+ attn += attn_mask
284
+
285
+ attn = attn.softmax(dim=-1)
286
+ attn = self.attn_drop(attn)
287
+
288
+ x = torch.bmm(attn, v)
289
+
290
+ if self.head_scale is not None:
291
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
292
+ x = x.view(-1, L, C)
293
+ x = x.transpose(0, 1).reshape(L, N, C)
294
+ x = self.out_proj(x)
295
+ x = self.out_drop(x)
296
+ return x
297
+
298
+ class CustomAttention(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim,
302
+ num_heads=8,
303
+ qkv_bias=True,
304
+ scaled_cosine=True,
305
+ scale_heads=False,
306
+ logit_scale_max=math.log(1. / 0.01),
307
+ attn_drop=0.,
308
+ proj_drop=0.,
309
+ xattn=False
310
+ ):
311
+ super().__init__()
312
+ self.scaled_cosine = scaled_cosine
313
+ self.scale_heads = scale_heads
314
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
315
+ self.num_heads = num_heads
316
+ self.head_dim = dim // num_heads
317
+ self.scale = self.head_dim ** -0.5
318
+ self.logit_scale_max = logit_scale_max
319
+
320
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
321
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
322
+ if qkv_bias:
323
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
324
+ else:
325
+ self.in_proj_bias = None
326
+
327
+ if self.scaled_cosine:
328
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
329
+ else:
330
+ self.logit_scale = None
331
+ self.attn_drop = nn.Dropout(attn_drop)
332
+ if self.scale_heads:
333
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
334
+ else:
335
+ self.head_scale = None
336
+ self.out_proj = nn.Linear(dim, dim)
337
+ self.out_drop = nn.Dropout(proj_drop)
338
+ self.xattn = xattn
339
+ self.xattn_drop = attn_drop
340
+
341
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
342
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
343
+ N_q, B_q, C_q = q.shape
344
+ N_k, B_k, C_k = k.shape
345
+ N_v, B_v, C_v = v.shape
346
+ if self.xattn:
347
+ # B, N, C -> B, N, num_heads, C
348
+ q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
349
+ k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
350
+ v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
351
+
352
+ x = xops.memory_efficient_attention(
353
+ q, k, v,
354
+ p=self.xattn_drop,
355
+ scale=self.scale if self.logit_scale is None else None,
356
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
357
+ )
358
+ else:
359
+ # B*H, L, C
360
+ q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
361
+ k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
362
+ v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
363
+
364
+ if self.logit_scale is not None:
365
+ # B*H, N_q, N_k
366
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
367
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
368
+ attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
369
+ attn = attn.view(-1, N_q, N_k)
370
+ else:
371
+ q = q * self.scale
372
+ attn = torch.bmm(q, k.transpose(-1, -2))
373
+
374
+ if attn_mask is not None:
375
+ if attn_mask.dtype == torch.bool:
376
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
377
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
378
+ attn_mask = new_attn_mask
379
+ attn += attn_mask
380
+
381
+ attn = attn.softmax(dim=-1)
382
+ attn = self.attn_drop(attn)
383
+
384
+ x = torch.bmm(attn, v)
385
+
386
+ if self.head_scale is not None:
387
+ x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
388
+ x = x.view(-1, N_q, C_q)
389
+ x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
390
+ x = self.out_proj(x)
391
+ x = self.out_drop(x)
392
+ return x
393
+
394
+ class CustomResidualAttentionBlock(nn.Module):
395
+ def __init__(
396
+ self,
397
+ d_model: int,
398
+ n_head: int,
399
+ mlp_ratio: float = 4.0,
400
+ ls_init_value: float = None,
401
+ act_layer: Callable = nn.GELU,
402
+ norm_layer: Callable = LayerNorm,
403
+ scale_cosine_attn: bool = False,
404
+ scale_heads: bool = False,
405
+ scale_attn: bool = False,
406
+ scale_fc: bool = False,
407
+ cross_attn: bool = False,
408
+ xattn: bool = False,
409
+ ):
410
+ super().__init__()
411
+
412
+ self.ln_1 = norm_layer(d_model)
413
+ self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
414
+ self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
415
+ self.attn = CustomAttention(
416
+ d_model, n_head,
417
+ qkv_bias=True,
418
+ attn_drop=0.,
419
+ proj_drop=0.,
420
+ scaled_cosine=scale_cosine_attn,
421
+ scale_heads=scale_heads,
422
+ xattn=xattn
423
+ )
424
+
425
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
426
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
427
+
428
+ self.ln_2 = norm_layer(d_model)
429
+ mlp_width = int(d_model * mlp_ratio)
430
+ self.mlp = nn.Sequential(OrderedDict([
431
+ ("c_fc", nn.Linear(d_model, mlp_width)),
432
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
433
+ ("gelu", act_layer()),
434
+ ("c_proj", nn.Linear(mlp_width, d_model))
435
+ ]))
436
+
437
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
438
+
439
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
440
+ q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
441
+ q = q + self.ls_2(self.mlp(self.ln_2(q)))
442
+ return q
443
+
444
+ class CustomTransformer(nn.Module):
445
+ def __init__(
446
+ self,
447
+ width: int,
448
+ layers: int,
449
+ heads: int,
450
+ mlp_ratio: float = 4.0,
451
+ ls_init_value: float = None,
452
+ act_layer: Callable = nn.GELU,
453
+ norm_layer: Callable = LayerNorm,
454
+ scale_cosine_attn: bool = True,
455
+ scale_heads: bool = False,
456
+ scale_attn: bool = False,
457
+ scale_fc: bool = False,
458
+ cross_attn: bool = False,
459
+ xattn: bool = False,
460
+ ):
461
+ super().__init__()
462
+ self.width = width
463
+ self.layers = layers
464
+ self.grad_checkpointing = False
465
+ self.xattn = xattn
466
+
467
+ self.resblocks = nn.ModuleList([
468
+ CustomResidualAttentionBlock(
469
+ width,
470
+ heads,
471
+ mlp_ratio,
472
+ ls_init_value=ls_init_value,
473
+ act_layer=act_layer,
474
+ norm_layer=norm_layer,
475
+ scale_cosine_attn=scale_cosine_attn,
476
+ scale_heads=scale_heads,
477
+ scale_attn=scale_attn,
478
+ scale_fc=scale_fc,
479
+ cross_attn=cross_attn,
480
+ xattn=xattn)
481
+ for _ in range(layers)
482
+ ])
483
+
484
+ def get_cast_dtype(self) -> torch.dtype:
485
+ return self.resblocks[0].mlp.c_fc.weight.dtype
486
+
487
+ def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
488
+ if k is None and v is None:
489
+ k = v = q
490
+ for r in self.resblocks:
491
+ if self.grad_checkpointing and not torch.jit.is_scripting():
492
+ q = checkpoint(r, q, k, v, attn_mask)
493
+ else:
494
+ q = r(q, k, v, attn_mask=attn_mask)
495
+ return q
496
+
497
+
498
+ class ResidualAttentionBlock(nn.Module):
499
+ def __init__(
500
+ self,
501
+ d_model: int,
502
+ n_head: int,
503
+ mlp_ratio: float = 4.0,
504
+ ls_init_value: float = None,
505
+ act_layer: Callable = nn.GELU,
506
+ norm_layer: Callable = LayerNorm,
507
+ xattn: bool = False,
508
+ ):
509
+ super().__init__()
510
+
511
+ self.ln_1 = norm_layer(d_model)
512
+ if xattn:
513
+ self.attn = Attention(d_model, n_head, xattn=True)
514
+ else:
515
+ self.attn = nn.MultiheadAttention(d_model, n_head)
516
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
517
+
518
+ self.ln_2 = norm_layer(d_model)
519
+ mlp_width = int(d_model * mlp_ratio)
520
+ self.mlp = nn.Sequential(OrderedDict([
521
+ ("c_fc", nn.Linear(d_model, mlp_width)),
522
+ ("gelu", act_layer()),
523
+ ("c_proj", nn.Linear(mlp_width, d_model))
524
+ ]))
525
+
526
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
527
+ self.xattn = xattn
528
+
529
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
530
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
531
+ if self.xattn:
532
+ return self.attn(x, attn_mask=attn_mask)
533
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
534
+
535
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
536
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
537
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
538
+ return x
539
+
540
+ class Transformer(nn.Module):
541
+ def __init__(
542
+ self,
543
+ width: int,
544
+ layers: int,
545
+ heads: int,
546
+ mlp_ratio: float = 4.0,
547
+ ls_init_value: float = None,
548
+ act_layer: Callable = nn.GELU,
549
+ norm_layer: Callable = LayerNorm,
550
+ xattn: bool = False,
551
+ ):
552
+ super().__init__()
553
+ self.width = width
554
+ self.layers = layers
555
+ self.grad_checkpointing = False
556
+
557
+ self.resblocks = nn.ModuleList([
558
+ ResidualAttentionBlock(
559
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
560
+ for _ in range(layers)
561
+ ])
562
+
563
+ def get_cast_dtype(self) -> torch.dtype:
564
+ return self.resblocks[0].mlp.c_fc.weight.dtype
565
+
566
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
567
+ for r in self.resblocks:
568
+ if self.grad_checkpointing and not torch.jit.is_scripting():
569
+ x = checkpoint(r, x, attn_mask)
570
+ else:
571
+ x = r(x, attn_mask=attn_mask)
572
+ return x
573
+
574
+
575
+ class VisionTransformer(nn.Module):
576
+ def __init__(
577
+ self,
578
+ image_size: int,
579
+ patch_size: int,
580
+ width: int,
581
+ layers: int,
582
+ heads: int,
583
+ mlp_ratio: float,
584
+ ls_init_value: float = None,
585
+ patch_dropout: float = 0.,
586
+ global_average_pool: bool = False,
587
+ output_dim: int = 512,
588
+ act_layer: Callable = nn.GELU,
589
+ norm_layer: Callable = LayerNorm,
590
+ xattn: bool = False,
591
+ ):
592
+ super().__init__()
593
+ self.image_size = to_2tuple(image_size)
594
+ self.patch_size = to_2tuple(patch_size)
595
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
596
+ self.output_dim = output_dim
597
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
598
+
599
+ scale = width ** -0.5
600
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
601
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
602
+
603
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
604
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
605
+ self.ln_pre = norm_layer(width)
606
+
607
+ self.transformer = Transformer(
608
+ width,
609
+ layers,
610
+ heads,
611
+ mlp_ratio,
612
+ ls_init_value=ls_init_value,
613
+ act_layer=act_layer,
614
+ norm_layer=norm_layer,
615
+ xattn=xattn
616
+ )
617
+
618
+ self.global_average_pool = global_average_pool
619
+ self.ln_post = norm_layer(width)
620
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
621
+
622
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
623
+ for param in self.parameters():
624
+ param.requires_grad = False
625
+
626
+ if unlocked_groups != 0:
627
+ groups = [
628
+ [
629
+ self.conv1,
630
+ self.class_embedding,
631
+ self.positional_embedding,
632
+ self.ln_pre,
633
+ ],
634
+ *self.transformer.resblocks[:-1],
635
+ [
636
+ self.transformer.resblocks[-1],
637
+ self.ln_post,
638
+ ],
639
+ self.proj,
640
+ ]
641
+
642
+ def _unlock(x):
643
+ if isinstance(x, Sequence):
644
+ for g in x:
645
+ _unlock(g)
646
+ else:
647
+ if isinstance(x, torch.nn.Parameter):
648
+ x.requires_grad = True
649
+ else:
650
+ for p in x.parameters():
651
+ p.requires_grad = True
652
+
653
+ _unlock(groups[-unlocked_groups:])
654
+
655
+ def get_num_layers(self):
656
+ return self.transformer.layers
657
+
658
+ @torch.jit.ignore
659
+ def set_grad_checkpointing(self, enable=True):
660
+ self.transformer.grad_checkpointing = enable
661
+
662
+ @torch.jit.ignore
663
+ def no_weight_decay(self):
664
+ return {'positional_embedding', 'class_embedding'}
665
+
666
+ def forward(self, x: torch.Tensor, return_all_features: bool=False):
667
+ x = self.conv1(x) # shape = [*, width, grid, grid]
668
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
669
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
670
+ x = torch.cat(
671
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
672
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
673
+ x = x + self.positional_embedding.to(x.dtype)
674
+
675
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
676
+ x = self.patch_dropout(x)
677
+ x = self.ln_pre(x)
678
+
679
+ x = x.permute(1, 0, 2) # NLD -> LND
680
+ x = self.transformer(x)
681
+ x = x.permute(1, 0, 2) # LND -> NLD
682
+
683
+ if not return_all_features:
684
+ if self.global_average_pool:
685
+ x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
686
+ else:
687
+ x = x[:, 0]
688
+
689
+ x = self.ln_post(x)
690
+
691
+ if self.proj is not None:
692
+ x = x @ self.proj
693
+
694
+ return x
695
+
696
+
697
+ class TextTransformer(nn.Module):
698
+ def __init__(
699
+ self,
700
+ context_length: int = 77,
701
+ vocab_size: int = 49408,
702
+ width: int = 512,
703
+ heads: int = 8,
704
+ layers: int = 12,
705
+ ls_init_value: float = None,
706
+ output_dim: int = 512,
707
+ act_layer: Callable = nn.GELU,
708
+ norm_layer: Callable = LayerNorm,
709
+ xattn: bool = False,
710
+ attn_mask: bool = True
711
+ ):
712
+ super().__init__()
713
+ self.context_length = context_length
714
+ self.vocab_size = vocab_size
715
+ self.width = width
716
+ self.output_dim = output_dim
717
+
718
+ self.token_embedding = nn.Embedding(vocab_size, width)
719
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
720
+ self.transformer = Transformer(
721
+ width=width,
722
+ layers=layers,
723
+ heads=heads,
724
+ ls_init_value=ls_init_value,
725
+ act_layer=act_layer,
726
+ norm_layer=norm_layer,
727
+ xattn=xattn
728
+ )
729
+
730
+ self.xattn = xattn
731
+ self.ln_final = norm_layer(width)
732
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
733
+
734
+ if attn_mask:
735
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
736
+ else:
737
+ self.attn_mask = None
738
+
739
+ self.init_parameters()
740
+
741
+ def init_parameters(self):
742
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
743
+ nn.init.normal_(self.positional_embedding, std=0.01)
744
+
745
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
746
+ attn_std = self.transformer.width ** -0.5
747
+ fc_std = (2 * self.transformer.width) ** -0.5
748
+ for block in self.transformer.resblocks:
749
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
750
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
751
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
752
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
753
+
754
+ if self.text_projection is not None:
755
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
756
+
757
+ @torch.jit.ignore
758
+ def set_grad_checkpointing(self, enable=True):
759
+ self.transformer.grad_checkpointing = enable
760
+
761
+ @torch.jit.ignore
762
+ def no_weight_decay(self):
763
+ # return {'positional_embedding', 'token_embedding'}
764
+ return {'positional_embedding'}
765
+
766
+ def get_num_layers(self):
767
+ return self.transformer.layers
768
+
769
+ def build_attention_mask(self):
770
+ # lazily create causal attention mask, with full attention between the vision tokens
771
+ # pytorch uses additive attention mask; fill with -inf
772
+ mask = torch.empty(self.context_length, self.context_length)
773
+ mask.fill_(float("-inf"))
774
+ mask.triu_(1) # zero out the lower diagonal
775
+ return mask
776
+
777
+ def forward(self, text, return_all_features: bool=False):
778
+ cast_dtype = self.transformer.get_cast_dtype()
779
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
780
+
781
+ x = x + self.positional_embedding.to(cast_dtype)
782
+ x = x.permute(1, 0, 2) # NLD -> LND
783
+ x = self.transformer(x, attn_mask=self.attn_mask)
784
+ # x = self.transformer(x) # no attention mask is applied
785
+ x = x.permute(1, 0, 2) # LND -> NLD
786
+ x = self.ln_final(x)
787
+
788
+ if not return_all_features:
789
+ # x.shape = [batch_size, n_ctx, transformer.width]
790
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
791
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
792
+ return x
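The modules above can be smoke-tested on CPU with random inputs. A minimal sketch, assuming the package is importable as eva_clip and relying only on the defaults visible in the constructors; the VisionTransformer hyperparameters below are illustrative and not taken from any of the bundled model configs:

    import torch
    from eva_clip.transformer import TextTransformer, VisionTransformer

    # Text tower: defaults are context_length=77, vocab_size=49408, width=512, output_dim=512
    text_model = TextTransformer()
    tokens = torch.randint(0, text_model.vocab_size, (2, text_model.context_length))
    text_feats = text_model(tokens)            # -> shape (2, 512)

    # Vision tower with illustrative ViT-B/16-like settings
    vision_model = VisionTransformer(
        image_size=224, patch_size=16, width=768, layers=12, heads=12,
        mlp_ratio=4.0, output_dim=512,
    )
    image_feats = vision_model(torch.randn(2, 3, 224, 224))   # -> shape (2, 512)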
eva_clip/utils.py ADDED
@@ -0,0 +1,326 @@
1
+ from itertools import repeat
2
+ import collections.abc
3
+ import logging
4
+ import math
5
+ import numpy as np
+ from scipy import interpolate  # used by resize_rel_pos_embed for cubic 2D resampling
6
+
7
+ import torch
8
+ from torch import nn as nn
9
+ from torchvision.ops.misc import FrozenBatchNorm2d
10
+ import torch.nn.functional as F
11
+
12
+ # open CLIP
13
+ def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
14
+ # Rescale the grid of position embeddings when loading from state_dict
15
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
16
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
17
+ return
18
+ grid_size = to_2tuple(model.visual.grid_size)
19
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
20
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
21
+ if new_seq_len == old_pos_embed.shape[0]:
22
+ return
23
+
24
+ if extra_tokens:
25
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
26
+ else:
27
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
28
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
29
+
30
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
31
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
32
+ pos_emb_img = F.interpolate(
33
+ pos_emb_img,
34
+ size=grid_size,
35
+ mode=interpolation,
36
+ align_corners=True,
37
+ )
38
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
39
+ if pos_emb_tok is not None:
40
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
41
+ else:
42
+ new_pos_embed = pos_emb_img
43
+ state_dict['visual.positional_embedding'] = new_pos_embed
44
+
45
+
46
+ def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
47
+ # Rescale the grid of position embeddings when loading from state_dict
48
+ old_pos_embed = state_dict.get('positional_embedding', None)
49
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
50
+ return
51
+ grid_size = to_2tuple(model.visual.grid_size)
52
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
53
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
54
+ if new_seq_len == old_pos_embed.shape[0]:
55
+ return
56
+
57
+ if extra_tokens:
58
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
59
+ else:
60
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
61
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
62
+
63
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
64
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
65
+ pos_emb_img = F.interpolate(
66
+ pos_emb_img,
67
+ size=grid_size,
68
+ mode=interpolation,
69
+ align_corners=True,
70
+ )
71
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
72
+ if pos_emb_tok is not None:
73
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
74
+ else:
75
+ new_pos_embed = pos_emb_img
76
+ state_dict['positional_embedding'] = new_pos_embed
77
+
78
+ def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
79
+ all_keys = list(state_dict.keys())
80
+ # interpolate position embedding
81
+ if 'visual.pos_embed' in state_dict:
82
+ pos_embed_checkpoint = state_dict['visual.pos_embed']
83
+ embedding_size = pos_embed_checkpoint.shape[-1]
84
+ num_patches = model.visual.patch_embed.num_patches
85
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
86
+ # height (== width) for the checkpoint position embedding
87
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
88
+ # height (== width) for the new position embedding
89
+ new_size = int(num_patches ** 0.5)
90
+ # class_token and dist_token are kept unchanged
91
+ if orig_size != new_size:
92
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
93
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
94
+ # only the position tokens are interpolated
95
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
96
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
97
+ pos_tokens = torch.nn.functional.interpolate(
98
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
99
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
100
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
101
+ state_dict['visual.pos_embed'] = new_pos_embed
102
+
103
+ patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
104
+ patch_size = model.visual.patch_embed.patch_size
105
+ state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
106
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
107
+
108
+
109
+ def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
110
+ all_keys = list(state_dict.keys())
111
+ # interpolate position embedding
112
+ if 'pos_embed' in state_dict:
113
+ pos_embed_checkpoint = state_dict['pos_embed']
114
+ embedding_size = pos_embed_checkpoint.shape[-1]
115
+ num_patches = model.visual.patch_embed.num_patches
116
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
117
+ # height (== width) for the checkpoint position embedding
118
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
119
+ # height (== width) for the new position embedding
120
+ new_size = int(num_patches ** 0.5)
121
+ # class_token and dist_token are kept unchanged
122
+ if orig_size != new_size:
123
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
124
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
125
+ # only the position tokens are interpolated
126
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
127
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
128
+ pos_tokens = torch.nn.functional.interpolate(
129
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
130
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
131
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
132
+ state_dict['pos_embed'] = new_pos_embed
133
+
134
+ patch_embed_proj = state_dict['patch_embed.proj.weight']
135
+ patch_size = model.visual.patch_embed.patch_size
136
+ state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
137
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
138
+
139
+
140
+ def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
141
+ all_keys = list(state_dict.keys())
142
+ for key in all_keys:
143
+ if "relative_position_index" in key:
144
+ state_dict.pop(key)
145
+
146
+ if "relative_position_bias_table" in key:
147
+ rel_pos_bias = state_dict[key]
148
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
149
+ dst_num_pos, _ = model.visual.state_dict()[key].size()
150
+ dst_patch_shape = model.visual.patch_embed.patch_shape
151
+ if dst_patch_shape[0] != dst_patch_shape[1]:
152
+ raise NotImplementedError()
153
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
154
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
155
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
156
+ if src_size != dst_size:
157
+ print("Position interpolate for %s from %dx%d to %dx%d" % (
158
+ key, src_size, src_size, dst_size, dst_size))
159
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
160
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
161
+
162
+ def geometric_progression(a, r, n):
163
+ return a * (1.0 - r ** n) / (1.0 - r)
164
+
165
+ left, right = 1.01, 1.5
166
+ while right - left > 1e-6:
167
+ q = (left + right) / 2.0
168
+ gp = geometric_progression(1, q, src_size // 2)
169
+ if gp > dst_size // 2:
170
+ right = q
171
+ else:
172
+ left = q
173
+
174
+ # if q > 1.090307:
175
+ # q = 1.090307
176
+
177
+ dis = []
178
+ cur = 1
179
+ for i in range(src_size // 2):
180
+ dis.append(cur)
181
+ cur += q ** (i + 1)
182
+
183
+ r_ids = [-_ for _ in reversed(dis)]
184
+
185
+ x = r_ids + [0] + dis
186
+ y = r_ids + [0] + dis
187
+
188
+ t = dst_size // 2.0
189
+ dx = np.arange(-t, t + 0.1, 1.0)
190
+ dy = np.arange(-t, t + 0.1, 1.0)
191
+
192
+ print("Original positions = %s" % str(x))
193
+ print("Target positions = %s" % str(dx))
194
+
195
+ all_rel_pos_bias = []
196
+
197
+ for i in range(num_attn_heads):
198
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
199
+ f = interpolate.interp2d(x, y, z, kind='cubic')  # scipy.interpolate, not torch.nn.functional
200
+ all_rel_pos_bias.append(
201
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
202
+
203
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
204
+
205
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
206
+ state_dict[key] = new_rel_pos_bias
207
+
208
+ # interpolate position embedding
209
+ if 'pos_embed' in state_dict:
210
+ pos_embed_checkpoint = state_dict['pos_embed']
211
+ embedding_size = pos_embed_checkpoint.shape[-1]
212
+ num_patches = model.visual.patch_embed.num_patches
213
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
214
+ # height (== width) for the checkpoint position embedding
215
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
216
+ # height (== width) for the new position embedding
217
+ new_size = int(num_patches ** 0.5)
218
+ # class_token and dist_token are kept unchanged
219
+ if orig_size != new_size:
220
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
221
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
222
+ # only the position tokens are interpolated
223
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
224
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
225
+ pos_tokens = torch.nn.functional.interpolate(
226
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
227
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
228
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
229
+ state_dict['pos_embed'] = new_pos_embed
230
+
231
+ patch_embed_proj = state_dict['patch_embed.proj.weight']
232
+ patch_size = model.visual.patch_embed.patch_size
233
+ state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
234
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
235
+
236
+
237
+ def freeze_batch_norm_2d(module, module_match={}, name=''):
238
+ """
239
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
240
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
241
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
242
+
243
+ Args:
244
+ module (torch.nn.Module): Any PyTorch module.
245
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
246
+ name (str): Full module name (prefix)
247
+
248
+ Returns:
249
+ torch.nn.Module: Resulting module
250
+
251
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
252
+ """
253
+ res = module
254
+ is_match = True
255
+ if module_match:
256
+ is_match = name in module_match
257
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
258
+ res = FrozenBatchNorm2d(module.num_features)
259
+ res.num_features = module.num_features
260
+ res.affine = module.affine
261
+ if module.affine:
262
+ res.weight.data = module.weight.data.clone().detach()
263
+ res.bias.data = module.bias.data.clone().detach()
264
+ res.running_mean.data = module.running_mean.data
265
+ res.running_var.data = module.running_var.data
266
+ res.eps = module.eps
267
+ else:
268
+ for child_name, child in module.named_children():
269
+ full_child_name = '.'.join([name, child_name]) if name else child_name
270
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
271
+ if new_child is not child:
272
+ res.add_module(child_name, new_child)
273
+ return res
274
+
275
+
276
+ # From PyTorch internals
277
+ def _ntuple(n):
278
+ def parse(x):
279
+ if isinstance(x, collections.abc.Iterable):
280
+ return x
281
+ return tuple(repeat(x, n))
282
+ return parse
283
+
284
+
285
+ to_1tuple = _ntuple(1)
286
+ to_2tuple = _ntuple(2)
287
+ to_3tuple = _ntuple(3)
288
+ to_4tuple = _ntuple(4)
289
+ to_ntuple = lambda n, x: _ntuple(n)(x)
290
+
291
+
292
+ def is_logging(args):
293
+ def is_global_master(args):
294
+ return args.rank == 0
295
+
296
+ def is_local_master(args):
297
+ return args.local_rank == 0
298
+
299
+ def is_master(args, local=False):
300
+ return is_local_master(args) if local else is_global_master(args)
301
+ return is_master
302
+
303
+
304
+ class AllGather(torch.autograd.Function):
305
+ """An autograd function that performs allgather on a tensor.
306
+ Performs all_gather operation on the provided tensors.
307
+ *** Warning ***: torch.distributed.all_gather has no gradient.
308
+ """
309
+
310
+ @staticmethod
311
+ def forward(ctx, tensor, rank, world_size):
312
+ tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
313
+ torch.distributed.all_gather(tensors_gather, tensor)
314
+ ctx.rank = rank
315
+ ctx.batch_size = tensor.shape[0]
316
+ return torch.cat(tensors_gather, 0)
317
+
318
+ @staticmethod
319
+ def backward(ctx, grad_output):
320
+ return (
321
+ grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
322
+ None,
323
+ None
324
+ )
325
+
326
+ allgather = AllGather.apply
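The helpers above are easy to exercise in isolation. A minimal sketch, assuming the package is importable as eva_clip and that torchvision is installed (it provides FrozenBatchNorm2d):

    import torch
    from torch import nn
    from eva_clip.utils import freeze_batch_norm_2d, to_2tuple

    # BatchNorm2d submodules are swapped for FrozenBatchNorm2d in place
    backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    frozen = freeze_batch_norm_2d(backbone)

    print(to_2tuple(224))         # (224, 224)
    print(to_2tuple((224, 336)))  # (224, 336)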
example_inputs/hinton.jpeg ADDED
example_inputs/lecun.jpg ADDED
example_inputs/lifeifei.jpg ADDED
example_inputs/liuyifei.png ADDED
example_inputs/rihanna.webp ADDED
example_inputs/zcy.webp ADDED
flux/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ try:
2
+ from ._version import version as __version__ # type: ignore
3
+ from ._version import version_tuple
4
+ except ImportError:
5
+ __version__ = "unknown (no version information available)"
6
+ version_tuple = (0, 0, "unknown", "noinfo")
7
+
8
+ from pathlib import Path
9
+
10
+ PACKAGE = __package__.replace("_", "-")
11
+ PACKAGE_ROOT = Path(__file__).parent
flux/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
flux/api.py ADDED
@@ -0,0 +1,194 @@
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_ENDPOINT = "https://api.bfl.ml"
10
+
11
+
12
+ class ApiException(Exception):
13
+ def __init__(self, status_code: int, detail: str = None):
14
+ super().__init__()
15
+ self.detail = detail
16
+ self.status_code = status_code
17
+
18
+ def __str__(self) -> str:
19
+ return self.__repr__()
20
+
21
+ def __repr__(self) -> str:
22
+ if self.detail is None:
23
+ message = None
24
+ elif isinstance(self.detail, str):
25
+ message = self.detail
26
+ else:
27
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
+
30
+
31
+ class ImageRequest:
32
+ def __init__(
33
+ self,
34
+ prompt: str,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ name: str = "flux.1-pro",
38
+ num_steps: int = 50,
39
+ prompt_upsampling: bool = False,
40
+ seed: int = None,
41
+ validate: bool = True,
42
+ launch: bool = True,
43
+ api_key: str = None,
44
+ ):
45
+ """
46
+ Manages an image generation request to the API.
47
+
48
+ Args:
49
+ prompt: Prompt to sample
50
+ width: Width of the image in pixels
51
+ height: Height of the image in pixels
52
+ name: Name of the model
53
+ num_steps: Number of network evaluations
54
+ prompt_upsampling: Use prompt upsampling
55
+ seed: Fix the generation seed
56
+ validate: Run input validation
57
+ launch: Directly launches request
58
+ api_key: Your API key if not provided by the environment
59
+
60
+ Raises:
61
+ ValueError: For invalid input
62
+ ApiException: For errors raised from the API
63
+ """
64
+ if validate:
65
+ if name not in ["flux.1-pro"]:
66
+ raise ValueError(f"Invalid model {name}")
67
+ elif width % 32 != 0:
68
+ raise ValueError(f"width must be divisible by 32, got {width}")
69
+ elif not (256 <= width <= 1440):
70
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
71
+ elif height % 32 != 0:
72
+ raise ValueError(f"height must be divisible by 32, got {height}")
73
+ elif not (256 <= height <= 1440):
74
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
75
+ elif not (1 <= num_steps <= 50):
76
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77
+
78
+ self.request_json = {
79
+ "prompt": prompt,
80
+ "width": width,
81
+ "height": height,
82
+ "variant": name,
83
+ "steps": num_steps,
84
+ "prompt_upsampling": prompt_upsampling,
85
+ }
86
+ if seed is not None:
87
+ self.request_json["seed"] = seed
88
+
89
+ self.request_id: str = None
90
+ self.result: dict = None
91
+ self._image_bytes: bytes = None
92
+ self._url: str = None
93
+ if api_key is None:
94
+ self.api_key = os.environ.get("BFL_API_KEY")
95
+ else:
96
+ self.api_key = api_key
97
+
98
+ if launch:
99
+ self.request()
100
+
101
+ def request(self):
102
+ """
103
+ Request to generate the image.
104
+ """
105
+ if self.request_id is not None:
106
+ return
107
+ response = requests.post(
108
+ f"{API_ENDPOINT}/v1/image",
109
+ headers={
110
+ "accept": "application/json",
111
+ "x-key": self.api_key,
112
+ "Content-Type": "application/json",
113
+ },
114
+ json=self.request_json,
115
+ )
116
+ result = response.json()
117
+ if response.status_code != 200:
118
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119
+ self.request_id = result["id"]
120
+
121
+ def retrieve(self) -> dict:
122
+ """
123
+ Wait for the generation to finish and retrieve response.
124
+ """
125
+ if self.request_id is None:
126
+ self.request()
127
+ while self.result is None:
128
+ response = requests.get(
129
+ f"{API_ENDPOINT}/v1/get_result",
130
+ headers={
131
+ "accept": "application/json",
132
+ "x-key": self.api_key,
133
+ },
134
+ params={
135
+ "id": self.request_id,
136
+ },
137
+ )
138
+ result = response.json()
139
+ if "status" not in result:
140
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141
+ elif result["status"] == "Ready":
142
+ self.result = result["result"]
143
+ elif result["status"] == "Pending":
144
+ time.sleep(0.5)
145
+ else:
146
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147
+ return self.result
148
+
149
+ @property
150
+ def bytes(self) -> bytes:
151
+ """
152
+ Generated image as bytes.
153
+ """
154
+ if self._image_bytes is None:
155
+ response = requests.get(self.url)
156
+ if response.status_code == 200:
157
+ self._image_bytes = response.content
158
+ else:
159
+ raise ApiException(status_code=response.status_code)
160
+ return self._image_bytes
161
+
162
+ @property
163
+ def url(self) -> str:
164
+ """
165
+ Public url to retrieve the image from
166
+ """
167
+ if self._url is None:
168
+ result = self.retrieve()
169
+ self._url = result["sample"]
170
+ return self._url
171
+
172
+ @property
173
+ def image(self) -> Image.Image:
174
+ """
175
+ Load the image as a PIL Image
176
+ """
177
+ return Image.open(io.BytesIO(self.bytes))
178
+
179
+ def save(self, path: str):
180
+ """
181
+ Save the generated image to a local path
182
+ """
183
+ suffix = Path(self.url).suffix
184
+ if not path.endswith(suffix):
185
+ path = path + suffix
186
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187
+ with open(path, "wb") as file:
188
+ file.write(self.bytes)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ from fire import Fire
193
+
194
+ Fire(ImageRequest)
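The ImageRequest wrapper above can also be used directly from Python rather than through Fire. A minimal sketch, assuming a valid key is available in the BFL_API_KEY environment variable (or passed as api_key=) and that the hosted flux.1-pro endpoint is reachable:

    from flux.api import ImageRequest

    request = ImageRequest(
        prompt="a watercolor fox in a misty forest",
        width=1024,      # must be a multiple of 32 in [256, 1440]
        height=768,
        num_steps=40,
    )
    request.save("fox.jpg")   # polls until the result is ready, then downloads it
    image = request.image     # same result as a PIL.Image for further processing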
flux/cli.py ADDED
@@ -0,0 +1,261 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from fire import Fire
10
+ from PIL import ExifTags, Image
11
+ from transformers import pipeline
12
+
13
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
14
+ from flux.util import (
15
+ configs,
16
+ embed_watermark,
17
+ load_ae,
18
+ load_clip,
19
+ load_flow_model,
20
+ load_t5,
21
+ )
22
+
23
+ NSFW_THRESHOLD = 0.85
24
+
25
+
26
+ @dataclass
27
+ class SamplingOptions:
28
+ prompt: str
29
+ width: int
30
+ height: int
31
+ num_steps: int
32
+ guidance: float
33
+ seed: int
34
+
35
+
36
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions:
37
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
38
+ usage = (
39
+ "Usage: Either write your prompt directly, leave this field empty "
40
+ "to repeat the prompt or write a command starting with a slash:\n"
41
+ "- '/w <width>' will set the width of the generated image\n"
42
+ "- '/h <height>' will set the height of the generated image\n"
43
+ "- '/s <seed>' sets the next seed\n"
44
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
45
+ "- '/n <steps>' sets the number of steps\n"
46
+ "- '/q' to quit"
47
+ )
48
+
49
+ while (prompt := input(user_question)).startswith("/"):
50
+ if prompt.startswith("/w"):
51
+ if prompt.count(" ") != 1:
52
+ print(f"Got invalid command '{prompt}'\n{usage}")
53
+ continue
54
+ _, width = prompt.split()
55
+ options.width = 16 * (int(width) // 16)
56
+ print(
57
+ f"Setting resolution to {options.width} x {options.height} "
58
+ f"({options.height * options.width / 1e6:.2f}MP)"
59
+ )
60
+ elif prompt.startswith("/h"):
61
+ if prompt.count(" ") != 1:
62
+ print(f"Got invalid command '{prompt}'\n{usage}")
63
+ continue
64
+ _, height = prompt.split()
65
+ options.height = 16 * (int(height) // 16)
66
+ print(
67
+ f"Setting resolution to {options.width} x {options.height} "
68
+ f"({options.height * options.width / 1e6:.2f}MP)"
69
+ )
70
+ elif prompt.startswith("/g"):
71
+ if prompt.count(" ") != 1:
72
+ print(f"Got invalid command '{prompt}'\n{usage}")
73
+ continue
74
+ _, guidance = prompt.split()
75
+ options.guidance = float(guidance)
76
+ print(f"Setting guidance to {options.guidance}")
77
+ elif prompt.startswith("/s"):
78
+ if prompt.count(" ") != 1:
79
+ print(f"Got invalid command '{prompt}'\n{usage}")
80
+ continue
81
+ _, seed = prompt.split()
82
+ options.seed = int(seed)
83
+ print(f"Setting seed to {options.seed}")
84
+ elif prompt.startswith("/n"):
85
+ if prompt.count(" ") != 1:
86
+ print(f"Got invalid command '{prompt}'\n{usage}")
87
+ continue
88
+ _, steps = prompt.split()
89
+ options.num_steps = int(steps)
90
+ print(f"Setting number of steps to {options.num_steps}")
91
+ elif prompt.startswith("/q"):
92
+ print("Quitting")
93
+ return None
94
+ else:
95
+ if not prompt.startswith("/h"):
96
+ print(f"Got invalid command '{prompt}'\n{usage}")
97
+ print(usage)
98
+ if prompt != "":
99
+ options.prompt = prompt
100
+ return options
101
+
102
+
103
+ @torch.inference_mode()
104
+ def main(
105
+ name: str = "flux-schnell",
106
+ width: int = 1360,
107
+ height: int = 768,
108
+ seed: int = None,
109
+ prompt: str = (
110
+ "a photo of a forest with mist swirling around the tree trunks. The word "
111
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
112
+ ),
113
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
114
+ num_steps: int = None,
115
+ loop: bool = False,
116
+ guidance: float = 3.5,
117
+ offload: bool = False,
118
+ output_dir: str = "output",
119
+ add_sampling_metadata: bool = True,
120
+ ):
121
+ """
122
+ Sample the flux model. Either interactively (set `--loop`) or run for a
123
+ single image.
124
+
125
+ Args:
126
+ name: Name of the model to load
127
+ height: height of the sample in pixels (should be a multiple of 16)
128
+ width: width of the sample in pixels (should be a multiple of 16)
129
+ seed: Set a seed for sampling
130
+ output_dir: directory where the output images are saved as `img_{idx}.jpg`,
131
+ with `{idx}` the index of the sample
132
+ prompt: Prompt used for sampling
133
+ device: Pytorch device
134
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
135
+ loop: start an interactive session and sample multiple times
136
+ guidance: guidance value used for guidance distillation
137
+ add_sampling_metadata: Add the prompt to the image Exif metadata
138
+ """
139
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
140
+
141
+ if name not in configs:
142
+ available = ", ".join(configs.keys())
143
+ raise ValueError(f"Got unknown model name: {name}, choose from {available}")
144
+
145
+ torch_device = torch.device(device)
146
+ if num_steps is None:
147
+ num_steps = 4 if name == "flux-schnell" else 50
148
+
149
+ # allow for packing and conversion to latent space
150
+ height = 16 * (height // 16)
151
+ width = 16 * (width // 16)
152
+
153
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
154
+ if not os.path.exists(output_dir):
155
+ os.makedirs(output_dir)
156
+ idx = 0
157
+ else:
158
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
159
+ if len(fns) > 0:
160
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
161
+ else:
162
+ idx = 0
163
+
164
+ # init all components
165
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
166
+ clip = load_clip(torch_device)
167
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
168
+ ae = load_ae(name, device="cpu" if offload else torch_device)
169
+
170
+ rng = torch.Generator(device="cpu")
171
+ opts = SamplingOptions(
172
+ prompt=prompt,
173
+ width=width,
174
+ height=height,
175
+ num_steps=num_steps,
176
+ guidance=guidance,
177
+ seed=seed,
178
+ )
179
+
180
+ if loop:
181
+ opts = parse_prompt(opts)
182
+
183
+ while opts is not None:
184
+ if opts.seed is None:
185
+ opts.seed = rng.seed()
186
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
187
+ t0 = time.perf_counter()
188
+
189
+ # prepare input
190
+ x = get_noise(
191
+ 1,
192
+ opts.height,
193
+ opts.width,
194
+ device=torch_device,
195
+ dtype=torch.bfloat16,
196
+ seed=opts.seed,
197
+ )
198
+ opts.seed = None
199
+ if offload:
200
+ ae = ae.cpu()
201
+ torch.cuda.empty_cache()
202
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
203
+ inp = prepare(t5, clip, x, prompt=opts.prompt)
204
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
205
+
206
+ # offload TEs to CPU, load model to gpu
207
+ if offload:
208
+ t5, clip = t5.cpu(), clip.cpu()
209
+ torch.cuda.empty_cache()
210
+ model = model.to(torch_device)
211
+
212
+ # denoise initial noise
213
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
214
+
215
+ # offload model, load autoencoder to gpu
216
+ if offload:
217
+ model.cpu()
218
+ torch.cuda.empty_cache()
219
+ ae.decoder.to(x.device)
220
+
221
+ # decode latents to pixel space
222
+ x = unpack(x.float(), opts.height, opts.width)
223
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
224
+ x = ae.decode(x)
225
+ t1 = time.perf_counter()
226
+
227
+ fn = output_name.format(idx=idx)
228
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
229
+ # bring into PIL format and save
230
+ x = x.clamp(-1, 1)
231
+ x = embed_watermark(x.float())
232
+ x = rearrange(x[0], "c h w -> h w c")
233
+
234
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
235
+ nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
236
+
237
+ if nsfw_score < NSFW_THRESHOLD:
238
+ exif_data = Image.Exif()
239
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
240
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
241
+ exif_data[ExifTags.Base.Model] = name
242
+ if add_sampling_metadata:
243
+ exif_data[ExifTags.Base.ImageDescription] = prompt
244
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
245
+ idx += 1
246
+ else:
247
+ print("Your generated image may contain NSFW content.")
248
+
249
+ if loop:
250
+ print("-" * 80)
251
+ opts = parse_prompt(opts)
252
+ else:
253
+ opts = None
254
+
255
+
256
+ def app():
257
+ Fire(main)
258
+
259
+
260
+ if __name__ == "__main__":
261
+ app()
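Because the entry point is plain Fire over main, the same sampling loop can be invoked programmatically. A minimal sketch with an illustrative prompt and resolution; this pulls the flux-schnell weights on first run and is only practical with a CUDA device:

    from flux.cli import main

    main(
        name="flux-schnell",
        prompt="a tiny robot watering a bonsai tree",
        width=1024,       # rounded down to a multiple of 16 internally
        height=768,
        num_steps=4,
        output_dir="output",
    )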
flux/math.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ if pe is not None:
8
+ q, k = apply_rope(q, k, pe)
9
+
10
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
11
+ x = rearrange(x, "B H L D -> B L (H D)")
12
+
13
+ return x
14
+
15
+
16
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
17
+ assert dim % 2 == 0
18
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
19
+ omega = 1.0 / (theta**scale)
20
+ out = torch.einsum("...n,d->...nd", pos, omega)
21
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
22
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
23
+ return out.float()
24
+
25
+
26
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
27
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
28
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
29
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
30
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
31
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
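A quick way to see what rope() returns: each of the dim // 2 frequencies is packed into a 2x2 rotation matrix, so the output has shape (..., n, dim // 2, 2, 2). A small shape check with illustrative values:

    import torch
    from flux.math import rope

    pos = torch.arange(8, dtype=torch.float64)[None]  # (batch=1, n=8) token positions
    pe = rope(pos, dim=32, theta=10_000)
    print(pe.shape)                                   # torch.Size([1, 8, 16, 2, 2])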
flux/model.py ADDED
@@ -0,0 +1,135 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+
15
+
16
+ @dataclass
17
+ class FluxParams:
18
+ in_channels: int
19
+ vec_in_dim: int
20
+ context_in_dim: int
21
+ hidden_size: int
22
+ mlp_ratio: float
23
+ num_heads: int
24
+ depth: int
25
+ depth_single_blocks: int
26
+ axes_dim: list[int]
27
+ theta: int
28
+ qkv_bias: bool
29
+ guidance_embed: bool
30
+
31
+
32
+ class Flux(nn.Module):
33
+ """
34
+ Transformer model for flow matching on sequences.
35
+ """
36
+
37
+ def __init__(self, params: FluxParams):
38
+ super().__init__()
39
+
40
+ self.params = params
41
+ self.in_channels = params.in_channels
42
+ self.out_channels = self.in_channels
43
+ if params.hidden_size % params.num_heads != 0:
44
+ raise ValueError(
45
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
46
+ )
47
+ pe_dim = params.hidden_size // params.num_heads
48
+ if sum(params.axes_dim) != pe_dim:
49
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
50
+ self.hidden_size = params.hidden_size
51
+ self.num_heads = params.num_heads
52
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
53
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
54
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
55
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
56
+ self.guidance_in = (
57
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
58
+ )
59
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
60
+
61
+ self.double_blocks = nn.ModuleList(
62
+ [
63
+ DoubleStreamBlock(
64
+ self.hidden_size,
65
+ self.num_heads,
66
+ mlp_ratio=params.mlp_ratio,
67
+ qkv_bias=params.qkv_bias,
68
+ )
69
+ for _ in range(params.depth)
70
+ ]
71
+ )
72
+
73
+ self.single_blocks = nn.ModuleList(
74
+ [
75
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
76
+ for _ in range(params.depth_single_blocks)
77
+ ]
78
+ )
79
+
80
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
81
+
82
+ self.pulid_ca = None
83
+ self.pulid_double_interval = 2
84
+ self.pulid_single_interval = 4
85
+
86
+ def forward(
87
+ self,
88
+ img: Tensor,
89
+ img_ids: Tensor,
90
+ txt: Tensor,
91
+ txt_ids: Tensor,
92
+ timesteps: Tensor,
93
+ y: Tensor,
94
+ guidance: Tensor = None,
95
+ id: Tensor = None,
96
+ id_weight: float = 1.0,
97
+ ) -> Tensor:
98
+ if img.ndim != 3 or txt.ndim != 3:
99
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
100
+
101
+ # running on sequences img
102
+ img = self.img_in(img)
103
+ vec = self.time_in(timestep_embedding(timesteps, 256))
104
+ if self.params.guidance_embed:
105
+ if guidance is None:
106
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
107
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
108
+ vec = vec + self.vector_in(y)
109
+ txt = self.txt_in(txt)
110
+
111
+ ids = torch.cat((txt_ids, img_ids), dim=1)
112
+ pe = self.pe_embedder(ids)
113
+
114
+ ca_idx = 0
115
+ for i, block in enumerate(self.double_blocks):
116
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
117
+
118
+ if i % self.pulid_double_interval == 0 and id is not None:
119
+ img = img + id_weight * self.pulid_ca[ca_idx](id, img)
120
+ ca_idx += 1
121
+
122
+ img = torch.cat((txt, img), 1)
123
+ for i, block in enumerate(self.single_blocks):
124
+ x = block(img, vec=vec, pe=pe)
125
+ real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
126
+
127
+ if i % self.pulid_single_interval == 0 and id is not None:
128
+ real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
129
+ ca_idx += 1
130
+
131
+ img = torch.cat((txt, real_img), 1)
132
+ img = img[:, txt.shape[1] :, ...]
133
+
134
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
135
+ return img
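A minimal shape-check sketch for the Flux model above, assuming the repo root is on the Python path; the tiny FluxParams values are illustrative only (the released flux-dev configuration lives in flux/util.py).

import torch
from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=256,
    mlp_ratio=4.0, num_heads=4, depth=2, depth_single_blocks=2,
    axes_dim=[16, 24, 24], theta=10_000, qkv_bias=True, guidance_embed=True,
)
model = Flux(params)  # hidden_size / num_heads = 64 = sum(axes_dim)

bs, l_img, l_txt = 1, 16, 8
out = model(
    img=torch.randn(bs, l_img, 64),       # packed latent patches
    img_ids=torch.zeros(bs, l_img, 3),    # per-patch position ids
    txt=torch.randn(bs, l_txt, 4096),     # T5 token features
    txt_ids=torch.zeros(bs, l_txt, 3),
    timesteps=torch.rand(bs),
    y=torch.randn(bs, 768),               # pooled CLIP vector
    guidance=torch.full((bs,), 4.0),      # required because guidance_embed=True
)
print(out.shape)  # torch.Size([1, 16, 64])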
flux/modules/__init__.py ADDED
File without changes
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ if self.sample:
271
+ std = torch.exp(0.5 * logvar)
272
+ return mean + std * torch.randn_like(mean)
273
+ else:
274
+ return mean
275
+
276
+
277
+ class AutoEncoder(nn.Module):
278
+ def __init__(self, params: AutoEncoderParams):
279
+ super().__init__()
280
+ self.encoder = Encoder(
281
+ resolution=params.resolution,
282
+ in_channels=params.in_channels,
283
+ ch=params.ch,
284
+ ch_mult=params.ch_mult,
285
+ num_res_blocks=params.num_res_blocks,
286
+ z_channels=params.z_channels,
287
+ )
288
+ self.decoder = Decoder(
289
+ resolution=params.resolution,
290
+ in_channels=params.in_channels,
291
+ ch=params.ch,
292
+ out_ch=params.out_ch,
293
+ ch_mult=params.ch_mult,
294
+ num_res_blocks=params.num_res_blocks,
295
+ z_channels=params.z_channels,
296
+ )
297
+ self.reg = DiagonalGaussian()
298
+
299
+ self.scale_factor = params.scale_factor
300
+ self.shift_factor = params.shift_factor
301
+
302
+ def encode(self, x: Tensor) -> Tensor:
303
+ z = self.reg(self.encoder(x))
304
+ z = self.scale_factor * (z - self.shift_factor)
305
+ return z
306
+
307
+ def decode(self, z: Tensor) -> Tensor:
308
+ z = z / self.scale_factor + self.shift_factor
309
+ return self.decoder(z)
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ return self.decode(self.encode(x))
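A small round-trip sketch for the autoencoder above; the reduced AutoEncoderParams are illustrative (the real flux AE configuration lives in flux/util.py).

import torch
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64, in_channels=3, ch=32, out_ch=3, ch_mult=[1, 2],
    num_res_blocks=1, z_channels=4, scale_factor=1.0, shift_factor=0.0,
)
ae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 64, 64)   # image in [-1, 1]
with torch.no_grad():
    z = ae.encode(x)            # (1, 4, 32, 32): one downsample per extra ch_mult level
    x_rec = ae.decode(z)        # (1, 3, 64, 64)
print(z.shape, x_rec.shape)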
flux/modules/conditioner.py ADDED
@@ -0,0 +1,37 @@
1
+ from torch import Tensor, nn
2
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
3
+
4
+
5
+ class HFEmbedder(nn.Module):
6
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
7
+ super().__init__()
8
+ self.is_clip = version.startswith("openai")
9
+ self.max_length = max_length
10
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
11
+
12
+ if self.is_clip:
13
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
14
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
15
+ else:
16
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
17
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
18
+
19
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
20
+
21
+ def forward(self, text: list[str]) -> Tensor:
22
+ batch_encoding = self.tokenizer(
23
+ text,
24
+ truncation=True,
25
+ max_length=self.max_length,
26
+ return_length=False,
27
+ return_overflowing_tokens=False,
28
+ padding="max_length",
29
+ return_tensors="pt",
30
+ )
31
+
32
+ outputs = self.hf_module(
33
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
34
+ attention_mask=None,
35
+ output_hidden_states=False,
36
+ )
37
+ return outputs[self.output_key]
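A usage sketch for HFEmbedder; the model ids match the loaders in flux/util.py, running it downloads the (large) weights, and the printed shapes follow the flux-dev config (768-dim pooled CLIP vector, 4096-dim T5 token states).

import torch
from flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32)
t5 = HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=128, torch_dtype=torch.float32)

prompt = ["portrait photo of a woman"]
vec = clip(prompt)   # pooler_output: (1, 768)
txt = t5(prompt)     # last_hidden_state: (1, 128, 4096)
print(vec.shape, txt.shape)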
flux/modules/layers.py ADDED
@@ -0,0 +1,253 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope
9
+
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
+ t.device
41
+ )
42
+
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ if torch.is_floating_point(t):
48
+ embedding = embedding.to(t)
49
+ return embedding
50
+
51
+
52
+ class MLPEmbedder(nn.Module):
53
+ def __init__(self, in_dim: int, hidden_dim: int):
54
+ super().__init__()
55
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
56
+ self.silu = nn.SiLU()
57
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ return self.out_layer(self.silu(self.in_layer(x)))
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ def __init__(self, dim: int):
65
+ super().__init__()
66
+ self.scale = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x: Tensor):
69
+ x_dtype = x.dtype
70
+ x = x.float()
71
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
+ return (x * rrms).to(dtype=x_dtype) * self.scale
73
+
74
+
75
+ class QKNorm(torch.nn.Module):
76
+ def __init__(self, dim: int):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim)
79
+ self.key_norm = RMSNorm(dim)
80
+
81
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
82
+ q = self.query_norm(q)
83
+ k = self.key_norm(k)
84
+ return q.to(v), k.to(v)
85
+
86
+
87
+ class SelfAttention(nn.Module):
88
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
89
+ super().__init__()
90
+ self.num_heads = num_heads
91
+ head_dim = dim // num_heads
92
+
93
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
94
+ self.norm = QKNorm(head_dim)
95
+ self.proj = nn.Linear(dim, dim)
96
+
97
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
98
+ qkv = self.qkv(x)
99
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
100
+ q, k = self.norm(q, k, v)
101
+ x = attention(q, k, v, pe=pe)
102
+ x = self.proj(x)
103
+ return x
104
+
105
+
106
+ @dataclass
107
+ class ModulationOut:
108
+ shift: Tensor
109
+ scale: Tensor
110
+ gate: Tensor
111
+
112
+
113
+ class Modulation(nn.Module):
114
+ def __init__(self, dim: int, double: bool):
115
+ super().__init__()
116
+ self.is_double = double
117
+ self.multiplier = 6 if double else 3
118
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
119
+
120
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
121
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
122
+
123
+ return (
124
+ ModulationOut(*out[:3]),
125
+ ModulationOut(*out[3:]) if self.is_double else None,
126
+ )
127
+
128
+
129
+ class DoubleStreamBlock(nn.Module):
130
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
131
+ super().__init__()
132
+
133
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
134
+ self.num_heads = num_heads
135
+ self.hidden_size = hidden_size
136
+ self.img_mod = Modulation(hidden_size, double=True)
137
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
138
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
139
+
140
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
141
+ self.img_mlp = nn.Sequential(
142
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
143
+ nn.GELU(approximate="tanh"),
144
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
145
+ )
146
+
147
+ self.txt_mod = Modulation(hidden_size, double=True)
148
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
150
+
151
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
152
+ self.txt_mlp = nn.Sequential(
153
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
154
+ nn.GELU(approximate="tanh"),
155
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
156
+ )
157
+
158
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
159
+ img_mod1, img_mod2 = self.img_mod(vec)
160
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
161
+
162
+ # prepare image for attention
163
+ img_modulated = self.img_norm1(img)
164
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
165
+ img_qkv = self.img_attn.qkv(img_modulated)
166
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
167
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
168
+
169
+ # prepare txt for attention
170
+ txt_modulated = self.txt_norm1(txt)
171
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
172
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
173
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
175
+
176
+ # run actual attention
177
+ q = torch.cat((txt_q, img_q), dim=2)
178
+ k = torch.cat((txt_k, img_k), dim=2)
179
+ v = torch.cat((txt_v, img_v), dim=2)
180
+
181
+ attn = attention(q, k, v, pe=pe)
182
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
183
+
184
+ # calculate the img blocks
185
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
186
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
187
+
188
+ # calculate the txt blocks
189
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
190
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
191
+ return img, txt
192
+
193
+
194
+ class SingleStreamBlock(nn.Module):
195
+ """
196
+ A DiT block with parallel linear layers as described in
197
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ hidden_size: int,
203
+ num_heads: int,
204
+ mlp_ratio: float = 4.0,
205
+ qk_scale: float = None,
206
+ ):
207
+ super().__init__()
208
+ self.hidden_dim = hidden_size
209
+ self.num_heads = num_heads
210
+ head_dim = hidden_size // num_heads
211
+ self.scale = qk_scale or head_dim**-0.5
212
+
213
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214
+ # qkv and mlp_in
215
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216
+ # proj and mlp_out
217
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218
+
219
+ self.norm = QKNorm(head_dim)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223
+
224
+ self.mlp_act = nn.GELU(approximate="tanh")
225
+ self.modulation = Modulation(hidden_size, double=False)
226
+
227
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228
+ mod, _ = self.modulation(vec)
229
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231
+
232
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233
+ q, k = self.norm(q, k, v)
234
+
235
+ # compute attention
236
+ attn = attention(q, k, v, pe=pe)
237
+ # compute activation in mlp stream, cat again and run second linear layer
238
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239
+ return x + mod.gate * output
240
+
241
+
242
+ class LastLayer(nn.Module):
243
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244
+ super().__init__()
245
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248
+
249
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252
+ x = self.linear(x)
253
+ return x
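A quick sketch of how the building blocks above combine: sinusoidal timestep features pass through an MLPEmbedder, and Modulation turns the resulting vector into shift/scale/gate terms; the sizes are illustrative.

import torch
from flux.modules.layers import MLPEmbedder, Modulation, timestep_embedding

t_emb = timestep_embedding(torch.rand(2), 256)        # (2, 256) sinusoidal features
vec = MLPEmbedder(in_dim=256, hidden_dim=512)(t_emb)  # (2, 512) conditioning vector
mod, _ = Modulation(512, double=False)(vec)           # single-stream style: one (shift, scale, gate) triple
print(vec.shape, mod.shift.shape, mod.scale.shape, mod.gate.shape)
# torch.Size([2, 512]) followed by three tensors of shape (2, 1, 512)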
flux/sampling.py ADDED
@@ -0,0 +1,161 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from .model import Flux
9
+ from .modules.conditioner import HFEmbedder
10
+
11
+
12
+ def get_noise(
13
+ num_samples: int,
14
+ height: int,
15
+ width: int,
16
+ device: torch.device,
17
+ dtype: torch.dtype,
18
+ seed: int,
19
+ ):
20
+ return torch.randn(
21
+ num_samples,
22
+ 16,
23
+ # allow for packing
24
+ 2 * math.ceil(height / 16),
25
+ 2 * math.ceil(width / 16),
26
+ device=device,
27
+ dtype=dtype,
28
+ generator=torch.Generator(device=device).manual_seed(seed),
29
+ )
30
+
31
+
32
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str) -> dict[str, Tensor]:
33
+ bs, c, h, w = img.shape
34
+ if bs == 1 and not isinstance(prompt, str):
35
+ bs = len(prompt)
36
+
37
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
38
+ if img.shape[0] == 1 and bs > 1:
39
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
40
+
41
+ img_ids = torch.zeros(h // 2, w // 2, 3)
42
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
43
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
44
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
45
+
46
+ if isinstance(prompt, str):
47
+ prompt = [prompt]
48
+ txt = t5(prompt)
49
+ if txt.shape[0] == 1 and bs > 1:
50
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
51
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
52
+
53
+ vec = clip(prompt)
54
+ if vec.shape[0] == 1 and bs > 1:
55
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
56
+
57
+ return {
58
+ "img": img,
59
+ "img_ids": img_ids.to(img.device),
60
+ "txt": txt.to(img.device),
61
+ "txt_ids": txt_ids.to(img.device),
62
+ "vec": vec.to(img.device),
63
+ }
64
+
65
+
66
+ def time_shift(mu: float, sigma: float, t: Tensor):
67
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
68
+
69
+
70
+ def get_lin_function(
71
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
72
+ ) -> Callable[[float], float]:
73
+ m = (y2 - y1) / (x2 - x1)
74
+ b = y1 - m * x1
75
+ return lambda x: m * x + b
76
+
77
+
78
+ def get_schedule(
79
+ num_steps: int,
80
+ image_seq_len: int,
81
+ base_shift: float = 0.5,
82
+ max_shift: float = 1.15,
83
+ shift: bool = True,
84
+ ) -> list[float]:
85
+ # extra step for zero
86
+ timesteps = torch.linspace(1, 0, num_steps + 1)
87
+
88
+ # shifting the schedule to favor high timesteps for higher signal images
89
+ if shift:
90
+ # estimate mu based on linear interpolation between two points
91
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
92
+ timesteps = time_shift(mu, 1.0, timesteps)
93
+
94
+ return timesteps.tolist()
95
+
96
+
97
+ def denoise(
98
+ model: Flux,
99
+ # model input
100
+ img: Tensor,
101
+ img_ids: Tensor,
102
+ txt: Tensor,
103
+ txt_ids: Tensor,
104
+ vec: Tensor,
105
+ timesteps: list[float],
106
+ guidance: float = 4.0,
107
+ id_weight=1.0,
108
+ id=None,
109
+ start_step=0,
110
+ uncond_id=None,
111
+ true_cfg=1.0,
112
+ timestep_to_start_cfg=1,
113
+ neg_txt=None,
114
+ neg_txt_ids=None,
115
+ neg_vec=None,
116
+ ):
117
+ # this is ignored for schnell
118
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
119
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
120
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
121
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
122
+ pred = model(
123
+ img=img,
124
+ img_ids=img_ids,
125
+ txt=txt,
126
+ txt_ids=txt_ids,
127
+ y=vec,
128
+ timesteps=t_vec,
129
+ guidance=guidance_vec,
130
+ id=id if i >= start_step else None,
131
+ id_weight=id_weight,
132
+ )
133
+
134
+ if use_true_cfg and i >= timestep_to_start_cfg:
135
+ neg_pred = model(
136
+ img=img,
137
+ img_ids=img_ids,
138
+ txt=neg_txt,
139
+ txt_ids=neg_txt_ids,
140
+ y=neg_vec,
141
+ timesteps=t_vec,
142
+ guidance=guidance_vec,
143
+ id=uncond_id if i >= start_step else None,
144
+ id_weight=id_weight,
145
+ )
146
+ pred = neg_pred + true_cfg * (pred - neg_pred)
147
+
148
+ img = img + (t_prev - t_curr) * pred
149
+
150
+ return img
151
+
152
+
153
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
154
+ return rearrange(
155
+ x,
156
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
157
+ h=math.ceil(height / 16),
158
+ w=math.ceil(width / 16),
159
+ ph=2,
160
+ pw=2,
161
+ )
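A small sketch of the noise and schedule helpers above; the 512x512 resolution and 4 steps are illustrative.

import torch
from flux.sampling import get_noise, get_schedule

h, w = 512, 512
x = get_noise(1, h, w, device=torch.device("cpu"), dtype=torch.float32, seed=42)
print(x.shape)                           # torch.Size([1, 16, 64, 64])

image_seq_len = (h // 16) * (w // 16)    # number of 2x2 latent patches after packing
ts = get_schedule(4, image_seq_len, shift=True)
print(len(ts), ts[0], ts[-1])            # 5 values, from 1.0 down to 0.0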
flux/util.py ADDED
@@ -0,0 +1,201 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from imwatermark import WatermarkEncoder
8
+ from safetensors.torch import load_file as load_sft
9
+
10
+ from flux.model import Flux, FluxParams
11
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
12
+ from flux.modules.conditioner import HFEmbedder
13
+
14
+
15
+ @dataclass
16
+ class ModelSpec:
17
+ params: FluxParams
18
+ ae_params: AutoEncoderParams
19
+ ckpt_path: str
20
+ ae_path: str
21
+ repo_id: str
22
+ repo_flow: str
23
+ repo_ae: str
24
+
25
+
26
+ configs = {
27
+ "flux-dev": ModelSpec(
28
+ repo_id="black-forest-labs/FLUX.1-dev",
29
+ repo_flow="flux1-dev.safetensors",
30
+ repo_ae="ae.safetensors",
31
+ ckpt_path='models/flux1-dev.safetensors',
32
+ params=FluxParams(
33
+ in_channels=64,
34
+ vec_in_dim=768,
35
+ context_in_dim=4096,
36
+ hidden_size=3072,
37
+ mlp_ratio=4.0,
38
+ num_heads=24,
39
+ depth=19,
40
+ depth_single_blocks=38,
41
+ axes_dim=[16, 56, 56],
42
+ theta=10_000,
43
+ qkv_bias=True,
44
+ guidance_embed=True,
45
+ ),
46
+ ae_path='models/ae.safetensors',
47
+ ae_params=AutoEncoderParams(
48
+ resolution=256,
49
+ in_channels=3,
50
+ ch=128,
51
+ out_ch=3,
52
+ ch_mult=[1, 2, 4, 4],
53
+ num_res_blocks=2,
54
+ z_channels=16,
55
+ scale_factor=0.3611,
56
+ shift_factor=0.1159,
57
+ ),
58
+ ),
59
+ "flux-schnell": ModelSpec(
60
+ repo_id="black-forest-labs/FLUX.1-schnell",
61
+ repo_flow="flux1-schnell.safetensors",
62
+ repo_ae="ae.safetensors",
63
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
64
+ params=FluxParams(
65
+ in_channels=64,
66
+ vec_in_dim=768,
67
+ context_in_dim=4096,
68
+ hidden_size=3072,
69
+ mlp_ratio=4.0,
70
+ num_heads=24,
71
+ depth=19,
72
+ depth_single_blocks=38,
73
+ axes_dim=[16, 56, 56],
74
+ theta=10_000,
75
+ qkv_bias=True,
76
+ guidance_embed=False,
77
+ ),
78
+ ae_path=os.getenv("AE"),
79
+ ae_params=AutoEncoderParams(
80
+ resolution=256,
81
+ in_channels=3,
82
+ ch=128,
83
+ out_ch=3,
84
+ ch_mult=[1, 2, 4, 4],
85
+ num_res_blocks=2,
86
+ z_channels=16,
87
+ scale_factor=0.3611,
88
+ shift_factor=0.1159,
89
+ ),
90
+ ),
91
+ }
92
+
93
+
94
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
95
+ if len(missing) > 0 and len(unexpected) > 0:
96
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
97
+ print("\n" + "-" * 79 + "\n")
98
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
99
+ elif len(missing) > 0:
100
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
101
+ elif len(unexpected) > 0:
102
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
103
+
104
+
105
+ def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
106
+ # Loading Flux
107
+ print("Init model")
108
+ ckpt_path = configs[name].ckpt_path
109
+ if (
110
+ ckpt_path is None
111
+ and configs[name].repo_id is not None
112
+ and configs[name].repo_flow is not None
113
+ and hf_download
114
+ ):
115
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
116
+
117
+ with torch.device(device):
118
+ model = Flux(configs[name].params).to(torch.bfloat16)
119
+
120
+ if ckpt_path is not None:
121
+ print("Loading checkpoint")
122
+ # load_sft doesn't support torch.device
123
+ sd = load_sft(ckpt_path, device=str(device))
124
+ missing, unexpected = model.load_state_dict(sd, strict=False)
125
+ print_load_warning(missing, unexpected)
126
+ return model
127
+
128
+
129
+ def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
130
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
132
+
133
+
134
+ def load_clip(device: str = "cuda") -> HFEmbedder:
135
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
136
+
137
+
138
+ def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
139
+ ckpt_path = configs[name].ae_path
140
+ if (
141
+ ckpt_path is None
142
+ and configs[name].repo_id is not None
143
+ and configs[name].repo_ae is not None
144
+ and hf_download
145
+ ):
146
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')
147
+
148
+ # Loading the autoencoder
149
+ print("Init AE")
150
+ with torch.device(device):
151
+ ae = AutoEncoder(configs[name].ae_params)
152
+
153
+ if ckpt_path is not None:
154
+ sd = load_sft(ckpt_path, device=str(device))
155
+ missing, unexpected = ae.load_state_dict(sd, strict=False)
156
+ print_load_warning(missing, unexpected)
157
+ return ae
158
+
159
+
160
+ class WatermarkEmbedder:
161
+ def __init__(self, watermark):
162
+ self.watermark = watermark
163
+ self.num_bits = len(WATERMARK_BITS)
164
+ self.encoder = WatermarkEncoder()
165
+ self.encoder.set_watermark("bits", self.watermark)
166
+
167
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Adds a predefined watermark to the input image
170
+
171
+ Args:
172
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
173
+
174
+ Returns:
175
+ same as input but watermarked
176
+ """
177
+ image = 0.5 * image + 0.5
178
+ squeeze = len(image.shape) == 4
179
+ if squeeze:
180
+ image = image[None, ...]
181
+ n = image.shape[0]
182
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
183
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
184
+ # watermarking library expects input as cv2 BGR format
185
+ for k in range(image_np.shape[0]):
186
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
187
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
188
+ image.device
189
+ )
190
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
191
+ if squeeze:
192
+ image = image[0]
193
+ image = 2 * image - 1
194
+ return image
195
+
196
+
197
+ # A fixed 48-bit message that was chosen at random
198
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
199
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
200
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
201
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
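A sketch of how the loaders above are typically wired together; it assumes a CUDA device with enough memory and Hub access to the FLUX.1-dev weights.

from flux.util import load_ae, load_clip, load_flow_model, load_t5

device = "cuda"
t5 = load_t5(device, max_length=128)                 # T5 text encoder
clip = load_clip(device)                             # CLIP text encoder (pooled vector)
model = load_flow_model("flux-dev", device=device)   # the Flux transformer, bfloat16
ae = load_ae("flux-dev", device=device)              # latent autoencoder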
models/.gitkeep ADDED
File without changes
pulid/attention_processor.py ADDED
@@ -0,0 +1,422 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ NUM_ZERO = 0
7
+ ORTHO = False
8
+ ORTHO_v2 = False
9
+
10
+
11
+ class AttnProcessor(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def __call__(
16
+ self,
17
+ attn,
18
+ hidden_states,
19
+ encoder_hidden_states=None,
20
+ attention_mask=None,
21
+ temb=None,
22
+ id_embedding=None,
23
+ id_scale=1.0,
24
+ ):
25
+ residual = hidden_states
26
+
27
+ if attn.spatial_norm is not None:
28
+ hidden_states = attn.spatial_norm(hidden_states, temb)
29
+
30
+ input_ndim = hidden_states.ndim
31
+
32
+ if input_ndim == 4:
33
+ batch_size, channel, height, width = hidden_states.shape
34
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
35
+
36
+ batch_size, sequence_length, _ = (
37
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
38
+ )
39
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
40
+
41
+ if attn.group_norm is not None:
42
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
43
+
44
+ query = attn.to_q(hidden_states)
45
+
46
+ if encoder_hidden_states is None:
47
+ encoder_hidden_states = hidden_states
48
+ elif attn.norm_cross:
49
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
50
+
51
+ key = attn.to_k(encoder_hidden_states)
52
+ value = attn.to_v(encoder_hidden_states)
53
+
54
+ query = attn.head_to_batch_dim(query)
55
+ key = attn.head_to_batch_dim(key)
56
+ value = attn.head_to_batch_dim(value)
57
+
58
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
59
+ hidden_states = torch.bmm(attention_probs, value)
60
+ hidden_states = attn.batch_to_head_dim(hidden_states)
61
+
62
+ # linear proj
63
+ hidden_states = attn.to_out[0](hidden_states)
64
+ # dropout
65
+ hidden_states = attn.to_out[1](hidden_states)
66
+
67
+ if input_ndim == 4:
68
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
69
+
70
+ if attn.residual_connection:
71
+ hidden_states = hidden_states + residual
72
+
73
+ hidden_states = hidden_states / attn.rescale_output_factor
74
+
75
+ return hidden_states
76
+
77
+
78
+ class IDAttnProcessor(nn.Module):
79
+ r"""
80
+ Attention processor for ID-Adapater.
81
+ Args:
82
+ hidden_size (`int`):
83
+ The hidden size of the attention layer.
84
+ cross_attention_dim (`int`):
85
+ The number of channels in the `encoder_hidden_states`.
86
+ scale (`float`, defaults to 1.0):
87
+ the weight scale of image prompt.
88
+ """
89
+
90
+ def __init__(self, hidden_size, cross_attention_dim=None):
91
+ super().__init__()
92
+ self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
93
+ self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
94
+
95
+ def __call__(
96
+ self,
97
+ attn,
98
+ hidden_states,
99
+ encoder_hidden_states=None,
100
+ attention_mask=None,
101
+ temb=None,
102
+ id_embedding=None,
103
+ id_scale=1.0,
104
+ ):
105
+ residual = hidden_states
106
+
107
+ if attn.spatial_norm is not None:
108
+ hidden_states = attn.spatial_norm(hidden_states, temb)
109
+
110
+ input_ndim = hidden_states.ndim
111
+
112
+ if input_ndim == 4:
113
+ batch_size, channel, height, width = hidden_states.shape
114
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
115
+
116
+ batch_size, sequence_length, _ = (
117
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
118
+ )
119
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
120
+
121
+ if attn.group_norm is not None:
122
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
123
+
124
+ query = attn.to_q(hidden_states)
125
+
126
+ if encoder_hidden_states is None:
127
+ encoder_hidden_states = hidden_states
128
+ elif attn.norm_cross:
129
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
130
+
131
+ key = attn.to_k(encoder_hidden_states)
132
+ value = attn.to_v(encoder_hidden_states)
133
+
134
+ query = attn.head_to_batch_dim(query)
135
+ key = attn.head_to_batch_dim(key)
136
+ value = attn.head_to_batch_dim(value)
137
+
138
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
139
+ hidden_states = torch.bmm(attention_probs, value)
140
+ hidden_states = attn.batch_to_head_dim(hidden_states)
141
+
142
+ # for id-adapter
143
+ if id_embedding is not None:
144
+ if NUM_ZERO == 0:
145
+ id_key = self.id_to_k(id_embedding)
146
+ id_value = self.id_to_v(id_embedding)
147
+ else:
148
+ zero_tensor = torch.zeros(
149
+ (id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
150
+ dtype=id_embedding.dtype,
151
+ device=id_embedding.device,
152
+ )
153
+ id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1))
154
+ id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1))
155
+
156
+ id_key = attn.head_to_batch_dim(id_key).to(query.dtype)
157
+ id_value = attn.head_to_batch_dim(id_value).to(query.dtype)
158
+
159
+ id_attention_probs = attn.get_attention_scores(query, id_key, None)
160
+ id_hidden_states = torch.bmm(id_attention_probs, id_value)
161
+ id_hidden_states = attn.batch_to_head_dim(id_hidden_states)
162
+
163
+ if not ORTHO:
164
+ hidden_states = hidden_states + id_scale * id_hidden_states
165
+ else:
166
+ projection = (
167
+ torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
168
+ / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
169
+ * hidden_states
170
+ )
171
+ orthogonal = id_hidden_states - projection
172
+ hidden_states = hidden_states + id_scale * orthogonal
173
+
174
+ # linear proj
175
+ hidden_states = attn.to_out[0](hidden_states)
176
+ # dropout
177
+ hidden_states = attn.to_out[1](hidden_states)
178
+
179
+ if input_ndim == 4:
180
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
181
+
182
+ if attn.residual_connection:
183
+ hidden_states = hidden_states + residual
184
+
185
+ hidden_states = hidden_states / attn.rescale_output_factor
186
+
187
+ return hidden_states
188
+
189
+
190
+ class AttnProcessor2_0(nn.Module):
191
+ r"""
192
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
193
+ """
194
+
195
+ def __init__(self):
196
+ super().__init__()
197
+ if not hasattr(F, "scaled_dot_product_attention"):
198
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
199
+
200
+ def __call__(
201
+ self,
202
+ attn,
203
+ hidden_states,
204
+ encoder_hidden_states=None,
205
+ attention_mask=None,
206
+ temb=None,
207
+ id_embedding=None,
208
+ id_scale=1.0,
209
+ ):
210
+ residual = hidden_states
211
+
212
+ if attn.spatial_norm is not None:
213
+ hidden_states = attn.spatial_norm(hidden_states, temb)
214
+
215
+ input_ndim = hidden_states.ndim
216
+
217
+ if input_ndim == 4:
218
+ batch_size, channel, height, width = hidden_states.shape
219
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
220
+
221
+ batch_size, sequence_length, _ = (
222
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
223
+ )
224
+
225
+ if attention_mask is not None:
226
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
227
+ # scaled_dot_product_attention expects attention_mask shape to be
228
+ # (batch, heads, source_length, target_length)
229
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
230
+
231
+ if attn.group_norm is not None:
232
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
233
+
234
+ query = attn.to_q(hidden_states)
235
+
236
+ if encoder_hidden_states is None:
237
+ encoder_hidden_states = hidden_states
238
+ elif attn.norm_cross:
239
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
240
+
241
+ key = attn.to_k(encoder_hidden_states)
242
+ value = attn.to_v(encoder_hidden_states)
243
+
244
+ inner_dim = key.shape[-1]
245
+ head_dim = inner_dim // attn.heads
246
+
247
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
248
+
249
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
250
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+
252
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
253
+ hidden_states = F.scaled_dot_product_attention(
254
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
255
+ )
256
+
257
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
258
+ hidden_states = hidden_states.to(query.dtype)
259
+
260
+ # linear proj
261
+ hidden_states = attn.to_out[0](hidden_states)
262
+ # dropout
263
+ hidden_states = attn.to_out[1](hidden_states)
264
+
265
+ if input_ndim == 4:
266
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
267
+
268
+ if attn.residual_connection:
269
+ hidden_states = hidden_states + residual
270
+
271
+ hidden_states = hidden_states / attn.rescale_output_factor
272
+
273
+ return hidden_states
274
+
275
+
276
+ class IDAttnProcessor2_0(torch.nn.Module):
277
+ r"""
278
+ Attention processor for ID-Adapater for PyTorch 2.0.
279
+ Args:
280
+ hidden_size (`int`):
281
+ The hidden size of the attention layer.
282
+ cross_attention_dim (`int`):
283
+ The number of channels in the `encoder_hidden_states`.
284
+ """
285
+
286
+ def __init__(self, hidden_size, cross_attention_dim=None):
287
+ super().__init__()
288
+ if not hasattr(F, "scaled_dot_product_attention"):
289
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
290
+
291
+ self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
292
+ self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
293
+
294
+ def __call__(
295
+ self,
296
+ attn,
297
+ hidden_states,
298
+ encoder_hidden_states=None,
299
+ attention_mask=None,
300
+ temb=None,
301
+ id_embedding=None,
302
+ id_scale=1.0,
303
+ ):
304
+ residual = hidden_states
305
+
306
+ if attn.spatial_norm is not None:
307
+ hidden_states = attn.spatial_norm(hidden_states, temb)
308
+
309
+ input_ndim = hidden_states.ndim
310
+
311
+ if input_ndim == 4:
312
+ batch_size, channel, height, width = hidden_states.shape
313
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
314
+
315
+ batch_size, sequence_length, _ = (
316
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
317
+ )
318
+
319
+ if attention_mask is not None:
320
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
321
+ # scaled_dot_product_attention expects attention_mask shape to be
322
+ # (batch, heads, source_length, target_length)
323
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
324
+
325
+ if attn.group_norm is not None:
326
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
327
+
328
+ query = attn.to_q(hidden_states)
329
+
330
+ if encoder_hidden_states is None:
331
+ encoder_hidden_states = hidden_states
332
+ elif attn.norm_cross:
333
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
334
+
335
+ key = attn.to_k(encoder_hidden_states)
336
+ value = attn.to_v(encoder_hidden_states)
337
+
338
+ inner_dim = key.shape[-1]
339
+ head_dim = inner_dim // attn.heads
340
+
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+
343
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
345
+
346
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
347
+ hidden_states = F.scaled_dot_product_attention(
348
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
349
+ )
350
+
351
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
352
+ hidden_states = hidden_states.to(query.dtype)
353
+
354
+ # for id embedding
355
+ if id_embedding is not None:
356
+ if NUM_ZERO == 0:
357
+ id_key = self.id_to_k(id_embedding).to(query.dtype)
358
+ id_value = self.id_to_v(id_embedding).to(query.dtype)
359
+ else:
360
+ zero_tensor = torch.zeros(
361
+ (id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
362
+ dtype=id_embedding.dtype,
363
+ device=id_embedding.device,
364
+ )
365
+ id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
366
+ id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
367
+
368
+ id_key = id_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
369
+ id_value = id_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
370
+
371
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
372
+ id_hidden_states = F.scaled_dot_product_attention(
373
+ query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
374
+ )
375
+
376
+ id_hidden_states = id_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
377
+ id_hidden_states = id_hidden_states.to(query.dtype)
378
+
379
+ if not ORTHO and not ORTHO_v2:
380
+ hidden_states = hidden_states + id_scale * id_hidden_states
381
+ elif ORTHO_v2:
382
+ orig_dtype = hidden_states.dtype
383
+ hidden_states = hidden_states.to(torch.float32)
384
+ id_hidden_states = id_hidden_states.to(torch.float32)
385
+ attn_map = query @ id_key.transpose(-2, -1)
386
+ attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
387
+ attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
388
+ projection = (
389
+ torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
390
+ / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
391
+ * hidden_states
392
+ )
393
+ orthogonal = id_hidden_states + (attn_mean - 1) * projection
394
+ hidden_states = hidden_states + id_scale * orthogonal
395
+ hidden_states = hidden_states.to(orig_dtype)
396
+ else:
397
+ orig_dtype = hidden_states.dtype
398
+ hidden_states = hidden_states.to(torch.float32)
399
+ id_hidden_states = id_hidden_states.to(torch.float32)
400
+ projection = (
401
+ torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
402
+ / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
403
+ * hidden_states
404
+ )
405
+ orthogonal = id_hidden_states - projection
406
+ hidden_states = hidden_states + id_scale * orthogonal
407
+ hidden_states = hidden_states.to(orig_dtype)
408
+
409
+ # linear proj
410
+ hidden_states = attn.to_out[0](hidden_states)
411
+ # dropout
412
+ hidden_states = attn.to_out[1](hidden_states)
413
+
414
+ if input_ndim == 4:
415
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
416
+
417
+ if attn.residual_connection:
418
+ hidden_states = hidden_states + residual
419
+
420
+ hidden_states = hidden_states / attn.rescale_output_factor
421
+
422
+ return hidden_states
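A sketch of how these processors can be attached to a diffusers SDXL UNet in the usual IP-Adapter style: self-attention layers keep the plain processor, cross-attention layers get the trainable ID key/value projections. The helper below and the cross_attention_dim=2048 value are assumptions for illustration rather than code from this repo; at inference, id_embedding and id_scale reach the processors via the pipeline's cross_attention_kwargs.

from pulid.attention_processor import AttnProcessor2_0, IDAttnProcessor2_0

def set_id_processors(unet, cross_attention_dim=2048):
    procs = {}
    for name in unet.attn_processors.keys():
        if name.endswith("attn1.processor"):
            # self-attention: no ID branch
            procs[name] = AttnProcessor2_0()
        else:
            # cross-attention: infer the block's hidden size from its position in the UNet
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            else:  # down_blocks
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            procs[name] = IDAttnProcessor2_0(hidden_size, cross_attention_dim).to(unet.device, dtype=unet.dtype)
    unet.set_attn_processor(procs)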
pulid/encoders.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class IDEncoder(nn.Module):
6
+ def __init__(self, width=1280, context_dim=2048, num_token=5):
7
+ super().__init__()
8
+ self.num_token = num_token
9
+ self.context_dim = context_dim
10
+ h1 = min((context_dim * num_token) // 4, 1024)
11
+ h2 = min((context_dim * num_token) // 2, 1024)
12
+ self.body = nn.Sequential(
13
+ nn.Linear(width, h1),
14
+ nn.LayerNorm(h1),
15
+ nn.LeakyReLU(),
16
+ nn.Linear(h1, h2),
17
+ nn.LayerNorm(h2),
18
+ nn.LeakyReLU(),
19
+ nn.Linear(h2, context_dim * num_token),
20
+ )
21
+
22
+ for i in range(5):
23
+ setattr(
24
+ self,
25
+ f'mapping_{i}',
26
+ nn.Sequential(
27
+ nn.Linear(1024, 1024),
28
+ nn.LayerNorm(1024),
29
+ nn.LeakyReLU(),
30
+ nn.Linear(1024, 1024),
31
+ nn.LayerNorm(1024),
32
+ nn.LeakyReLU(),
33
+ nn.Linear(1024, context_dim),
34
+ ),
35
+ )
36
+
37
+ setattr(
38
+ self,
39
+ f'mapping_patch_{i}',
40
+ nn.Sequential(
41
+ nn.Linear(1024, 1024),
42
+ nn.LayerNorm(1024),
43
+ nn.LeakyReLU(),
44
+ nn.Linear(1024, 1024),
45
+ nn.LayerNorm(1024),
46
+ nn.LeakyReLU(),
47
+ nn.Linear(1024, context_dim),
48
+ ),
49
+ )
50
+
51
+ def forward(self, x, y):
52
+ # x shape [N, C]
53
+ x = self.body(x)
54
+ x = x.reshape(-1, self.num_token, self.context_dim)
55
+
56
+ hidden_states = ()
57
+ for i, emb in enumerate(y):
58
+ hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(
59
+ emb[:, 1:]
60
+ ).mean(dim=1, keepdim=True)
61
+ hidden_states += (hidden_state,)
62
+ hidden_states = torch.cat(hidden_states, dim=1)
63
+
64
+ return torch.cat([x, hidden_states], dim=1)
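A shape-check sketch for IDEncoder; the 1280-dim id embedding and the five 257-token, 1024-dim ViT feature maps are assumed sizes chosen to match the layer widths above.

import torch
from pulid.encoders import IDEncoder

enc = IDEncoder()                                  # width=1280, context_dim=2048, num_token=5
x = torch.randn(2, 1280)                           # face id embedding
y = [torch.randn(2, 257, 1024) for _ in range(5)]  # CLS token + patch tokens per scale
out = enc(x, y)
print(out.shape)                                   # torch.Size([2, 10, 2048]): 5 id tokens + 5 vit tokens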
pulid/encoders_flux.py ADDED
@@ -0,0 +1,207 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ # FFN
8
+ def FeedForward(dim, mult=4):
9
+ inner_dim = int(dim * mult)
10
+ return nn.Sequential(
11
+ nn.LayerNorm(dim),
12
+ nn.Linear(dim, inner_dim, bias=False),
13
+ nn.GELU(),
14
+ nn.Linear(inner_dim, dim, bias=False),
15
+ )
16
+
17
+
18
+ def reshape_tensor(x, heads):
19
+ bs, length, width = x.shape
20
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
21
+ x = x.view(bs, length, heads, -1)
22
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
23
+ x = x.transpose(1, 2)
24
+ # (bs, n_heads, length, dim_per_head); the reshape keeps this 4-D layout
25
+ x = x.reshape(bs, heads, length, -1)
26
+ return x
27
+
28
+
29
+ class PerceiverAttentionCA(nn.Module):
30
+ def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
31
+ super().__init__()
32
+ self.scale = dim_head ** -0.5
33
+ self.dim_head = dim_head
34
+ self.heads = heads
35
+ inner_dim = dim_head * heads
36
+
37
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
38
+ self.norm2 = nn.LayerNorm(dim)
39
+
40
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
41
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
42
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ Args:
47
+ x (torch.Tensor): image features
48
+ shape (b, n1, D)
49
+ latents (torch.Tensor): latent features
50
+ shape (b, n2, D)
51
+ """
52
+ x = self.norm1(x)
53
+ latents = self.norm2(latents)
54
+
55
+ b, seq_len, _ = latents.shape
56
+
57
+ q = self.to_q(latents)
58
+ k, v = self.to_kv(x).chunk(2, dim=-1)
59
+
60
+ q = reshape_tensor(q, self.heads)
61
+ k = reshape_tensor(k, self.heads)
62
+ v = reshape_tensor(v, self.heads)
63
+
64
+ # attention
65
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
66
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
67
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
68
+ out = weight @ v
69
+
70
+ out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
71
+
72
+ return self.to_out(out)
73
+
74
+
75
+ class PerceiverAttention(nn.Module):
76
+ def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
77
+ super().__init__()
78
+ self.scale = dim_head ** -0.5
79
+ self.dim_head = dim_head
80
+ self.heads = heads
81
+ inner_dim = dim_head * heads
82
+
83
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
84
+ self.norm2 = nn.LayerNorm(dim)
85
+
86
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
87
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
88
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
89
+
90
+ def forward(self, x, latents):
91
+ """
92
+ Args:
93
+ x (torch.Tensor): image features
94
+ shape (b, n1, D)
95
+ latents (torch.Tensor): latent features
96
+ shape (b, n2, D)
97
+ """
98
+ x = self.norm1(x)
99
+ latents = self.norm2(latents)
100
+
101
+ b, seq_len, _ = latents.shape
102
+
103
+ q = self.to_q(latents)
104
+ kv_input = torch.cat((x, latents), dim=-2)
105
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
106
+
107
+ q = reshape_tensor(q, self.heads)
108
+ k = reshape_tensor(k, self.heads)
109
+ v = reshape_tensor(v, self.heads)
110
+
111
+ # attention
112
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
113
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
114
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
115
+ out = weight @ v
116
+
117
+ out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
118
+
119
+ return self.to_out(out)
120
+
121
+
122
+ class IDFormer(nn.Module):
123
+ """
124
+ - perceiver resampler like arch (compared with previous MLP-like arch)
125
+ - we concat id embedding (generated by arcface) and query tokens as latents
126
+ - latents attend to each other and interact with vit features through cross-attention
127
+ - vit features are multi-scale and inserted into IDFormer in order; currently, each scale corresponds to two
128
+ IDFormer layers
129
+ """
130
+ def __init__(
131
+ self,
132
+ dim=1024,
133
+ depth=10,
134
+ dim_head=64,
135
+ heads=16,
136
+ num_id_token=5,
137
+ num_queries=32,
138
+ output_dim=2048,
139
+ ff_mult=4,
140
+ ):
141
+ super().__init__()
142
+
143
+ self.num_id_token = num_id_token
144
+ self.dim = dim
145
+ self.num_queries = num_queries
146
+ assert depth % 5 == 0
147
+ self.depth = depth // 5
148
+ scale = dim ** -0.5
149
+
150
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
151
+ self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
152
+
153
+ self.layers = nn.ModuleList([])
154
+ for _ in range(depth):
155
+ self.layers.append(
156
+ nn.ModuleList(
157
+ [
158
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
159
+ FeedForward(dim=dim, mult=ff_mult),
160
+ ]
161
+ )
162
+ )
163
+
164
+ for i in range(5):
165
+ setattr(
166
+ self,
167
+ f'mapping_{i}',
168
+ nn.Sequential(
169
+ nn.Linear(1024, 1024),
170
+ nn.LayerNorm(1024),
171
+ nn.LeakyReLU(),
172
+ nn.Linear(1024, 1024),
173
+ nn.LayerNorm(1024),
174
+ nn.LeakyReLU(),
175
+ nn.Linear(1024, dim),
176
+ ),
177
+ )
178
+
179
+ self.id_embedding_mapping = nn.Sequential(
180
+ nn.Linear(1280, 1024),
181
+ nn.LayerNorm(1024),
182
+ nn.LeakyReLU(),
183
+ nn.Linear(1024, 1024),
184
+ nn.LayerNorm(1024),
185
+ nn.LeakyReLU(),
186
+ nn.Linear(1024, dim * num_id_token),
187
+ )
188
+
189
+ def forward(self, x, y):
190
+
191
+ latents = self.latents.repeat(x.size(0), 1, 1)
192
+
193
+ x = self.id_embedding_mapping(x)
194
+ x = x.reshape(-1, self.num_id_token, self.dim)
195
+
196
+ latents = torch.cat((latents, x), dim=1)
197
+
198
+ for i in range(5):
199
+ vit_feature = getattr(self, f'mapping_{i}')(y[i])
200
+ ctx_feature = torch.cat((x, vit_feature), dim=1)
201
+ for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:
202
+ latents = attn(ctx_feature, latents) + latents
203
+ latents = ff(latents) + latents
204
+
205
+ latents = latents[:, :self.num_queries]
206
+ latents = latents @ self.proj_out
207
+ return latents
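A shape-check sketch for IDFormer; the inputs mirror the docstring (an arcface id embedding plus five multi-scale ViT feature stacks), with illustrative token counts.

import torch
from pulid.encoders_flux import IDFormer

idformer = IDFormer()                              # dim=1024, 32 query tokens, 2048-dim output
x = torch.randn(1, 1280)                           # arcface id embedding
y = [torch.randn(1, 257, 1024) for _ in range(5)]  # one ViT feature stack per scale
id_cond = idformer(x, y)
print(id_cond.shape)                               # torch.Size([1, 32, 2048]): consumed by pulid_ca in flux/model.py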