Yiwen-ntu committed
Commit 7262fda
Parent: a3ef4ce

Upload 47 files

Files changed (47):
  1. MeshAnything/miche/LICENSE +674 -0
  2. MeshAnything/miche/encode.py +73 -0
  3. MeshAnything/miche/michelangelo/__init__.py +1 -0
  4. MeshAnything/miche/michelangelo/data/__init__.py +1 -0
  5. MeshAnything/miche/michelangelo/data/templates.json +69 -0
  6. MeshAnything/miche/michelangelo/data/transforms.py +407 -0
  7. MeshAnything/miche/michelangelo/data/utils.py +59 -0
  8. MeshAnything/miche/michelangelo/graphics/__init__.py +1 -0
  9. MeshAnything/miche/michelangelo/graphics/primitives/__init__.py +9 -0
  10. MeshAnything/miche/michelangelo/graphics/primitives/mesh.py +114 -0
  11. MeshAnything/miche/michelangelo/graphics/primitives/volume.py +21 -0
  12. MeshAnything/miche/michelangelo/models/__init__.py +1 -0
  13. MeshAnything/miche/michelangelo/models/asl_diffusion/__init__.py +1 -0
  14. MeshAnything/miche/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py +483 -0
  15. MeshAnything/miche/michelangelo/models/asl_diffusion/asl_udt.py +104 -0
  16. MeshAnything/miche/michelangelo/models/asl_diffusion/base.py +13 -0
  17. MeshAnything/miche/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py +393 -0
  18. MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py +80 -0
  19. MeshAnything/miche/michelangelo/models/conditional_encoders/__init__.py +3 -0
  20. MeshAnything/miche/michelangelo/models/conditional_encoders/clip.py +89 -0
  21. MeshAnything/miche/michelangelo/models/conditional_encoders/encoder_factory.py +562 -0
  22. MeshAnything/miche/michelangelo/models/modules/__init__.py +3 -0
  23. MeshAnything/miche/michelangelo/models/modules/checkpoint.py +69 -0
  24. MeshAnything/miche/michelangelo/models/modules/diffusion_transformer.py +218 -0
  25. MeshAnything/miche/michelangelo/models/modules/distributions.py +100 -0
  26. MeshAnything/miche/michelangelo/models/modules/embedder.py +213 -0
  27. MeshAnything/miche/michelangelo/models/modules/transformer_blocks.py +286 -0
  28. MeshAnything/miche/michelangelo/models/modules/transformer_vit.py +308 -0
  29. MeshAnything/miche/michelangelo/models/tsal/__init__.py +1 -0
  30. MeshAnything/miche/michelangelo/models/tsal/asl_pl_module.py +395 -0
  31. MeshAnything/miche/michelangelo/models/tsal/clip_asl_module.py +118 -0
  32. MeshAnything/miche/michelangelo/models/tsal/inference_utils.py +80 -0
  33. MeshAnything/miche/michelangelo/models/tsal/loss.py +303 -0
  34. MeshAnything/miche/michelangelo/models/tsal/sal_perceiver.py +423 -0
  35. MeshAnything/miche/michelangelo/models/tsal/sal_pl_module.py +290 -0
  36. MeshAnything/miche/michelangelo/models/tsal/tsal_base.py +120 -0
  37. MeshAnything/miche/michelangelo/utils/__init__.py +3 -0
  38. MeshAnything/miche/michelangelo/utils/eval.py +12 -0
  39. MeshAnything/miche/michelangelo/utils/io.py +47 -0
  40. MeshAnything/miche/michelangelo/utils/misc.py +83 -0
  41. MeshAnything/miche/michelangelo/utils/visualizers/__init__.py +1 -0
  42. MeshAnything/miche/michelangelo/utils/visualizers/color_util.py +43 -0
  43. MeshAnything/miche/michelangelo/utils/visualizers/html_util.py +49 -0
  44. MeshAnything/miche/michelangelo/utils/visualizers/pythreejs_viewer.py +534 -0
  45. MeshAnything/miche/shapevae-256.yaml +46 -0
  46. MeshAnything/models/meshanything.py +223 -0
  47. MeshAnything/models/shape_opt.py +464 -0
MeshAnything/miche/LICENSE ADDED
@@ -0,0 +1,674 @@
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
MeshAnything/miche/encode.py ADDED
@@ -0,0 +1,73 @@
+ # -*- coding: utf-8 -*-
+ import argparse
+ from omegaconf import OmegaConf
+ import numpy as np
+ import torch
+ from .michelangelo.utils.misc import instantiate_from_config
+
+ def load_surface(fp):
+
+     with np.load(fp) as input_pc:
+         surface = input_pc['points']
+         normal = input_pc['normals']
+
+     rng = np.random.default_rng()
+     ind = rng.choice(surface.shape[0], 4096, replace=False)
+     surface = torch.FloatTensor(surface[ind])
+     normal = torch.FloatTensor(normal[ind])
+
+     surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
+
+     return surface
+
+ def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
+
+     surface = load_surface(args.pointcloud_path)
+     # old_surface = surface.clone()
+
+     # surface[0, :, 0] *= -1
+     # surface[0, :, 1] *= -1
+     surface[0, :, 2] *= -1
+
+     # encoding
+     shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
+     shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
+
+     # decoding
+     latents = model.model.shape_model.decode(shape_zq)
+     # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
+
+     return 0
+
+ def load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt"):
+     model_config = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml")
+     # print(model_config)
+     if hasattr(model_config, "model"):
+         model_config = model_config.model
+
+     model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
+     model = model.cuda()
+     model = model.eval()
+
+     return model
+ if __name__ == "__main__":
+     '''
+     1. Reconstruct point cloud
+     2. Image-conditioned generation
+     3. Text-conditioned generation
+     '''
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config_path", type=str, required=True)
+     parser.add_argument("--ckpt_path", type=str, required=True)
+     parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
+     parser.add_argument("--image_path", type=str, help='Path to the input image')
+     parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
+     parser.add_argument("--output_dir", type=str, default='./output')
+     parser.add_argument("-s", "--seed", type=int, default=0)
+     args = parser.parse_args()
+
+     print('-----------------------------------------------------------------------------')
+     print(f'>>> Output directory: {args.output_dir}')
+     print('-----------------------------------------------------------------------------')
+
+     reconstruction(args, load_model(args.ckpt_path))
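
The encode.py entry point above can also be used as a library: load_model builds the shape VAE from shapevae-256.yaml plus a checkpoint, load_surface samples 4,096 oriented points from an .npz file with 'points' and 'normals' arrays, and the encode/decode calls produce the latents used downstream. A minimal sketch of that programmatic path (not part of this commit; it assumes the repository root is on PYTHONPATH, a CUDA device, and the shapevae-256.ckpt checkpoint placed next to the bundled YAML):

    # Hypothetical usage sketch; paths and the CUDA requirement are assumptions, not part of the upload.
    from MeshAnything.miche.encode import load_model, load_surface

    model = load_model(ckpt_path="MeshAnything/miche/shapevae-256.ckpt")  # shape VAE configured by shapevae-256.yaml
    surface = load_surface("./example_data/surface.npz")                  # (1, 4096, 6) tensor of xyz + normals on the GPU

    # Encode the sampled surface into the VAE latent space, then decode the latent set.
    shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
    shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
    latents = model.model.shape_model.decode(shape_zq)
    print(shape_embed.shape, latents.shape)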
MeshAnything/miche/michelangelo/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/data/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/data/templates.json ADDED
@@ -0,0 +1,69 @@
+ {
+     "shape": [
+         "a point cloud model of {}.",
+         "There is a {} in the scene.",
+         "There is the {} in the scene.",
+         "a photo of a {} in the scene.",
+         "a photo of the {} in the scene.",
+         "a photo of one {} in the scene.",
+         "itap of a {}.",
+         "itap of my {}.",
+         "itap of the {}.",
+         "a photo of a {}.",
+         "a photo of my {}.",
+         "a photo of the {}.",
+         "a photo of one {}.",
+         "a photo of many {}.",
+         "a good photo of a {}.",
+         "a good photo of the {}.",
+         "a bad photo of a {}.",
+         "a bad photo of the {}.",
+         "a photo of a nice {}.",
+         "a photo of the nice {}.",
+         "a photo of a cool {}.",
+         "a photo of the cool {}.",
+         "a photo of a weird {}.",
+         "a photo of the weird {}.",
+         "a photo of a small {}.",
+         "a photo of the small {}.",
+         "a photo of a large {}.",
+         "a photo of the large {}.",
+         "a photo of a clean {}.",
+         "a photo of the clean {}.",
+         "a photo of a dirty {}.",
+         "a photo of the dirty {}.",
+         "a bright photo of a {}.",
+         "a bright photo of the {}.",
+         "a dark photo of a {}.",
+         "a dark photo of the {}.",
+         "a photo of a hard to see {}.",
+         "a photo of the hard to see {}.",
+         "a low resolution photo of a {}.",
+         "a low resolution photo of the {}.",
+         "a cropped photo of a {}.",
+         "a cropped photo of the {}.",
+         "a close-up photo of a {}.",
+         "a close-up photo of the {}.",
+         "a jpeg corrupted photo of a {}.",
+         "a jpeg corrupted photo of the {}.",
+         "a blurry photo of a {}.",
+         "a blurry photo of the {}.",
+         "a pixelated photo of a {}.",
+         "a pixelated photo of the {}.",
+         "a black and white photo of the {}.",
+         "a black and white photo of a {}",
+         "a plastic {}.",
+         "the plastic {}.",
+         "a toy {}.",
+         "the toy {}.",
+         "a plushie {}.",
+         "the plushie {}.",
+         "a cartoon {}.",
+         "the cartoon {}.",
+         "an embroidered {}.",
+         "the embroidered {}.",
+         "a painting of the {}.",
+         "a painting of a {}."
+     ]
+
+ }
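
The templates above follow the CLIP-style prompt pattern: every entry in the "shape" list carries a {} placeholder that is filled with a category name before text encoding. A small illustration of how they might be expanded (the category "chair" is only an example, not something defined in this commit):

    import json

    with open("MeshAnything/miche/michelangelo/data/templates.json") as f:
        templates = json.load(f)["shape"]

    # Fill each template with an example label; any category string works here.
    prompts = [t.format("chair") for t in templates]
    print(prompts[0])  # "a point cloud model of chair."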
MeshAnything/miche/michelangelo/data/transforms.py ADDED
@@ -0,0 +1,407 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ import numpy as np
5
+ import warnings
6
+ import random
7
+ from omegaconf.listconfig import ListConfig
8
+ from webdataset import pipelinefilter
9
+ import torch
10
+ import torchvision.transforms.functional as TVF
11
+ from torchvision.transforms import InterpolationMode
12
+ from torchvision.transforms.transforms import _interpolation_modes_from_int
13
+ from typing import Sequence
14
+
15
+ from MeshAnything.miche.michelangelo.utils import instantiate_from_config
16
+
17
+
18
+ def _uid_buffer_pick(buf_dict, rng):
19
+ uid_keys = list(buf_dict.keys())
20
+ selected_uid = rng.choice(uid_keys)
21
+ buf = buf_dict[selected_uid]
22
+
23
+ k = rng.randint(0, len(buf) - 1)
24
+ sample = buf[k]
25
+ buf[k] = buf[-1]
26
+ buf.pop()
27
+
28
+ if len(buf) == 0:
29
+ del buf_dict[selected_uid]
30
+
31
+ return sample
32
+
33
+
34
+ def _add_to_buf_dict(buf_dict, sample):
35
+ key = sample["__key__"]
36
+ uid, uid_sample_id = key.split("_")
37
+ if uid not in buf_dict:
38
+ buf_dict[uid] = []
39
+ buf_dict[uid].append(sample)
40
+
41
+ return buf_dict
42
+
43
+
44
+ def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
45
+ """Shuffle the data in the stream.
46
+
47
+ This uses a buffer of size `bufsize`. Shuffling at
48
+ startup is less random; this is traded off against
49
+ yielding samples quickly.
50
+
51
+ data: iterator
52
+ bufsize: buffer size for shuffling
53
+ returns: iterator
54
+ rng: either random module or random.Random instance
55
+
56
+ """
57
+ if rng is None:
58
+ rng = random.Random(int((os.getpid() + time.time()) * 1e9))
59
+ initial = min(initial, bufsize)
60
+ buf_dict = dict()
61
+ current_samples = 0
62
+ for sample in data:
63
+ _add_to_buf_dict(buf_dict, sample)
64
+ current_samples += 1
65
+
66
+ if current_samples < bufsize:
67
+ try:
68
+ _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708
69
+ current_samples += 1
70
+ except StopIteration:
71
+ pass
72
+
73
+ if current_samples >= initial:
74
+ current_samples -= 1
75
+ yield _uid_buffer_pick(buf_dict, rng)
76
+
77
+ while current_samples > 0:
78
+ current_samples -= 1
79
+ yield _uid_buffer_pick(buf_dict, rng)
80
+
81
+
82
+ uid_shuffle = pipelinefilter(_uid_shuffle)
83
+
84
+
85
+ class RandomSample(object):
86
+ def __init__(self,
87
+ num_volume_samples: int = 1024,
88
+ num_near_samples: int = 1024):
89
+
90
+ super().__init__()
91
+
92
+ self.num_volume_samples = num_volume_samples
93
+ self.num_near_samples = num_near_samples
94
+
95
+ def __call__(self, sample):
96
+ rng = np.random.default_rng()
97
+
98
+ # 1. sample surface input
99
+ total_surface = sample["surface"]
100
+ ind = rng.choice(total_surface.shape[0], replace=False)
101
+ surface = total_surface[ind]
102
+
103
+ # 2. sample volume/near geometric points
104
+ vol_points = sample["vol_points"]
105
+ vol_label = sample["vol_label"]
106
+ near_points = sample["near_points"]
107
+ near_label = sample["near_label"]
108
+
109
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
110
+ vol_points = vol_points[ind]
111
+ vol_label = vol_label[ind]
112
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
113
+
114
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
115
+ near_points = near_points[ind]
116
+ near_label = near_label[ind]
117
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
118
+
119
+ # concat sampled volume and near points
120
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
121
+
122
+ sample = {
123
+ "surface": surface,
124
+ "geo_points": geo_points
125
+ }
126
+
127
+ return sample
128
+
129
+
130
+ class SplitRandomSample(object):
131
+ def __init__(self,
132
+ use_surface_sample: bool = False,
133
+ num_surface_samples: int = 4096,
134
+ num_volume_samples: int = 1024,
135
+ num_near_samples: int = 1024):
136
+
137
+ super().__init__()
138
+
139
+ self.use_surface_sample = use_surface_sample
140
+ self.num_surface_samples = num_surface_samples
141
+ self.num_volume_samples = num_volume_samples
142
+ self.num_near_samples = num_near_samples
143
+
144
+ def __call__(self, sample):
145
+
146
+ rng = np.random.default_rng()
147
+
148
+ # 1. sample surface input
149
+ surface = sample["surface"]
150
+
151
+ if self.use_surface_sample:
152
+ replace = surface.shape[0] < self.num_surface_samples
153
+ ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
154
+ surface = surface[ind]
155
+
156
+ # 2. sample volume/near geometric points
157
+ vol_points = sample["vol_points"]
158
+ vol_label = sample["vol_label"]
159
+ near_points = sample["near_points"]
160
+ near_label = sample["near_label"]
161
+
162
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
163
+ vol_points = vol_points[ind]
164
+ vol_label = vol_label[ind]
165
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
166
+
167
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
168
+ near_points = near_points[ind]
169
+ near_label = near_label[ind]
170
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
171
+
172
+ # concat sampled volume and near points
173
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
174
+
175
+ sample = {
176
+ "surface": surface,
177
+ "geo_points": geo_points
178
+ }
179
+
180
+ return sample
181
+
182
+
183
+ class FeatureSelection(object):
184
+
185
+ VALID_SURFACE_FEATURE_DIMS = {
186
+ "none": [0, 1, 2], # xyz
187
+ "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal
188
+ "normal": [0, 1, 2, 6, 7, 8]
189
+ }
190
+
191
+ def __init__(self, surface_feature_type: str):
192
+
193
+ self.surface_feature_type = surface_feature_type
194
+ self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
195
+
196
+ def __call__(self, sample):
197
+ sample["surface"] = sample["surface"][:, self.surface_dims]
198
+ return sample
199
+
200
+
201
+ class AxisScaleTransform(object):
202
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
203
+ assert isinstance(interval, (tuple, list, ListConfig))
204
+ self.interval = interval
205
+ self.min_val = interval[0]
206
+ self.max_val = interval[1]
207
+ self.inter_size = interval[1] - interval[0]
208
+ self.jitter = jitter
209
+ self.jitter_scale = jitter_scale
210
+
211
+ def __call__(self, sample):
212
+
213
+ surface = sample["surface"][..., 0:3]
214
+ geo_points = sample["geo_points"][..., 0:3]
215
+
216
+ scaling = torch.rand(1, 3) * self.inter_size + self.min_val
217
+ # print(scaling)
218
+ surface = surface * scaling
219
+ geo_points = geo_points * scaling
220
+
221
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
222
+ surface *= scale
223
+ geo_points *= scale
224
+
225
+ if self.jitter:
226
+ surface += self.jitter_scale * torch.randn_like(surface)
227
+ surface.clamp_(min=-1.015, max=1.015)
228
+
229
+ sample["surface"][..., 0:3] = surface
230
+ sample["geo_points"][..., 0:3] = geo_points
231
+
232
+ return sample
233
+
234
+
235
+ class ToTensor(object):
236
+
237
+ def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
238
+ self.tensor_keys = tensor_keys
239
+
240
+ def __call__(self, sample):
241
+ for key in self.tensor_keys:
242
+ if key not in sample:
243
+ continue
244
+
245
+ sample[key] = torch.tensor(sample[key], dtype=torch.float32)
246
+
247
+ return sample
248
+
249
+
250
+ class AxisScale(object):
251
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
252
+ assert isinstance(interval, (tuple, list, ListConfig))
253
+ self.interval = interval
254
+ self.jitter = jitter
255
+ self.jitter_scale = jitter_scale
256
+
257
+ def __call__(self, surface, *args):
258
+ scaling = torch.rand(1, 3) * (self.interval[1] - self.interval[0]) + self.interval[0]
259
+ # print(scaling)
260
+ surface = surface * scaling
261
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
262
+ surface *= scale
263
+
264
+ args_outputs = []
265
+ for _arg in args:
266
+ _arg = _arg * scaling * scale
267
+ args_outputs.append(_arg)
268
+
269
+ if self.jitter:
270
+ surface += self.jitter_scale * torch.randn_like(surface)
271
+ surface.clamp_(min=-1, max=1)
272
+
273
+ if len(args) == 0:
274
+ return surface
275
+ else:
276
+ return surface, *args_outputs
277
+
278
+
279
+ class RandomResize(torch.nn.Module):
280
+ """Apply randomly Resize with a given probability."""
281
+
282
+ def __init__(
283
+ self,
284
+ size,
285
+ resize_radio=(0.5, 1),
286
+ allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
287
+ interpolation=InterpolationMode.BICUBIC,
288
+ max_size=None,
289
+ antialias=None,
290
+ ):
291
+ super().__init__()
292
+ if not isinstance(size, (int, Sequence)):
293
+ raise TypeError(f"Size should be int or sequence. Got {type(size)}")
294
+ if isinstance(size, Sequence) and len(size) not in (1, 2):
295
+ raise ValueError("If size is a sequence, it should have 1 or 2 values")
296
+
297
+ self.size = size
298
+ self.max_size = max_size
299
+ # Backward compatibility with integer value
300
+ if isinstance(interpolation, int):
301
+ warnings.warn(
302
+ "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
303
+ "Please use InterpolationMode enum."
304
+ )
305
+ interpolation = _interpolation_modes_from_int(interpolation)
306
+
307
+ self.interpolation = interpolation
308
+ self.antialias = antialias
309
+
310
+ self.resize_radio = resize_radio
311
+ self.allow_resize_interpolations = allow_resize_interpolations
312
+
313
+ def random_resize_params(self):
314
+ radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
315
+
316
+ if isinstance(self.size, int):
317
+ size = int(self.size * radio)
318
+ elif isinstance(self.size, Sequence):
319
+ size = list(self.size)
320
+ size = (int(size[0] * radio), int(size[1] * radio))
321
+ else:
322
+ raise RuntimeError()
323
+
324
+ interpolation = self.allow_resize_interpolations[
325
+ torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
326
+ ]
327
+ return size, interpolation
328
+
329
+ def forward(self, img):
330
+ size, interpolation = self.random_resize_params()
331
+ img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
332
+ img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
333
+ return img
334
+
335
+ def __repr__(self) -> str:
336
+ detail = f"(size={self.size}, interpolation={self.interpolation.value},"
337
+ detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}"
338
+ return f"{self.__class__.__name__}{detail}"
339
+
340
+
341
+ class Compose(object):
342
+ """Composes several transforms together. This transform does not support torchscript.
343
+ Please, see the note below.
344
+
345
+ Args:
346
+ transforms (list of ``Transform`` objects): list of transforms to compose.
347
+
348
+ Example:
349
+ >>> transforms.Compose([
350
+ >>> transforms.CenterCrop(10),
351
+ >>> transforms.ToTensor(),
352
+ >>> ])
353
+
354
+ .. note::
355
+ In order to script the transformations, please use ``torch.nn.Sequential`` as below.
356
+
357
+ >>> transforms = torch.nn.Sequential(
358
+ >>> transforms.CenterCrop(10),
359
+ >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
360
+ >>> )
361
+ >>> scripted_transforms = torch.jit.script(transforms)
362
+
363
+ Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not
364
+ require `lambda` functions or ``PIL.Image``.
365
+
366
+ """
367
+
368
+ def __init__(self, transforms):
369
+ self.transforms = transforms
370
+
371
+ def __call__(self, *args):
372
+ for t in self.transforms:
373
+ args = t(*args) if isinstance(args, tuple) else t(args)  # single returns (e.g. dicts) are not unpacked
374
+ return args
375
+
376
+ def __repr__(self):
377
+ format_string = self.__class__.__name__ + '('
378
+ for t in self.transforms:
379
+ format_string += '\n'
380
+ format_string += ' {0}'.format(t)
381
+ format_string += '\n)'
382
+ return format_string
383
+
384
+
385
+ def identity(*args, **kwargs):
386
+ if len(args) == 1:
387
+ return args[0]
388
+ else:
389
+ return args
390
+
391
+
392
+ def build_transforms(cfg):
393
+
394
+ if cfg is None:
395
+ return identity
396
+
397
+ transforms = []
398
+
399
+ for transform_name, cfg_instance in cfg.items():
400
+ transform_instance = instantiate_from_config(cfg_instance)
401
+ transforms.append(transform_instance)
402
+ print(f"Build transform: {transform_instance}")
403
+
404
+ transforms = Compose(transforms)
405
+
406
+ return transforms
407
+
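A minimal usage sketch (not part of the uploaded file, and assuming the usual target/params convention of instantiate_from_config): build_transforms turns an OmegaConf mapping into a Compose of the classes above, with the mapping keys only fixing the iteration order.

import numpy as np
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "feature_selection": {
        "target": "MeshAnything.miche.michelangelo.data.transforms.FeatureSelection",
        "params": {"surface_feature_type": "normal"},
    },
    "to_tensor": {
        "target": "MeshAnything.miche.michelangelo.data.transforms.ToTensor",
    },
    "axis_scale": {
        "target": "MeshAnything.miche.michelangelo.data.transforms.AxisScaleTransform",
        "params": {"interval": [0.75, 1.25], "jitter": True, "jitter_scale": 0.005},
    },
})
transforms = build_transforms(cfg)

# toy inputs: xyz + watertight normal + normal features, plus labelled geometry points
sample = {
    "surface": np.random.rand(2048, 9).astype(np.float32),
    "geo_points": np.random.rand(2048, 4).astype(np.float32),
}
sample = transforms(sample)  # dict of numpy arrays in, dict of torch tensors out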
MeshAnything/miche/michelangelo/data/utils.py ADDED
@@ -0,0 +1,59 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+
7
+ def worker_init_fn(_):
8
+ worker_info = torch.utils.data.get_worker_info()
9
+ worker_id = worker_info.id
10
+
11
+ # dataset = worker_info.dataset
12
+ # split_size = dataset.num_records // worker_info.num_workers
13
+ # # reset num_records to the true number to retain reliable length information
14
+ # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
15
+ # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
16
+ # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
17
+
18
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
19
+
20
+
21
+ def collation_fn(samples, combine_tensors=True, combine_scalars=True):
22
+ """
23
+
24
+ Args:
25
+ samples (list[dict]): per-item dicts that all share the same keys.
26
+ combine_tensors: if True, stack torch.Tensor / np.ndarray values along a new batch dimension.
27
+ combine_scalars: if True, convert lists of int / float values to np.ndarray.
28
+
29
+ Returns:
30
+ result (dict): one entry per key, batched as described above.
31
+ """
32
+
33
+ result = {}
34
+
35
+ keys = samples[0].keys()
36
+
37
+ for key in keys:
38
+ result[key] = []
39
+
40
+ for sample in samples:
41
+ for key in keys:
42
+ val = sample[key]
43
+ result[key].append(val)
44
+
45
+ for key in keys:
46
+ val_list = result[key]
47
+ if isinstance(val_list[0], (int, float)):
48
+ if combine_scalars:
49
+ result[key] = np.array(result[key])
50
+
51
+ elif isinstance(val_list[0], torch.Tensor):
52
+ if combine_tensors:
53
+ result[key] = torch.stack(val_list)
54
+
55
+ elif isinstance(val_list[0], np.ndarray):
56
+ if combine_tensors:
57
+ result[key] = np.stack(val_list)
58
+
59
+ return result
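A short sketch of how these helpers are meant to be wired into a PyTorch DataLoader (my_dataset is a placeholder for any map-style dataset that returns per-sample dicts; it is not defined in this commit):

from torch.utils.data import DataLoader

loader = DataLoader(
    my_dataset,                      # placeholder dataset returning dicts of tensors/arrays
    batch_size=8,
    num_workers=4,
    worker_init_fn=worker_init_fn,   # reseeds numpy independently in each worker
    collate_fn=collation_fn,         # stacks tensors/arrays and gathers scalars per key
)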
MeshAnything/miche/michelangelo/graphics/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/graphics/primitives/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .volume import generate_dense_grid_points
4
+
5
+ from .mesh import (
6
+ MeshOutput,
7
+ save_obj,
8
+ savemeshtes2
9
+ )
MeshAnything/miche/michelangelo/graphics/primitives/mesh.py ADDED
@@ -0,0 +1,114 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ import PIL.Image
7
+ from typing import Optional
8
+
9
+ import trimesh
10
+
11
+
12
+ def save_obj(pointnp_px3, facenp_fx3, fname):
13
+ fid = open(fname, "w")
14
+ write_str = ""
15
+ for pidx, p in enumerate(pointnp_px3):
16
+ pp = p
17
+ write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18
+
19
+ for i, f in enumerate(facenp_fx3):
20
+ f1 = f + 1
21
+ write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22
+ fid.write(write_str)
23
+ fid.close()
24
+ return
25
+
26
+
27
+ def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28
+ fol, na = os.path.split(fname)
29
+ na, _ = os.path.splitext(na)
30
+
31
+ matname = "%s/%s.mtl" % (fol, na)
32
+ fid = open(matname, "w")
33
+ fid.write("newmtl material_0\n")
34
+ fid.write("Kd 1 1 1\n")
35
+ fid.write("Ka 0 0 0\n")
36
+ fid.write("Ks 0.4 0.4 0.4\n")
37
+ fid.write("Ns 10\n")
38
+ fid.write("illum 2\n")
39
+ fid.write("map_Kd %s.png\n" % na)
40
+ fid.close()
41
+ ####
42
+
43
+ fid = open(fname, "w")
44
+ fid.write("mtllib %s.mtl\n" % na)
45
+
46
+ for pidx, p in enumerate(pointnp_px3):
47
+ pp = p
48
+ fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49
+
50
+ for pidx, p in enumerate(tcoords_px2):
51
+ pp = p
52
+ fid.write("vt %f %f\n" % (pp[0], pp[1]))
53
+
54
+ fid.write("usemtl material_0\n")
55
+ for i, f in enumerate(facenp_fx3):
56
+ f1 = f + 1
57
+ f2 = facetex_fx3[i] + 1
58
+ fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59
+ fid.close()
60
+
61
+ PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62
+ os.path.join(fol, "%s.png" % na))
63
+
64
+ return
65
+
66
+
67
+ class MeshOutput(object):
68
+
69
+ def __init__(self,
70
+ mesh_v: np.ndarray,
71
+ mesh_f: np.ndarray,
72
+ vertex_colors: Optional[np.ndarray] = None,
73
+ uvs: Optional[np.ndarray] = None,
74
+ mesh_tex_idx: Optional[np.ndarray] = None,
75
+ tex_map: Optional[np.ndarray] = None):
76
+
77
+ self.mesh_v = mesh_v
78
+ self.mesh_f = mesh_f
79
+ self.vertex_colors = vertex_colors
80
+ self.uvs = uvs
81
+ self.mesh_tex_idx = mesh_tex_idx
82
+ self.tex_map = tex_map
83
+
84
+ def contain_uv_texture(self):
85
+ return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86
+
87
+ def contain_vertex_colors(self):
88
+ return self.vertex_colors is not None
89
+
90
+ def export(self, fname):
91
+
92
+ if self.contain_uv_texture():
93
+ savemeshtes2(
94
+ self.mesh_v,
95
+ self.uvs,
96
+ self.mesh_f,
97
+ self.mesh_tex_idx,
98
+ self.tex_map,
99
+ fname
100
+ )
101
+
102
+ elif self.contain_vertex_colors():
103
+ mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104
+ mesh_obj.export(fname)
105
+
106
+ else:
107
+ save_obj(
108
+ self.mesh_v,
109
+ self.mesh_f,
110
+ fname
111
+ )
112
+
113
+
114
+
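For reference, a small sketch (not in the commit) of how MeshOutput.export picks its writer: plain OBJ via save_obj when no attributes are set, trimesh when vertex colors are present, and savemeshtes2 when UVs, face texture indices and a texture map are all provided.

import numpy as np

verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
faces = np.array([[0, 1, 2]], dtype=np.int64)

MeshOutput(mesh_v=verts, mesh_f=faces).export("tri.obj")  # plain OBJ via save_obj

colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
MeshOutput(mesh_v=verts, mesh_f=faces, vertex_colors=colors).export("tri.ply")  # via trimesh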
MeshAnything/miche/michelangelo/graphics/primitives/volume.py ADDED
@@ -0,0 +1,21 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+
5
+
6
+ def generate_dense_grid_points(bbox_min: np.ndarray,
7
+ bbox_max: np.ndarray,
8
+ octree_depth: int,
9
+ indexing: str = "ij"):
10
+ length = bbox_max - bbox_min
11
+ num_cells = np.exp2(octree_depth)
12
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16
+ xyz = np.stack((xs, ys, zs), axis=-1)
17
+ xyz = xyz.reshape(-1, 3)
18
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19
+
20
+ return xyz, grid_size, length
21
+
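A quick sanity check (illustrative only): with octree_depth = 7 the grid has 2**7 cells per axis and therefore 129 sample positions per axis.

import numpy as np

bbox_min = np.array([-1.0, -1.0, -1.0])
bbox_max = np.array([1.0, 1.0, 1.0])
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)

assert grid_size == [129, 129, 129]
assert xyz.shape == (129 ** 3, 3)
assert np.allclose(length, 2.0)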
MeshAnything/miche/michelangelo/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/models/asl_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,483 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from einops import rearrange
14
+
15
+ from diffusers.schedulers import (
16
+ DDPMScheduler,
17
+ DDIMScheduler,
18
+ KarrasVeScheduler,
19
+ DPMSolverMultistepScheduler
20
+ )
21
+
22
+ from MeshAnything.miche.michelangelo.utils import instantiate_from_config
23
+ # from MeshAnything.miche.michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule
24
+ from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
25
+ from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample
26
+
27
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
28
+
29
+
30
+ def disabled_train(self, mode=True):
31
+ """Overwrite model.train with this function to make sure train/eval mode
32
+ does not change anymore."""
33
+ return self
34
+
35
+
36
+ class ASLDiffuser(pl.LightningModule):
37
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
38
+ # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
39
+ model: nn.Module
40
+
41
+ def __init__(self, *,
42
+ first_stage_config,
43
+ denoiser_cfg,
44
+ scheduler_cfg,
45
+ optimizer_cfg,
46
+ loss_cfg,
47
+ first_stage_key: str = "surface",
48
+ cond_stage_key: str = "image",
49
+ cond_stage_trainable: bool = True,
50
+ scale_by_std: bool = False,
51
+ z_scale_factor: float = 1.0,
52
+ ckpt_path: Optional[str] = None,
53
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
54
+
55
+ super().__init__()
56
+
57
+ self.first_stage_key = first_stage_key
58
+ self.cond_stage_key = cond_stage_key
59
+ self.cond_stage_trainable = cond_stage_trainable
60
+
61
+ # 1. initialize first stage.
62
+ # Note: the condition model contained in the first stage model.
63
+ self.first_stage_config = first_stage_config
64
+ self.first_stage_model = None
65
+ # self.instantiate_first_stage(first_stage_config)
66
+
67
+ # 2. initialize conditional stage
68
+ # self.instantiate_cond_stage(cond_stage_config)
69
+ self.cond_stage_model = {
70
+ "image": self.encode_image,
71
+ "image_unconditional_embedding": self.empty_img_cond,
72
+ "text": self.encode_text,
73
+ "text_unconditional_embedding": self.empty_text_cond,
74
+ "surface": self.encode_surface,
75
+ "surface_unconditional_embedding": self.empty_surface_cond,
76
+ }
77
+
78
+ # 3. diffusion model
79
+ self.model = instantiate_from_config(
80
+ denoiser_cfg, device=None, dtype=None
81
+ )
82
+
83
+ self.optimizer_cfg = optimizer_cfg
84
+
85
+ # 4. scheduling strategy
86
+ self.scheduler_cfg = scheduler_cfg
87
+
88
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
89
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
90
+
91
+ # 5. loss configures
92
+ self.loss_cfg = loss_cfg
93
+
94
+ self.scale_by_std = scale_by_std
95
+ if scale_by_std:
96
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
97
+ else:
98
+ self.z_scale_factor = z_scale_factor
99
+
100
+ self.ckpt_path = ckpt_path
101
+ if ckpt_path is not None:
102
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
103
+
104
+ def instantiate_first_stage(self, config):
105
+ model = instantiate_from_config(config)
106
+ self.first_stage_model = model.eval()
107
+ self.first_stage_model.train = disabled_train
108
+ for param in self.first_stage_model.parameters():
109
+ param.requires_grad = False
110
+
111
+ self.first_stage_model = self.first_stage_model.to(self.device)
112
+
113
+ # def instantiate_cond_stage(self, config):
114
+ # if not self.cond_stage_trainable:
115
+ # if config == "__is_first_stage__":
116
+ # print("Using first stage also as cond stage.")
117
+ # self.cond_stage_model = self.first_stage_model
118
+ # elif config == "__is_unconditional__":
119
+ # print(f"Training {self.__class__.__name__} as an unconditional model.")
120
+ # self.cond_stage_model = None
121
+ # # self.be_unconditional = True
122
+ # else:
123
+ # model = instantiate_from_config(config)
124
+ # self.cond_stage_model = model.eval()
125
+ # self.cond_stage_model.train = disabled_train
126
+ # for param in self.cond_stage_model.parameters():
127
+ # param.requires_grad = False
128
+ # else:
129
+ # assert config != "__is_first_stage__"
130
+ # assert config != "__is_unconditional__"
131
+ # model = instantiate_from_config(config)
132
+ # self.cond_stage_model = model
133
+
134
+ def init_from_ckpt(self, path, ignore_keys=()):
135
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
136
+
137
+ keys = list(state_dict.keys())
138
+ for k in keys:
139
+ for ik in ignore_keys:
140
+ if k.startswith(ik):
141
+ print("Deleting key {} from state_dict.".format(k))
142
+ del state_dict[k]
143
+
144
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
145
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
146
+ if len(missing) > 0:
147
+ print(f"Missing Keys: {missing}")
148
+ print(f"Unexpected Keys: {unexpected}")
149
+
150
+ @property
151
+ def zero_rank(self):
152
+ if self._trainer:
153
+ zero_rank = self.trainer.local_rank == 0
154
+ else:
155
+ zero_rank = True
156
+
157
+ return zero_rank
158
+
159
+ def configure_optimizers(self) -> Tuple[List, List]:
160
+
161
+ lr = self.learning_rate
162
+
163
+ trainable_parameters = list(self.model.parameters())
164
+ # if the conditional encoder is trainable
165
+
166
+ # if self.cond_stage_trainable:
167
+ # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
168
+ # trainable_parameters += conditioner_params
169
+ # print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
170
+
171
+ if self.optimizer_cfg is None:
172
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
173
+ schedulers = []
174
+ else:
175
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
176
+ scheduler_func = instantiate_from_config(
177
+ self.optimizer_cfg.scheduler,
178
+ max_decay_steps=self.trainer.max_steps,
179
+ lr_max=lr
180
+ )
181
+ scheduler = {
182
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
183
+ "interval": "step",
184
+ "frequency": 1
185
+ }
186
+ optimizers = [optimizer]
187
+ schedulers = [scheduler]
188
+
189
+ return optimizers, schedulers
190
+
191
+ @torch.no_grad()
192
+ def encode_text(self, text):
193
+
194
+ b = text.shape[0]
195
+ text_tokens = rearrange(text, "b t l -> (b t) l")
196
+ text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
197
+ text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
198
+ text_embed = text_embed.mean(dim=1)
199
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
200
+
201
+ return text_embed
202
+
203
+ @torch.no_grad()
204
+ def encode_image(self, img):
205
+
206
+ return self.first_stage_model.model.encode_image_embed(img)
207
+
208
+ @torch.no_grad()
209
+ def encode_surface(self, surface):
210
+
211
+ return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
212
+
213
+ @torch.no_grad()
214
+ def empty_text_cond(self, cond):
215
+
216
+ return torch.zeros_like(cond, device=cond.device)
217
+
218
+ @torch.no_grad()
219
+ def empty_img_cond(self, cond):
220
+
221
+ return torch.zeros_like(cond, device=cond.device)
222
+
223
+ @torch.no_grad()
224
+ def empty_surface_cond(self, cond):
225
+
226
+ return torch.zeros_like(cond, device=cond.device)
227
+
228
+ @torch.no_grad()
229
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
230
+
231
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
232
+ z_q = self.z_scale_factor * z_q
233
+
234
+ return z_q
235
+
236
+ @torch.no_grad()
237
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
238
+
239
+ z_q = 1. / self.z_scale_factor * z_q
240
+ latents = self.first_stage_model.decode(z_q, **kwargs)
241
+ return latents
242
+
243
+ @rank_zero_only
244
+ @torch.no_grad()
245
+ def on_train_batch_start(self, batch, batch_idx):
246
+ # only for very first batch
247
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
248
+ and batch_idx == 0 and self.ckpt_path is None:
249
+ # set rescale weight to 1./std of encodings
250
+ print("### USING STD-RESCALING ###")
251
+
252
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
253
+ z = z_q.detach()
254
+
255
+ del self.z_scale_factor
256
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
257
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
258
+
259
+ print("### USING STD-RESCALING ###")
260
+
261
+ def compute_loss(self, model_outputs, split):
262
+ """
263
+
264
+ Args:
265
+ model_outputs (dict):
266
+ - x_0:
267
+ - noise:
268
+ - noise_prior:
269
+ - noise_pred:
270
+ - noise_pred_prior:
271
+
272
+ split (str):
273
+
274
+ Returns:
275
+
276
+ """
277
+
278
+ pred = model_outputs["pred"]
279
+
280
+ if self.noise_scheduler.prediction_type == "epsilon":
281
+ target = model_outputs["noise"]
282
+ elif self.noise_scheduler.prediction_type == "sample":
283
+ target = model_outputs["x_0"]
284
+ else:
285
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
286
+
287
+ if self.loss_cfg.loss_type == "l1":
288
+ simple = F.l1_loss(pred, target, reduction="mean")
289
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
290
+ simple = F.mse_loss(pred, target, reduction="mean")
291
+ else:
292
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
293
+
294
+ total_loss = simple
295
+
296
+ loss_dict = {
297
+ f"{split}/total_loss": total_loss.clone().detach(),
298
+ f"{split}/simple": simple.detach(),
299
+ }
300
+
301
+ return total_loss, loss_dict
302
+
303
+ def forward(self, batch):
304
+ """
305
+
306
+ Args:
307
+ batch:
308
+
309
+ Returns:
310
+
311
+ """
312
+
313
+ if self.first_stage_model is None:
314
+ self.instantiate_first_stage(self.first_stage_config)
315
+
316
+ latents = self.encode_first_stage(batch[self.first_stage_key])
317
+
318
+ # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
319
+
320
+ conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
321
+
322
+ mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
323
+ conditions = conditions * mask.to(conditions)
324
+
325
+ # Sample noise that we"ll add to the latents
326
+ # [batch_size, n_token, latent_dim]
327
+ noise = torch.randn_like(latents)
328
+ bs = latents.shape[0]
329
+ # Sample a random timestep for each motion
330
+ timesteps = torch.randint(
331
+ 0,
332
+ self.noise_scheduler.config.num_train_timesteps,
333
+ (bs,),
334
+ device=latents.device,
335
+ )
336
+ timesteps = timesteps.long()
337
+ # Add noise to the latents according to the noise magnitude at each timestep
338
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
339
+
340
+ # diffusion model forward
341
+ noise_pred = self.model(noisy_z, timesteps, conditions)
342
+
343
+ diffusion_outputs = {
344
+ "x_0": latents,  # clean latents; used as the target when prediction_type == "sample"
345
+ "noise": noise,
346
+ "pred": noise_pred
347
+ }
348
+
349
+ return diffusion_outputs
350
+
351
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
352
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
353
+ """
354
+
355
+ Args:
356
+ batch (dict): the batch sample, and it contains:
357
+ - surface (torch.FloatTensor):
358
+ - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
359
+ - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
360
+ - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
361
+ - text (list of str):
362
+
363
+ batch_idx (int):
364
+
365
+ optimizer_idx (int):
366
+
367
+ Returns:
368
+ loss (torch.FloatTensor):
369
+
370
+ """
371
+
372
+ diffusion_outputs = self(batch)
373
+
374
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
375
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
376
+
377
+ return loss
378
+
379
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
380
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
381
+ """
382
+
383
+ Args:
384
+ batch (dict): the batch sample, and it contains:
385
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
386
+ - surface_feats (torch.FloatTensor): [n_pts, c]
387
+ - text (list of str):
388
+
389
+ batch_idx (int):
390
+
391
+ optimizer_idx (int):
392
+
393
+ Returns:
394
+ loss (torch.FloatTensor):
395
+
396
+ """
397
+
398
+ diffusion_outputs = self(batch)
399
+
400
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
401
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
402
+
403
+ return loss
404
+
405
+ @torch.no_grad()
406
+ def sample(self,
407
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
408
+ sample_times: int = 1,
409
+ steps: Optional[int] = None,
410
+ guidance_scale: Optional[float] = None,
411
+ eta: float = 0.0,
412
+ return_intermediates: bool = False, **kwargs):
413
+
414
+ if self.first_stage_model is None:
415
+ self.instantiate_first_stage(self.first_stage_config)
416
+
417
+ if steps is None:
418
+ steps = self.scheduler_cfg.num_inference_steps
419
+
420
+ if guidance_scale is None:
421
+ guidance_scale = self.scheduler_cfg.guidance_scale
422
+ do_classifier_free_guidance = guidance_scale > 0
423
+
424
+ # conditional encode
425
+ xc = batch[self.cond_stage_key]
426
+ # cond = self.cond_stage_model[self.cond_stage_key](xc)
427
+ cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
428
+
429
+ if do_classifier_free_guidance:
430
+ """
431
+ Note: There are two kinds of uncond for text.
432
+ 1: using "" as uncond text; (in SAL diffusion)
433
+ 2: zeros_like(cond) as uncond text; (in MDM)
434
+ """
435
+ # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
436
+ un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
437
+ # un_cond = torch.zeros_like(cond, device=cond.device)
438
+ cond = torch.cat([un_cond, cond], dim=0)
439
+
440
+ outputs = []
441
+ latents = None
442
+
443
+ if not return_intermediates:
444
+ for _ in range(sample_times):
445
+ sample_loop = ddim_sample(
446
+ self.denoise_scheduler,
447
+ self.model,
448
+ shape=self.first_stage_model.latent_shape,
449
+ cond=cond,
450
+ steps=steps,
451
+ guidance_scale=guidance_scale,
452
+ do_classifier_free_guidance=do_classifier_free_guidance,
453
+ device=self.device,
454
+ eta=eta,
455
+ disable_prog=not self.zero_rank
456
+ )
457
+ for sample, t in sample_loop:
458
+ latents = sample
459
+ outputs.append(self.decode_first_stage(latents, **kwargs))
460
+ else:
461
+
462
+ sample_loop = ddim_sample(
463
+ self.denoise_scheduler,
464
+ self.model,
465
+ shape=self.first_stage_model.latent_shape,
466
+ cond=cond,
467
+ steps=steps,
468
+ guidance_scale=guidance_scale,
469
+ do_classifier_free_guidance=do_classifier_free_guidance,
470
+ device=self.device,
471
+ eta=eta,
472
+ disable_prog=not self.zero_rank
473
+ )
474
+
475
+ iter_size = steps // sample_times
476
+ i = 0
477
+ for sample, t in sample_loop:
478
+ latents = sample
479
+ if i % iter_size == 0 or i == steps - 1:
480
+ outputs.append(self.decode_first_stage(latents, **kwargs))
481
+ i += 1
482
+
483
+ return outputs
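A hedged sketch of calling the sampler above at inference time; diffuser and images are placeholders, and the expected image preprocessing depends on the first-stage model's own encoder rather than anything defined in this file.

batch = {"image": images}      # placeholder conditioning batch for cond_stage_key == "image"
outputs = diffuser.sample(
    batch,
    sample_times=1,            # one draw per call
    steps=50,                  # DDIM steps
    guidance_scale=7.5,        # > 0 switches on classifier-free guidance
)
decoded = outputs[0]           # one decoded first-stage output per draw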
MeshAnything/miche/michelangelo/models/asl_diffusion/asl_udt.py ADDED
@@ -0,0 +1,104 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Optional
6
+ from diffusers.models.embeddings import Timesteps
7
+ import math
8
+
9
+ from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import MLP
10
+ from MeshAnything.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
11
+
12
+
13
+ class ConditionalASLUDTDenoiser(nn.Module):
14
+
15
+ def __init__(self, *,
16
+ device: Optional[torch.device],
17
+ dtype: Optional[torch.dtype],
18
+ input_channels: int,
19
+ output_channels: int,
20
+ n_ctx: int,
21
+ width: int,
22
+ layers: int,
23
+ heads: int,
24
+ context_dim: int,
25
+ context_ln: bool = True,
26
+ skip_ln: bool = False,
27
+ init_scale: float = 0.25,
28
+ flip_sin_to_cos: bool = False,
29
+ use_checkpoint: bool = False):
30
+ super().__init__()
31
+
32
+ self.use_checkpoint = use_checkpoint
33
+
34
+ init_scale = init_scale * math.sqrt(1.0 / width)
35
+
36
+ self.backbone = UNetDiffusionTransformer(
37
+ device=device,
38
+ dtype=dtype,
39
+ n_ctx=n_ctx,
40
+ width=width,
41
+ layers=layers,
42
+ heads=heads,
43
+ skip_ln=skip_ln,
44
+ init_scale=init_scale,
45
+ use_checkpoint=use_checkpoint
46
+ )
47
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48
+ self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
49
+ self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
50
+
51
+ # timestep embedding
52
+ self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
53
+ self.time_proj = MLP(
54
+ device=device, dtype=dtype, width=width, init_scale=init_scale
55
+ )
56
+
62
+ if context_ln:
63
+ self.context_embed = nn.Sequential(
64
+ nn.LayerNorm(context_dim, device=device, dtype=dtype),
65
+ nn.Linear(context_dim, width, device=device, dtype=dtype),
66
+ )
67
+ else:
68
+ self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
69
+
70
+ def forward(self,
71
+ model_input: torch.FloatTensor,
72
+ timestep: torch.LongTensor,
73
+ context: torch.FloatTensor):
74
+
75
+ r"""
76
+ Args:
77
+ model_input (torch.FloatTensor): [bs, n_data, c]
78
+ timestep (torch.LongTensor): [bs,]
79
+ context (torch.FloatTensor): [bs, context_tokens, c]
80
+
81
+ Returns:
82
+ sample (torch.FloatTensor): [bs, n_data, c]
83
+
84
+ """
85
+
86
+ _, n_data, _ = model_input.shape
87
+
88
+ # 1. time
89
+ t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
90
+
91
+ # 2. conditions projector
92
+ context = self.context_embed(context)
93
+
94
+ # 3. denoiser
95
+ x = self.input_proj(model_input)
96
+ x = torch.cat([t_emb, context, x], dim=1)
97
+ x = self.backbone(x)
98
+ x = self.ln_post(x)
99
+ x = x[:, -n_data:]
100
+ sample = self.output_proj(x)
101
+
102
+ return sample
103
+
104
+
MeshAnything/miche/michelangelo/models/asl_diffusion/base.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BaseDenoiser(nn.Module):
8
+
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def forward(self, x, t, context):
13
+ raise NotImplementedError
MeshAnything/miche/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,393 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from diffusers.schedulers import (
14
+ DDPMScheduler,
15
+ DDIMScheduler,
16
+ KarrasVeScheduler,
17
+ DPMSolverMultistepScheduler
18
+ )
19
+
20
+ from MeshAnything.miche.michelangelo.utils import instantiate_from_config
21
+ from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
22
+ from MeshAnything.miche.michelangelo.models.asl_diffusion.inference_utils import ddim_sample
23
+
24
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
25
+
26
+
27
+ def disabled_train(self, mode=True):
28
+ """Overwrite model.train with this function to make sure train/eval mode
29
+ does not change anymore."""
30
+ return self
31
+
32
+
33
+ class ClipASLDiffuser(pl.LightningModule):
34
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
35
+ cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
36
+ model: nn.Module
37
+
38
+ def __init__(self, *,
39
+ first_stage_config,
40
+ cond_stage_config,
41
+ denoiser_cfg,
42
+ scheduler_cfg,
43
+ optimizer_cfg,
44
+ loss_cfg,
45
+ first_stage_key: str = "surface",
46
+ cond_stage_key: str = "image",
47
+ scale_by_std: bool = False,
48
+ z_scale_factor: float = 1.0,
49
+ ckpt_path: Optional[str] = None,
50
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
51
+
52
+ super().__init__()
53
+
54
+ self.first_stage_key = first_stage_key
55
+ self.cond_stage_key = cond_stage_key
56
+
57
+ # 1. lazy initialize first stage
58
+ self.instantiate_first_stage(first_stage_config)
59
+
60
+ # 2. initialize conditional stage
61
+ self.instantiate_cond_stage(cond_stage_config)
62
+
63
+ # 3. diffusion model
64
+ self.model = instantiate_from_config(
65
+ denoiser_cfg, device=None, dtype=None
66
+ )
67
+
68
+ self.optimizer_cfg = optimizer_cfg
69
+
70
+ # 4. scheduling strategy
71
+ self.scheduler_cfg = scheduler_cfg
72
+
73
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
74
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
75
+
76
+ # 5. loss configures
77
+ self.loss_cfg = loss_cfg
78
+
79
+ self.scale_by_std = scale_by_std
80
+ if scale_by_std:
81
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
82
+ else:
83
+ self.z_scale_factor = z_scale_factor
84
+
85
+ self.ckpt_path = ckpt_path
86
+ if ckpt_path is not None:
87
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
88
+
89
+ def instantiate_non_trainable_model(self, config):
90
+ model = instantiate_from_config(config)
91
+ model = model.eval()
92
+ model.train = disabled_train
93
+ for param in model.parameters():
94
+ param.requires_grad = False
95
+
96
+ return model
97
+
98
+ def instantiate_first_stage(self, first_stage_config):
99
+ self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
100
+ self.first_stage_model.set_shape_model_only()
101
+
102
+ def instantiate_cond_stage(self, cond_stage_config):
103
+ self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
104
+
105
+ def init_from_ckpt(self, path, ignore_keys=()):
106
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
107
+
108
+ keys = list(state_dict.keys())
109
+ for k in keys:
110
+ for ik in ignore_keys:
111
+ if k.startswith(ik):
112
+ print("Deleting key {} from state_dict.".format(k))
113
+ del state_dict[k]
114
+
115
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
116
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
117
+ if len(missing) > 0:
118
+ print(f"Missing Keys: {missing}")
119
+ print(f"Unexpected Keys: {unexpected}")
120
+
121
+ @property
122
+ def zero_rank(self):
123
+ if self._trainer:
124
+ zero_rank = self.trainer.local_rank == 0
125
+ else:
126
+ zero_rank = True
127
+
128
+ return zero_rank
129
+
130
+ def configure_optimizers(self) -> Tuple[List, List]:
131
+
132
+ lr = self.learning_rate
133
+
134
+ trainable_parameters = list(self.model.parameters())
135
+ if self.optimizer_cfg is None:
136
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
137
+ schedulers = []
138
+ else:
139
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
140
+ scheduler_func = instantiate_from_config(
141
+ self.optimizer_cfg.scheduler,
142
+ max_decay_steps=self.trainer.max_steps,
143
+ lr_max=lr
144
+ )
145
+ scheduler = {
146
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
147
+ "interval": "step",
148
+ "frequency": 1
149
+ }
150
+ optimizers = [optimizer]
151
+ schedulers = [scheduler]
152
+
153
+ return optimizers, schedulers
154
+
155
+ @torch.no_grad()
156
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
157
+
158
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
159
+ z_q = self.z_scale_factor * z_q
160
+
161
+ return z_q
162
+
163
+ @torch.no_grad()
164
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
165
+
166
+ z_q = 1. / self.z_scale_factor * z_q
167
+ latents = self.first_stage_model.decode(z_q, **kwargs)
168
+ return latents
169
+
170
+ @rank_zero_only
171
+ @torch.no_grad()
172
+ def on_train_batch_start(self, batch, batch_idx):
173
+ # only for very first batch
174
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
175
+ and batch_idx == 0 and self.ckpt_path is None:
176
+ # set rescale weight to 1./std of encodings
177
+ print("### USING STD-RESCALING ###")
178
+
179
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
180
+ z = z_q.detach()
181
+
182
+ del self.z_scale_factor
183
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
184
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
185
+
186
+ print("### USING STD-RESCALING ###")
187
+
188
+ def compute_loss(self, model_outputs, split):
189
+ """
190
+
191
+ Args:
192
+ model_outputs (dict):
193
+ - x_0: the clean latents (the regression target when prediction_type == "sample")
194
+ - noise: the added Gaussian noise (the target when prediction_type == "epsilon")
195
+ - pred: the denoiser output
198
+
199
+ split (str):
200
+
201
+ Returns:
202
+
203
+ """
204
+
205
+ pred = model_outputs["pred"]
206
+
207
+ if self.noise_scheduler.prediction_type == "epsilon":
208
+ target = model_outputs["noise"]
209
+ elif self.noise_scheduler.prediction_type == "sample":
210
+ target = model_outputs["x_0"]
211
+ else:
212
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
213
+
214
+ if self.loss_cfg.loss_type == "l1":
215
+ simple = F.l1_loss(pred, target, reduction="mean")
216
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
217
+ simple = F.mse_loss(pred, target, reduction="mean")
218
+ else:
219
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
220
+
221
+ total_loss = simple
222
+
223
+ loss_dict = {
224
+ f"{split}/total_loss": total_loss.clone().detach(),
225
+ f"{split}/simple": simple.detach(),
226
+ }
227
+
228
+ return total_loss, loss_dict
229
+
230
+ def forward(self, batch):
231
+ """
232
+
233
+ Args:
234
+ batch:
235
+
236
+ Returns:
237
+
238
+ """
239
+
240
+ latents = self.encode_first_stage(batch[self.first_stage_key])
241
+ conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
242
+
243
+ # Sample noise that we'll add to the latents
244
+ # [batch_size, n_token, latent_dim]
245
+ noise = torch.randn_like(latents)
246
+ bs = latents.shape[0]
247
+ # Sample a random timestep for each item in the batch
248
+ timesteps = torch.randint(
249
+ 0,
250
+ self.noise_scheduler.config.num_train_timesteps,
251
+ (bs,),
252
+ device=latents.device,
253
+ )
254
+ timesteps = timesteps.long()
255
+ # Add noise to the latents according to the noise magnitude at each timestep
256
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
257
+
258
+ # diffusion model forward
259
+ noise_pred = self.model(noisy_z, timesteps, conditions)
260
+
261
+ diffusion_outputs = {
262
+ "x_0": latents,  # clean latents; used as the target when prediction_type == "sample"
263
+ "noise": noise,
264
+ "pred": noise_pred
265
+ }
266
+
267
+ return diffusion_outputs
268
+
269
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
270
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
271
+ """
272
+
273
+ Args:
274
+ batch (dict): the batch sample, and it contains:
275
+ - surface (torch.FloatTensor):
276
+ - image (torch.FloatTensor): if provided, [bs, 3, h, w], item range [0, 1]
277
+ - depth (torch.FloatTensor): if provided, [bs, 1, h, w], item range [-1, 1]
278
+ - normal (torch.FloatTensor): if provided, [bs, 3, h, w], item range [-1, 1]
279
+ - text (list of str):
280
+
281
+ batch_idx (int):
282
+
283
+ optimizer_idx (int):
284
+
285
+ Returns:
286
+ loss (torch.FloatTensor):
287
+
288
+ """
289
+
290
+ diffusion_outputs = self(batch)
291
+
292
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
293
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
294
+
295
+ return loss
296
+
297
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
298
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
299
+ """
300
+
301
+ Args:
302
+ batch (dict): the batch sample, and it contains:
303
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
304
+ - surface_feats (torch.FloatTensor): [n_pts, c]
305
+ - text (list of str):
306
+
307
+ batch_idx (int):
308
+
309
+ optimizer_idx (int):
310
+
311
+ Returns:
312
+ loss (torch.FloatTensor):
313
+
314
+ """
315
+
316
+ diffusion_outputs = self(batch)
317
+
318
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
319
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
320
+
321
+ return loss
322
+
323
+ @torch.no_grad()
324
+ def sample(self,
325
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
326
+ sample_times: int = 1,
327
+ steps: Optional[int] = None,
328
+ guidance_scale: Optional[float] = None,
329
+ eta: float = 0.0,
330
+ return_intermediates: bool = False, **kwargs):
331
+
332
+ if steps is None:
333
+ steps = self.scheduler_cfg.num_inference_steps
334
+
335
+ if guidance_scale is None:
336
+ guidance_scale = self.scheduler_cfg.guidance_scale
337
+ do_classifier_free_guidance = guidance_scale > 0
338
+
339
+ # conditional encode
340
+ xc = batch[self.cond_stage_key]
341
+
342
+ # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
343
+
344
+ cond = self.cond_stage_model(xc)
345
+
346
+ if do_classifier_free_guidance:
347
+ un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
348
+ cond = torch.cat([un_cond, cond], dim=0)
349
+
350
+ outputs = []
351
+ latents = None
352
+
353
+ if not return_intermediates:
354
+ for _ in range(sample_times):
355
+ sample_loop = ddim_sample(
356
+ self.denoise_scheduler,
357
+ self.model,
358
+ shape=self.first_stage_model.latent_shape,
359
+ cond=cond,
360
+ steps=steps,
361
+ guidance_scale=guidance_scale,
362
+ do_classifier_free_guidance=do_classifier_free_guidance,
363
+ device=self.device,
364
+ eta=eta,
365
+ disable_prog=not self.zero_rank
366
+ )
367
+ for sample, t in sample_loop:
368
+ latents = sample
369
+ outputs.append(self.decode_first_stage(latents, **kwargs))
370
+ else:
371
+
372
+ sample_loop = ddim_sample(
373
+ self.denoise_scheduler,
374
+ self.model,
375
+ shape=self.first_stage_model.latent_shape,
376
+ cond=cond,
377
+ steps=steps,
378
+ guidance_scale=guidance_scale,
379
+ do_classifier_free_guidance=do_classifier_free_guidance,
380
+ device=self.device,
381
+ eta=eta,
382
+ disable_prog=not self.zero_rank
383
+ )
384
+
385
+ iter_size = steps // sample_times
386
+ i = 0
387
+ for sample, t in sample_loop:
388
+ latents = sample
389
+ if i % iter_size == 0 or i == steps - 1:
390
+ outputs.append(self.decode_first_stage(latents, **kwargs))
391
+ i += 1
392
+
393
+ return outputs
MeshAnything/miche/michelangelo/models/asl_diffusion/inference_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ from typing import Tuple, List, Union, Optional
6
+ from diffusers.schedulers import DDIMScheduler
7
+
8
+
9
+ __all__ = ["ddim_sample"]
10
+
11
+
12
+ def ddim_sample(ddim_scheduler: DDIMScheduler,
13
+ diffusion_model: torch.nn.Module,
14
+ shape: Union[List[int], Tuple[int]],
15
+ cond: torch.FloatTensor,
16
+ steps: int,
17
+ eta: float = 0.0,
18
+ guidance_scale: float = 3.0,
19
+ do_classifier_free_guidance: bool = True,
20
+ generator: Optional[torch.Generator] = None,
21
+ device: torch.device = "cuda:0",
22
+ disable_prog: bool = True):
23
+
24
+ assert steps > 0, f"{steps} must > 0."
25
+
26
+ # init latents
27
+ bsz = cond.shape[0]
28
+ if do_classifier_free_guidance:
29
+ bsz = bsz // 2
30
+
31
+ latents = torch.randn(
32
+ (bsz, *shape),
33
+ generator=generator,
34
+ device=cond.device,
35
+ dtype=cond.dtype,
36
+ )
37
+ # scale the initial noise by the standard deviation required by the scheduler
38
+ latents = latents * ddim_scheduler.init_noise_sigma
39
+ # set timesteps
40
+ ddim_scheduler.set_timesteps(steps)
41
+ timesteps = ddim_scheduler.timesteps.to(device)
42
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
43
+ # eta (η) is only used with the DDIMScheduler, and between [0, 1]
44
+ extra_step_kwargs = {
45
+ "eta": eta,
46
+ "generator": generator
47
+ }
48
+
49
+ # reverse
50
+ for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
51
+ # expand the latents if we are doing classifier free guidance
52
+ latent_model_input = (
53
+ torch.cat([latents] * 2)
54
+ if do_classifier_free_guidance
55
+ else latents
56
+ )
57
+ # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
58
+ # predict the noise residual
59
+ timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
60
+ timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
61
+ noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
62
+
63
+ # perform guidance
64
+ if do_classifier_free_guidance:
65
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
66
+ noise_pred = noise_pred_uncond + guidance_scale * (
67
+ noise_pred_text - noise_pred_uncond
68
+ )
69
+ # text_embeddings_for_guidance = encoder_hidden_states.chunk(
70
+ # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
71
+ # compute the previous noisy sample x_t -> x_t-1
72
+ latents = ddim_scheduler.step(
73
+ noise_pred, t, latents, **extra_step_kwargs
74
+ ).prev_sample
75
+
76
+ yield latents, t
77
+
78
+
79
+ def karra_sample():
80
+ pass
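Since ddim_sample is a generator, callers iterate it and keep the final latents. A sketch (denoiser and cond are placeholders that follow the forward(x, t, context) contract used above):

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

latents = None
for latents, t in ddim_sample(
        scheduler,
        denoiser,                         # placeholder diffusion model
        shape=(256, 64),                  # latent shape without the batch dimension
        cond=cond,                        # [2 * bs, n_tokens, c] once CFG doubles the batch
        steps=50,
        guidance_scale=7.5,
        do_classifier_free_guidance=True,
        device=cond.device,
        disable_prog=False):
    pass                                  # only the last yielded latents are kept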
MeshAnything/miche/michelangelo/models/conditional_encoders/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .clip import CLIPEncoder
MeshAnything/miche/michelangelo/models/conditional_encoders/clip.py ADDED
@@ -0,0 +1,89 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from dataclasses import dataclass
7
+ from torchvision.transforms import Normalize
8
+ from transformers import CLIPModel, CLIPTokenizer
9
+ from transformers.utils import ModelOutput
10
+ from typing import Iterable, Optional, Union, List
11
+
12
+
13
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14
+
15
+
16
+ @dataclass
17
+ class CLIPEmbedOutput(ModelOutput):
18
+ last_hidden_state: torch.FloatTensor = None
19
+ pooler_output: torch.FloatTensor = None
20
+ embeds: torch.FloatTensor = None
21
+
22
+
23
+ class CLIPEncoder(torch.nn.Module):
24
+
25
+ def __init__(self, model_path="openai/clip-vit-base-patch32"):
26
+
27
+ super().__init__()
28
+
29
+ # Load the CLIP model and processor
30
+ self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31
+ self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32
+ self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33
+
34
+ self.model.training = False
35
+ for p in self.model.parameters():
36
+ p.requires_grad = False
37
+
38
+ @torch.no_grad()
39
+ def encode_image(self, images: Iterable[Optional[ImageType]]):
40
+ pixel_values = self.image_preprocess(images)
41
+
42
+ vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43
+
44
+ pooler_output = vision_outputs[1] # pooled_output
45
+ image_features = self.model.visual_projection(pooler_output)
46
+
47
+ visual_embeds = CLIPEmbedOutput(
48
+ last_hidden_state=vision_outputs.last_hidden_state,
49
+ pooler_output=pooler_output,
50
+ embeds=image_features
51
+ )
52
+
53
+ return visual_embeds
54
+
55
+ @torch.no_grad()
56
+ def encode_text(self, texts: List[str]):
57
+ text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
58
+
59
+ text_outputs = self.model.text_model(input_ids=text_inputs.input_ids)
60
+
61
+ pooler_output = text_outputs[1] # pooled_output
62
+ text_features = self.model.text_projection(pooler_output)
63
+
64
+ text_embeds = CLIPEmbedOutput(
65
+ last_hidden_state=text_outputs.last_hidden_state,
66
+ pooler_output=pooler_output,
67
+ embeds=text_features
68
+ )
69
+
70
+ return text_embeds
71
+
72
+ def forward(self,
73
+ images: Iterable[Optional[ImageType]],
74
+ texts: List[str]):
75
+
76
+ visual_embeds = self.encode_image(images)
77
+ text_embeds = self.encode_text(texts)
78
+
79
+ return visual_embeds, text_embeds
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
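A short usage sketch for the frozen wrapper above (weights are downloaded from the openai/clip-vit-base-patch32 hub repo on first use):

encoder = CLIPEncoder()
text_out = encoder.encode_text(["a chair", "a table"])
print(text_out.embeds.shape)  # torch.Size([2, 512]) for the base patch32 checkpoint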
MeshAnything/miche/michelangelo/models/conditional_encoders/encoder_factory.py ADDED
@@ -0,0 +1,562 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import transforms
7
+ from transformers import CLIPModel, CLIPTokenizer
8
+ from collections import OrderedDict
9
+
10
+ from MeshAnything.miche.michelangelo.data.transforms import RandomResize
11
+
12
+
13
+ class AbstractEncoder(nn.Module):
14
+ embedding_dim: int
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def encode(self, *args, **kwargs):
20
+ raise NotImplementedError
21
+
22
+
23
+ class ClassEmbedder(nn.Module):
24
+ def __init__(self, embed_dim, n_classes=1000, key="class"):
25
+ super().__init__()
26
+ self.key = key
27
+ self.embedding = nn.Embedding(n_classes, embed_dim)
28
+
29
+ def forward(self, batch, key=None):
30
+ if key is None:
31
+ key = self.key
32
+ # this is for use in crossattn
33
+ c = batch[key][:, None]
34
+ c = self.embedding(c)
35
+ return c
36
+
37
+
38
+ class FrozenCLIPTextEmbedder(AbstractEncoder):
39
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
40
+
41
+ def __init__(
42
+ self,
43
+ version="openai/clip-vit-large-patch14",
44
+ tokenizer_version=None,
45
+ device="cuda",
46
+ max_length=77,
47
+ zero_embedding_radio: float = 0.1,
48
+ ):
49
+ super().__init__()
50
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
51
+
52
+ self.device = device
53
+ self.max_length = max_length
54
+ self.zero_embedding_radio = zero_embedding_radio
55
+
56
+ self.clip_dict = OrderedDict()
57
+ self.clip_name = os.path.split(version)[-1]
58
+
59
+ transformer = CLIPModel.from_pretrained(version).text_model
60
+
61
+ for param in transformer.parameters():
62
+ param.requires_grad = False
63
+ self.clip_dict[self.clip_name] = transformer
64
+
65
+ self._move_flag = False
66
+
67
+ @property
68
+ def clip(self):
69
+ return self.clip_dict[self.clip_name]
70
+
71
+ def move(self):
72
+ if self._move_flag:
73
+ return
74
+
75
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
76
+ self._move_flag = True
77
+
78
+ def unconditional_embedding(self, batch_size):
79
+ empty_text = [""] * batch_size
80
+ empty_z = self.forward(empty_text)
81
+ return empty_z
82
+
83
+ def forward(self, text):
84
+ self.move()
85
+
86
+ batch_encoding = self.tokenizer(
87
+ text,
88
+ truncation=True,
89
+ max_length=self.max_length,
90
+ return_length=True,
91
+ return_overflowing_tokens=False,
92
+ padding="max_length",
93
+ return_tensors="pt",
94
+ )
95
+
96
+ tokens = batch_encoding["input_ids"].to(self.device)
97
+ outputs = self.clip(input_ids=tokens)
98
+
99
+ z = outputs.last_hidden_state
100
+ return z
101
+
102
+ def encode(self, text):
103
+ batch_size = len(text)
104
+ batch_mask = torch.rand((batch_size,))
105
+ for i in range(batch_size):
106
+ if batch_mask[i] < self.zero_embedding_radio:
107
+ text[i] = ""
108
+
109
+ return self(text)
110
+
111
+ class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
112
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
113
+
114
+ def __init__(
115
+ self,
116
+ version="openai/clip-vit-large-patch14",
117
+ tokenizer_version=None,
118
+ device="cuda",
119
+ max_length=77,
120
+ zero_embedding_radio: float = 0.1,
121
+ ):
122
+ super().__init__()
123
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
124
+
125
+ self.device = device
126
+ self.max_length = max_length
127
+ self.zero_embedding_radio = zero_embedding_radio
128
+
129
+ self.clip_dict = OrderedDict()
130
+ self.clip_name = os.path.split(version)[-1]
131
+
132
+ transformer = CLIPModel.from_pretrained(version).text_model
133
+
134
+ for param in transformer.parameters():
135
+ param.requires_grad = False
136
+ self.clip_dict[self.clip_name] = transformer
137
+
138
+ self._move_flag = False
139
+
140
+ @property
141
+ def clip(self):
142
+ return self.clip_dict[self.clip_name]
143
+
144
+ def move(self):
145
+ if self._move_flag:
146
+ return
147
+
148
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
149
+ self._move_flag = True
150
+
151
+ def unconditional_embedding(self, batch_size):
152
+ empty_text = [""] * batch_size
153
+ empty_z = self.forward(empty_text)
154
+ return empty_z
155
+
156
+ def forward(self, text):
157
+ self.move()
158
+
159
+ batch_encoding = self.tokenizer(
160
+ text,
161
+ truncation=True,
162
+ max_length=self.max_length,
163
+ return_length=True,
164
+ return_overflowing_tokens=False,
165
+ padding="max_length",
166
+ return_tensors="pt",
167
+ )
168
+
169
+ tokens = batch_encoding["input_ids"].to(self.device)
170
+ outputs = self.clip(input_ids=tokens)
171
+
172
+ z = outputs.last_hidden_state
173
+ return z
174
+
175
+ def encode(self, text):
176
+ batch_size = len(text)
177
+ batch_mask = torch.rand((batch_size,))
178
+ for i in range(batch_size):
179
+ if batch_mask[i] < self.zero_embedding_radio:
180
+ text[i] = ""
181
+
182
+ return self(text)
183
+
184
+
185
+ class FrozenCLIPImageEmbedder(AbstractEncoder):
186
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
187
+
188
+ def __init__(
189
+ self,
190
+ version="openai/clip-vit-large-patch14",
191
+ device="cuda",
192
+ zero_embedding_radio=0.1,
193
+ normalize_embedding=True,
194
+ num_projection_vector=0,
195
+ linear_mapping_bias=True,
196
+ reverse_visual_projection=False,
197
+ ):
198
+ super().__init__()
199
+
200
+ self.device = device
201
+
202
+ self.clip_dict = OrderedDict()
203
+ self.clip_name = os.path.split(version)[-1]
204
+
205
+ clip_model = CLIPModel.from_pretrained(version)
206
+ clip_model.text_model = None
207
+ clip_model.text_projection = None
208
+ clip_model = clip_model.eval()
209
+ for param in clip_model.parameters():  # freeze the CLIP vision weights
210
+ param.requires_grad = False
211
+ self.clip_dict[self.clip_name] = clip_model
212
+
213
+ self.transform = transforms.Compose(
214
+ [
215
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
216
+ transforms.CenterCrop(224), # crop a (224, 224) square
217
+ transforms.Normalize(
218
+ mean=[0.48145466, 0.4578275, 0.40821073],
219
+ std=[0.26862954, 0.26130258, 0.27577711],
220
+ ),
221
+ ]
222
+ )
223
+ self.zero_embedding_radio = zero_embedding_radio
224
+
225
+ self.num_projection_vector = num_projection_vector
226
+ self.reverse_visual_projection = reverse_visual_projection
227
+ self.normalize_embedding = normalize_embedding
228
+
229
+ embedding_dim = (
230
+ clip_model.visual_projection.in_features
231
+ if reverse_visual_projection
232
+ else clip_model.visual_projection.out_features
233
+ )
234
+ self.embedding_dim = embedding_dim
235
+ if self.num_projection_vector > 0:
236
+ self.projection = nn.Linear(
237
+ embedding_dim,
238
+ clip_model.visual_projection.out_features * num_projection_vector,
239
+ bias=linear_mapping_bias,
240
+ )
241
+ nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)
242
+
243
+ self._move_flag = False
244
+
245
+ @property
246
+ def clip(self):
247
+ return self.clip_dict[self.clip_name]
248
+
249
+ def unconditional_embedding(self, batch_size):
250
+ zero = torch.zeros(
251
+ batch_size,
252
+ 1,
253
+ self.embedding_dim,
254
+ device=self.device,
255
+ dtype=self.clip.visual_projection.weight.dtype,
256
+ )
257
+ if self.num_projection_vector > 0:
258
+ zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
259
+ return zero
260
+
261
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
262
+ if value_range is not None:
263
+ low, high = value_range
264
+ image = (image - low) / (high - low)
265
+
266
+ image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
267
+
268
+ if self.reverse_visual_projection:
269
+ z = self.clip.vision_model(self.transform(image))[1]
270
+ else:
271
+ z = self.clip.get_image_features(self.transform(image))
272
+
273
+ if self.normalize_embedding:
274
+ z = z / z.norm(dim=-1, keepdim=True)
275
+ if z.ndim == 2:
276
+ z = z.unsqueeze(dim=-2)
277
+
278
+ if zero_embedding_radio > 0:
279
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio  # keep each embedding with probability (1 - zero_embedding_radio), as in FrozenCLIPImageGridEmbedder
280
+ z = z * mask.to(z)
281
+
282
+ if self.num_projection_vector > 0:
283
+ z = self.projection(z).view(len(image), self.num_projection_vector, -1)
284
+
285
+ return z
286
+
287
+ def move(self):
288
+ if self._move_flag:
289
+ return
290
+
291
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
292
+ self._move_flag = True
293
+
294
+ def encode(self, image):
295
+ self.move()
296
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
297
+
298
+
299
+ class FrozenCLIPImageGridEmbedder(AbstractEncoder):
300
+
301
+ def __init__(
302
+ self,
303
+ version="openai/clip-vit-large-patch14",
304
+ device="cuda",
305
+ zero_embedding_radio=0.1,
306
+ ):
307
+ super().__init__()
308
+
309
+ self.device = device
310
+
311
+ self.clip_dict = OrderedDict()
312
+ self.clip_name = os.path.split(version)[-1]
313
+
314
+ clip_model: CLIPModel = CLIPModel.from_pretrained(version)
315
+ clip_model.text_model = None
316
+ clip_model.text_projection = None
317
+ clip_model = clip_model.eval()
318
+ for param in clip_model.parameters():  # freeze the CLIP vision weights
319
+ param.requires_grad = False
320
+ self.clip_dict[self.clip_name] = clip_model
321
+
322
+ self.transform = transforms.Compose(
323
+ [
324
+ transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
325
+ transforms.CenterCrop(224), # crop a (224, 224) square
326
+ transforms.Normalize(
327
+ mean=[0.48145466, 0.4578275, 0.40821073],
328
+ std=[0.26862954, 0.26130258, 0.27577711],
329
+ ),
330
+ ]
331
+ )
332
+ self.zero_embedding_radio = zero_embedding_radio
333
+ self.embedding_dim = clip_model.vision_embed_dim
334
+
335
+ self._move_flag = False
336
+
337
+ @property
338
+ def clip(self):
339
+ return self.clip_dict[self.clip_name]
340
+
341
+ def move(self):
342
+ if self._move_flag:
343
+ return
344
+
345
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
346
+ self._move_flag = True
347
+
348
+ def unconditional_embedding(self, batch_size):
349
+ zero = torch.zeros(
350
+ batch_size,
351
+ self.clip.vision_model.embeddings.num_positions,
352
+ self.embedding_dim,
353
+ device=self.device,
354
+ dtype=self.clip.visual_projection.weight.dtype,
355
+ )
356
+ return zero
357
+
358
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
359
+ self.move()
360
+
361
+ if value_range is not None:
362
+ low, high = value_range
363
+ image = (image - low) / (high - low)
364
+
365
+ image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
366
+
367
+ z = self.clip.vision_model(self.transform(image)).last_hidden_state
368
+
369
+ if zero_embedding_radio > 0:
370
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
371
+ z = z * mask.to(z)
372
+
373
+ return z
374
+
375
+ def encode(self, image):
376
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
377
+
378
+
379
+ class MoECLIPImageEncoder(nn.Module):
380
+ def __init__(
381
+ self,
382
+ versions,
383
+ hidden_state_dim,
384
+ num_projection_vector=8,
385
+ zero_embedding_radio=0.1,
386
+ device="cuda",
387
+ precision="fp16",
388
+ normalize=False,
389
+ clip_max=0,
390
+ transform_type="base",
391
+ argument_p=0.2,
392
+ ):
393
+ super().__init__()
394
+
395
+ self.device = torch.device(device)
396
+ self.hidden_state_dim = hidden_state_dim
397
+ self.zero_embedding_radio = zero_embedding_radio
398
+ self.num_projection_vector = num_projection_vector
399
+ self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
400
+ self.normalize = normalize
401
+ self.clip_max = clip_max
402
+
403
+ if transform_type == "base":
404
+ self.transform = transforms.Compose(
405
+ [
406
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
407
+ transforms.CenterCrop(224), # crop a (224, 224) square
408
+ transforms.Normalize(
409
+ mean=[0.48145466, 0.4578275, 0.40821073],
410
+ std=[0.26862954, 0.26130258, 0.27577711],
411
+ ),
412
+ ]
413
+ )
414
+ elif transform_type == "crop_blur_resize":
415
+ self.transform = transforms.Compose(
416
+ [
417
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
418
+ transforms.CenterCrop(224), # crop a (224, 224) square
419
+ transforms.RandomApply(
420
+ transforms=[
421
+ transforms.RandomResizedCrop(
422
+ size=224,
423
+ scale=(0.8, 1.0),
424
+ ratio=(0.99, 1.01),
425
+ interpolation=transforms.InterpolationMode.BICUBIC,
426
+ ),
427
+ ],
428
+ p=argument_p,
429
+ ),
430
+ transforms.RandomApply(
431
+ transforms=[
432
+ transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
433
+ ],
434
+ p=argument_p,
435
+ ),
436
+ transforms.RandomApply(
437
+ transforms=[
438
+ RandomResize(size=224, resize_radio=(0.2, 1)),
439
+ ],
440
+ p=argument_p,
441
+ ),
442
+ transforms.Normalize(
443
+ mean=[0.48145466, 0.4578275, 0.40821073],
444
+ std=[0.26862954, 0.26130258, 0.27577711],
445
+ ),
446
+ ]
447
+ )
448
+ else:
449
+ raise ValueError(f"invalid {transform_type=}")
450
+
451
+ if isinstance(versions, str):
452
+ versions = (versions,)
453
+
454
+ # If the clips were registered directly as submodules of this class: 1. every checkpoint would store redundant copies of their weights; 2. Lightning would call .to() on them, which would also cast the LayerNorm weights to fp16.
455
+ clips = OrderedDict()
456
+
457
+ for v in versions:
458
+ # Because the clips are not submodules, passing device="cuda" here would incorrectly place every CLIP model's weights on cuda:0.
459
+ clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
460
+ delattr(clips[v], "transformer")
461
+ clips[v].eval()
462
+ clips[v].requires_grad_(False)
463
+
464
+ self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
465
+
466
+ if self.num_projection_vector == 0:
467
+ self.projection = nn.Identity()
468
+ else:
469
+ self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
470
+ self.projection.to(dtype=self.dtype)
471
+ nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
472
+
473
+ self.clips = clips
474
+
475
+ self._move_flag = False
476
+
477
+ def move(self):
478
+ if self._move_flag:
479
+ return
480
+
481
+ def convert_weights(model: nn.Module):
482
+ """Convert applicable model parameters to fp16"""
483
+
484
+ def _convert_weights_to_fp16(l):
485
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
486
+ l.weight.data = l.weight.data.type(self.dtype)
487
+ if l.bias is not None:
488
+ l.bias.data = l.bias.data.type(self.dtype)
489
+
490
+ if isinstance(l, nn.MultiheadAttention):
491
+ for attr in [
492
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
493
+ "in_proj_bias",
494
+ "bias_k",
495
+ "bias_v",
496
+ ]:
497
+ tensor = getattr(l, attr)
498
+ if tensor is not None:
499
+ tensor.data = tensor.data.type(self.dtype)
500
+
501
+ for name in ["text_projection", "proj"]:
502
+ if hasattr(l, name):
503
+ attr = getattr(l, name)
504
+ if attr is not None:
505
+ attr.data = attr.data.type(self.dtype)
506
+
507
+ model.apply(_convert_weights_to_fp16)
508
+
509
+ for k in self.clips:
510
+ self.clips[k].to(self.device)
511
+ convert_weights(self.clips[k]) # fp32 -> self.dtype
512
+ self._move_flag = True
513
+
514
+ def unconditional_embedding(self, batch_size=None):
515
+ zero = torch.zeros(
516
+ batch_size,
517
+ self.clips_hidden_dim,
518
+ device=self.device,
519
+ dtype=self.dtype,
520
+ )
521
+ if self.num_projection_vector > 0:
522
+ zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
523
+ return zero
524
+
525
+ def convert_embedding(self, z):
526
+ if self.num_projection_vector > 0:
527
+ z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
528
+ return z
529
+
530
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
531
+ if value_range is not None:
532
+ low, high = value_range
533
+ image = (image - low) / (high - low)
534
+
535
+ image = self.transform(image)
536
+
537
+ with torch.no_grad():
538
+ embs = []
539
+ for v in self.clips:
540
+ x = self.clips[v].encode_image(image)
541
+ if self.normalize:
542
+ x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
543
+ # clip_max only works with normalization
544
+ if self.clip_max > 0:
545
+ x = x.clamp(-self.clip_max, self.clip_max)
546
+ embs.append(x)
547
+
548
+ z = torch.cat(embs, dim=-1)
549
+ if self.normalize:
550
+ z /= z.size(-1) ** 0.5
551
+
552
+ if zero_embedding_radio > 0:
553
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
554
+ z = z * mask.to(z)  # zero out dropped embeddings; multiplication matches the other CLIP encoders above
555
+
556
+ if self.num_projection_vector > 0:
557
+ z = self.projection(z).view(len(image), self.num_projection_vector, -1)
558
+ return z
559
+
560
+ def encode(self, image):
561
+ self.move()
562
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
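A minimal usage sketch for the frozen CLIP image embedder defined above. The batch shape, the device="cpu" choice and the printed sizes are assumptions (the sizes follow the default openai/clip-vit-large-patch14 checkpoint, which is downloaded on first use):

import torch
from MeshAnything.miche.michelangelo.models.conditional_encoders.encoder_factory import (
    FrozenCLIPImageEmbedder,
)

embedder = FrozenCLIPImageEmbedder(device="cpu", zero_embedding_radio=0.1)

# Hypothetical batch of two RGB images already scaled to [-1, 1].
images = torch.rand(2, 3, 224, 224) * 2 - 1

z = embedder.encode(images)                   # encode() moves the cached CLIP model to `device`, then embeds
uncond = embedder.unconditional_embedding(2)  # all-zero conditioning embedding
print(z.shape, uncond.shape)                  # torch.Size([2, 1, 768]) torch.Size([2, 1, 768])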
MeshAnything/miche/michelangelo/models/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .checkpoint import checkpoint
MeshAnything/miche/michelangelo/models/modules/checkpoint.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4
+ """
5
+
6
+ import torch
7
+ from typing import Callable, Iterable, Sequence, Union
8
+
9
+
10
+ def checkpoint(
11
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12
+ inputs: Sequence[torch.Tensor],
13
+ params: Iterable[torch.Tensor],
14
+ flag: bool,
15
+ use_deepspeed: bool = False
16
+ ):
17
+ """
18
+ Evaluate a function without caching intermediate activations, allowing for
19
+ reduced memory at the expense of extra compute in the backward pass.
20
+ :param func: the function to evaluate.
21
+ :param inputs: the argument sequence to pass to `func`.
22
+ :param params: a sequence of parameters `func` depends on but does not
23
+ explicitly take as arguments.
24
+ :param flag: if False, disable gradient checkpointing.
25
+ :param use_deepspeed: if True, use deepspeed
26
+ """
27
+ if flag:
28
+ if use_deepspeed:
29
+ import deepspeed
30
+ return deepspeed.checkpointing.checkpoint(func, *inputs)
31
+
32
+ args = tuple(inputs) + tuple(params)
33
+ return CheckpointFunction.apply(func, len(inputs), *args)
34
+ else:
35
+ return func(*inputs)
36
+
37
+
38
+ class CheckpointFunction(torch.autograd.Function):
39
+ @staticmethod
40
+ @torch.cuda.amp.custom_fwd
41
+ def forward(ctx, run_function, length, *args):
42
+ ctx.run_function = run_function
43
+ ctx.input_tensors = list(args[:length])
44
+ ctx.input_params = list(args[length:])
45
+
46
+ with torch.no_grad():
47
+ output_tensors = ctx.run_function(*ctx.input_tensors)
48
+ return output_tensors
49
+
50
+ @staticmethod
51
+ @torch.cuda.amp.custom_bwd
52
+ def backward(ctx, *output_grads):
53
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54
+ with torch.enable_grad():
55
+ # Fixes a bug where the first op in run_function modifies the
56
+ # Tensor storage in place, which is not allowed for detach()'d
57
+ # Tensors.
58
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59
+ output_tensors = ctx.run_function(*shallow_copies)
60
+ input_grads = torch.autograd.grad(
61
+ output_tensors,
62
+ ctx.input_tensors + ctx.input_params,
63
+ output_grads,
64
+ allow_unused=True,
65
+ )
66
+ del ctx.input_tensors
67
+ del ctx.input_params
68
+ del output_tensors
69
+ return (None, None) + input_grads
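A small sketch of the calling convention for the checkpoint helper above; the toy block and tensor sizes are made up for illustration:

import torch
import torch.nn as nn
from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# flag=True re-runs `block` in the backward pass instead of storing activations;
# flag=False falls back to a plain call.
y = checkpoint(block, (x,), block.parameters(), flag=True)
y.sum().backward()
print(x.grad.shape)   # torch.Size([8, 64])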
MeshAnything/miche/michelangelo/models/modules/diffusion_transformer.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+
8
+ from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint
9
+ from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import (
10
+ init_linear,
11
+ MLP,
12
+ MultiheadCrossAttention,
13
+ MultiheadAttention,
14
+ ResidualAttentionBlock
15
+ )
16
+
17
+
18
+ class AdaLayerNorm(nn.Module):
19
+ def __init__(self,
20
+ device: torch.device,
21
+ dtype: torch.dtype,
22
+ width: int):
23
+
24
+ super().__init__()
25
+
26
+ self.silu = nn.SiLU(inplace=True)
27
+ self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
28
+ self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
29
+
30
+ def forward(self, x, timestep):
31
+ emb = self.linear(timestep)
32
+ scale, shift = torch.chunk(emb, 2, dim=2)
33
+ x = self.layernorm(x) * (1 + scale) + shift
34
+ return x
35
+
36
+
37
+ class DitBlock(nn.Module):
38
+ def __init__(
39
+ self,
40
+ *,
41
+ device: torch.device,
42
+ dtype: torch.dtype,
43
+ n_ctx: int,
44
+ width: int,
45
+ heads: int,
46
+ context_dim: int,
47
+ qkv_bias: bool = False,
48
+ init_scale: float = 1.0,
49
+ use_checkpoint: bool = False
50
+ ):
51
+ super().__init__()
52
+
53
+ self.use_checkpoint = use_checkpoint
54
+
55
+ self.attn = MultiheadAttention(
56
+ device=device,
57
+ dtype=dtype,
58
+ n_ctx=n_ctx,
59
+ width=width,
60
+ heads=heads,
61
+ init_scale=init_scale,
62
+ qkv_bias=qkv_bias
63
+ )
64
+ self.ln_1 = AdaLayerNorm(device, dtype, width)
65
+
66
+ if context_dim is not None:
67
+ self.ln_2 = AdaLayerNorm(device, dtype, width)
68
+ self.cross_attn = MultiheadCrossAttention(
69
+ device=device,
70
+ dtype=dtype,
71
+ width=width,
72
+ heads=heads,
73
+ data_width=context_dim,
74
+ init_scale=init_scale,
75
+ qkv_bias=qkv_bias
76
+ )
77
+
78
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
79
+ self.ln_3 = AdaLayerNorm(device, dtype, width)
80
+
81
+ def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
82
+ return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
83
+
84
+ def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
85
+ x = x + self.attn(self.ln_1(x, t))
86
+ if context is not None:
87
+ x = x + self.cross_attn(self.ln_2(x, t), context)
88
+ x = x + self.mlp(self.ln_3(x, t))
89
+ return x
90
+
91
+
92
+ class DiT(nn.Module):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ device: Optional[torch.device],
97
+ dtype: Optional[torch.dtype],
98
+ n_ctx: int,
99
+ width: int,
100
+ layers: int,
101
+ heads: int,
102
+ context_dim: int,
103
+ init_scale: float = 0.25,
104
+ qkv_bias: bool = False,
105
+ use_checkpoint: bool = False
106
+ ):
107
+ super().__init__()
108
+ self.n_ctx = n_ctx
109
+ self.width = width
110
+ self.layers = layers
111
+
112
+ self.resblocks = nn.ModuleList(
113
+ [
114
+ DitBlock(
115
+ device=device,
116
+ dtype=dtype,
117
+ n_ctx=n_ctx,
118
+ width=width,
119
+ heads=heads,
120
+ context_dim=context_dim,
121
+ qkv_bias=qkv_bias,
122
+ init_scale=init_scale,
123
+ use_checkpoint=use_checkpoint
124
+ )
125
+ for _ in range(layers)
126
+ ]
127
+ )
128
+
129
+ def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
130
+ for block in self.resblocks:
131
+ x = block(x, t, context)
132
+ return x
133
+
134
+
135
+ class UNetDiffusionTransformer(nn.Module):
136
+ def __init__(
137
+ self,
138
+ *,
139
+ device: Optional[torch.device],
140
+ dtype: Optional[torch.dtype],
141
+ n_ctx: int,
142
+ width: int,
143
+ layers: int,
144
+ heads: int,
145
+ init_scale: float = 0.25,
146
+ qkv_bias: bool = False,
147
+ skip_ln: bool = False,
148
+ use_checkpoint: bool = False
149
+ ):
150
+ super().__init__()
151
+
152
+ self.n_ctx = n_ctx
153
+ self.width = width
154
+ self.layers = layers
155
+
156
+ self.encoder = nn.ModuleList()
157
+ for _ in range(layers):
158
+ resblock = ResidualAttentionBlock(
159
+ device=device,
160
+ dtype=dtype,
161
+ n_ctx=n_ctx,
162
+ width=width,
163
+ heads=heads,
164
+ init_scale=init_scale,
165
+ qkv_bias=qkv_bias,
166
+ use_checkpoint=use_checkpoint
167
+ )
168
+ self.encoder.append(resblock)
169
+
170
+ self.middle_block = ResidualAttentionBlock(
171
+ device=device,
172
+ dtype=dtype,
173
+ n_ctx=n_ctx,
174
+ width=width,
175
+ heads=heads,
176
+ init_scale=init_scale,
177
+ qkv_bias=qkv_bias,
178
+ use_checkpoint=use_checkpoint
179
+ )
180
+
181
+ self.decoder = nn.ModuleList()
182
+ for _ in range(layers):
183
+ resblock = ResidualAttentionBlock(
184
+ device=device,
185
+ dtype=dtype,
186
+ n_ctx=n_ctx,
187
+ width=width,
188
+ heads=heads,
189
+ init_scale=init_scale,
190
+ qkv_bias=qkv_bias,
191
+ use_checkpoint=use_checkpoint
192
+ )
193
+ linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
194
+ init_linear(linear, init_scale)
195
+
196
+ layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
197
+
198
+ self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
199
+
200
+ def forward(self, x: torch.Tensor):
201
+
202
+ enc_outputs = []
203
+ for block in self.encoder:
204
+ x = block(x)
205
+ enc_outputs.append(x)
206
+
207
+ x = self.middle_block(x)
208
+
209
+ for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
210
+ x = torch.cat([enc_outputs.pop(), x], dim=-1)
211
+ x = linear(x)
212
+
213
+ if layer_norm is not None:
214
+ x = layer_norm(x)
215
+
216
+ x = resblock(x)
217
+
218
+ return x
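A shape-level sketch of the skip-connected transformer above; the token count, width, depth and head count are illustrative assumptions, not the configuration used by the repository:

import torch
from MeshAnything.miche.michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer

net = UNetDiffusionTransformer(
    device=torch.device("cpu"),
    dtype=torch.float32,
    n_ctx=256,     # number of latent tokens
    width=128,
    layers=2,      # 2 encoder blocks, 1 middle block, 2 decoder blocks
    heads=4,
    skip_ln=True,  # LayerNorm after each skip concatenation + linear
)

x = torch.randn(2, 256, 128)   # [batch, n_ctx, width]
print(net(x).shape)            # torch.Size([2, 256, 128])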
MeshAnything/miche/michelangelo/models/modules/distributions.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union, List
4
+
5
+
6
+ class AbstractDistribution(object):
7
+ def sample(self):
8
+ raise NotImplementedError()
9
+
10
+ def mode(self):
11
+ raise NotImplementedError()
12
+
13
+
14
+ class DiracDistribution(AbstractDistribution):
15
+ def __init__(self, value):
16
+ self.value = value
17
+
18
+ def sample(self):
19
+ return self.value
20
+
21
+ def mode(self):
22
+ return self.value
23
+
24
+
25
+ class DiagonalGaussianDistribution(object):
26
+ def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
27
+ self.feat_dim = feat_dim
28
+ self.parameters = parameters
29
+
30
+ if isinstance(parameters, list):
31
+ self.mean = parameters[0]
32
+ self.logvar = parameters[1]
33
+ else:
34
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
35
+
36
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
37
+ self.deterministic = deterministic
38
+ self.std = torch.exp(0.5 * self.logvar)
39
+ self.var = torch.exp(self.logvar)
40
+ if self.deterministic:
41
+ self.var = self.std = torch.zeros_like(self.mean)
42
+
43
+ def sample(self):
44
+ x = self.mean + self.std * torch.randn_like(self.mean)
45
+ return x
46
+
47
+ def kl(self, other=None, dims=(1, 2, 3)):
48
+ if self.deterministic:
49
+ return torch.Tensor([0.])
50
+ else:
51
+ if other is None:
52
+ return 0.5 * torch.mean(torch.pow(self.mean, 2)
53
+ + self.var - 1.0 - self.logvar,
54
+ dim=dims)
55
+ else:
56
+ return 0.5 * torch.mean(
57
+ torch.pow(self.mean - other.mean, 2) / other.var
58
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
59
+ dim=dims)
60
+
61
+ def nll(self, sample, dims=(1, 2, 3)):
62
+ if self.deterministic:
63
+ return torch.Tensor([0.])
64
+ logtwopi = np.log(2.0 * np.pi)
65
+ return 0.5 * torch.sum(
66
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67
+ dim=dims)
68
+
69
+ def mode(self):
70
+ return self.mean
71
+
72
+
73
+ def normal_kl(mean1, logvar1, mean2, logvar2):
74
+ """
75
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
76
+ Compute the KL divergence between two gaussians.
77
+ Shapes are automatically broadcasted, so batches can be compared to
78
+ scalars, among other use cases.
79
+ """
80
+ tensor = None
81
+ for obj in (mean1, logvar1, mean2, logvar2):
82
+ if isinstance(obj, torch.Tensor):
83
+ tensor = obj
84
+ break
85
+ assert tensor is not None, "at least one argument must be a Tensor"
86
+
87
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
88
+ # Tensors, but it does not work for torch.exp().
89
+ logvar1, logvar2 = [
90
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
91
+ for x in (logvar1, logvar2)
92
+ ]
93
+
94
+ return 0.5 * (
95
+ -1.0
96
+ + logvar2
97
+ - logvar1
98
+ + torch.exp(logvar1 - logvar2)
99
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
100
+ )
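A quick sketch of the diagonal Gaussian above; the latent shape and feat_dim are arbitrary assumptions:

import torch
from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution

# Hypothetical latent with mean and log-variance stacked on the last axis.
params = torch.randn(4, 256, 2 * 64)
posterior = DiagonalGaussianDistribution(params, feat_dim=-1)

z = posterior.sample()          # reparameterised sample, [4, 256, 64]
kl = posterior.kl(dims=(1, 2))  # KL against a standard normal, one value per batch item
print(z.shape, kl.shape)        # torch.Size([4, 256, 64]) torch.Size([4])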
MeshAnything/miche/michelangelo/models/modules/embedder.py ADDED
@@ -0,0 +1,213 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+ VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
9
+
10
+
11
+ class FourierEmbedder(nn.Module):
12
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
13
+ each feature dimension of `x[..., i]` into:
14
+ [
15
+ sin(x[..., i]),
16
+ sin(f_1*x[..., i]),
17
+ sin(f_2*x[..., i]),
18
+ ...
19
+ sin(f_N * x[..., i]),
20
+ cos(x[..., i]),
21
+ cos(f_1*x[..., i]),
22
+ cos(f_2*x[..., i]),
23
+ ...
24
+ cos(f_N * x[..., i]),
25
+ x[..., i] # only present if include_input is True.
26
+ ], here f_i is the frequency.
27
+
28
+ Let the frequency index i run over [0, 1, 2, ..., num_freqs - 1].
29
+ If logspace is True, the frequencies are f_i = 2^i for i = 0, ..., num_freqs - 1;
30
+ Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
31
+
32
+ Args:
33
+ num_freqs (int): the number of frequencies, default is 6;
34
+ logspace (bool): if True, the frequencies are f_i = 2^i for i = 0, ..., num_freqs - 1,
35
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
36
+ input_dim (int): the input dimension, default is 3;
37
+ include_input (bool): include the input tensor or not, default is True.
38
+
39
+ Attributes:
40
+ frequencies (torch.Tensor): if logspace is True, the frequencies are f_i = 2^i for i = 0, ..., num_freqs - 1,
41
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
42
+
43
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44
+ otherwise, it is input_dim * num_freqs * 2.
45
+
46
+ """
47
+
48
+ def __init__(self,
49
+ num_freqs: int = 6,
50
+ logspace: bool = True,
51
+ input_dim: int = 3,
52
+ include_input: bool = True,
53
+ include_pi: bool = True) -> None:
54
+
55
+ """The initialization"""
56
+
57
+ super().__init__()
58
+
59
+ if logspace:
60
+ frequencies = 2.0 ** torch.arange(
61
+ num_freqs,
62
+ dtype=torch.float32
63
+ )
64
+ else:
65
+ frequencies = torch.linspace(
66
+ 1.0,
67
+ 2.0 ** (num_freqs - 1),
68
+ num_freqs,
69
+ dtype=torch.float32
70
+ )
71
+
72
+ if include_pi:
73
+ frequencies *= torch.pi
74
+
75
+ self.register_buffer("frequencies", frequencies, persistent=False)
76
+ self.include_input = include_input
77
+ self.num_freqs = num_freqs
78
+
79
+ self.out_dim = self.get_dims(input_dim)
80
+
81
+ def get_dims(self, input_dim):
82
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
83
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
84
+
85
+ return out_dim
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ """ Forward process.
89
+
90
+ Args:
91
+ x: tensor of shape [..., dim]
92
+
93
+ Returns:
94
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95
+ where temp is 1 if include_input is True and 0 otherwise.
96
+ """
97
+
98
+ if self.num_freqs > 0:
99
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
100
+ if self.include_input:
101
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
102
+ else:
103
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
104
+ else:
105
+ return x
106
+
107
+
108
+ class LearnedFourierEmbedder(nn.Module):
109
+ """ following @crowsonkb "s lead with learned sinusoidal pos emb """
110
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
111
+
112
+ def __init__(self, in_channels, dim):
113
+ super().__init__()
114
+ assert (dim % 2) == 0
115
+ half_dim = dim // 2
116
+ per_channel_dim = half_dim // in_channels
117
+ self.weights = nn.Parameter(torch.randn(per_channel_dim))
118
+
119
+ def forward(self, x):
120
+ """
121
+
122
+ Args:
123
+ x (torch.FloatTensor): [..., c]
124
+
125
+ Returns:
126
+ x (torch.FloatTensor): [..., d]
127
+ """
128
+
129
+ # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
130
+ freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
131
+ fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
132
+ return fouriered
133
+
134
+
135
+ class TriplaneLearnedFourierEmbedder(nn.Module):
136
+ def __init__(self, in_channels, dim):
137
+ super().__init__()
138
+
139
+ self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
140
+ self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
141
+ self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
142
+
143
+ self.out_dim = in_channels + dim
144
+
145
+ def forward(self, x):
146
+
147
+ yz_embed = self.yz_plane_embedder(x)
148
+ xz_embed = self.xz_plane_embedder(x)
149
+ xy_embed = self.xy_plane_embedder(x)
150
+
151
+ embed = yz_embed + xz_embed + xy_embed
152
+
153
+ return embed
154
+
155
+
156
+ def sequential_pos_embed(num_len, embed_dim):
157
+ assert embed_dim % 2 == 0
158
+
159
+ pos = torch.arange(num_len, dtype=torch.float32)
160
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32)
161
+ omega /= embed_dim / 2.
162
+ omega = 1. / 10000 ** omega # (D/2,)
163
+
164
+ pos = pos.reshape(-1) # (M,)
165
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
166
+
167
+ emb_sin = torch.sin(out) # (M, D/2)
168
+ emb_cos = torch.cos(out) # (M, D/2)
169
+
170
+ embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
171
+
172
+ return embeddings
173
+
174
+
175
+ def timestep_embedding(timesteps, dim, max_period=10000):
176
+ """
177
+ Create sinusoidal timestep embeddings.
178
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
179
+ These may be fractional.
180
+ :param dim: the dimension of the output.
181
+ :param max_period: controls the minimum frequency of the embeddings.
182
+ :return: an [N x dim] Tensor of positional embeddings.
183
+ """
184
+ half = dim // 2
185
+ freqs = torch.exp(
186
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187
+ ).to(device=timesteps.device)
188
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190
+ if dim % 2:
191
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192
+ return embedding
193
+
194
+
195
+ def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196
+ num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197
+ log2_hashmap_size=19, desired_resolution=None):
198
+ if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199
+ return nn.Identity(), input_dim
200
+
201
+ elif embed_type == "fourier":
202
+ embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203
+ logspace=True, include_input=True)
204
+ return embedder_obj, embedder_obj.out_dim
205
+
206
+ elif embed_type == "hashgrid":
207
+ raise NotImplementedError
208
+
209
+ elif embed_type == "sphere_harmonic":
210
+ raise NotImplementedError
211
+
212
+ else:
213
+ raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}")
MeshAnything/miche/michelangelo/models/modules/transformer_blocks.py ADDED
@@ -0,0 +1,286 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Optional
8
+
9
+ from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint
10
+
11
+
12
+ def init_linear(l, stddev):
13
+ nn.init.normal_(l.weight, std=stddev)
14
+ if l.bias is not None:
15
+ nn.init.constant_(l.bias, 0.0)
16
+
17
+
18
+ class MultiheadAttention(nn.Module):
19
+ def __init__(
20
+ self,
21
+ *,
22
+ device: torch.device,
23
+ dtype: torch.dtype,
24
+ n_ctx: int,
25
+ width: int,
26
+ heads: int,
27
+ init_scale: float,
28
+ qkv_bias: bool,
29
+ flash: bool = False
30
+ ):
31
+ super().__init__()
32
+ self.n_ctx = n_ctx
33
+ self.width = width
34
+ self.heads = heads
35
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
36
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
37
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
38
+ init_linear(self.c_qkv, init_scale)
39
+ init_linear(self.c_proj, init_scale)
40
+
41
+ def forward(self, x):
42
+ x = self.c_qkv(x)
43
+ x = checkpoint(self.attention, (x,), (), True)
44
+ x = self.c_proj(x)
45
+ return x
46
+
47
+
48
+ class QKVMultiheadAttention(nn.Module):
49
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
50
+ super().__init__()
51
+ self.device = device
52
+ self.dtype = dtype
53
+ self.heads = heads
54
+ self.n_ctx = n_ctx
55
+ self.flash = flash
56
+
57
+ def forward(self, qkv):
58
+ bs, n_ctx, width = qkv.shape
59
+ attn_ch = width // self.heads // 3
60
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
61
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
62
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
63
+
64
+ if self.flash:
65
+ out = F.scaled_dot_product_attention(q, k, v)
66
+ else:
67
+ weight = torch.einsum(
68
+ "bthc,bshc->bhts", q * scale, k * scale
69
+ ) # More stable with f16 than dividing afterwards
70
+ wdtype = weight.dtype
71
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
72
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
73
+
74
+ return out
75
+
76
+
77
+ class ResidualAttentionBlock(nn.Module):
78
+ def __init__(
79
+ self,
80
+ *,
81
+ device: torch.device,
82
+ dtype: torch.dtype,
83
+ n_ctx: int,
84
+ width: int,
85
+ heads: int,
86
+ init_scale: float = 1.0,
87
+ qkv_bias: bool = True,
88
+ flash: bool = False,
89
+ use_checkpoint: bool = False
90
+ ):
91
+ super().__init__()
92
+
93
+ self.use_checkpoint = use_checkpoint
94
+
95
+ self.attn = MultiheadAttention(
96
+ device=device,
97
+ dtype=dtype,
98
+ n_ctx=n_ctx,
99
+ width=width,
100
+ heads=heads,
101
+ init_scale=init_scale,
102
+ qkv_bias=qkv_bias,
103
+ flash=flash
104
+ )
105
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
106
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
107
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
108
+
109
+ def _forward(self, x: torch.Tensor):
110
+ x = x + self.attn(self.ln_1(x))
111
+ x = x + self.mlp(self.ln_2(x))
112
+ return x
113
+
114
+ def forward(self, x: torch.Tensor):
115
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
116
+
117
+
118
+ class MultiheadCrossAttention(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ device: torch.device,
123
+ dtype: torch.dtype,
124
+ width: int,
125
+ heads: int,
126
+ init_scale: float,
127
+ qkv_bias: bool = True,
128
+ flash: bool = False,
129
+ n_data: Optional[int] = None,
130
+ data_width: Optional[int] = None,
131
+ ):
132
+ super().__init__()
133
+ self.n_data = n_data
134
+ self.width = width
135
+ self.heads = heads
136
+ self.data_width = width if data_width is None else data_width
137
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
138
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
139
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
140
+ self.attention = QKVMultiheadCrossAttention(
141
+ device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
142
+ )
143
+ init_linear(self.c_q, init_scale)
144
+ init_linear(self.c_kv, init_scale)
145
+ init_linear(self.c_proj, init_scale)
146
+
147
+ def forward(self, x, data):
148
+ x = self.c_q(x)
149
+ data = self.c_kv(data)
150
+ x = checkpoint(self.attention, (x, data), (), True)
151
+ x = self.c_proj(x)
152
+ return x
153
+
154
+
155
+ class QKVMultiheadCrossAttention(nn.Module):
156
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
157
+ flash: bool = False, n_data: Optional[int] = None):
158
+
159
+ super().__init__()
160
+ self.device = device
161
+ self.dtype = dtype
162
+ self.heads = heads
163
+ self.n_data = n_data
164
+ self.flash = flash
165
+
166
+ def forward(self, q, kv):
167
+ _, n_ctx, _ = q.shape
168
+ bs, n_data, width = kv.shape
169
+ attn_ch = width // self.heads // 2
170
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
171
+ q = q.view(bs, n_ctx, self.heads, -1)
172
+ kv = kv.view(bs, n_data, self.heads, -1)
173
+ k, v = torch.split(kv, attn_ch, dim=-1)
174
+
175
+ if self.flash:
176
+ out = F.scaled_dot_product_attention(q, k, v)
177
+ else:
178
+ weight = torch.einsum(
179
+ "bthc,bshc->bhts", q * scale, k * scale
180
+ ) # More stable with f16 than dividing afterwards
181
+ wdtype = weight.dtype
182
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
183
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
184
+
185
+ return out
186
+
187
+
188
+ class ResidualCrossAttentionBlock(nn.Module):
189
+ def __init__(
190
+ self,
191
+ *,
192
+ device: Optional[torch.device],
193
+ dtype: Optional[torch.dtype],
194
+ n_data: Optional[int] = None,
195
+ width: int,
196
+ heads: int,
197
+ data_width: Optional[int] = None,
198
+ init_scale: float = 0.25,
199
+ qkv_bias: bool = True,
200
+ flash: bool = False
201
+ ):
202
+ super().__init__()
203
+
204
+ if data_width is None:
205
+ data_width = width
206
+
207
+ self.attn = MultiheadCrossAttention(
208
+ device=device,
209
+ dtype=dtype,
210
+ n_data=n_data,
211
+ width=width,
212
+ heads=heads,
213
+ data_width=data_width,
214
+ init_scale=init_scale,
215
+ qkv_bias=qkv_bias,
216
+ flash=flash,
217
+ )
218
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
219
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
220
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
221
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
222
+
223
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
224
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
225
+ x = x + self.mlp(self.ln_3(x))
226
+ return x
227
+
228
+
229
+ class MLP(nn.Module):
230
+ def __init__(self, *,
231
+ device: Optional[torch.device],
232
+ dtype: Optional[torch.dtype],
233
+ width: int,
234
+ init_scale: float):
235
+ super().__init__()
236
+ self.width = width
237
+ self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
238
+ self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
239
+ self.gelu = nn.GELU()
240
+ init_linear(self.c_fc, init_scale)
241
+ init_linear(self.c_proj, init_scale)
242
+
243
+ def forward(self, x):
244
+ return self.c_proj(self.gelu(self.c_fc(x)))
245
+
246
+
247
+ class Transformer(nn.Module):
248
+ def __init__(
249
+ self,
250
+ *,
251
+ device: Optional[torch.device],
252
+ dtype: Optional[torch.dtype],
253
+ n_ctx: int,
254
+ width: int,
255
+ layers: int,
256
+ heads: int,
257
+ init_scale: float = 0.25,
258
+ qkv_bias: bool = True,
259
+ flash: bool = False,
260
+ use_checkpoint: bool = False
261
+ ):
262
+ super().__init__()
263
+ self.n_ctx = n_ctx
264
+ self.width = width
265
+ self.layers = layers
266
+ self.resblocks = nn.ModuleList(
267
+ [
268
+ ResidualAttentionBlock(
269
+ device=device,
270
+ dtype=dtype,
271
+ n_ctx=n_ctx,
272
+ width=width,
273
+ heads=heads,
274
+ init_scale=init_scale,
275
+ qkv_bias=qkv_bias,
276
+ flash=flash,
277
+ use_checkpoint=use_checkpoint
278
+ )
279
+ for _ in range(layers)
280
+ ]
281
+ )
282
+
283
+ def forward(self, x: torch.Tensor):
284
+ for block in self.resblocks:
285
+ x = block(x)
286
+ return x
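A shape check for the plain Transformer stack above, with illustrative hyper-parameters:

import torch
from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import Transformer

encoder = Transformer(
    device=torch.device("cpu"),
    dtype=torch.float32,
    n_ctx=257,   # e.g. 256 latent tokens plus one global token (an assumption)
    width=96,
    layers=2,
    heads=6,
)

tokens = torch.randn(1, 257, 96)
print(encoder(tokens).shape)   # torch.Size([1, 257, 96])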
MeshAnything/miche/michelangelo/models/modules/transformer_vit.py ADDED
@@ -0,0 +1,308 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+ import warnings
8
+
9
+ from MeshAnything.miche.michelangelo.models.modules.checkpoint import checkpoint
10
+
11
+
12
+ def _trunc_normal_(tensor, mean, std, a, b):
13
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
14
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
15
+ def norm_cdf(x):
16
+ # Computes standard normal cumulative distribution function
17
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
18
+
19
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
20
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
21
+ "The distribution of values may be incorrect.",
22
+ stacklevel=2)
23
+
24
+ # Values are generated by using a truncated uniform distribution and
25
+ # then using the inverse CDF for the normal distribution.
26
+ # Get upper and lower cdf values
27
+ l = norm_cdf((a - mean) / std)
28
+ u = norm_cdf((b - mean) / std)
29
+
30
+ # Uniformly fill tensor with values from [l, u], then translate to
31
+ # [2l-1, 2u-1].
32
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
33
+
34
+ # Use inverse cdf transform for normal distribution to get truncated
35
+ # standard normal
36
+ tensor.erfinv_()
37
+
38
+ # Transform to proper mean, std
39
+ tensor.mul_(std * math.sqrt(2.))
40
+ tensor.add_(mean)
41
+
42
+ # Clamp to ensure it's in the proper range
43
+ tensor.clamp_(min=a, max=b)
44
+ return tensor
45
+
46
+
47
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
48
+ # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor
49
+ r"""Fills the input Tensor with values drawn from a truncated
50
+ normal distribution. The values are effectively drawn from the
51
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
52
+ with values outside :math:`[a, b]` redrawn until they are within
53
+ the bounds. The method used for generating the random values works
54
+ best when :math:`a \leq \text{mean} \leq b`.
55
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
56
+ applied while sampling the normal with mean/std applied, therefore a, b args
57
+ should be adjusted to match the range of mean, std args.
58
+ Args:
59
+ tensor: an n-dimensional `torch.Tensor`
60
+ mean: the mean of the normal distribution
61
+ std: the standard deviation of the normal distribution
62
+ a: the minimum cutoff value
63
+ b: the maximum cutoff value
64
+ Examples:
65
+ >>> w = torch.empty(3, 5)
66
+ >>> nn.init.trunc_normal_(w)
67
+ """
68
+ with torch.no_grad():
69
+ return _trunc_normal_(tensor, mean, std, a, b)
70
+
71
+
72
+ def init_weights(m):
73
+ if isinstance(m, nn.Linear):
74
+ trunc_normal_(m.weight, std=.02)
75
+ if isinstance(m, nn.Linear) and m.bias is not None:
76
+ nn.init.constant_(m.bias, 0)
77
+ elif isinstance(m, nn.LayerNorm):
78
+ nn.init.constant_(m.bias, 0)
79
+ nn.init.constant_(m.weight, 1.0)
80
+
81
+
82
+ class MultiheadAttention(nn.Module):
83
+ def __init__(
84
+ self,
85
+ *,
86
+ device: torch.device,
87
+ dtype: torch.dtype,
88
+ n_ctx: int,
89
+ width: int,
90
+ heads: int,
91
+ qkv_bias: bool
92
+ ):
93
+ super().__init__()
94
+ self.n_ctx = n_ctx
95
+ self.width = width
96
+ self.heads = heads
97
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
98
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
99
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
100
+
101
+ def forward(self, x):
102
+ x = self.c_qkv(x)
103
+ x = checkpoint(self.attention, (x,), (), True)
104
+ x = self.c_proj(x)
105
+ return x
106
+
107
+
108
+ class QKVMultiheadAttention(nn.Module):
109
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
110
+ super().__init__()
111
+ self.device = device
112
+ self.dtype = dtype
113
+ self.heads = heads
114
+ self.n_ctx = n_ctx
115
+
116
+ def forward(self, qkv):
117
+ bs, n_ctx, width = qkv.shape
118
+ attn_ch = width // self.heads // 3
119
+ scale = 1 / math.sqrt(attn_ch)
120
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
121
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
122
+ weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
123
+ wdtype = weight.dtype
124
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
125
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
126
+
127
+
128
+ class ResidualAttentionBlock(nn.Module):
129
+ def __init__(
130
+ self,
131
+ *,
132
+ device: torch.device,
133
+ dtype: torch.dtype,
134
+ n_ctx: int,
135
+ width: int,
136
+ heads: int,
137
+ qkv_bias: bool = True,
138
+ use_checkpoint: bool = False
139
+ ):
140
+ super().__init__()
141
+
142
+ self.use_checkpoint = use_checkpoint
143
+
144
+ self.attn = MultiheadAttention(
145
+ device=device,
146
+ dtype=dtype,
147
+ n_ctx=n_ctx,
148
+ width=width,
149
+ heads=heads,
150
+ qkv_bias=qkv_bias
151
+ )
152
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
153
+ self.mlp = MLP(device=device, dtype=dtype, width=width)
154
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
155
+
156
+ def _forward(self, x: torch.Tensor):
157
+ x = x + self.attn(self.ln_1(x))
158
+ x = x + self.mlp(self.ln_2(x))
159
+ return x
160
+
161
+ def forward(self, x: torch.Tensor):
162
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
163
+
164
+
165
+ class MultiheadCrossAttention(nn.Module):
166
+ def __init__(
167
+ self,
168
+ *,
169
+ device: torch.device,
170
+ dtype: torch.dtype,
171
+ width: int,
172
+ heads: int,
173
+ qkv_bias: bool = True,
174
+ n_data: Optional[int] = None,
175
+ data_width: Optional[int] = None,
176
+ ):
177
+ super().__init__()
178
+ self.n_data = n_data
179
+ self.width = width
180
+ self.heads = heads
181
+ self.data_width = width if data_width is None else data_width
182
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
183
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
184
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
185
+ self.attention = QKVMultiheadCrossAttention(
186
+ device=device, dtype=dtype, heads=heads, n_data=n_data
187
+ )
188
+
189
+ def forward(self, x, data):
190
+ x = self.c_q(x)
191
+ data = self.c_kv(data)
192
+ x = checkpoint(self.attention, (x, data), (), True)
193
+ x = self.c_proj(x)
194
+ return x
195
+
196
+
197
+ class QKVMultiheadCrossAttention(nn.Module):
198
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None):
199
+ super().__init__()
200
+ self.device = device
201
+ self.dtype = dtype
202
+ self.heads = heads
203
+ self.n_data = n_data
204
+
205
+ def forward(self, q, kv):
206
+ _, n_ctx, _ = q.shape
207
+ bs, n_data, width = kv.shape
208
+ attn_ch = width // self.heads // 2
209
+ scale = 1 / math.sqrt(attn_ch)
210
+ q = q.view(bs, n_ctx, self.heads, -1)
211
+ kv = kv.view(bs, n_data, self.heads, -1)
212
+ k, v = torch.split(kv, attn_ch, dim=-1)
213
+ weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
214
+ wdtype = weight.dtype
215
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
216
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
217
+
218
+
219
+ class ResidualCrossAttentionBlock(nn.Module):
220
+ def __init__(
221
+ self,
222
+ *,
223
+ device: Optional[torch.device],
224
+ dtype: Optional[torch.dtype],
225
+ n_data: Optional[int] = None,
226
+ width: int,
227
+ heads: int,
228
+ data_width: Optional[int] = None,
229
+ qkv_bias: bool = True
230
+ ):
231
+ super().__init__()
232
+
233
+ if data_width is None:
234
+ data_width = width
235
+
236
+ self.attn = MultiheadCrossAttention(
237
+ device=device,
238
+ dtype=dtype,
239
+ n_data=n_data,
240
+ width=width,
241
+ heads=heads,
242
+ data_width=data_width,
243
+ qkv_bias=qkv_bias
244
+ )
245
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
246
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
247
+ self.mlp = MLP(device=device, dtype=dtype, width=width)
248
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
249
+
250
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
251
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
252
+ x = x + self.mlp(self.ln_3(x))
253
+ return x
254
+
255
+
256
+ class MLP(nn.Module):
257
+ def __init__(self, *,
258
+ device: Optional[torch.device],
259
+ dtype: Optional[torch.dtype],
260
+ width: int):
261
+ super().__init__()
262
+ self.width = width
263
+ self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
264
+ self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
265
+ self.gelu = nn.GELU()
266
+
267
+ def forward(self, x):
268
+ return self.c_proj(self.gelu(self.c_fc(x)))
269
+
270
+
271
+ class Transformer(nn.Module):
272
+ def __init__(
273
+ self,
274
+ *,
275
+ device: Optional[torch.device],
276
+ dtype: Optional[torch.dtype],
277
+ n_ctx: int,
278
+ width: int,
279
+ layers: int,
280
+ heads: int,
281
+ qkv_bias: bool = True,
282
+ use_checkpoint: bool = False
283
+ ):
284
+ super().__init__()
285
+ self.n_ctx = n_ctx
286
+ self.width = width
287
+ self.layers = layers
288
+ self.resblocks = nn.ModuleList(
289
+ [
290
+ ResidualAttentionBlock(
291
+ device=device,
292
+ dtype=dtype,
293
+ n_ctx=n_ctx,
294
+ width=width,
295
+ heads=heads,
296
+ qkv_bias=qkv_bias,
297
+ use_checkpoint=use_checkpoint
298
+ )
299
+ for _ in range(layers)
300
+ ]
301
+ )
302
+
303
+ self.apply(init_weights)
304
+
305
+ def forward(self, x: torch.Tensor):
306
+ for block in self.resblocks:
307
+ x = block(x)
308
+ return x
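The ViT-style module above also provides a cross-attention block whose context can have a different width than the queries; a minimal sketch with made-up sizes:

import torch
from MeshAnything.miche.michelangelo.models.modules.transformer_vit import ResidualCrossAttentionBlock

block = ResidualCrossAttentionBlock(
    device=torch.device("cpu"),
    dtype=torch.float32,
    width=64,       # query token width
    heads=4,
    data_width=32,  # context features may be narrower than the queries
)

queries = torch.randn(2, 16, 64)
context = torch.randn(2, 100, 32)
print(block(queries, context).shape)   # torch.Size([2, 16, 64])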
MeshAnything/miche/michelangelo/models/tsal/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/models/tsal/asl_pl_module.py ADDED
@@ -0,0 +1,395 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import List, Tuple, Dict, Optional
4
+ from omegaconf import DictConfig
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from torch.optim import lr_scheduler
10
+ from typing import Union
11
+ from functools import partial
12
+
13
+ from MeshAnything.miche.michelangelo.utils import instantiate_from_config
14
+
15
+ from .tsal_base import (
16
+ AlignedShapeAsLatentModule,
17
+ ShapeAsLatentModule,
18
+ Latent2MeshOutput,
19
+ AlignedMeshOutput
20
+ )
21
+ from MeshAnything.miche.michelangelo.models.tsal.inference_utils import extract_geometry
22
+ import trimesh
23
+
24
+ class AlignedShapeAsLatentPLModule(nn.Module):
25
+ def __init__(self, *,
26
+ shape_module_cfg,
27
+ aligned_module_cfg,
28
+ loss_cfg,
29
+ optimizer_cfg: Optional[DictConfig] = None,
30
+ ckpt_path: Optional[str] = None,
31
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
32
+
33
+ super().__init__()
34
+
35
+ shape_model: ShapeAsLatentModule = instantiate_from_config(
36
+ shape_module_cfg, device=None, dtype=None
37
+ )
38
+ self.model: AlignedShapeAsLatentModule = instantiate_from_config(
39
+ aligned_module_cfg, shape_model=shape_model
40
+ )
41
+
42
+ self.loss = instantiate_from_config(loss_cfg)
43
+
44
+ self.optimizer_cfg = optimizer_cfg
45
+
46
+ if ckpt_path is not None:
47
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
48
+
49
+ def set_shape_model_only(self):
50
+ self.model.set_shape_model_only()
51
+
52
+
53
+
54
+ @property
55
+ def latent_shape(self):
56
+ return self.model.shape_model.latent_shape
57
+
58
+ @property
59
+ def zero_rank(self):
60
+ if self._trainer:
61
+ zero_rank = self.trainer.local_rank == 0
62
+ else:
63
+ zero_rank = True
64
+
65
+ return zero_rank
66
+
67
+ def init_from_ckpt(self, path, ignore_keys=()):
68
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
69
+
70
+ keys = list(state_dict.keys())
71
+ for k in keys:
72
+ for ik in ignore_keys:
73
+ if k.startswith(ik):
74
+ print("Deleting key {} from state_dict.".format(k))
75
+ del state_dict[k]
76
+
77
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
78
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
79
+ if len(missing) > 0:
80
+ print(f"Missing Keys: {missing}")
81
+ print(f"Unexpected Keys: {unexpected}")
82
+
83
+ def configure_optimizers(self) -> Tuple[List, List]:
84
+ lr = self.learning_rate
85
+
86
+ trainable_parameters = list(self.model.parameters())
87
+
88
+ if self.optimizer_cfg is None:
89
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
90
+ schedulers = []
91
+ else:
92
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
93
+ scheduler_func = instantiate_from_config(
94
+ self.optimizer_cfg.scheduler,
95
+ max_decay_steps=self.trainer.max_steps,
96
+ lr_max=lr
97
+ )
98
+ scheduler = {
99
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
100
+ "interval": "step",
101
+ "frequency": 1
102
+ }
103
+ optimizers = [optimizer]
104
+ schedulers = [scheduler]
105
+
106
+ return optimizers, schedulers
107
+
108
+ def forward(self,
109
+ surface: torch.FloatTensor,
110
+ image: torch.FloatTensor,
111
+ text: torch.FloatTensor,
112
+ volume_queries: torch.FloatTensor):
113
+
114
+ """
115
+
116
+ Args:
117
+ surface (torch.FloatTensor):
118
+ image (torch.FloatTensor):
119
+ text (torch.FloatTensor):
120
+ volume_queries (torch.FloatTensor):
121
+
122
+ Returns:
123
+
124
+ """
125
+
126
+ embed_outputs, shape_z = self.model(surface, image, text)
127
+
128
+ shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
129
+ latents = self.model.shape_model.decode(shape_zq)
130
+ logits = self.model.shape_model.query_geometry(volume_queries, latents)
131
+
132
+ return embed_outputs, logits, posterior
133
+
134
+ def encode(self, surface: torch.FloatTensor, sample_posterior=True):
135
+
136
+ pc = surface[..., 0:3]
137
+ feats = surface[..., 3:6]
138
+
139
+ shape_embed, shape_zq, posterior = self.model.shape_model.encode(
140
+ pc=pc, feats=feats, sample_posterior=sample_posterior
141
+ )
142
+
143
+ return shape_zq
144
+
145
+ def encode_latents(self, surface: torch.FloatTensor):
146
+
147
+ pc = surface[..., 0:3]
148
+ feats = surface[..., 3:6]
149
+
150
+ shape_embed, shape_latents = self.model.shape_model.encode_latents(
151
+ pc=pc, feats=feats
152
+ )
153
+ shape_embed = shape_embed.unsqueeze(1)
154
+ assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256
155
+ cat_latents = torch.cat([shape_embed, shape_latents], dim=1)
156
+
157
+ return cat_latents
158
+
159
+ def recon(self, surface):
160
+ cat_latents = self.encode_latents(surface)
161
+ shape_latents = cat_latents[:, 1:]
162
+ shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_latents)
163
+
164
+ # decoding
165
+ latents = self.model.shape_model.decode(shape_zq)
166
+ geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
167
+
168
+ # reconstruction
169
+ mesh_v_f, has_surface = extract_geometry(
170
+ geometric_func=geometric_func,
171
+ device=surface.device,
172
+ batch_size=surface.shape[0],
173
+ bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
174
+ octree_depth=7,
175
+ num_chunks=10000,
176
+ )
177
+ recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])
178
+
179
+ return recon_mesh
180
+
181
+
182
+ def to_shape_latents(self, latents):
183
+
184
+ shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False)
185
+ return self.model.shape_model.decode(shape_zq)
186
+
187
+ def decode(self,
188
+ z_q,
189
+ bounds: Union[Tuple[float], List[float], float] = 1.1,
190
+ octree_depth: int = 7,
191
+ num_chunks: int = 10000) -> List[Latent2MeshOutput]:
192
+
193
+ latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim]
194
+ outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
195
+
196
+ return outputs
197
+
198
+ def training_step(self, batch: Dict[str, torch.FloatTensor],
199
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
200
+ """
201
+
202
+ Args:
203
+ batch (dict): the batch sample, and it contains:
204
+ - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
205
+ - image (torch.FloatTensor): [bs, 3, 224, 224]
206
+ - text (torch.FloatTensor): [bs, num_templates, 77]
207
+ - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
208
+
209
+ batch_idx (int):
210
+
211
+ optimizer_idx (int):
212
+
213
+ Returns:
214
+ loss (torch.FloatTensor):
215
+
216
+ """
217
+
218
+ surface = batch["surface"]
219
+ image = batch["image"]
220
+ text = batch["text"]
221
+
222
+ volume_queries = batch["geo_points"][..., 0:3]
223
+ shape_labels = batch["geo_points"][..., -1]
224
+
225
+ embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
226
+
227
+ aeloss, log_dict_ae = self.loss(
228
+ **embed_outputs,
229
+ posteriors=posteriors,
230
+ shape_logits=shape_logits,
231
+ shape_labels=shape_labels,
232
+ split="train"
233
+ )
234
+
235
+ self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
236
+ sync_dist=False, rank_zero_only=True)
237
+
238
+ return aeloss
239
+
240
+ def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
241
+
242
+ surface = batch["surface"]
243
+ image = batch["image"]
244
+ text = batch["text"]
245
+
246
+ volume_queries = batch["geo_points"][..., 0:3]
247
+ shape_labels = batch["geo_points"][..., -1]
248
+
249
+ embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
250
+
251
+ aeloss, log_dict_ae = self.loss(
252
+ **embed_outputs,
253
+ posteriors=posteriors,
254
+ shape_logits=shape_logits,
255
+ shape_labels=shape_labels,
256
+ split="val"
257
+ )
258
+ self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
259
+ sync_dist=False, rank_zero_only=True)
260
+
261
+ return aeloss
262
+
263
+ def visual_alignment(self,
264
+ surface: torch.FloatTensor,
265
+ image: torch.FloatTensor,
266
+ text: torch.FloatTensor,
267
+ description: Optional[List[str]] = None,
268
+ bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
269
+ octree_depth: int = 7,
270
+ num_chunks: int = 10000) -> List[AlignedMeshOutput]:
271
+
272
+ """
273
+
274
+ Args:
275
+ surface:
276
+ image:
277
+ text:
278
+ description:
279
+ bounds:
280
+ octree_depth:
281
+ num_chunks:
282
+
283
+ Returns:
284
+ mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
285
+
286
+ """
287
+
288
+ outputs = []
289
+
290
+ device = surface.device
291
+ bs = surface.shape[0]
292
+
293
+ embed_outputs, shape_z = self.model(surface, image, text)
294
+
295
+ # calculate the similarity
296
+ image_embed = embed_outputs["image_embed"]
297
+ text_embed = embed_outputs["text_embed"]
298
+ shape_embed = embed_outputs["shape_embed"]
299
+
300
+ # normalized features
301
+ shape_embed = F.normalize(shape_embed, dim=-1, p=2)
302
+ text_embed = F.normalize(text_embed, dim=-1, p=2)
303
+ image_embed = F.normalize(image_embed, dim=-1, p=2)
304
+
305
+ # B x B
306
+ shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
307
+
308
+ # B x B
309
+ shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)
310
+
311
+ # shape reconstruction
312
+ shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
313
+ latents = self.model.shape_model.decode(shape_zq)
314
+ geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
315
+
316
+ # 2. decode geometry
317
+ mesh_v_f, has_surface = extract_geometry(
318
+ geometric_func=geometric_func,
319
+ device=device,
320
+ batch_size=bs,
321
+ bounds=bounds,
322
+ octree_depth=octree_depth,
323
+ num_chunks=num_chunks,
324
+ disable=not self.zero_rank
325
+ )
326
+
327
+ # 3. decode texture
328
+ for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
329
+ if not is_surface:
330
+ outputs.append(None)
331
+ continue
332
+
333
+ out = AlignedMeshOutput()
334
+ out.mesh_v = mesh_v
335
+ out.mesh_f = mesh_f
336
+ out.surface = surface[i].cpu().numpy()
337
+ out.image = image[i].cpu().numpy()
338
+ if description is not None:
339
+ out.text = description[i]
340
+ out.shape_text_similarity = shape_text_similarity[i, i]
341
+ out.shape_image_similarity = shape_image_similarity[i, i]
342
+
343
+ outputs.append(out)
344
+
345
+ return outputs
346
+
347
+ def latent2mesh(self,
348
+ latents: torch.FloatTensor,
349
+ bounds: Union[Tuple[float], List[float], float] = 1.1,
350
+ octree_depth: int = 7,
351
+ num_chunks: int = 10000) -> List[Latent2MeshOutput]:
352
+
353
+ """
354
+
355
+ Args:
356
+ latents: [bs, num_latents, dim]
357
+ bounds:
358
+ octree_depth:
359
+ num_chunks:
360
+
361
+ Returns:
362
+ mesh_outputs (List[MeshOutput]): the mesh outputs list.
363
+
364
+ """
365
+
366
+ outputs = []
367
+
368
+ geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
369
+
370
+ # 2. decode geometry
371
+ device = latents.device
372
+ mesh_v_f, has_surface = extract_geometry(
373
+ geometric_func=geometric_func,
374
+ device=device,
375
+ batch_size=len(latents),
376
+ bounds=bounds,
377
+ octree_depth=octree_depth,
378
+ num_chunks=num_chunks,
379
+ disable=not self.zero_rank
380
+ )
381
+
382
+ # 3. decode texture
383
+ for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
384
+ if not is_surface:
385
+ outputs.append(None)
386
+ continue
387
+
388
+ out = Latent2MeshOutput()
389
+ out.mesh_v = mesh_v
390
+ out.mesh_f = mesh_f
391
+
392
+ outputs.append(out)
393
+
394
+ return outputs
395
+
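
For orientation, a minimal usage sketch of the reconstruction path defined above. The config key layout (model: {target, params}), the config and checkpoint paths, and the surface tensor are assumptions for illustration; a trained checkpoint and a real sampled surface (xyz plus normals) are needed for a meaningful mesh.

    # Hedged sketch: build the module from a config and run the recon() path above.
    import torch
    from omegaconf import OmegaConf
    from MeshAnything.miche.michelangelo.utils import instantiate_from_config

    cfg = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml")           # assumed layout: cfg.model.{target, params}
    model = instantiate_from_config(cfg.model, ckpt_path="shapevae.ckpt")  # hypothetical checkpoint path
    model = model.eval().cuda()

    surface = torch.randn(1, 4096, 6, device="cuda")   # stand-in: [:, :, :3] = xyz, [:, :, 3:6] = normals
    with torch.no_grad():
        mesh = model.recon(surface)                    # trimesh.Trimesh (empty if marching cubes finds no surface)
    mesh.export("recon.obj")
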
MeshAnything/miche/michelangelo/models/tsal/clip_asl_module.py ADDED
@@ -0,0 +1,118 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange
6
+ from transformers import CLIPModel
7
+
8
+ from MeshAnything.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule
9
+
10
+
11
+ class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
12
+
13
+ def __init__(self, *,
14
+ shape_model,
15
+ clip_model_version: str = "openai/clip-vit-large-patch14"):
16
+
17
+ super().__init__()
18
+
19
+ # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
20
+ # for params in self.clip_model.parameters():
21
+ # params.requires_grad = False
22
+ self.clip_model = None
23
+ self.shape_model = shape_model
24
+ self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width))
25
+ # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5)
26
+
27
+ def set_shape_model_only(self):
28
+ self.clip_model = None
29
+
30
+ def encode_shape_embed(self, surface, return_latents: bool = False):
31
+ """
32
+
33
+ Args:
34
+ surface (torch.FloatTensor): [bs, n, 3 + c]
35
+ return_latents (bool):
36
+
37
+ Returns:
38
+ x (torch.FloatTensor): [bs, projection_dim]
39
+ shape_latents (torch.FloatTensor): [bs, m, d]
40
+ """
41
+
42
+ pc = surface[..., 0:3]
43
+ feats = surface[..., 3:]
44
+
45
+ shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
46
+ x = shape_embed @ self.shape_projection
47
+
48
+ if return_latents:
49
+ return x, shape_latents
50
+ else:
51
+ return x
52
+
53
+ def encode_image_embed(self, image):
54
+ """
55
+
56
+ Args:
57
+ image (torch.FloatTensor): [bs, 3, h, w]
58
+
59
+ Returns:
60
+ x (torch.FloatTensor): [bs, projection_dim]
61
+ """
62
+
63
+ x = self.clip_model.get_image_features(image)
64
+
65
+ return x
66
+
67
+ def encode_text_embed(self, text):
68
+ x = self.clip_model.get_text_features(text)
69
+ return x
70
+
71
+ def forward(self, surface, image, text):
72
+ """
73
+
74
+ Args:
75
+ surface (torch.FloatTensor):
76
+ image (torch.FloatTensor): [bs, 3, 224, 224]
77
+ text (torch.LongTensor): [bs, num_templates, 77]
78
+
79
+ Returns:
80
+ embed_outputs (dict): the embedding outputs, and it contains:
81
+ - image_embed (torch.FloatTensor):
82
+ - text_embed (torch.FloatTensor):
83
+ - shape_embed (torch.FloatTensor):
84
+ - logit_scale (float):
85
+ """
86
+
87
+ # # text embedding
88
+ # text_embed_all = []
89
+ # for i in range(text.shape[0]):
90
+ # text_for_one_sample = text[i]
91
+ # text_embed = self.encode_text_embed(text_for_one_sample)
92
+ # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
93
+ # text_embed = text_embed.mean(dim=0)
94
+ # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
95
+ # text_embed_all.append(text_embed)
96
+ # text_embed_all = torch.stack(text_embed_all)
97
+
98
+ b = text.shape[0]
99
+ text_tokens = rearrange(text, "b t l -> (b t) l")
100
+ text_embed = self.encode_text_embed(text_tokens)
101
+ text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
102
+ text_embed = text_embed.mean(dim=1)
103
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
104
+
105
+ # image embedding
106
+ image_embed = self.encode_image_embed(image)
107
+
108
+ # shape embedding
109
+ shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
110
+
111
+ embed_outputs = {
112
+ "image_embed": image_embed,
113
+ "text_embed": text_embed,
114
+ "shape_embed": shape_embed,
115
+ # "logit_scale": self.clip_model.logit_scale.exp()
116
+ }
117
+
118
+ return embed_outputs, shape_latents
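
The shape branch of this module can be exercised without CLIP, since the snapshot above leaves clip_model as None. A rough sketch follows; the perceiver hyper-parameters are illustrative, and note that shape_projection is created with torch.empty here (its init is commented out), so the projected embedding is only meaningful after loading pretrained weights.

    import torch
    from MeshAnything.miche.michelangelo.models.tsal.sal_perceiver import AlignedShapeLatentPerceiver
    from MeshAnything.miche.michelangelo.models.tsal.clip_asl_module import CLIPAlignedShapeAsLatentModule

    # Illustrative (non-official) hyper-parameters, small enough to run on CPU.
    shape_model = AlignedShapeLatentPerceiver(
        device=None, dtype=None, num_latents=64, point_feats=3, embed_dim=64,
        width=256, heads=8, num_encoder_layers=2, num_decoder_layers=2,
    )
    aligned = CLIPAlignedShapeAsLatentModule(shape_model=shape_model)

    surface = torch.randn(2, 2048, 6)   # xyz + normals
    shape_embed, shape_latents = aligned.encode_shape_embed(surface, return_latents=True)
    print(shape_embed.shape, shape_latents.shape)   # torch.Size([2, 256]) torch.Size([2, 64, 256])
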
MeshAnything/miche/michelangelo/models/tsal/inference_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ from einops import repeat
6
+ import numpy as np
7
+ from typing import Callable, Tuple, List, Union, Optional
8
+ from skimage import measure
9
+
10
+ from MeshAnything.miche.michelangelo.graphics.primitives import generate_dense_grid_points
11
+
12
+
13
+ @torch.no_grad()
14
+ def extract_geometry(geometric_func: Callable,
15
+ device: torch.device,
16
+ batch_size: int = 1,
17
+ bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
18
+ octree_depth: int = 7,
19
+ num_chunks: int = 10000,
20
+ disable: bool = True):
21
+ """
22
+
23
+ Args:
24
+ geometric_func:
25
+ device:
26
+ bounds:
27
+ octree_depth:
28
+ batch_size:
29
+ num_chunks:
30
+ disable:
31
+
32
+ Returns:
33
+
34
+ """
35
+
36
+ if isinstance(bounds, float):
37
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
38
+
39
+ bbox_min = np.array(bounds[0:3])
40
+ bbox_max = np.array(bounds[3:6])
41
+ bbox_size = bbox_max - bbox_min
42
+
43
+ xyz_samples, grid_size, length = generate_dense_grid_points(
44
+ bbox_min=bbox_min,
45
+ bbox_max=bbox_max,
46
+ octree_depth=octree_depth,
47
+ indexing="ij"
48
+ )
49
+ xyz_samples = torch.FloatTensor(xyz_samples)
50
+
51
+ batch_logits = []
52
+ for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
53
+ desc="Implicit Function:", disable=disable, leave=False):
54
+ queries = xyz_samples[start: start + num_chunks, :].to(device)
55
+ batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
56
+
57
+ logits = geometric_func(batch_queries)
58
+ batch_logits.append(logits.cpu())
59
+
60
+ grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
61
+
62
+ mesh_v_f = []
63
+ has_surface = np.zeros((batch_size,), dtype=np.bool_)
64
+ for i in range(batch_size):
65
+ try:
66
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
67
+ vertices = vertices / grid_size * bbox_size + bbox_min
68
+ # vertices[:, [0, 1]] = vertices[:, [1, 0]]
69
+ mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
70
+ has_surface[i] = True
71
+
72
+ except ValueError:
73
+ mesh_v_f.append((None, None))
74
+ has_surface[i] = False
75
+
76
+ except RuntimeError:
77
+ mesh_v_f.append((None, None))
78
+ has_surface[i] = False
79
+
80
+ return mesh_v_f, has_surface
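
extract_geometry only needs a callable mapping [batch, n_points, 3] queries to [batch, n_points] logits, so it can be sanity-checked against an analytic field before any model is involved. A small sketch with a sphere occupancy function (everything below is synthetic):

    import torch
    from MeshAnything.miche.michelangelo.models.tsal.inference_utils import extract_geometry

    def sphere_logits(queries: torch.FloatTensor) -> torch.FloatTensor:
        # Positive inside a sphere of radius 0.8, negative outside: [B, P, 3] -> [B, P].
        return 0.8 - queries.norm(dim=-1)

    mesh_v_f, has_surface = extract_geometry(
        geometric_func=sphere_logits,
        device=torch.device("cpu"),
        batch_size=1,
        bounds=1.0,        # a float is expanded to (-1, -1, -1, 1, 1, 1) above
        octree_depth=6,    # coarser grid, enough for a quick check
        num_chunks=10000,
    )
    vertices, faces = mesh_v_f[0]
    print(has_surface[0], vertices.shape, faces.shape)
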
MeshAnything/miche/michelangelo/models/tsal/loss.py ADDED
@@ -0,0 +1,303 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from typing import Optional, Tuple, Dict
7
+
8
+ from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution
9
+ from MeshAnything.miche.michelangelo.utils.eval import compute_psnr
10
+ from MeshAnything.miche.michelangelo.utils import misc
11
+
12
+
13
+ class KLNearFar(nn.Module):
14
+ def __init__(self,
15
+ near_weight: float = 0.1,
16
+ kl_weight: float = 1.0,
17
+ num_near_samples: Optional[int] = None):
18
+
19
+ super().__init__()
20
+
21
+ self.near_weight = near_weight
22
+ self.kl_weight = kl_weight
23
+ self.num_near_samples = num_near_samples
24
+ self.geo_criterion = nn.BCEWithLogitsLoss()
25
+
26
+ def forward(self,
27
+ posteriors: Optional[DiagonalGaussianDistribution],
28
+ logits: torch.FloatTensor,
29
+ labels: torch.FloatTensor,
30
+ split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
31
+
32
+ """
33
+
34
+ Args:
35
+ posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
36
+ logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
37
+ labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
38
+ split (str):
39
+ **kwargs:
40
+
41
+ Returns:
42
+ loss (torch.Tensor): (,)
43
+ log (dict):
44
+
45
+ """
46
+
47
+ if self.num_near_samples is None:
48
+ num_vol = logits.shape[1] // 2
49
+ else:
50
+ num_vol = logits.shape[1] - self.num_near_samples
51
+
52
+ vol_logits = logits[:, 0:num_vol]
53
+ vol_labels = labels[:, 0:num_vol]
54
+
55
+ near_logits = logits[:, num_vol:]
56
+ near_labels = labels[:, num_vol:]
57
+
58
+ # occupancy loss
59
+ # vol_bce = self.geo_criterion(vol_logits, vol_labels)
60
+ # near_bce = self.geo_criterion(near_logits, near_labels)
61
+ vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
62
+ near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
63
+
64
+ if posteriors is None:
65
+ kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
66
+ else:
67
+ kl_loss = posteriors.kl(dims=(1, 2))
68
+ kl_loss = torch.mean(kl_loss)
69
+
70
+ loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight
71
+
72
+ with torch.no_grad():
73
+ preds = logits >= 0
74
+ accuracy = (preds == labels).float()
75
+ accuracy = accuracy.mean()
76
+ pos_ratio = torch.mean(labels)
77
+
78
+ log = {
79
+ "{}/total_loss".format(split): loss.clone().detach(),
80
+ "{}/near".format(split): near_bce.detach(),
81
+ "{}/far".format(split): vol_bce.detach(),
82
+ "{}/kl".format(split): kl_loss.detach(),
83
+ "{}/accuracy".format(split): accuracy,
84
+ "{}/pos_ratio".format(split): pos_ratio
85
+ }
86
+
87
+ if posteriors is not None:
88
+ log[f"{split}/mean"] = posteriors.mean.mean().detach()
89
+ log[f"{split}/std_mean"] = posteriors.std.mean().detach()
90
+ log[f"{split}/std_max"] = posteriors.std.max().detach()
91
+
92
+ return loss, log
93
+
94
+
95
+ class KLNearFarColor(nn.Module):
96
+ def __init__(self,
97
+ near_weight: float = 0.1,
98
+ kl_weight: float = 1.0,
99
+ color_weight: float = 1.0,
100
+ color_criterion: str = "mse",
101
+ num_near_samples: Optional[int] = None):
102
+
103
+ super().__init__()
104
+
105
+ self.color_weight = color_weight
106
+ self.near_weight = near_weight
107
+ self.kl_weight = kl_weight
108
+ self.num_near_samples = num_near_samples
109
+
110
+ if color_criterion == "mse":
111
+ self.color_criterion = nn.MSELoss()
112
+
113
+ elif color_criterion == "l1":
114
+ self.color_criterion = nn.L1Loss()
115
+
116
+ else:
117
+ raise ValueError(f"{color_criterion} must be [`mse`, `l1`].")
118
+
119
+ self.geo_criterion = nn.BCEWithLogitsLoss()
120
+
121
+ def forward(self,
122
+ posteriors: Optional[DiagonalGaussianDistribution],
123
+ logits: torch.FloatTensor,
124
+ labels: torch.FloatTensor,
125
+ pred_colors: torch.FloatTensor,
126
+ gt_colors: torch.FloatTensor,
127
+ split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
128
+
129
+ """
130
+
131
+ Args:
132
+ posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
133
+ logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
134
+ labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
135
+ pred_colors (torch.FloatTensor): [B, M, 3]
136
+ gt_colors (torch.FloatTensor): [B, M, 3]
137
+ split (str):
138
+ **kwargs:
139
+
140
+ Returns:
141
+ loss (torch.Tensor): (,)
142
+ log (dict):
143
+
144
+ """
145
+
146
+ if self.num_near_samples is None:
147
+ num_vol = logits.shape[1] // 2
148
+ else:
149
+ num_vol = logits.shape[1] - self.num_near_samples
150
+
151
+ vol_logits = logits[:, 0:num_vol]
152
+ vol_labels = labels[:, 0:num_vol]
153
+
154
+ near_logits = logits[:, num_vol:]
155
+ near_labels = labels[:, num_vol:]
156
+
157
+ # occupancy loss
158
+ # vol_bce = self.geo_criterion(vol_logits, vol_labels)
159
+ # near_bce = self.geo_criterion(near_logits, near_labels)
160
+ vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
161
+ near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
162
+
163
+ # surface color loss
164
+ color = self.color_criterion(pred_colors, gt_colors)
165
+
166
+ if posteriors is None:
167
+ kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
168
+ else:
169
+ kl_loss = posteriors.kl(dims=(1, 2))
170
+ kl_loss = torch.mean(kl_loss)
171
+
172
+ loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight
173
+
174
+ with torch.no_grad():
175
+ preds = logits >= 0
176
+ accuracy = (preds == labels).float()
177
+ accuracy = accuracy.mean()
178
+ psnr = compute_psnr(pred_colors, gt_colors)
179
+
180
+ log = {
181
+ "{}/total_loss".format(split): loss.clone().detach(),
182
+ "{}/near".format(split): near_bce.detach(),
183
+ "{}/far".format(split): vol_bce.detach(),
184
+ "{}/color".format(split): color.detach(),
185
+ "{}/kl".format(split): kl_loss.detach(),
186
+ "{}/psnr".format(split): psnr.detach(),
187
+ "{}/accuracy".format(split): accuracy
188
+ }
189
+
190
+ return loss, log
191
+
192
+
193
+ class ContrastKLNearFar(nn.Module):
194
+ def __init__(self,
195
+ contrast_weight: float = 1.0,
196
+ near_weight: float = 0.1,
197
+ kl_weight: float = 1.0,
198
+ num_near_samples: Optional[int] = None):
199
+
200
+ super().__init__()
201
+
202
+ self.labels = None
203
+ self.last_local_batch_size = None
204
+
205
+ self.contrast_weight = contrast_weight
206
+ self.near_weight = near_weight
207
+ self.kl_weight = kl_weight
208
+ self.num_near_samples = num_near_samples
209
+ self.geo_criterion = nn.BCEWithLogitsLoss()
210
+
211
+ def forward(self,
212
+ shape_embed: torch.FloatTensor,
213
+ text_embed: torch.FloatTensor,
214
+ image_embed: torch.FloatTensor,
215
+ logit_scale: torch.FloatTensor,
216
+ posteriors: Optional[DiagonalGaussianDistribution],
217
+ shape_logits: torch.FloatTensor,
218
+ shape_labels: torch.FloatTensor,
219
+ split: Optional[str] = "train", **kwargs):
220
+
221
+ local_batch_size = shape_embed.size(0)
222
+
223
+ if local_batch_size != self.last_local_batch_size:
224
+ self.labels = local_batch_size * misc.get_rank() + torch.arange(
225
+ local_batch_size, device=shape_embed.device
226
+ ).long()
227
+ self.last_local_batch_size = local_batch_size
228
+
229
+ # normalized features
230
+ shape_embed = F.normalize(shape_embed, dim=-1, p=2)
231
+ text_embed = F.normalize(text_embed, dim=-1, p=2)
232
+ image_embed = F.normalize(image_embed, dim=-1, p=2)
233
+
234
+ # gather features from all GPUs
235
+ shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
236
+ [shape_embed, text_embed, image_embed]
237
+ )
238
+
239
+ # cosine similarity as logits
240
+ logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
241
+ logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
242
+ logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
243
+ logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
244
+ contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
245
+ F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
246
+ (F.cross_entropy(logits_per_shape_image, self.labels) +
247
+ F.cross_entropy(logits_per_image_shape, self.labels)) / 2
248
+
249
+ # shape reconstruction
250
+ if self.num_near_samples is None:
251
+ num_vol = shape_logits.shape[1] // 2
252
+ else:
253
+ num_vol = shape_logits.shape[1] - self.num_near_samples
254
+
255
+ vol_logits = shape_logits[:, 0:num_vol]
256
+ vol_labels = shape_labels[:, 0:num_vol]
257
+
258
+ near_logits = shape_logits[:, num_vol:]
259
+ near_labels = shape_labels[:, num_vol:]
260
+
261
+ # occupancy loss
262
+ vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
263
+ near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
264
+
265
+ if posteriors is None:
266
+ kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
267
+ else:
268
+ kl_loss = posteriors.kl(dims=(1, 2))
269
+ kl_loss = torch.mean(kl_loss)
270
+
271
+ loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight
272
+
273
+ # compute accuracy
274
+ with torch.no_grad():
275
+ pred = torch.argmax(logits_per_shape_text, dim=-1)
276
+ correct = pred.eq(self.labels).sum()
277
+ shape_text_acc = 100 * correct / local_batch_size
278
+
279
+ pred = torch.argmax(logits_per_shape_image, dim=-1)
280
+ correct = pred.eq(self.labels).sum()
281
+ shape_image_acc = 100 * correct / local_batch_size
282
+
283
+ preds = shape_logits >= 0
284
+ accuracy = (preds == shape_labels).float()
285
+ accuracy = accuracy.mean()
286
+
287
+ log = {
288
+ "{}/contrast".format(split): contrast_loss.clone().detach(),
289
+ "{}/near".format(split): near_bce.detach(),
290
+ "{}/far".format(split): vol_bce.detach(),
291
+ "{}/kl".format(split): kl_loss.detach(),
292
+ "{}/shape_text_acc".format(split): shape_text_acc,
293
+ "{}/shape_image_acc".format(split): shape_image_acc,
294
+ "{}/total_loss".format(split): loss.clone().detach(),
295
+ "{}/accuracy".format(split): accuracy,
296
+ }
297
+
298
+ if posteriors is not None:
299
+ log[f"{split}/mean"] = posteriors.mean.mean().detach()
300
+ log[f"{split}/std_mean"] = posteriors.std.mean().detach()
301
+ log[f"{split}/std_max"] = posteriors.std.max().detach()
302
+
303
+ return loss, log
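
These loss classes are plain nn.Modules and can be probed in isolation. A quick sketch of KLNearFar on random logits with posteriors=None, so only the two BCE terms contribute (all numbers are synthetic):

    import torch
    from MeshAnything.miche.michelangelo.models.tsal.loss import KLNearFar

    criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)

    logits = torch.randn(4, 2048)                  # first half: volume points, second half: near-surface points
    labels = (torch.rand(4, 2048) > 0.5).float()   # binary occupancy targets

    loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="train")
    print(loss.item(), log["train/accuracy"].item())
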
MeshAnything/miche/michelangelo/models/tsal/sal_perceiver.py ADDED
@@ -0,0 +1,423 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Optional
6
+ from einops import repeat
7
+ import math
8
+
9
+ from MeshAnything.miche.michelangelo.models.modules import checkpoint
10
+ from MeshAnything.miche.michelangelo.models.modules.embedder import FourierEmbedder
11
+ from MeshAnything.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution
12
+ from MeshAnything.miche.michelangelo.models.modules.transformer_blocks import (
13
+ ResidualCrossAttentionBlock,
14
+ Transformer
15
+ )
16
+
17
+ from .tsal_base import ShapeAsLatentModule
18
+
19
+
20
+ class CrossAttentionEncoder(nn.Module):
21
+
22
+ def __init__(self, *,
23
+ device: Optional[torch.device],
24
+ dtype: Optional[torch.dtype],
25
+ num_latents: int,
26
+ fourier_embedder: FourierEmbedder,
27
+ point_feats: int,
28
+ width: int,
29
+ heads: int,
30
+ layers: int,
31
+ init_scale: float = 0.25,
32
+ qkv_bias: bool = True,
33
+ flash: bool = False,
34
+ use_ln_post: bool = False,
35
+ use_checkpoint: bool = False):
36
+
37
+ super().__init__()
38
+
39
+ self.use_checkpoint = use_checkpoint
40
+ self.num_latents = num_latents
41
+
42
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
43
+
44
+ self.fourier_embedder = fourier_embedder
45
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
46
+ self.cross_attn = ResidualCrossAttentionBlock(
47
+ device=device,
48
+ dtype=dtype,
49
+ width=width,
50
+ heads=heads,
51
+ init_scale=init_scale,
52
+ qkv_bias=qkv_bias,
53
+ flash=flash,
54
+ )
55
+
56
+ self.self_attn = Transformer(
57
+ device=device,
58
+ dtype=dtype,
59
+ n_ctx=num_latents,
60
+ width=width,
61
+ layers=layers,
62
+ heads=heads,
63
+ init_scale=init_scale,
64
+ qkv_bias=qkv_bias,
65
+ flash=flash,
66
+ use_checkpoint=False
67
+ )
68
+
69
+ if use_ln_post:
70
+ self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
71
+ else:
72
+ self.ln_post = None
73
+
74
+ def _forward(self, pc, feats):
75
+ """
76
+
77
+ Args:
78
+ pc (torch.FloatTensor): [B, N, 3]
79
+ feats (torch.FloatTensor or None): [B, N, C]
80
+
81
+ Returns:
82
+
83
+ """
84
+
85
+ bs = pc.shape[0]
86
+
87
+ data = self.fourier_embedder(pc)
88
+ if feats is not None:
89
+ data = torch.cat([data, feats], dim=-1)
90
+ data = self.input_proj(data)
91
+
92
+ query = repeat(self.query, "m c -> b m c", b=bs)
93
+ latents = self.cross_attn(query, data)
94
+ latents = self.self_attn(latents)
95
+
96
+ if self.ln_post is not None:
97
+ latents = self.ln_post(latents)
98
+
99
+ return latents, pc
100
+
101
+ def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
102
+ """
103
+
104
+ Args:
105
+ pc (torch.FloatTensor): [B, N, 3]
106
+ feats (torch.FloatTensor or None): [B, N, C]
107
+
108
+ Returns:
109
+ dict
110
+ """
111
+
112
+ return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
113
+
114
+
115
+ class CrossAttentionDecoder(nn.Module):
116
+
117
+ def __init__(self, *,
118
+ device: Optional[torch.device],
119
+ dtype: Optional[torch.dtype],
120
+ num_latents: int,
121
+ out_channels: int,
122
+ fourier_embedder: FourierEmbedder,
123
+ width: int,
124
+ heads: int,
125
+ init_scale: float = 0.25,
126
+ qkv_bias: bool = True,
127
+ flash: bool = False,
128
+ use_checkpoint: bool = False):
129
+
130
+ super().__init__()
131
+
132
+ self.use_checkpoint = use_checkpoint
133
+ self.fourier_embedder = fourier_embedder
134
+
135
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
136
+
137
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
138
+ device=device,
139
+ dtype=dtype,
140
+ n_data=num_latents,
141
+ width=width,
142
+ heads=heads,
143
+ init_scale=init_scale,
144
+ qkv_bias=qkv_bias,
145
+ flash=flash
146
+ )
147
+
148
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
149
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
150
+
151
+ def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
152
+ queries = self.query_proj(self.fourier_embedder(queries))
153
+ x = self.cross_attn_decoder(queries, latents)
154
+ x = self.ln_post(x)
155
+ x = self.output_proj(x)
156
+ return x
157
+
158
+ def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
159
+ return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
160
+
161
+
162
+ class ShapeAsLatentPerceiver(ShapeAsLatentModule):
163
+ def __init__(self, *,
164
+ device: Optional[torch.device],
165
+ dtype: Optional[torch.dtype],
166
+ num_latents: int,
167
+ point_feats: int = 0,
168
+ embed_dim: int = 0,
169
+ num_freqs: int = 8,
170
+ include_pi: bool = True,
171
+ width: int,
172
+ heads: int,
173
+ num_encoder_layers: int,
174
+ num_decoder_layers: int,
175
+ init_scale: float = 0.25,
176
+ qkv_bias: bool = True,
177
+ flash: bool = False,
178
+ use_ln_post: bool = False,
179
+ use_checkpoint: bool = False):
180
+
181
+ super().__init__()
182
+
183
+ self.use_checkpoint = use_checkpoint
184
+
185
+ self.num_latents = num_latents
186
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
187
+
188
+ init_scale = init_scale * math.sqrt(1.0 / width)
189
+ self.encoder = CrossAttentionEncoder(
190
+ device=device,
191
+ dtype=dtype,
192
+ fourier_embedder=self.fourier_embedder,
193
+ num_latents=num_latents,
194
+ point_feats=point_feats,
195
+ width=width,
196
+ heads=heads,
197
+ layers=num_encoder_layers,
198
+ init_scale=init_scale,
199
+ qkv_bias=qkv_bias,
200
+ flash=flash,
201
+ use_ln_post=use_ln_post,
202
+ use_checkpoint=use_checkpoint
203
+ )
204
+
205
+ self.embed_dim = embed_dim
206
+ if embed_dim > 0:
207
+ # VAE embed
208
+ self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
209
+ self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
210
+ self.latent_shape = (num_latents, embed_dim)
211
+ else:
212
+ self.latent_shape = (num_latents, width)
213
+
214
+ self.transformer = Transformer(
215
+ device=device,
216
+ dtype=dtype,
217
+ n_ctx=num_latents,
218
+ width=width,
219
+ layers=num_decoder_layers,
220
+ heads=heads,
221
+ init_scale=init_scale,
222
+ qkv_bias=qkv_bias,
223
+ flash=flash,
224
+ use_checkpoint=use_checkpoint
225
+ )
226
+
227
+ # geometry decoder
228
+ self.geo_decoder = CrossAttentionDecoder(
229
+ device=device,
230
+ dtype=dtype,
231
+ fourier_embedder=self.fourier_embedder,
232
+ out_channels=1,
233
+ num_latents=num_latents,
234
+ width=width,
235
+ heads=heads,
236
+ init_scale=init_scale,
237
+ qkv_bias=qkv_bias,
238
+ flash=flash,
239
+ use_checkpoint=use_checkpoint
240
+ )
241
+
242
+ def encode(self,
243
+ pc: torch.FloatTensor,
244
+ feats: Optional[torch.FloatTensor] = None,
245
+ sample_posterior: bool = True):
246
+ """
247
+
248
+ Args:
249
+ pc (torch.FloatTensor): [B, N, 3]
250
+ feats (torch.FloatTensor or None): [B, N, C]
251
+ sample_posterior (bool):
252
+
253
+ Returns:
254
+ latents (torch.FloatTensor)
255
+ center_pos (torch.FloatTensor or None):
256
+ posterior (DiagonalGaussianDistribution or None):
257
+ """
258
+
259
+ latents, center_pos = self.encoder(pc, feats)
260
+
261
+ posterior = None
262
+ if self.embed_dim > 0:
263
+ moments = self.pre_kl(latents)
264
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
265
+
266
+ if sample_posterior:
267
+ latents = posterior.sample()
268
+ else:
269
+ latents = posterior.mode()
270
+
271
+ return latents, center_pos, posterior
272
+
273
+ def decode(self, latents: torch.FloatTensor):
274
+ latents = self.post_kl(latents)
275
+ return self.transformer(latents)
276
+
277
+ def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
278
+ logits = self.geo_decoder(queries, latents).squeeze(-1)
279
+ return logits
280
+
281
+ def forward(self,
282
+ pc: torch.FloatTensor,
283
+ feats: torch.FloatTensor,
284
+ volume_queries: torch.FloatTensor,
285
+ sample_posterior: bool = True):
286
+ """
287
+
288
+ Args:
289
+ pc (torch.FloatTensor): [B, N, 3]
290
+ feats (torch.FloatTensor or None): [B, N, C]
291
+ volume_queries (torch.FloatTensor): [B, P, 3]
292
+ sample_posterior (bool):
293
+
294
+ Returns:
295
+ logits (torch.FloatTensor): [B, P]
296
+ center_pos (torch.FloatTensor): [B, M, 3]
297
+ posterior (DiagonalGaussianDistribution or None).
298
+
299
+ """
300
+
301
+ latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
302
+
303
+ latents = self.decode(latents)
304
+ logits = self.query_geometry(volume_queries, latents)
305
+
306
+ return logits, center_pos, posterior
307
+
308
+
309
+ class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
310
+
311
+ def __init__(self, *,
312
+ device: Optional[torch.device],
313
+ dtype: Optional[torch.dtype],
314
+ num_latents: int,
315
+ point_feats: int = 0,
316
+ embed_dim: int = 0,
317
+ num_freqs: int = 8,
318
+ include_pi: bool = True,
319
+ width: int,
320
+ heads: int,
321
+ num_encoder_layers: int,
322
+ num_decoder_layers: int,
323
+ init_scale: float = 0.25,
324
+ qkv_bias: bool = True,
325
+ flash: bool = False,
326
+ use_ln_post: bool = False,
327
+ use_checkpoint: bool = False):
328
+
329
+ super().__init__(
330
+ device=device,
331
+ dtype=dtype,
332
+ num_latents=1 + num_latents,
333
+ point_feats=point_feats,
334
+ embed_dim=embed_dim,
335
+ num_freqs=num_freqs,
336
+ include_pi=include_pi,
337
+ width=width,
338
+ heads=heads,
339
+ num_encoder_layers=num_encoder_layers,
340
+ num_decoder_layers=num_decoder_layers,
341
+ init_scale=init_scale,
342
+ qkv_bias=qkv_bias,
343
+ flash=flash,
344
+ use_ln_post=use_ln_post,
345
+ use_checkpoint=use_checkpoint
346
+ )
347
+
348
+ self.width = width
349
+
350
+ def encode(self,
351
+ pc: torch.FloatTensor,
352
+ feats: Optional[torch.FloatTensor] = None,
353
+ sample_posterior: bool = True):
354
+ """
355
+
356
+ Args:
357
+ pc (torch.FloatTensor): [B, N, 3]
358
+ feats (torch.FloatTensor or None): [B, N, c]
359
+ sample_posterior (bool):
360
+
361
+ Returns:
362
+ shape_embed (torch.FloatTensor)
363
+ kl_embed (torch.FloatTensor):
364
+ posterior (DiagonalGaussianDistribution or None):
365
+ """
366
+
367
+ shape_embed, latents = self.encode_latents(pc, feats)
368
+ kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
369
+
370
+ return shape_embed, kl_embed, posterior
371
+
372
+ def encode_latents(self,
373
+ pc: torch.FloatTensor,
374
+ feats: Optional[torch.FloatTensor] = None):
375
+
376
+ x, _ = self.encoder(pc, feats)
377
+
378
+ shape_embed = x[:, 0]
379
+ latents = x[:, 1:]
380
+
381
+ return shape_embed, latents
382
+
383
+ def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
384
+ posterior = None
385
+ if self.embed_dim > 0:
386
+ moments = self.pre_kl(latents)
387
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
388
+
389
+ if sample_posterior:
390
+ kl_embed = posterior.sample()
391
+ else:
392
+ kl_embed = posterior.mode()
393
+ else:
394
+ kl_embed = latents
395
+
396
+ return kl_embed, posterior
397
+
398
+ def forward(self,
399
+ pc: torch.FloatTensor,
400
+ feats: torch.FloatTensor,
401
+ volume_queries: torch.FloatTensor,
402
+ sample_posterior: bool = True):
403
+ """
404
+
405
+ Args:
406
+ pc (torch.FloatTensor): [B, N, 3]
407
+ feats (torch.FloatTensor or None): [B, N, C]
408
+ volume_queries (torch.FloatTensor): [B, P, 3]
409
+ sample_posterior (bool):
410
+
411
+ Returns:
412
+ shape_embed (torch.FloatTensor): [B, projection_dim]
413
+ logits (torch.FloatTensor): [B, M]
414
+ posterior (DiagonalGaussianDistribution or None).
415
+
416
+ """
417
+
418
+ shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
419
+
420
+ latents = self.decode(kl_embed)
421
+ logits = self.query_geometry(volume_queries, latents)
422
+
423
+ return shape_embed, logits, posterior
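
To make the tensor contract of AlignedShapeLatentPerceiver concrete, a small round trip through encode, decode and query_geometry. The hyper-parameters are deliberately tiny and are not the ones from shapevae-256.yaml:

    import torch
    from MeshAnything.miche.michelangelo.models.tsal.sal_perceiver import AlignedShapeLatentPerceiver

    model = AlignedShapeLatentPerceiver(
        device=None, dtype=None, num_latents=64, point_feats=3, embed_dim=8,
        width=128, heads=8, num_encoder_layers=2, num_decoder_layers=2,
    )

    pc = torch.randn(2, 1024, 3)             # surface points
    feats = torch.randn(2, 1024, 3)          # e.g. normals
    queries = torch.rand(2, 512, 3) * 2 - 1  # volume query points in [-1, 1]^3

    shape_embed, kl_embed, posterior = model.encode(pc, feats, sample_posterior=True)
    latents = model.decode(kl_embed)
    logits = model.query_geometry(queries, latents)
    print(shape_embed.shape, kl_embed.shape, logits.shape)  # [2, 128], [2, 64, 8], [2, 512]
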
MeshAnything/miche/michelangelo/models/tsal/sal_pl_module.py ADDED
@@ -0,0 +1,290 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import List, Tuple, Dict, Optional
4
+ from omegaconf import DictConfig
5
+
6
+ import torch
7
+ from torch.optim import lr_scheduler
8
+ import pytorch_lightning as pl
9
+ from typing import Union
10
+ from functools import partial
11
+
12
+ from MeshAnything.miche.michelangelo.utils import instantiate_from_config
13
+
14
+ from .inference_utils import extract_geometry
15
+ from .tsal_base import (
16
+ ShapeAsLatentModule,
17
+ Latent2MeshOutput,
18
+ Point2MeshOutput
19
+ )
20
+
21
+
22
+ class ShapeAsLatentPLModule(pl.LightningModule):
23
+
24
+ def __init__(self, *,
25
+ module_cfg,
26
+ loss_cfg,
27
+ optimizer_cfg: Optional[DictConfig] = None,
28
+ ckpt_path: Optional[str] = None,
29
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
30
+
31
+ super().__init__()
32
+
33
+ self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None)
34
+
35
+ self.loss = instantiate_from_config(loss_cfg)
36
+
37
+ self.optimizer_cfg = optimizer_cfg
38
+
39
+ if ckpt_path is not None:
40
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
41
+
42
+ self.save_hyperparameters()
43
+
44
+ @property
45
+ def latent_shape(self):
46
+ return self.sal.latent_shape
47
+
48
+ @property
49
+ def zero_rank(self):
50
+ if self._trainer:
51
+ zero_rank = self.trainer.local_rank == 0
52
+ else:
53
+ zero_rank = True
54
+
55
+ return zero_rank
56
+
57
+ def init_from_ckpt(self, path, ignore_keys=()):
58
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
59
+
60
+ keys = list(state_dict.keys())
61
+ for k in keys:
62
+ for ik in ignore_keys:
63
+ if k.startswith(ik):
64
+ print("Deleting key {} from state_dict.".format(k))
65
+ del state_dict[k]
66
+
67
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
68
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ print(f"Unexpected Keys: {unexpected}")
72
+
73
+ def configure_optimizers(self) -> Tuple[List, List]:
74
+ lr = self.learning_rate
75
+
76
+ # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
77
+ # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
78
+
79
+ if self.optimizer_cfg is None:
80
+ optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
81
+ schedulers = []
82
+ else:
83
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters())
84
+ scheduler_func = instantiate_from_config(
85
+ self.optimizer_cfg.scheduler,
86
+ max_decay_steps=self.trainer.max_steps,
87
+ lr_max=lr
88
+ )
89
+ scheduler = {
90
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
91
+ "interval": "step",
92
+ "frequency": 1
93
+ }
94
+ optimizers = [optimizer]
95
+ schedulers = [scheduler]
96
+
97
+ return optimizers, schedulers
98
+
99
+ def forward(self,
100
+ pc: torch.FloatTensor,
101
+ feats: torch.FloatTensor,
102
+ volume_queries: torch.FloatTensor):
103
+
104
+ logits, center_pos, posterior = self.sal(pc, feats, volume_queries)
105
+
106
+ return posterior, logits
107
+
108
+ def encode(self, surface: torch.FloatTensor, sample_posterior=True):
109
+
110
+ pc = surface[..., 0:3]
111
+ feats = surface[..., 3:6]
112
+
113
+ latents, center_pos, posterior = self.sal.encode(
114
+ pc=pc, feats=feats, sample_posterior=sample_posterior
115
+ )
116
+
117
+ return latents
118
+
119
+ def decode(self,
120
+ z_q,
121
+ bounds: Union[Tuple[float], List[float], float] = 1.1,
122
+ octree_depth: int = 7,
123
+ num_chunks: int = 10000) -> List[Latent2MeshOutput]:
124
+
125
+ latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim]
126
+ outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
127
+
128
+ return outputs
129
+
130
+ def training_step(self, batch: Dict[str, torch.FloatTensor],
131
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
132
+ """
133
+
134
+ Args:
135
+ batch (dict): the batch sample, and it contains:
136
+ - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
137
+ - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
138
+
139
+ batch_idx (int):
140
+
141
+ optimizer_idx (int):
142
+
143
+ Returns:
144
+ loss (torch.FloatTensor):
145
+
146
+ """
147
+
148
+ pc = batch["surface"][..., 0:3]
149
+ feats = batch["surface"][..., 3:]
150
+
151
+ volume_queries = batch["geo_points"][..., 0:3]
152
+ volume_labels = batch["geo_points"][..., -1]
153
+
154
+ posterior, logits = self(
155
+ pc=pc, feats=feats, volume_queries=volume_queries
156
+ )
157
+ aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train")
158
+
159
+ self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
160
+ sync_dist=False, rank_zero_only=True)
161
+
162
+ return aeloss
163
+
164
+ def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
165
+
166
+ pc = batch["surface"][..., 0:3]
167
+ feats = batch["surface"][..., 3:]
168
+
169
+ volume_queries = batch["geo_points"][..., 0:3]
170
+ volume_labels = batch["geo_points"][..., -1]
171
+
172
+ posterior, logits = self(
173
+ pc=pc, feats=feats, volume_queries=volume_queries,
174
+ )
175
+ aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val")
176
+
177
+ self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
178
+ sync_dist=False, rank_zero_only=True)
179
+
180
+ return aeloss
181
+
182
+ def point2mesh(self,
183
+ pc: torch.FloatTensor,
184
+ feats: torch.FloatTensor,
185
+ bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
186
+ octree_depth: int = 7,
187
+ num_chunks: int = 10000) -> List[Point2MeshOutput]:
188
+
189
+ """
190
+
191
+ Args:
192
+ pc:
193
+ feats:
194
+ bounds:
195
+ octree_depth:
196
+ num_chunks:
197
+
198
+ Returns:
199
+ mesh_outputs (List[MeshOutput]): the mesh outputs list.
200
+
201
+ """
202
+
203
+ outputs = []
204
+
205
+ device = pc.device
206
+ bs = pc.shape[0]
207
+
208
+ # 1. point encoder + latents transformer
209
+ latents, center_pos, posterior = self.sal.encode(pc, feats)
210
+ latents = self.sal.decode(latents) # latents: [bs, num_latents, dim]
211
+
212
+ geometric_func = partial(self.sal.query_geometry, latents=latents)
213
+
214
+ # 2. decode geometry
215
+ mesh_v_f, has_surface = extract_geometry(
216
+ geometric_func=geometric_func,
217
+ device=device,
218
+ batch_size=bs,
219
+ bounds=bounds,
220
+ octree_depth=octree_depth,
221
+ num_chunks=num_chunks,
222
+ disable=not self.zero_rank
223
+ )
224
+
225
+ # 3. decode texture
226
+ for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
227
+ if not is_surface:
228
+ outputs.append(None)
229
+ continue
230
+
231
+ out = Point2MeshOutput()
232
+ out.mesh_v = mesh_v
233
+ out.mesh_f = mesh_f
234
+ out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy()
235
+
236
+ if center_pos is not None:
237
+ out.center = center_pos[i].cpu().numpy()
238
+
239
+ outputs.append(out)
240
+
241
+ return outputs
242
+
243
+ def latent2mesh(self,
244
+ latents: torch.FloatTensor,
245
+ bounds: Union[Tuple[float], List[float], float] = 1.1,
246
+ octree_depth: int = 7,
247
+ num_chunks: int = 10000) -> List[Latent2MeshOutput]:
248
+
249
+ """
250
+
251
+ Args:
252
+ latents: [bs, num_latents, dim]
253
+ bounds:
254
+ octree_depth:
255
+ num_chunks:
256
+
257
+ Returns:
258
+ mesh_outputs (List[MeshOutput]): the mesh outputs list.
259
+
260
+ """
261
+
262
+ outputs = []
263
+
264
+ geometric_func = partial(self.sal.query_geometry, latents=latents)
265
+
266
+ # 2. decode geometry
267
+ device = latents.device
268
+ mesh_v_f, has_surface = extract_geometry(
269
+ geometric_func=geometric_func,
270
+ device=device,
271
+ batch_size=len(latents),
272
+ bounds=bounds,
273
+ octree_depth=octree_depth,
274
+ num_chunks=num_chunks,
275
+ disable=not self.zero_rank
276
+ )
277
+
278
+ # 3. decode texture
279
+ for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
280
+ if not is_surface:
281
+ outputs.append(None)
282
+ continue
283
+
284
+ out = Latent2MeshOutput()
285
+ out.mesh_v = mesh_v
286
+ out.mesh_f = mesh_f
287
+
288
+ outputs.append(out)
289
+
290
+ return outputs
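
The batch layout assumed by training_step and validation_step above is worth spelling out; the shapes follow the docstring, and the tensors below are random stand-ins for what a DataLoader would produce:

    import torch

    batch = {
        "surface": torch.randn(8, 4096, 6),                  # [bs, n_surface, 3 + input_dim]: xyz + e.g. normals
        "geo_points": torch.cat(
            [torch.rand(8, 2048, 3) * 2 - 1,                 # query positions in [-1, 1]^3
             (torch.rand(8, 2048, 1) > 0.5).float()],        # occupancy labels
            dim=-1,
        ),                                                   # [bs, n_pts, 3 + 1]
    }
    pc, feats = batch["surface"][..., :3], batch["surface"][..., 3:]
    queries, labels = batch["geo_points"][..., :3], batch["geo_points"][..., -1]
    print(pc.shape, feats.shape, queries.shape, labels.shape)
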
MeshAnything/miche/michelangelo/models/tsal/tsal_base.py ADDED
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch.nn as nn
4
+ from typing import Tuple, List, Optional
5
+
6
+
7
+ class Point2MeshOutput(object):
8
+ def __init__(self):
9
+ self.mesh_v = None
10
+ self.mesh_f = None
11
+ self.center = None
12
+ self.pc = None
13
+
14
+
15
+ class Latent2MeshOutput(object):
16
+
17
+ def __init__(self):
18
+ self.mesh_v = None
19
+ self.mesh_f = None
20
+
21
+
22
+ class AlignedMeshOutput(object):
23
+
24
+ def __init__(self):
25
+ self.mesh_v = None
26
+ self.mesh_f = None
27
+ self.surface = None
28
+ self.image = None
29
+ self.text: Optional[str] = None
30
+ self.shape_text_similarity: Optional[float] = None
31
+ self.shape_image_similarity: Optional[float] = None
32
+
33
+
34
+ class ShapeAsLatentPLModule(nn.Module):
35
+ latent_shape: Tuple[int]
36
+
37
+ def encode(self, surface, *args, **kwargs):
38
+ raise NotImplementedError
39
+
40
+ def decode(self, z_q, *args, **kwargs):
41
+ raise NotImplementedError
42
+
43
+ def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
44
+ raise NotImplementedError
45
+
46
+ def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
47
+ raise NotImplementedError
48
+
49
+
50
+ class ShapeAsLatentModule(nn.Module):
51
+ latent_shape: Tuple[int, int]
52
+
53
+ def __init__(self, *args, **kwargs):
54
+ super().__init__()
55
+
56
+ def encode(self, *args, **kwargs):
57
+ raise NotImplementedError
58
+
59
+ def decode(self, *args, **kwargs):
60
+ raise NotImplementedError
61
+
62
+ def query_geometry(self, *args, **kwargs):
63
+ raise NotImplementedError
64
+
65
+
66
+ class AlignedShapeAsLatentPLModule(nn.Module):
67
+ latent_shape: Tuple[int]
68
+
69
+ def set_shape_model_only(self):
70
+ raise NotImplementedError
71
+
72
+ def encode(self, surface, *args, **kwargs):
73
+ raise NotImplementedError
74
+
75
+ def decode(self, z_q, *args, **kwargs):
76
+ raise NotImplementedError
77
+
78
+ def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
79
+ raise NotImplementedError
80
+
81
+ def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
82
+ raise NotImplementedError
83
+
84
+
85
+ class AlignedShapeAsLatentModule(nn.Module):
86
+ shape_model: ShapeAsLatentModule
87
+ latent_shape: Tuple[int, int]
88
+
89
+ def __init__(self, *args, **kwargs):
90
+ super().__init__()
91
+
92
+ def set_shape_model_only(self):
93
+ raise NotImplementedError
94
+
95
+ def encode_image_embed(self, *args, **kwargs):
96
+ raise NotImplementedError
97
+
98
+ def encode_text_embed(self, *args, **kwargs):
99
+ raise NotImplementedError
100
+
101
+ def encode_shape_embed(self, *args, **kwargs):
102
+ raise NotImplementedError
103
+
104
+
105
+ class TexturedShapeAsLatentModule(nn.Module):
106
+
107
+ def __init__(self, *args, **kwargs):
108
+ super().__init__()
109
+
110
+ def encode(self, *args, **kwargs):
111
+ raise NotImplementedError
112
+
113
+ def decode(self, *args, **kwargs):
114
+ raise NotImplementedError
115
+
116
+ def query_geometry(self, *args, **kwargs):
117
+ raise NotImplementedError
118
+
119
+ def query_color(self, *args, **kwargs):
120
+ raise NotImplementedError
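
The *MeshOutput containers above are plain attribute holders, and downstream code typically just filters out the None entries produced when marching cubes finds no surface. A minimal consumer sketch (the file names are arbitrary):

    import trimesh
    from MeshAnything.miche.michelangelo.models.tsal.tsal_base import Latent2MeshOutput

    def export_outputs(outputs, prefix="shape"):
        # Skip failed reconstructions (None) and write the rest as OBJ files.
        for i, out in enumerate(outputs):
            if out is None:
                continue
            trimesh.Trimesh(out.mesh_v, out.mesh_f).export(f"{prefix}_{i:03d}.obj")

    # fabricated single-triangle output, only to show the attribute contract
    demo = Latent2MeshOutput()
    demo.mesh_v = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    demo.mesh_f = [[0, 1, 2]]
    export_outputs([demo, None])
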
MeshAnything/miche/michelangelo/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .misc import instantiate_from_config
MeshAnything/miche/michelangelo/utils/eval.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+
5
+
6
+ def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
7
+
8
+ mse = torch.mean((x - y) ** 2)
9
+ psnr = 10 * torch.log10(data_range / (mse + eps))
10
+
11
+ return psnr
12
+
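
A quick check of compute_psnr. Note that, as written, it divides data_range (not its square) by the MSE, so the values are comparable with each other but not directly with the textbook PSNR definition; with identical inputs the result is capped by eps at roughly 73 dB.

    import torch
    from MeshAnything.miche.michelangelo.utils.eval import compute_psnr

    x = torch.rand(4, 2048, 3) * 2 - 1        # e.g. colors in [-1, 1], matching data_range=2
    y = x + 0.05 * torch.randn_like(x)

    print(compute_psnr(x, x).item())          # ~73 dB, limited by eps
    print(compute_psnr(x, y).item())          # lower for the noisy copy
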
MeshAnything/miche/michelangelo/utils/io.py ADDED
@@ -0,0 +1,47 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import io
5
+ import tarfile
6
+ import json
7
+ import numpy as np
8
+ import numpy.lib.format
9
+
10
+
11
+ def mkdir(path):
12
+ os.makedirs(path, exist_ok=True)
13
+ return path
14
+
15
+
16
+ def npy_loads(data):
17
+ stream = io.BytesIO(data)
18
+ return np.lib.format.read_array(stream)
19
+
20
+
21
+ def npz_loads(data):
22
+ return np.load(io.BytesIO(data))
23
+
24
+
25
+ def json_loads(data):
26
+ return json.loads(data)
27
+
28
+
29
+ def load_json(filepath):
30
+ with open(filepath, "r") as f:
31
+ data = json.load(f)
32
+ return data
33
+
34
+
35
+ def write_json(filepath, data):
36
+ with open(filepath, "w") as f:
37
+ json.dump(data, f, indent=2)
38
+
39
+
40
+ def extract_tar(tar_path, tar_cache_folder):
41
+
42
+ with tarfile.open(tar_path, "r") as tar:
43
+ tar.extractall(path=tar_cache_folder)
44
+
45
+ tar_uids = sorted(os.listdir(tar_cache_folder))
46
+ print(f"extract tar: {tar_path} to {tar_cache_folder}")
47
+ return tar_uids
MeshAnything/miche/michelangelo/utils/misc.py ADDED
@@ -0,0 +1,83 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import importlib
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+
9
+
10
+ def get_obj_from_str(string, reload=False):
11
+ module, cls = string.rsplit(".", 1)
12
+ if reload:
13
+ module_imp = importlib.import_module(module)
14
+ importlib.reload(module_imp)
15
+ return getattr(importlib.import_module(module, package=None), cls)
16
+
17
+
18
+ def get_obj_from_config(config):
19
+ if "target" not in config:
20
+ raise KeyError("Expected key `target` to instantiate.")
21
+
22
+ return get_obj_from_str(config["target"])
23
+
24
+
25
+ def instantiate_from_config(config, **kwargs):
26
+ if "target" not in config:
27
+ raise KeyError("Expected key `target` to instantiate.")
28
+
29
+ cls = get_obj_from_str(config["target"])
30
+
31
+ params = config.get("params", dict())
32
+ # params.update(kwargs)
33
+ # instance = cls(**params)
34
+ kwargs.update(params)
35
+ instance = cls(**kwargs)
36
+
37
+ return instance
38
+
39
+
40
+ def is_dist_avail_and_initialized():
41
+ if not dist.is_available():
42
+ return False
43
+ if not dist.is_initialized():
44
+ return False
45
+ return True
46
+
47
+
48
+ def get_rank():
49
+ if not is_dist_avail_and_initialized():
50
+ return 0
51
+ return dist.get_rank()
52
+
53
+
54
+ def get_world_size():
55
+ if not is_dist_avail_and_initialized():
56
+ return 1
57
+ return dist.get_world_size()
58
+
59
+
60
+ def all_gather_batch(tensors):
61
+ """
62
+ Performs all_gather operation on the provided tensors.
63
+ """
64
+ # Queue the gathered tensors
65
+ world_size = get_world_size()
66
+ # There is no need for reduction in the single-proc case
67
+ if world_size == 1:
68
+ return tensors
69
+ tensor_list = []
70
+ output_tensor = []
71
+ for tensor in tensors:
72
+ tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
73
+ dist.all_gather(
74
+ tensor_all,
75
+ tensor,
76
+ async_op=False # performance opt
77
+ )
78
+
79
+ tensor_list.append(tensor_all)
80
+
81
+ for tensor_all in tensor_list:
82
+ output_tensor.append(torch.cat(tensor_all, dim=0))
83
+ return output_tensor
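
instantiate_from_config builds any importable class from a {"target", "params"} mapping; extra keyword arguments are merged in, with "params" taking precedence because of the kwargs.update(params) above. A tiny self-contained example:

    from MeshAnything.miche.michelangelo.utils.misc import instantiate_from_config

    config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 16, "out_features": 4},
    }
    layer = instantiate_from_config(config, bias=False)  # bias comes from kwargs, the rest from params
    print(layer)  # Linear(in_features=16, out_features=4, bias=False)
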
MeshAnything/miche/michelangelo/utils/visualizers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
MeshAnything/miche/michelangelo/utils/visualizers/color_util.py ADDED
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+
5
+ # Helper functions
6
+ def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
7
+ colormap = plt.cm.get_cmap(colormap)
8
+ if normalize:
9
+ vmin = np.min(inp)
10
+ vmax = np.max(inp)
11
+
12
+ norm = plt.Normalize(vmin, vmax)
13
+ return colormap(norm(inp))[:, :3]
14
+
15
+
16
+ def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
17
+ # tex dims need to be power of two.
18
+ array = np.ones((width, height, 3), dtype='float32')
19
+
20
+ # width in texels of each checker
21
+ checker_w = width / n_checkers_x
22
+ checker_h = height / n_checkers_y
23
+
24
+ for y in range(height):
25
+ for x in range(width):
26
+ color_key = int(x / checker_w) + int(y / checker_h)
27
+ if color_key % 2 == 0:
28
+ array[x, y, :] = [1., 0.874, 0.0]
29
+ else:
30
+ array[x, y, :] = [0., 0., 0.]
31
+ return array
32
+
33
+
34
+ def gen_circle(width=256, height=256):
35
+ xx, yy = np.mgrid[:width, :height]
36
+ circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
37
+ array = np.ones((width, height, 4), dtype='float32')
38
+ array[:, :, 0] = (circle <= width)
39
+ array[:, :, 1] = (circle <= width)
40
+ array[:, :, 2] = (circle <= width)
41
+ array[:, :, 3] = circle <= width
42
+ return array
43
+
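A small sketch of how these two helpers behave, assuming the package path shown in this upload; the inputs are illustrative and only numpy and matplotlib are needed:
```python
import numpy as np
from MeshAnything.miche.michelangelo.utils.visualizers.color_util import get_colors, gen_checkers

# Map per-vertex scalars to RGB rows via the chosen matplotlib colormap.
values = np.linspace(0.0, 1.0, 5)
rgb = get_colors(values, colormap="viridis")
print(rgb.shape)  # (5, 3)

# Checkerboard texture used as the default UV texture in the viewer.
tex = gen_checkers(n_checkers_x=20, n_checkers_y=20)
print(tex.shape, tex.dtype)  # (256, 256, 3) float32
```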
MeshAnything/miche/michelangelo/utils/visualizers/html_util.py ADDED
@@ -0,0 +1,49 @@
1
+ # -*- coding: utf-8 -*-
2
+ import io
3
+ import base64
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def to_html_frame(content):
9
+
10
+ html_frame = f"""
11
+ <html>
12
+ <body>
13
+ {content}
14
+ </body>
15
+ </html>
16
+ """
17
+
18
+ return html_frame
19
+
20
+
21
+ def to_single_row_table(caption: str, content: str):
22
+
23
+ table_html = f"""
24
+ <table border = "1">
25
+ <caption>{caption}</caption>
26
+ <tr>
27
+ <td>{content}</td>
28
+ </tr>
29
+ </table>
30
+ """
31
+
32
+ return table_html
33
+
34
+
35
+ def to_image_embed_tag(image: np.ndarray):
36
+
37
+ # Convert np.ndarray to bytes
38
+ img = Image.fromarray(image)
39
+ raw_bytes = io.BytesIO()
40
+ img.save(raw_bytes, "PNG")
41
+
42
+ # Encode bytes to base64
43
+ image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
44
+
45
+ image_tag = f"""
46
+ <img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
47
+ """
48
+
49
+ return image_tag
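The three helpers compose into a self-contained HTML snippet; a sketch with an illustrative random image (the output filename is arbitrary):
```python
import numpy as np
from MeshAnything.miche.michelangelo.utils.visualizers.html_util import (
    to_html_frame,
    to_image_embed_tag,
    to_single_row_table,
)

# Any uint8 RGB array works; this one is random noise for illustration.
image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)

table = to_single_row_table(caption="preview", content=to_image_embed_tag(image))
page = to_html_frame(table)

with open("preview.html", "w") as f:
    f.write(page)
```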
MeshAnything/miche/michelangelo/utils/visualizers/pythreejs_viewer.py ADDED
@@ -0,0 +1,534 @@
1
+ import numpy as np
2
+ from ipywidgets import embed
3
+ import pythreejs as p3s
4
+ import uuid
5
+
6
+ from .color_util import get_colors, gen_circle, gen_checkers
7
+
8
+
9
+ EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/[email protected]/dist/embed-amd.js"
10
+
11
+
12
+ class PyThreeJSViewer(object):
13
+
14
+ def __init__(self, settings, render_mode="WEBSITE"):
15
+ self.render_mode = render_mode
16
+ self.__update_settings(settings)
17
+ self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
18
+ self._light2 = p3s.AmbientLight(intensity=0.5)
19
+ self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"],
20
+ aspect=self.__s["width"] / self.__s["height"], children=[self._light])
21
+ self._orbit = p3s.OrbitControls(controlling=self._cam)
22
+ self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80"
23
+ self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],
24
+ width=self.__s["width"], height=self.__s["height"],
25
+ antialias=self.__s["antialias"])
26
+
27
+ self.__objects = {}
28
+ self.__cnt = 0
29
+
30
+ def jupyter_mode(self):
31
+ self.render_mode = "JUPYTER"
32
+
33
+ def offline(self):
34
+ self.render_mode = "OFFLINE"
35
+
36
+ def website(self):
37
+ self.render_mode = "WEBSITE"
38
+
39
+ def __get_shading(self, shading):
40
+ shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black",
41
+ "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None],
42
+ "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0,
43
+ "line_width": 1.0, "line_color": "black",
44
+ "point_color": "red", "point_size": 0.01, "point_shape": "circle",
45
+ "text_color": "red"
46
+ }
47
+ for k in shading:
48
+ shad[k] = shading[k]
49
+ return shad
50
+
51
+ def __update_settings(self, settings={}):
52
+ sett = {"width": 600, "height": 600, "antialias": True, "scale": 1.5, "background": "#ffffff",
53
+ "fov": 30}
54
+ for k in settings:
55
+ sett[k] = settings[k]
56
+ self.__s = sett
57
+
58
+ def __add_object(self, obj, parent=None):
59
+ if not parent: # Object is added to global scene and objects dict
60
+ self.__objects[self.__cnt] = obj
61
+ self.__cnt += 1
62
+ self._scene.add(obj["mesh"])
63
+ else: # Object is added to parent object and NOT to objects dict
64
+ parent.add(obj["mesh"])
65
+
66
+ self.__update_view()
67
+
68
+ if self.render_mode == "JUPYTER":
69
+ return self.__cnt - 1
70
+ elif self.render_mode == "WEBSITE":
71
+ return self
72
+
73
+ def __add_line_geometry(self, lines, shading, obj=None):
74
+ lines = lines.astype("float32", copy=False)
75
+ mi = np.min(lines, axis=0)
76
+ ma = np.max(lines, axis=0)
77
+
78
+ geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))
79
+ material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
80
+ # , vertexColors='VertexColors'),
81
+ lines = p3s.LineSegments2(geometry=geometry, material=material) # type='LinePieces')
82
+ line_obj = {"geometry": geometry, "mesh": lines, "material": material,
83
+ "max": ma, "min": mi, "type": "Lines", "wireframe": None}
84
+
85
+ if obj:
86
+ return self.__add_object(line_obj, obj), line_obj
87
+ else:
88
+ return self.__add_object(line_obj)
89
+
90
+ def __update_view(self):
91
+ if len(self.__objects) == 0:
92
+ return
93
+ ma = np.zeros((len(self.__objects), 3))
94
+ mi = np.zeros((len(self.__objects), 3))
95
+ for r, obj in enumerate(self.__objects):
96
+ ma[r] = self.__objects[obj]["max"]
97
+ mi[r] = self.__objects[obj]["min"]
98
+ ma = np.max(ma, axis=0)
99
+ mi = np.min(mi, axis=0)
100
+ diag = np.linalg.norm(ma - mi)
101
+ mean = ((ma - mi) / 2 + mi).tolist()
102
+ scale = self.__s["scale"] * (diag)
103
+ self._orbit.target = mean
104
+ self._cam.lookAt(mean)
105
+ self._cam.position = [mean[0], mean[1], mean[2] + scale]
106
+ self._light.position = [mean[0], mean[1], mean[2] + scale]
107
+
108
+ self._orbit.exec_three_obj_method('update')
109
+ self._cam.exec_three_obj_method('updateProjectionMatrix')
110
+
111
+ def __get_bbox(self, v):
112
+ m = np.min(v, axis=0)
113
+ M = np.max(v, axis=0)
114
+
115
+ # Corners of the bounding box
116
+ v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],
117
+ [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])
118
+
119
+ f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],
120
+ [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)
121
+ return v_box, f_box
122
+
123
+ def __get_colors(self, v, f, c, sh):
124
+ coloring = "VertexColors"
125
+ if type(c) == np.ndarray and c.size == 3: # Single color
126
+ colors = np.ones_like(v)
127
+ colors[:, 0] = c[0]
128
+ colors[:, 1] = c[1]
129
+ colors[:, 2] = c[2]
130
+ # print("Single colors")
131
+ elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for
132
+ if c.shape[0] == f.shape[0]: # faces
133
+ colors = np.hstack([c, c, c]).reshape((-1, 3))
134
+ coloring = "FaceColors"
135
+ # print("Face color values")
136
+ elif c.shape[0] == v.shape[0]: # vertices
137
+ colors = c
138
+ # print("Vertex color values")
139
+ else: # Wrong size, fallback
140
+ print("Invalid color array given! Supported are numpy arrays.", type(c))
141
+ colors = np.ones_like(v)
142
+ colors[:, 0] = 1.0
143
+ colors[:, 1] = 0.874
144
+ colors[:, 2] = 0.0
145
+ elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces
146
+ normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
147
+ cc = get_colors(c, sh["colormap"], normalize=normalize,
148
+ vmin=sh["normalize"][0], vmax=sh["normalize"][1])
149
+ # print(cc.shape)
150
+ colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
151
+ coloring = "FaceColors"
152
+ # print("Face function values")
153
+ elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices
154
+ normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
155
+ colors = get_colors(c, sh["colormap"], normalize=normalize,
156
+ vmin=sh["normalize"][0], vmax=sh["normalize"][1])
157
+ # print("Vertex function values")
158
+
159
+ else:
160
+ colors = np.ones_like(v)
161
+ colors[:, 0] = 1.0
162
+ colors[:, 1] = 0.874
163
+ colors[:, 2] = 0.0
164
+
165
+ # No color
166
+ if c is not None:
167
+ print("Invalid color array given! Supported are numpy arrays.", type(c))
168
+
169
+ return colors, coloring
170
+
171
+ def __get_point_colors(self, v, c, sh):
172
+ v_color = True
173
+ if c is None: # No color given, use global color
174
+ # conv = mpl.colors.ColorConverter()
175
+ colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"]))
176
+ v_color = False
177
+ elif isinstance(c, str): # No color given, use global color
178
+ # conv = mpl.colors.ColorConverter()
179
+ colors = c # np.array(conv.to_rgb(c))
180
+ v_color = False
181
+ elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:
182
+ # Point color
183
+ colors = c.astype("float32", copy=False)
184
+
185
+ elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:
186
+ # Function values for vertices, but the colors are features
187
+ c_norm = np.linalg.norm(c, ord=2, axis=-1)
188
+ normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
189
+ colors = get_colors(c_norm, sh["colormap"], normalize=normalize,
190
+ vmin=sh["normalize"][0], vmax=sh["normalize"][1])
191
+ colors = colors.astype("float32", copy=False)
192
+
193
+ elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color
194
+ normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
195
+ colors = get_colors(c, sh["colormap"], normalize=normalize,
196
+ vmin=sh["normalize"][0], vmax=sh["normalize"][1])
197
+ colors = colors.astype("float32", copy=False)
198
+ # print("Vertex function values")
199
+
200
+ else:
201
+ print("Invalid color array given! Supported are numpy arrays.", type(c))
202
+ colors = sh["point_color"]
203
+ v_color = False
204
+
205
+ return colors, v_color
206
+
207
+ def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
208
+ shading.update(kwargs)
209
+ sh = self.__get_shading(shading)
210
+ mesh_obj = {}
211
+
212
+ # it is a tet
213
+ if v.shape[1] == 3 and f.shape[1] == 4:
214
+ f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)
215
+ for i in range(f.shape[0]):
216
+ f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])
217
+ f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])
218
+ f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])
219
+ f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])
220
+ f = f_tmp
221
+
222
+ if v.shape[1] == 2:
223
+ v = np.append(v, np.zeros([v.shape[0], 1]), 1)
224
+
225
+ # Type adjustment vertices
226
+ v = v.astype("float32", copy=False)
227
+
228
+ # Color setup
229
+ colors, coloring = self.__get_colors(v, f, c, sh)
230
+
231
+ # Type adjustment faces and colors
232
+ c = colors.astype("float32", copy=False)
233
+
234
+ # Material and geometry setup
235
+ ba_dict = {"color": p3s.BufferAttribute(c)}
236
+ if coloring == "FaceColors":
237
+ verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
238
+ for ii in range(f.shape[0]):
239
+ # print(ii*3, f[ii])
240
+ verts[ii * 3] = v[f[ii, 0]]
241
+ verts[ii * 3 + 1] = v[f[ii, 1]]
242
+ verts[ii * 3 + 2] = v[f[ii, 2]]
243
+ v = verts
244
+ else:
245
+ f = f.astype("uint32", copy=False).ravel()
246
+ ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)
247
+
248
+ ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)
249
+
250
+ if uv is not None:
251
+ uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
252
+ if texture_data is None:
253
+ texture_data = gen_checkers(20, 20)
254
+ tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
255
+ material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
256
+ roughness=sh["roughness"], metalness=sh["metalness"],
257
+ flatShading=sh["flat"],
258
+ polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
259
+ ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
260
+ else:
261
+ material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
262
+ side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
263
+ flatShading=sh["flat"],
264
+ polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
265
+
266
+ if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well
267
+ ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)
268
+
269
+ geometry = p3s.BufferGeometry(attributes=ba_dict)
270
+
271
+ if coloring == "VertexColors" and type(n) == type(None):
272
+ geometry.exec_three_obj_method('computeVertexNormals')
273
+ elif coloring == "FaceColors" and type(n) == type(None):
274
+ geometry.exec_three_obj_method('computeFaceNormals')
275
+
276
+ # Mesh setup
277
+ mesh = p3s.Mesh(geometry=geometry, material=material)
278
+
279
+ # Wireframe setup
280
+ mesh_obj["wireframe"] = None
281
+ if sh["wireframe"]:
282
+ wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry
283
+ wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
284
+ wireframe = p3s.LineSegments(wf_geometry, wf_material)
285
+ mesh.add(wireframe)
286
+ mesh_obj["wireframe"] = wireframe
287
+
288
+ # Bounding box setup
289
+ if sh["bbox"]:
290
+ v_box, f_box = self.__get_bbox(v)
291
+ _, bbox = self.add_edges(v_box, f_box, sh, mesh)
292
+ mesh_obj["bbox"] = [bbox, v_box, f_box]
293
+
294
+ # Object setup
295
+ mesh_obj["max"] = np.max(v, axis=0)
296
+ mesh_obj["min"] = np.min(v, axis=0)
297
+ mesh_obj["geometry"] = geometry
298
+ mesh_obj["mesh"] = mesh
299
+ mesh_obj["material"] = material
300
+ mesh_obj["type"] = "Mesh"
301
+ mesh_obj["shading"] = sh
302
+ mesh_obj["coloring"] = coloring
303
+ mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed
304
+
305
+ return self.__add_object(mesh_obj)
306
+
307
+ def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
308
+ shading.update(kwargs)
309
+ if len(beginning.shape) == 1:
310
+ if len(beginning) == 2:
311
+ beginning = np.array([[beginning[0], beginning[1], 0]])
312
+ else:
313
+ if beginning.shape[1] == 2:
314
+ beginning = np.append(
315
+ beginning, np.zeros([beginning.shape[0], 1]), 1)
316
+ if len(ending.shape) == 1:
317
+ if len(ending) == 2:
318
+ ending = np.array([[ending[0], ending[1], 0]])
319
+ else:
320
+ if ending.shape[1] == 2:
321
+ ending = np.append(
322
+ ending, np.zeros([ending.shape[0], 1]), 1)
323
+
324
+ sh = self.__get_shading(shading)
325
+ lines = np.hstack([beginning, ending])
326
+ lines = lines.reshape((-1, 3))
327
+ return self.__add_line_geometry(lines, sh, obj)
328
+
329
+ def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
330
+ shading.update(kwargs)
331
+ if vertices.shape[1] == 2:
332
+ vertices = np.append(
333
+ vertices, np.zeros([vertices.shape[0], 1]), 1)
334
+ sh = self.__get_shading(shading)
335
+ lines = np.zeros((edges.size, 3))
336
+ cnt = 0
337
+ for e in edges:
338
+ lines[cnt, :] = vertices[e[0]]
339
+ lines[cnt + 1, :] = vertices[e[1]]
340
+ cnt += 2
341
+ return self.__add_line_geometry(lines, sh, obj)
342
+
343
+ def add_points(self, points, c=None, shading={}, obj=None, **kwargs):
344
+ shading.update(kwargs)
345
+ if len(points.shape) == 1:
346
+ if len(points) == 2:
347
+ points = np.array([[points[0], points[1], 0]])
348
+ else:
349
+ if points.shape[1] == 2:
350
+ points = np.append(
351
+ points, np.zeros([points.shape[0], 1]), 1)
352
+ sh = self.__get_shading(shading)
353
+ points = points.astype("float32", copy=False)
354
+ mi = np.min(points, axis=0)
355
+ ma = np.max(points, axis=0)
356
+
357
+ g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
358
+ m_attributes = {"size": sh["point_size"]}
359
+
360
+ if sh["point_shape"] == "circle": # Plot circles
361
+ tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
362
+ m_attributes["map"] = tex
363
+ m_attributes["alphaTest"] = 0.5
364
+ m_attributes["transparency"] = True
365
+ else: # Plot squares
366
+ pass
367
+
368
+ colors, v_colors = self.__get_point_colors(points, c, sh)
369
+ if v_colors: # Colors per point
370
+ m_attributes["vertexColors"] = 'VertexColors'
371
+ g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
372
+
373
+ else: # Colors for all points
374
+ m_attributes["color"] = colors
375
+
376
+ material = p3s.PointsMaterial(**m_attributes)
377
+ geometry = p3s.BufferGeometry(attributes=g_attributes)
378
+ points = p3s.Points(geometry=geometry, material=material)
379
+ point_obj = {"geometry": geometry, "mesh": points, "material": material,
380
+ "max": ma, "min": mi, "type": "Points", "wireframe": None}
381
+
382
+ if obj:
383
+ return self.__add_object(point_obj, obj), point_obj
384
+ else:
385
+ return self.__add_object(point_obj)
386
+
387
+ def remove_object(self, obj_id):
388
+ if obj_id not in self.__objects:
389
+ print("Invalid object id. Valid ids are: ", list(self.__objects.keys()))
390
+ return
391
+ self._scene.remove(self.__objects[obj_id]["mesh"])
392
+ del self.__objects[obj_id]
393
+ self.__update_view()
394
+
395
+ def reset(self):
396
+ for obj_id in list(self.__objects.keys()).copy():
397
+ self._scene.remove(self.__objects[obj_id]["mesh"])
398
+ del self.__objects[obj_id]
399
+ self.__update_view()
400
+
401
+ def update_object(self, oid=0, vertices=None, colors=None, faces=None):
402
+ obj = self.__objects[oid]
403
+ if type(vertices) != type(None):
404
+ if obj["coloring"] == "FaceColors":
405
+ f = obj["arrays"][1]
406
+ verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
407
+ for ii in range(f.shape[0]):
408
+ # print(ii*3, f[ii])
409
+ verts[ii * 3] = vertices[f[ii, 0]]
410
+ verts[ii * 3 + 1] = vertices[f[ii, 1]]
411
+ verts[ii * 3 + 2] = vertices[f[ii, 2]]
412
+ v = verts
413
+
414
+ else:
415
+ v = vertices.astype("float32", copy=False)
416
+ obj["geometry"].attributes["position"].array = v
417
+ # self.wireframe.attributes["position"].array = v # Wireframe updates?
418
+ obj["geometry"].attributes["position"].needsUpdate = True
419
+ # obj["geometry"].exec_three_obj_method('computeVertexNormals')
420
+ if type(colors) != type(None):
421
+ colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
422
+ colors = colors.astype("float32", copy=False)
423
+ obj["geometry"].attributes["color"].array = colors
424
+ obj["geometry"].attributes["color"].needsUpdate = True
425
+ if type(faces) != type(None):
426
+ if obj["coloring"] == "FaceColors":
427
+ print("Face updates are currently only possible in vertex color mode.")
428
+ return
429
+ f = faces.astype("uint32", copy=False).ravel()
430
+ print(obj["geometry"].attributes)
431
+ obj["geometry"].attributes["index"].array = f
432
+ # self.wireframe.attributes["position"].array = v # Wireframe updates?
433
+ obj["geometry"].attributes["index"].needsUpdate = True
434
+ # obj["geometry"].exec_three_obj_method('computeVertexNormals')
435
+ # self.mesh.geometry.verticesNeedUpdate = True
436
+ # self.mesh.geometry.elementsNeedUpdate = True
437
+ # self.update()
438
+ if self.render_mode == "WEBSITE":
439
+ return self
440
+
441
+ # def update(self):
442
+ # self.mesh.exec_three_obj_method('update')
443
+ # self.orbit.exec_three_obj_method('update')
444
+ # self.cam.exec_three_obj_method('updateProjectionMatrix')
445
+ # self.scene.exec_three_obj_method('update')
446
+
447
+ def add_text(self, text, shading={}, **kwargs):
448
+ shading.update(kwargs)
449
+ sh = self.__get_shading(shading)
450
+ tt = p3s.TextTexture(string=text, color=sh["text_color"])
451
+ sm = p3s.SpriteMaterial(map=tt)
452
+ text = p3s.Sprite(material=sm, scaleToTexture=True)
453
+ self._scene.add(text)
454
+
455
+ # def add_widget(self, widget, callback):
456
+ # self.widgets.append(widget)
457
+ # widget.observe(callback, names='value')
458
+
459
+ # def add_dropdown(self, options, default, desc, cb):
460
+ # widget = widgets.Dropdown(options=options, value=default, description=desc)
461
+ # self.__widgets.append(widget)
462
+ # widget.observe(cb, names="value")
463
+ # display(widget)
464
+
465
+ # def add_button(self, text, cb):
466
+ # button = widgets.Button(description=text)
467
+ # self.__widgets.append(button)
468
+ # button.on_click(cb)
469
+ # display(button)
470
+
471
+ def to_html(self, imports=True, html_frame=True):
472
+ # Bake positions (fixes centering bug in offline rendering)
473
+ if len(self.__objects) == 0:
474
+ return
475
+ ma = np.zeros((len(self.__objects), 3))
476
+ mi = np.zeros((len(self.__objects), 3))
477
+ for r, obj in enumerate(self.__objects):
478
+ ma[r] = self.__objects[obj]["max"]
479
+ mi[r] = self.__objects[obj]["min"]
480
+ ma = np.max(ma, axis=0)
481
+ mi = np.min(mi, axis=0)
482
+ diag = np.linalg.norm(ma - mi)
483
+ mean = (ma - mi) / 2 + mi
484
+ for r, obj in enumerate(self.__objects):
485
+ v = self.__objects[obj]["geometry"].attributes["position"].array
486
+ v -= mean
487
+ v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window
488
+
489
+ scale = self.__s["scale"] * (diag)
490
+ self._orbit.target = [0.0, 0.0, 0.0]
491
+ self._cam.lookAt([0.0, 0.0, 0.0])
492
+ # self._cam.position = [0.0, 0.0, scale]
493
+ self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window
494
+ self._light.position = [0.0, 0.0, scale]
495
+
496
+ state = embed.dependency_state(self._renderer)
497
+
498
+ # Somehow these entries are missing when the state is exported in python.
499
+ # Exporting from the GUI works, so we are inserting the missing entries.
500
+ for k in state:
501
+ if state[k]["model_name"] == "OrbitControlsModel":
502
+ state[k]["state"]["maxAzimuthAngle"] = "inf"
503
+ state[k]["state"]["maxDistance"] = "inf"
504
+ state[k]["state"]["maxZoom"] = "inf"
505
+ state[k]["state"]["minAzimuthAngle"] = "-inf"
506
+
507
+ tpl = embed.load_requirejs_template
508
+ if not imports:
509
+ embed.load_requirejs_template = ""
510
+
511
+ s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
512
+ # s = embed.embed_snippet(self.__w, state=state)
513
+ embed.load_requirejs_template = tpl
514
+
515
+ if html_frame:
516
+ s = "<html>\n<body>\n" + s + "\n</body>\n</html>"
517
+
518
+ # Revert changes
519
+ for r, obj in enumerate(self.__objects):
520
+ v = self.__objects[obj]["geometry"].attributes["position"].array
521
+ v += mean
522
+ self.__update_view()
523
+
524
+ return s
525
+
526
+ def save(self, filename=""):
527
+ if filename == "":
528
+ uid = str(uuid.uuid4()) + ".html"
529
+ else:
530
+ filename = filename.replace(".html", "")
531
+ uid = filename + '.html'
532
+ with open(uid, "w") as f:
533
+ f.write(self.to_html())
534
+ print("Plot saved to file %s." % uid)
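A minimal sketch of driving the viewer offline, assuming `pythreejs` and `ipywidgets` are installed; the single-triangle geometry is illustrative:
```python
import numpy as np
from MeshAnything.miche.michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer

# One triangle; real callers pass reconstructed mesh vertices and faces.
v = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])
f = np.array([[0, 1, 2]], dtype=np.uint32)

viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
viewer.add_mesh(v, f, shading={"wireframe": True})
viewer.save("triangle")  # writes triangle.html via to_html()
```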
MeshAnything/miche/shapevae-256.yaml ADDED
@@ -0,0 +1,46 @@
1
+ model:
2
+ target: MeshAnything.miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
3
+ params:
4
+ shape_module_cfg:
5
+ target: MeshAnything.miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
6
+ params:
7
+ num_latents: 256
8
+ embed_dim: 64
9
+ point_feats: 3 # normal
10
+ num_freqs: 8
11
+ include_pi: false
12
+ heads: 12
13
+ width: 768
14
+ num_encoder_layers: 8
15
+ num_decoder_layers: 16
16
+ use_ln_post: true
17
+ init_scale: 0.25
18
+ qkv_bias: false
19
+ use_checkpoint: true
20
+ aligned_module_cfg:
21
+ target: MeshAnything.miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
22
+ params:
23
+ clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
24
+
25
+ loss_cfg:
26
+ target: MeshAnything.miche.michelangelo.models.tsal.loss.ContrastKLNearFar
27
+ params:
28
+ contrast_weight: 0.1
29
+ near_weight: 0.1
30
+ kl_weight: 0.001
31
+
32
+ optimizer_cfg:
33
+ optimizer:
34
+ target: torch.optim.AdamW
35
+ params:
36
+ betas: [0.9, 0.99]
37
+ eps: 1.e-6
38
+ weight_decay: 1.e-2
39
+
40
+ scheduler:
41
+ target: MeshAnything.miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
42
+ params:
43
+ warm_up_steps: 5000
44
+ f_start: 1.e-6
45
+ f_min: 1.e-3
46
+ f_max: 1.0
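A sketch of turning this config into a model with the `instantiate_from_config` helper from `misc.py`, assuming OmegaConf is installed, the CLIP checkpoint path under `aligned_module_cfg` exists locally, and the PL module's own `__init__` resolves the nested `target` entries:
```python
from omegaconf import OmegaConf
from MeshAnything.miche.michelangelo.utils.misc import instantiate_from_config

cfg = OmegaConf.load("MeshAnything/miche/shapevae-256.yaml")

# Builds AlignedShapeAsLatentPLModule with the shape, alignment and loss sub-configs above.
model = instantiate_from_config(cfg.model)
```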
MeshAnything/models/meshanything.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoModel
4
+ from MeshAnything.miche.encode import load_model
5
+ from MeshAnything.models.shape_opt import ShapeOPTConfig
6
+ from einops.layers.torch import Rearrange
7
+ from einops import rearrange, repeat, reduce, pack, unpack
8
+ import torch.nn.functional as F
9
+
10
+ class NoiseResistantDecoder(nn.Module):
11
+
12
+ def __init__(self, args):
13
+ super().__init__()
14
+ self.args = args
15
+ self.pad_id = -1
16
+ self.num_quantizers = 3
17
+
18
+ self.discrete_num = 128
19
+ self.codebook_size = args.codebook_size
20
+ self.codebook_dim = args.codebook_dim
21
+
22
+ config = AutoConfig.from_pretrained("bert-base-uncased")
23
+ config.num_hidden_layers = 6
24
+ self.decoder= AutoModel.from_config(config=config).to_bettertransformer().encoder
25
+ self.n_embd = self.decoder.config.hidden_size
26
+
27
+ self.pos_embedding = nn.Embedding(18000, self.n_embd)
28
+ self.layernorm = nn.LayerNorm(self.n_embd)
29
+ self.point_layernorm = nn.LayerNorm(self.n_embd)
30
+
31
+ self.cond_length = 257
32
+ self.cond_dim = 768
33
+ self.point_pe = nn.Embedding(self.cond_length, self.n_embd)
34
+ self.cond_proj = nn.Linear(self.cond_dim, self.n_embd)
35
+ self.cond_head_proj = nn.Linear(self.cond_dim, self.n_embd)
36
+
37
+ self.project_down_codebook = nn.Linear(self.codebook_dim * 3, self.n_embd)
38
+ self.to_coor_logits = nn.Sequential(
39
+ nn.Linear(self.n_embd, self.discrete_num * 9),
40
+ Rearrange('... (v c) -> ... v c', v = 9)
41
+ )
42
+ def process_point_feature(self, encode_feature):
43
+ point_feature = torch.zeros(encode_feature.shape[0], self.cond_length, self.n_embd, device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
44
+ point_feature[:, 0] = self.cond_head_proj(encode_feature[:, 0])
45
+ point_feature[:, 1:] = self.cond_proj(encode_feature[:, 1:])
46
+
47
+ point_feature = self.point_layernorm(point_feature + self.point_pe.weight[None, :point_feature.shape[1]])
48
+ return point_feature
49
+
50
+ def forward(self, input_ids, input_embeds, point_feature = None):
51
+ input_ids = input_ids.reshape(input_ids.shape[0], -1)
52
+ point_feature = self.process_point_feature(point_feature)
53
+
54
+ face_embeds = rearrange(input_embeds, 'b (nf nv) d -> b nf (nv d)', nv = 3)
55
+ face_embeds = self.project_down_codebook(face_embeds)
56
+
57
+ face_mask = reduce(input_ids != self.pad_id, 'b (nf nv q) -> b nf', 'all', nv = 3, q = self.num_quantizers)
58
+ face_embeds[~face_mask] = 0
59
+
60
+ face_embeds = self.layernorm(face_embeds + self.pos_embedding.weight[None, :face_embeds.shape[1]])
61
+
62
+ outputs = self.decoder(
63
+ hidden_states=torch.concatenate([point_feature, face_embeds], dim=1),
64
+ )
65
+ decoded = outputs.last_hidden_state[:, self.cond_length:] # batch x nfaces x dim
66
+ decoded = decoded.masked_fill(~face_mask.unsqueeze(-1), 0.)
67
+
68
+ # batch x nfaces x 9 -> batch x nfaces x 3 x 3
69
+ pred_face_logits = self.to_coor_logits(decoded) # batch x nfaces x 9 x ndiscrete
70
+ pred_face_coords = rearrange(pred_face_logits.argmax(dim = -1), '... (v c) -> ... v c', v = 3)
71
+
72
+ continuous_coors = undiscretize(
73
+ pred_face_coords,
74
+ num_discrete = self.discrete_num,
75
+ low = -0.5,
76
+ high = 0.5
77
+ )
78
+ continuous_coors = continuous_coors.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1 1'), float('nan'))
79
+
80
+ return continuous_coors
81
+
82
+ class MeshAnything(nn.Module):
83
+ def __init__(self, args):
84
+ super().__init__()
85
+ self.args = args
86
+ self.point_encoder = load_model(ckpt_path=None)
87
+ self.tokenizer = NoiseResistantDecoder(args)
88
+
89
+ self.num_quantizers = 3
90
+ self.face_per_token = self.num_quantizers * 3
91
+ self.cond_length = 257
92
+ self.cond_dim = 768
93
+ self.max_length = args.n_max_triangles * self.face_per_token + 2 + self.cond_length
94
+
95
+ self.config = ShapeOPTConfig.from_pretrained(
96
+ args.llm,
97
+ n_positions=18259,
98
+ max_position_embeddings=18259,
99
+ vocab_size=self.tokenizer.codebook_size + 3,
100
+ _attn_implementation="flash_attention_2"
101
+ )
102
+ self.bos_token_id = 0
103
+ self.eos_token_id = 1
104
+ self.pad_token_id = 2
105
+ self.config.bos_token_id = self.bos_token_id
106
+ self.config.eos_token_id = self.eos_token_id
107
+ self.config.pad_token_id = self.pad_token_id
108
+ self.config.quantize_codebook_dim = self.tokenizer.codebook_dim
109
+ self.config.face_per_token = self.face_per_token
110
+ self.config._attn_implementation="flash_attention_2"
111
+ self.config.cond_length = self.cond_length
112
+ if self.config.word_embed_proj_dim != self.config.hidden_size:
113
+ self.config.word_embed_proj_dim = self.config.hidden_size
114
+ self.transformer = AutoModelForCausalLM.from_config(
115
+ config=self.config, use_flash_attention_2 = True
116
+ )
117
+ self.transformer.to_bettertransformer()
118
+ self.transformer.model.decoder.quantize_codebooks = nn.Parameter(torch.zeros(1, self.tokenizer.codebook_size, self.tokenizer.codebook_dim))
119
+
120
+ self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
121
+ self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim)
122
+
123
+ self.eval()
124
+
125
+ def process_point_feature(self, point_feature):
126
+ encode_feature = torch.zeros(point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim,
127
+ device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
128
+ encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
129
+ shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
130
+ encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1))
131
+
132
+ return encode_feature
133
+
134
+ @torch.no_grad()
135
+ def forward(self, pc_normal, sampling=False) -> dict:
136
+ batch_size = pc_normal.shape[0]
137
+ point_feature = self.point_encoder.encode_latents(pc_normal)
138
+ processed_point_feature = self.process_point_feature(point_feature)
139
+
140
+ generate_length = self.max_length - self.cond_length
141
+ net_device = next(self.parameters()).device
142
+ outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id
143
+ if not sampling:
144
+ results = self.transformer.generate(
145
+ inputs_embeds=processed_point_feature,
146
+ max_new_tokens=generate_length, # all faces plus two
147
+ num_beams=1,
148
+ bos_token_id=self.bos_token_id,
149
+ eos_token_id=self.eos_token_id,
150
+ pad_token_id=self.pad_token_id,
151
+ )
152
+ else:
153
+ results = self.transformer.generate(
154
+ inputs_embeds = processed_point_feature,
155
+ max_new_tokens=generate_length, # all faces plus two
156
+ do_sample=True,
157
+ top_k=50,
158
+ top_p=0.95,
159
+ bos_token_id = self.bos_token_id,
160
+ eos_token_id = self.eos_token_id,
161
+ pad_token_id = self.pad_token_id,
162
+ )
163
+ assert results.shape[1] <= generate_length # B x ID bos is not included since it's predicted
164
+ outputs[:, :results.shape[1]] = results
165
+ # batch x ntokens ====> batch x ntokens x D
166
+ outputs = outputs[:, 1: -1]
167
+
168
+ outputs[outputs == self.bos_token_id] = self.tokenizer.pad_id
169
+ outputs[outputs == self.eos_token_id] = self.tokenizer.pad_id
170
+ outputs[outputs == self.pad_token_id] = self.tokenizer.pad_id
171
+
172
+ outputs[outputs != self.tokenizer.pad_id] -= 3
173
+ code_embed = self.get_codes(outputs)
174
+ decoder_output = self.tokenizer(outputs, code_embed, point_feature=point_feature)
175
+
176
+ return decoder_output
177
+
178
+ def get_codes(self, indices):
179
+ indices = indices.reshape(indices.shape[0], -1)
180
+
181
+ indices = rearrange(indices, 'b (n q) -> b n q', q=self.num_quantizers)
182
+
183
+ batch, quantize_dim = indices.shape[0], indices.shape[-1]
184
+ # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
185
+
186
+ indices, ps = pack([indices], 'b * q')
187
+
188
+ # because of quantize dropout, one can pass in indices that are coarse
189
+ # and the network should be able to reconstruct
190
+
191
+ if quantize_dim < self.num_quantizers:
192
+ indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
193
+
194
+ # take care of quantizer dropout
195
+
196
+ mask = indices == -1.
197
+ indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
198
+
199
+ # dummy implementation for shared codebook
200
+ all_codes = self.transformer.model.decoder.quantize_codebooks[0][indices]
201
+ all_codes = all_codes.permute(2, 0, 1, 3)
202
+
203
+ # mask out any codes that were dropout-ed
204
+
205
+ all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
206
+
207
+ # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
208
+
209
+ codes, = unpack(all_codes, ps, 'q b * d')
210
+
211
+ codes_summed = reduce(codes, 'q ... -> ...', 'sum')
212
+ return codes_summed
213
+
214
+ def undiscretize(
215
+ t,
216
+ low,
217
+ high,
218
+ num_discrete
219
+ ) -> Tensor:
220
+ t = t.float()
221
+
222
+ t /= num_discrete
223
+ return t * (high - low) + low
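`undiscretize` is the inverse of the coordinate quantization used by the decoder: token indices in `[0, num_discrete)` are mapped back onto `[low, high]`. A quick numeric check, assuming the function is importable from this module at module level:
```python
import torch
from MeshAnything.models.meshanything import undiscretize

tokens = torch.tensor([0, 64, 127])
coords = undiscretize(tokens, low=-0.5, high=0.5, num_discrete=128)
print(coords)  # tensor([-0.5000,  0.0000,  0.4922])
```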
MeshAnything/models/shape_opt.py ADDED
@@ -0,0 +1,464 @@
1
+ from transformers import AutoModelForCausalLM, AutoConfig, OPTConfig
2
+ from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTModel, OPTDecoder, OPTLearnedPositionalEmbedding, OPTDecoderLayer
3
+ from typing import List, Optional, Tuple, Union
4
+ from einops import repeat
5
+ from transformers.modeling_outputs import (
6
+ CausalLMOutputWithPast,
7
+ )
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from transformers.utils import replace_return_docstrings, logging
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast
13
+ # from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
14
+
15
+ class ShapeOPTConfig(OPTConfig):
16
+ model_type = "shape_opt"
17
+
18
+ class ShapeOPT(OPTForCausalLM):
19
+ config_class = ShapeOPTConfig
20
+ def __init__(self, config: ShapeOPTConfig):
21
+ super(OPTForCausalLM, self).__init__(config)
22
+ self.model = ShapeOPTModel(config)
23
+
24
+ self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
25
+
26
+ # Initialize weights and apply final processing
27
+ self.post_init()
28
+
29
+ def tie_weights(self):
30
+ """
31
+ Tie the weights between the input embeddings and the output embeddings.
32
+
33
+ If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
34
+ weights instead.
35
+ """
36
+ if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
37
+ if hasattr(self, self.base_model_prefix):
38
+ self = getattr(self, self.base_model_prefix)
39
+ self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
40
+
41
+ for module in self.modules():
42
+ if hasattr(module, "_tie_weights"):
43
+ module._tie_weights()
44
+
45
+
46
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="OPTConfig")
47
+ def forward(
48
+ self,
49
+ input_ids: torch.LongTensor = None,
50
+ face_ids: torch.LongTensor = None,
51
+ attention_mask: Optional[torch.Tensor] = None,
52
+ head_mask: Optional[torch.Tensor] = None,
53
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
54
+ inputs_embeds: Optional[torch.FloatTensor] = None,
55
+ labels: Optional[torch.LongTensor] = None,
56
+ use_cache: Optional[bool] = None,
57
+ output_attentions: Optional[bool] = None,
58
+ output_hidden_states: Optional[bool] = None,
59
+ return_dict: Optional[bool] = None,
60
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
61
+ r"""
62
+ Args:
63
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
64
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
65
+ provide it.
66
+
67
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
68
+ [`PreTrainedTokenizer.__call__`] for details.
69
+
70
+ [What are input IDs?](../glossary#input-ids)
71
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
72
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
73
+
74
+ - 1 for tokens that are **not masked**,
75
+ - 0 for tokens that are **masked**.
76
+
77
+ [What are attention masks?](../glossary#attention-mask)
78
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
79
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
80
+
81
+ - 1 indicates the head is **not masked**,
82
+ - 0 indicates the head is **masked**.
83
+
84
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
85
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
86
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
87
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
88
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
89
+
90
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
91
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
92
+
93
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
94
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
95
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
96
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
97
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
98
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
99
+ than the model's internal embedding lookup matrix.
100
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
101
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
102
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
103
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
104
+ use_cache (`bool`, *optional*):
105
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
106
+ (see `past_key_values`).
107
+ output_attentions (`bool`, *optional*):
108
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
109
+ returned tensors for more detail.
110
+ output_hidden_states (`bool`, *optional*):
111
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
112
+ for more detail.
113
+ return_dict (`bool`, *optional*):
114
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
115
+
116
+ Returns:
117
+
118
+ Example:
119
+
120
+ ```python
121
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
122
+
123
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
124
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
125
+
126
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
127
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
128
+
129
+ >>> # Generate
130
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
131
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
132
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
133
+ ```"""
134
+
135
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
136
+ output_hidden_states = (
137
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
138
+ )
139
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
140
+
141
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
142
+ outputs = self.model.decoder(
143
+ input_ids=input_ids,
144
+ face_ids = face_ids,
145
+ attention_mask=attention_mask,
146
+ head_mask=head_mask,
147
+ past_key_values=past_key_values,
148
+ inputs_embeds=inputs_embeds,
149
+ use_cache=use_cache,
150
+ output_attentions=output_attentions,
151
+ output_hidden_states=output_hidden_states,
152
+ return_dict=return_dict,
153
+ )
154
+
155
+ logits = self.lm_head(outputs[0]).contiguous()
156
+
157
+ loss = None
158
+ if labels is not None:
159
+ # move labels to correct device to enable model parallelism
160
+ labels = labels.to(logits.device)
161
+ # Shift so that tokens < n predict n
162
+ shift_logits = logits[..., :-1, :].contiguous()
163
+ shift_labels = labels[..., 1:].contiguous()
164
+ # Flatten the tokens
165
+ loss_fct = CrossEntropyLoss()
166
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
167
+
168
+ if not return_dict:
169
+ output = (logits,) + outputs[1:]
170
+ return (loss,) + output if loss is not None else output
171
+
172
+ return CausalLMOutputWithPast(
173
+ loss=loss,
174
+ logits=logits,
175
+ past_key_values=outputs.past_key_values,
176
+ hidden_states=outputs.hidden_states,
177
+ attentions=outputs.attentions,
178
+ )
179
+
180
+ class ShapeOPTModel(OPTModel):
181
+ config_class = ShapeOPTConfig
182
+ def __init__(self, config: ShapeOPTConfig):
183
+ super(OPTModel,self).__init__(config)
184
+ self.decoder = ShapeOPTDecoder(config)
185
+ # Initialize weights and apply final processing
186
+ self.post_init()
187
+
188
+ class ShapeOPTDecoder(OPTDecoder):
189
+ config_class = ShapeOPTConfig
190
+ def __init__(self, config: ShapeOPTConfig):
191
+ super(OPTDecoder,self).__init__(config)
192
+ self.config = config
193
+ self.dropout = config.dropout
194
+ self.layerdrop = config.layerdrop
195
+ self.padding_idx = config.pad_token_id
196
+ self.max_target_positions = config.max_position_embeddings
197
+ self.vocab_size = config.vocab_size
198
+
199
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) # not used
200
+ self.hidden_size = config.hidden_size
201
+ self.word_embed_proj_dim = config.word_embed_proj_dim
202
+ self.extra_embeds = nn.Embedding(3, config.word_embed_proj_dim) #padding_idx=self.padding_idx)
203
+ self.input_layer = nn.Linear(config.quantize_codebook_dim, config.word_embed_proj_dim)
204
+
205
+ self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
206
+ self.token_embed_positions = OPTFacePositionalEmbedding(config.face_per_token + 3, config.word_embed_proj_dim) #padding_idx=self.padding_idx)
207
+ self.face_per_token = config.face_per_token
208
+ self.cond_length = config.cond_length
209
+ self.cond_embed = nn.Embedding(2, config.word_embed_proj_dim)
210
+
211
+ if config.word_embed_proj_dim != config.hidden_size:
212
+ self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
213
+ else:
214
+ self.project_out = None
215
+
216
+ if config.word_embed_proj_dim != config.hidden_size:
217
+ self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
218
+ else:
219
+ self.project_in = None
220
+ # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
221
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
222
+ # see https://github.com/facebookresearch/metaseq/pull/164
223
+ if config.do_layer_norm_before and not config._remove_final_layer_norm:
224
+ self.final_layer_norm = nn.LayerNorm(
225
+ config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
226
+ )
227
+ else:
228
+ self.final_layer_norm = None
229
+
230
+ self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
231
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
232
+
233
+ self.gradient_checkpointing = False
234
+ # Initialize weights and apply final processing
235
+ self.post_init()
236
+
237
+ def embed_with_vae(self, input_ids):
238
+ inputs_embeds = repeat(torch.zeros(input_ids.shape, device=input_ids.device), 'b n -> b n d',
239
+ d=self.word_embed_proj_dim).clone().detach()
240
+ idx_in_extra = torch.isin(input_ids, torch.LongTensor([0, 1, 2]).to(input_ids.device))
241
+ inputs_embeds[idx_in_extra] += self.extra_embeds(input_ids[idx_in_extra])
242
+ self.quantize_codebooks = self.quantize_codebooks.to(input_ids.device)
243
+ inputs_embeds[~idx_in_extra] += self.input_layer(self.quantize_codebooks[0][input_ids[~idx_in_extra] - 3])
244
+
245
+ return inputs_embeds
246
+
247
+
248
+ def forward(
249
+ self,
250
+ input_ids: torch.LongTensor = None,
251
+ face_ids: torch.LongTensor = None,
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ head_mask: Optional[torch.Tensor] = None,
254
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
255
+ inputs_embeds: Optional[torch.FloatTensor] = None,
256
+ use_cache: Optional[bool] = None,
257
+ output_attentions: Optional[bool] = None,
258
+ output_hidden_states: Optional[bool] = None,
259
+ return_dict: Optional[bool] = None,
260
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
261
+ r"""
262
+ Args:
263
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
264
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
265
+ provide it.
266
+
267
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
268
+ [`PreTrainedTokenizer.__call__`] for details.
269
+
270
+ [What are input IDs?](../glossary#input-ids)
271
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
272
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
273
+
274
+ - 1 for tokens that are **not masked**,
275
+ - 0 for tokens that are **masked**.
276
+
277
+ [What are attention masks?](../glossary#attention-mask)
278
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
279
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
280
+
281
+ - 1 indicates the head is **not masked**,
282
+ - 0 indicates the head is **masked**.
283
+
284
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
285
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
286
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
287
+
288
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
289
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
290
+
291
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
292
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
293
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
294
+
295
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
296
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
297
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
298
+ than the model's internal embedding lookup matrix.
299
+ output_attentions (`bool`, *optional*):
300
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
301
+ returned tensors for more detail.
302
+ output_hidden_states (`bool`, *optional*):
303
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
304
+ for more detail.
305
+ return_dict (`bool`, *optional*):
306
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
307
+ """
308
+ # OPT Decoder
309
+ # print("used my Trans")
310
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
311
+ output_hidden_states = (
312
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
313
+ )
314
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
315
+
316
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
317
+ # Transformer Decoder
318
+ if input_ids is not None:
319
+ input_shape = input_ids.size()
320
+ input_ids = input_ids.view(-1, input_shape[-1])
321
+ inputs_embeds = self.embed_with_vae(input_ids) # nothing to do with position
322
+
323
+ face_embeds = self.token_embed_positions(attention_mask[:, self.cond_length:], face_ids, input_ids,
324
+ self.face_per_token)
325
+ inputs_embeds += face_embeds
326
+ cond_embed_query = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=inputs_embeds.device,
327
+ dtype=inputs_embeds.dtype).long()
328
+ inputs_embeds = inputs_embeds + self.cond_embed(cond_embed_query)
329
+
330
+ elif inputs_embeds is not None:
331
+ # assert self.cond and not self.training
332
+
333
+ total_length = inputs_embeds.shape[1] # B x length x embeding
334
+ cond_embed_query = torch.zeros((inputs_embeds.shape[0], total_length), device=inputs_embeds.device,
335
+ dtype=inputs_embeds.dtype).long()
336
+ inputs_embeds = inputs_embeds + self.cond_embed(cond_embed_query)
337
+ else:
338
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
339
+
340
+ batch_size, seq_length = inputs_embeds.shape[:2] # seq_length not used since mask_seq_length is not used
341
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
342
+ # required mask seq length can be calculated via length of past
343
+ mask_seq_length = past_key_values_length + seq_length # not used since attention mask is input
344
+
345
+ # embed positions
346
+ if self._use_flash_attention_2:
347
+ # 2d mask is passed through the layers
348
+ assert attention_mask is not None
349
+ causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
350
+ attention_mask = (
351
+ torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
352
+ if attention_mask is None
353
+ else attention_mask
354
+ )
355
+ else:
356
+ raise ValueError("Only flash_attention_2 is supported in MeshAnything")
357
+
358
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
359
+
360
+ if self.project_in is not None:
361
+ inputs_embeds = self.project_in(inputs_embeds)
362
+
363
+ hidden_states = inputs_embeds + pos_embeds
364
+
365
+ # decoder layers
366
+ all_hidden_states = () if output_hidden_states else None
367
+ all_self_attns = () if output_attentions else None
368
+ next_decoder_cache = () if use_cache else None
369
+
370
+ # check if head_mask has a correct number of layers specified if desired
371
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
372
+ if attn_mask is not None:
373
+ if attn_mask.size()[0] != (len(self.layers)):
374
+ raise ValueError(
375
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
376
+ f" {head_mask.size()[0]}."
377
+ )
378
+
379
+ for idx, decoder_layer in enumerate(self.layers):
380
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
381
+ if output_hidden_states:
382
+ all_hidden_states += (hidden_states,)
383
+
384
+ if self.training:
385
+ dropout_probability = torch.rand([])
386
+ if dropout_probability < self.layerdrop:
387
+ continue
388
+
389
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
390
+
391
+ if self.gradient_checkpointing and self.training:
392
+ layer_outputs = self._gradient_checkpointing_func(
393
+ decoder_layer.__call__,
394
+ hidden_states,
395
+ causal_attention_mask,
396
+ head_mask[idx] if head_mask is not None else None,
397
+ None,
398
+ output_attentions,
399
+ use_cache,
400
+ )
401
+ else:
402
+ layer_outputs = decoder_layer(
403
+ hidden_states,
404
+ attention_mask=causal_attention_mask,
405
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
406
+ past_key_value=past_key_value,
407
+ output_attentions=output_attentions,
408
+ use_cache=use_cache,
409
+ )
410
+
411
+ hidden_states = layer_outputs[0]
412
+
413
+ if use_cache:
414
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
415
+
416
+ if output_attentions:
417
+ all_self_attns += (layer_outputs[1],)
418
+
419
+ if self.final_layer_norm is not None:
420
+ hidden_states = self.final_layer_norm(hidden_states)
421
+
422
+ if self.project_out is not None:
423
+ hidden_states = self.project_out(hidden_states)
424
+
425
+ # add hidden states from the last decoder layer
426
+ if output_hidden_states:
427
+ all_hidden_states += (hidden_states,)
428
+
429
+ next_cache = next_decoder_cache if use_cache else None
430
+ if not return_dict:
431
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
432
+ return BaseModelOutputWithPast(
433
+ last_hidden_state=hidden_states,
434
+ past_key_values=next_cache,
435
+ hidden_states=all_hidden_states,
436
+ attentions=all_self_attns,
437
+ )
438
+
439
+ class OPTFacePositionalEmbedding(nn.Embedding):
440
+ """
441
+ This module learns positional embeddings up to a fixed maximum size.
442
+ """
443
+
444
+ def __init__(self, num_embeddings: int, embedding_dim: int):
445
+ super().__init__(num_embeddings, embedding_dim)
446
+
447
+ def forward(self, attention_mask=None, face_ids = None, input_ids = None, face_per_token = None):
448
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
449
+ if face_ids is not None:
450
+ return super().forward(face_ids)
451
+
452
+ assert input_ids.shape[1] == 1
453
+ idx_in_extra = torch.isin(input_ids, torch.LongTensor([0, 1, 2]).to(input_ids.device))
454
+ cur_ids = input_ids.clone().detach()
455
+
456
+ cur_index = (attention_mask.sum(dim=1, keepdim=True) - 2) % face_per_token + 3
457
+ cur_ids[~idx_in_extra]=cur_index[~idx_in_extra]
458
+
459
+ return super().forward(cur_ids)
460
+
461
+
462
+ AutoConfig.register("shape_opt", ShapeOPTConfig)
463
+ AutoModelForCausalLM.register(ShapeOPTConfig, ShapeOPT)
464
+
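Because of the two `register` calls above, the custom architecture can be built through the `Auto*` classes; a sketch mirroring how `MeshAnything.__init__` prepares the config (the numeric values are illustrative, and `flash-attn` must be installed since the decoder only accepts `flash_attention_2`):
```python
from transformers import AutoModelForCausalLM
from MeshAnything.models.shape_opt import ShapeOPTConfig

config = ShapeOPTConfig.from_pretrained(
    "facebook/opt-350m",          # base OPT config to start from
    n_positions=18259,
    max_position_embeddings=18259,
    vocab_size=8192 + 3,          # codebook size + bos/eos/pad (illustrative)
    _attn_implementation="flash_attention_2",
)
config.quantize_codebook_dim = 512   # illustrative codebook dimension
config.face_per_token = 9
config.cond_length = 257
config.word_embed_proj_dim = config.hidden_size

model = AutoModelForCausalLM.from_config(config)  # resolves to ShapeOPT via the registration
```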