Spaces:
Aditya Patkar committed on
Commit c046d7f • 1 Parent(s): cb0d40a
Added training files, enforced code formatting
- Training/MSML612_Project_DDPMTraining.ipynb +0 -0
- Training/MSML_612_Project_MinImagen.ipynb +1 -0
- app.py +15 -12
- constants.py +1 -1
- feature_to_sprite.py +42 -40
- text_to_image.py +3 -1
- utilities.py +90 -63
Training/MSML612_Project_DDPMTraining.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
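The DDPM training notebook itself is not rendered above, but the model it trains is the ContextUnet that feature_to_sprite.py loads further down. Purely as an illustrative sketch (not this notebook's code; the call signature, noise schedule, and dataloader shape below are assumptions), a context-conditioned DDPM training loop follows this pattern:

# Illustrative sketch only -- not the notebook's actual contents.
# Assumes a ContextUnet-style model called as model(noisy_image, normalized_timestep, context).
import torch
import torch.nn.functional as F

def train_ddpm(model, dataloader, timesteps=500, beta1=1e-4, beta2=0.02, lr=1e-3, device="cpu"):
    # linear beta schedule between beta1 and beta2 (assumed)
    betas = torch.linspace(beta1, beta2, timesteps, device=device)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for images, context in dataloader:  # sprite batch and per-image feature vectors
        images, context = images.to(device), context.to(device)
        t = torch.randint(0, timesteps, (images.shape[0],), device=device)
        noise = torch.randn_like(images)
        ab = alpha_bar[t].view(-1, 1, 1, 1)
        noisy = ab.sqrt() * images + (1 - ab).sqrt() * noise      # forward diffusion q(x_t | x_0)
        pred = model(noisy, t.float() / timesteps, context)       # assumed ContextUnet-style call
        loss = F.mse_loss(pred, noise)                            # train the network to predict the added noise
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model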
Training/MSML_612_Project_MinImagen.ipynb
ADDED
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","authorship_tag":"ABX9TyNaEJSE86MIYL6MGO8lUr3+"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"f77bbc2602d846f8bb1c9e06f7b519ef":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_d2c87c66057f4e33bdbfe078ee47c0b2","IPY_MODEL_fe33fa70d2c540e9831d408ffb5d3af9","IPY_MODEL_61d7b89f706348548ebcaf0ced92a44e"],"layout":"IPY_MODEL_a51b1054233342feaef0b16d2627a658"}},"d2c87c66057f4e33bdbfe078ee47c0b2":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c4aa2aa530034af487957db71e9c509f","placeholder":"β","style":"IPY_MODEL_b4957dccdc5c4b83982212497febf4dc","value":"100%"}},"fe33fa70d2c540e9831d408ffb5d3af9":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_592c8770c24c4366843336989c8cdb5f","max":2,"min":0,"orientation":"horizontal","style":"IPY_MODEL_f6c8ccfd74844b7ebafb37f181b51130","value":2}},"61d7b89f706348548ebcaf0ced92a44e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_178935e2880043dc9eca862c09c05c16","placeholder":"β","style":"IPY_MODEL_aac7d4c2eebf4eb1914d2d98c52243ce","value":" 2/2 [00:00<00:00, 
46.37it/s]"}},"a51b1054233342feaef0b16d2627a658":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c4aa2aa530034af487957db71e9c509f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b4957dccdc5c4b83982212497febf4dc":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"592c8770c24c4366843336989c8cdb5f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f6c8ccfd74844b7eb
afb37f181b51130":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"178935e2880043dc9eca862c09c05c16":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"aac7d4c2eebf4eb1914d2d98c52243ce":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["# Introduction\n","\n","This is a training script for a diffusion model called MinImagen. 
A smaller adaptation of original Imagen architecture introduced by Google."],"metadata":{"id":"yQMHUdQADLcT"}},{"cell_type":"markdown","source":["# Setup"],"metadata":{"id":"Qz5vkO4bEh9b"}},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"LeDWeQuVl_6s","executionInfo":{"status":"ok","timestamp":1690783982884,"user_tz":-330,"elapsed":117837,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"6fb30b50-90aa-4217-82e9-8fdea0826b9c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting minimagen\n"," Downloading minimagen-0.0.9-py3-none-any.whl (43 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m43.0/43.0 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting aiohttp==3.8.1 (from minimagen)\n"," Downloading aiohttp-3.8.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.2 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m21.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting aiosignal==1.2.0 (from minimagen)\n"," Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n","Requirement already satisfied: async-timeout==4.0.2 in /usr/local/lib/python3.10/dist-packages (from minimagen) (4.0.2)\n","Collecting attrs==21.4.0 (from minimagen)\n"," Downloading attrs-21.4.0-py2.py3-none-any.whl (60 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m60.6/60.6 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting certifi==2022.6.15 (from minimagen)\n"," Downloading certifi-2022.6.15-py3-none-any.whl (160 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m160.2/160.2 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting charset-normalizer==2.1.0 (from minimagen)\n"," Downloading charset_normalizer-2.1.0-py3-none-any.whl (39 kB)\n","Collecting colorama==0.4.5 (from minimagen)\n"," Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB)\n","Collecting datasets==2.3.2 (from minimagen)\n"," Downloading datasets-2.3.2-py3-none-any.whl (362 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m362.3/362.3 kB\u001b[0m \u001b[31m36.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting dill==0.3.5.1 (from minimagen)\n"," Downloading dill-0.3.5.1-py2.py3-none-any.whl (95 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting einops==0.4.1 (from minimagen)\n"," Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n","Collecting einops-exts==0.0.3 (from minimagen)\n"," Downloading einops_exts-0.0.3-py3-none-any.whl (3.8 kB)\n","Collecting filelock==3.7.1 (from minimagen)\n"," Downloading filelock-3.7.1-py3-none-any.whl (10 kB)\n","Collecting frozenlist==1.3.0 (from minimagen)\n"," Downloading frozenlist-1.3.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (157 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m157.9/157.9 kB\u001b[0m \u001b[31m19.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting fsspec==2022.5.0 (from minimagen)\n"," 
Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m140.6/140.6 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting future==0.18.2 (from minimagen)\n"," Downloading future-0.18.2.tar.gz (829 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m829.2/829.2 kB\u001b[0m \u001b[31m71.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n","Collecting huggingface-hub==0.8.1 (from minimagen)\n"," Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m101.5/101.5 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting idna==3.3 (from minimagen)\n"," Downloading idna-3.3-py3-none-any.whl (61 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m61.2/61.2 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multidict==6.0.2 (from minimagen)\n"," Downloading multidict-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m114.5/114.5 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting multiprocess==0.70.13 (from minimagen)\n"," Downloading multiprocess-0.70.13-py310-none-any.whl (133 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m133.1/133.1 kB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting numpy==1.23.1 (from minimagen)\n"," Downloading numpy-1.23.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.0 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m17.0/17.0 MB\u001b[0m \u001b[31m99.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting packaging==21.3 (from minimagen)\n"," Downloading packaging-21.3-py3-none-any.whl (40 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pandas==1.4.3 (from minimagen)\n"," Downloading pandas-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)\n","\u001b[2K \u001b[90mβββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m105.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting Pillow==9.2.0 (from minimagen)\n"," Downloading Pillow-9.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (3.2 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m3.2/3.2 MB\u001b[0m \u001b[31m80.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pyarrow==8.0.0 (from minimagen)\n"," Downloading pyarrow-8.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m29.4/29.4 MB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting pyparsing==3.0.9 (from minimagen)\n"," Downloading pyparsing-3.0.9-py3-none-any.whl (98 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m98.3/98.3 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: python-dateutil==2.8.2 in /usr/local/lib/python3.10/dist-packages (from minimagen) (2.8.2)\n","Collecting pytz==2022.1 (from minimagen)\n"," Downloading pytz-2022.1-py2.py3-none-any.whl (503 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m503.5/503.5 kB\u001b[0m \u001b[31m44.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting PyYAML==6.0 (from minimagen)\n"," Downloading PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (682 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m682.2/682.2 kB\u001b[0m \u001b[31m56.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting regex==2022.7.9 (from minimagen)\n"," Downloading regex-2022.7.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (764 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m764.0/764.0 kB\u001b[0m \u001b[31m60.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting requests==2.28.1 (from minimagen)\n"," Downloading requests-2.28.1-py3-none-any.whl (62 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m62.8/62.8 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting resize-right==0.0.2 (from minimagen)\n"," Downloading resize_right-0.0.2-py3-none-any.whl (8.9 kB)\n","Collecting responses==0.18.0 (from minimagen)\n"," Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n","Collecting sentencepiece==0.1.96 (from minimagen)\n"," Downloading sentencepiece-0.1.96-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m82.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: six==1.16.0 in /usr/local/lib/python3.10/dist-packages (from minimagen) (1.16.0)\n","Collecting tdqm==0.0.1 (from minimagen)\n"," Downloading tdqm-0.0.1.tar.gz (1.4 kB)\n"," Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n","Collecting tokenizers==0.12.1 (from minimagen)\n"," Downloading tokenizers-0.12.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m6.6/6.6 MB\u001b[0m \u001b[31m99.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch==1.12.0 (from minimagen)\n"," Downloading torch-1.12.0-cp310-cp310-manylinux1_x86_64.whl (776.3 MB)\n","\u001b[2K \u001b[90mβββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m776.3/776.3 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torchvision==0.13.0 (from minimagen)\n"," Downloading torchvision-0.13.0-cp310-cp310-manylinux1_x86_64.whl (19.1 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m19.1/19.1 MB\u001b[0m \u001b[31m70.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting tqdm==4.64.0 (from minimagen)\n"," Downloading tqdm-4.64.0-py2.py3-none-any.whl (78 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m78.4/78.4 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting transformers==4.20.1 (from minimagen)\n"," Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m118.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting typing-extensions==4.3.0 (from minimagen)\n"," Downloading typing_extensions-4.3.0-py3-none-any.whl (25 kB)\n","Collecting urllib3==1.26.10 (from minimagen)\n"," Downloading urllib3-1.26.10-py2.py3-none-any.whl (139 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m139.2/139.2 kB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting xxhash==3.0.0 (from minimagen)\n"," Downloading xxhash-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (211 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m211.6/211.6 kB\u001b[0m \u001b[31m25.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting yarl==1.7.2 (from minimagen)\n"," Downloading yarl-1.7.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (305 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m305.3/305.3 kB\u001b[0m \u001b[31m32.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.10/dist-packages (from datasets==2.3.2->minimagen) (2023.6.0)\n","INFO: pip is looking at multiple versions of fsspec[http] to determine which version is compatible with other requirements. 
This could take a while.\n","Collecting fsspec[http]>=2021.05.0 (from datasets==2.3.2->minimagen)\n"," Downloading fsspec-2023.5.0-py3-none-any.whl (160 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m160.1/160.1 kB\u001b[0m \u001b[31m19.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2023.4.0-py3-none-any.whl (153 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m154.0/154.0 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2023.3.0-py3-none-any.whl (145 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m145.4/145.4 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2023.1.0-py3-none-any.whl (143 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m143.0/143.0 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.11.0-py3-none-any.whl (139 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m139.5/139.5 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.10.0-py3-none-any.whl (138 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m138.8/138.8 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.8.2-py3-none-any.whl (140 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m19.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hINFO: pip is looking at multiple versions of fsspec[http] to determine which version is compatible with other requirements. This could take a while.\n"," Downloading fsspec-2022.7.1-py3-none-any.whl (141 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m141.2/141.2 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25h Downloading fsspec-2022.7.0-py3-none-any.whl (141 kB)\n","\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m141.2/141.2 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hBuilding wheels for collected packages: future, tdqm\n"," Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491057 sha256=7027523a9d22db99ee9aceb42c9fee72cfd0c8ed02e5907796d4270817864786\n"," Stored in directory: /root/.cache/pip/wheels/22/73/06/557dc4f4ef68179b9d763930d6eec26b88ed7c389b19588a1c\n"," Building wheel for tdqm (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n"," Created wheel for tdqm: filename=tdqm-0.0.1-py3-none-any.whl size=1321 sha256=78bbe236ae25f778666a9625686232a20f68f3d889132d56ea15c515fd88a311\n"," Stored in directory: /root/.cache/pip/wheels/37/31/b8/7b711038035720ba0df14376af06e5e76b9bd61759c861ad92\n","Successfully built future tdqm\n","Installing collected packages: tokenizers, sentencepiece, resize-right, pytz, einops, xxhash, urllib3, typing-extensions, tqdm, regex, PyYAML, pyparsing, Pillow, numpy, multidict, idna, future, fsspec, frozenlist, filelock, einops-exts, dill, colorama, charset-normalizer, certifi, attrs, yarl, torch, tdqm, requests, pyarrow, pandas, packaging, multiprocess, aiosignal, torchvision, responses, huggingface-hub, aiohttp, transformers, datasets, minimagen\n"," Attempting uninstall: pytz\n"," Found existing installation: pytz 2022.7.1\n"," Uninstalling pytz-2022.7.1:\n"," Successfully uninstalled pytz-2022.7.1\n"," Attempting uninstall: urllib3\n"," Found existing installation: urllib3 1.26.16\n"," Uninstalling urllib3-1.26.16:\n"," Successfully uninstalled urllib3-1.26.16\n"," Attempting uninstall: typing-extensions\n"," Found existing installation: typing_extensions 4.7.1\n"," Uninstalling typing_extensions-4.7.1:\n"," Successfully uninstalled typing_extensions-4.7.1\n"," Attempting uninstall: tqdm\n"," Found existing installation: tqdm 4.65.0\n"," Uninstalling tqdm-4.65.0:\n"," Successfully uninstalled tqdm-4.65.0\n"," Attempting uninstall: regex\n"," Found existing installation: regex 2022.10.31\n"," Uninstalling regex-2022.10.31:\n"," Successfully uninstalled regex-2022.10.31\n"," Attempting uninstall: PyYAML\n"," Found existing installation: PyYAML 6.0.1\n"," Uninstalling PyYAML-6.0.1:\n"," Successfully uninstalled PyYAML-6.0.1\n"," Attempting uninstall: pyparsing\n"," Found existing installation: pyparsing 3.1.0\n"," Uninstalling pyparsing-3.1.0:\n"," Successfully uninstalled pyparsing-3.1.0\n"," Attempting uninstall: Pillow\n"," Found existing installation: Pillow 9.4.0\n"," Uninstalling Pillow-9.4.0:\n"," Successfully uninstalled Pillow-9.4.0\n"," Attempting uninstall: numpy\n"," Found existing installation: numpy 1.22.4\n"," Uninstalling numpy-1.22.4:\n"," Successfully uninstalled numpy-1.22.4\n"," Attempting uninstall: multidict\n"," Found existing installation: multidict 6.0.4\n"," Uninstalling multidict-6.0.4:\n"," Successfully uninstalled multidict-6.0.4\n"," Attempting uninstall: idna\n"," Found existing installation: idna 3.4\n"," Uninstalling idna-3.4:\n"," Successfully uninstalled idna-3.4\n"," Attempting uninstall: future\n"," Found existing installation: future 0.18.3\n"," Uninstalling future-0.18.3:\n"," Successfully uninstalled future-0.18.3\n"," Attempting uninstall: fsspec\n"," Found existing installation: fsspec 2023.6.0\n"," Uninstalling fsspec-2023.6.0:\n"," Successfully uninstalled fsspec-2023.6.0\n"," Attempting uninstall: frozenlist\n"," Found existing installation: frozenlist 1.4.0\n"," Uninstalling frozenlist-1.4.0:\n"," Successfully uninstalled frozenlist-1.4.0\n"," Attempting uninstall: filelock\n"," Found existing installation: filelock 3.12.2\n"," Uninstalling filelock-3.12.2:\n"," Successfully uninstalled filelock-3.12.2\n"," Attempting uninstall: charset-normalizer\n"," Found existing installation: charset-normalizer 2.0.12\n"," Uninstalling charset-normalizer-2.0.12:\n"," Successfully uninstalled charset-normalizer-2.0.12\n"," Attempting uninstall: certifi\n"," Found existing installation: certifi 2023.7.22\n"," Uninstalling 
certifi-2023.7.22:\n"," Successfully uninstalled certifi-2023.7.22\n"," Attempting uninstall: attrs\n"," Found existing installation: attrs 23.1.0\n"," Uninstalling attrs-23.1.0:\n"," Successfully uninstalled attrs-23.1.0\n"," Attempting uninstall: yarl\n"," Found existing installation: yarl 1.9.2\n"," Uninstalling yarl-1.9.2:\n"," Successfully uninstalled yarl-1.9.2\n"," Attempting uninstall: torch\n"," Found existing installation: torch 2.0.1+cu118\n"," Uninstalling torch-2.0.1+cu118:\n"," Successfully uninstalled torch-2.0.1+cu118\n"," Attempting uninstall: requests\n"," Found existing installation: requests 2.27.1\n"," Uninstalling requests-2.27.1:\n"," Successfully uninstalled requests-2.27.1\n"," Attempting uninstall: pyarrow\n"," Found existing installation: pyarrow 9.0.0\n"," Uninstalling pyarrow-9.0.0:\n"," Successfully uninstalled pyarrow-9.0.0\n"," Attempting uninstall: pandas\n"," Found existing installation: pandas 1.5.3\n"," Uninstalling pandas-1.5.3:\n"," Successfully uninstalled pandas-1.5.3\n"," Attempting uninstall: packaging\n"," Found existing installation: packaging 23.1\n"," Uninstalling packaging-23.1:\n"," Successfully uninstalled packaging-23.1\n"," Attempting uninstall: aiosignal\n"," Found existing installation: aiosignal 1.3.1\n"," Uninstalling aiosignal-1.3.1:\n"," Successfully uninstalled aiosignal-1.3.1\n"," Attempting uninstall: torchvision\n"," Found existing installation: torchvision 0.15.2+cu118\n"," Uninstalling torchvision-0.15.2+cu118:\n"," Successfully uninstalled torchvision-0.15.2+cu118\n"," Attempting uninstall: aiohttp\n"," Found existing installation: aiohttp 3.8.5\n"," Uninstalling aiohttp-3.8.5:\n"," Successfully uninstalled aiohttp-3.8.5\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n","gcsfs 2023.6.0 requires fsspec==2023.6.0, but you have fsspec 2022.5.0 which is incompatible.\n","google-colab 1.0.0 requires pandas==1.5.3, but you have pandas 1.4.3 which is incompatible.\n","google-colab 1.0.0 requires requests==2.27.1, but you have requests 2.28.1 which is incompatible.\n","torchaudio 2.0.2+cu118 requires torch==2.0.1, but you have torch 1.12.0 which is incompatible.\n","torchdata 0.6.1 requires torch==2.0.1, but you have torch 1.12.0 which is incompatible.\n","torchtext 0.15.2 requires torch==2.0.1, but you have torch 1.12.0 which is incompatible.\n","yfinance 0.2.25 requires pytz>=2022.5, but you have pytz 2022.1 which is incompatible.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed Pillow-9.2.0 PyYAML-6.0 aiohttp-3.8.1 aiosignal-1.2.0 attrs-21.4.0 certifi-2022.6.15 charset-normalizer-2.1.0 colorama-0.4.5 datasets-2.3.2 dill-0.3.5.1 einops-0.4.1 einops-exts-0.0.3 filelock-3.7.1 frozenlist-1.3.0 fsspec-2022.5.0 future-0.18.2 huggingface-hub-0.8.1 idna-3.3 minimagen-0.0.9 multidict-6.0.2 multiprocess-0.70.13 numpy-1.23.1 packaging-21.3 pandas-1.4.3 pyarrow-8.0.0 pyparsing-3.0.9 pytz-2022.1 regex-2022.7.9 requests-2.28.1 resize-right-0.0.2 responses-0.18.0 sentencepiece-0.1.96 tdqm-0.0.1 tokenizers-0.12.1 torch-1.12.0 torchvision-0.13.0 tqdm-4.64.0 transformers-4.20.1 typing-extensions-4.3.0 urllib3-1.26.10 xxhash-3.0.0 yarl-1.7.2\n"]},{"output_type":"display_data","data":{"application/vnd.colab-display-data+json":{"pip_warning":{"packages":["PIL","certifi","numpy","packaging","torch","tqdm"]}}},"metadata":{}}],"source":["#install the minimagen package\n","!pip install minimagen"]},{"cell_type":"code","source":["#utility imports\n","import os\n","from datetime import datetime\n","\n","#pytorch related imports\n","import torch.utils.data as data_utils\n","from torch import optim\n","\n","#minimagen related imports\n","from minimagen.Imagen import Imagen\n","from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest\n","from minimagen.generate import load_minimagen, load_params\n","from minimagen.t5 import get_encoded_dim\n","from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \\\n"," create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \\\n"," load_testing_parameters"],"metadata":{"id":"b_4eGJywmHR5","executionInfo":{"status":"error","timestamp":1690968676849,"user_tz":-330,"elapsed":5122,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"154b276a-c7dd-4469-d5b8-8d1fc03b5159","colab":{"base_uri":"https://localhost:8080/","height":381}},"execution_count":null,"outputs":[{"output_type":"error","ename":"ModuleNotFoundError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-1-192156f481b5>\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mminimagen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mImagen\u001b[0m \u001b[0;32mimport\u001b[0m 
\u001b[0mImagen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mminimagen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnet\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mUnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSuper\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBaseTest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSuperTest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mminimagen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload_minimagen\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'minimagen'","","\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"],"errorDetails":{"actions":[{"action":"open_url","actionText":"Open Examples","url":"/notebooks/snippets/importing_libraries.ipynb"}]}}]},{"cell_type":"code","source":["# Get device: Connect to GPU runtime for better performance\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","# Command line argument parser\n","parser = get_minimagen_parser()\n","class args_cls:\n"," a = 0\n","\n","#get an instance of the args_cls\n","args = args_cls()"],"metadata":{"id":"GoNwdipqmH95"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#directory creation for training\n","timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n","dir_path = f\"./training_{timestamp}\"\n","training_dir = create_directory(dir_path)"],"metadata":{"id":"eq8I0I7MmKFz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#A dictionary of hyperparameters\n","hyperparameters = dict(\n"," PARAMETERS=None,\n"," NUM_WORKERS=0,\n"," BATCH_SIZE=20,\n"," MAX_NUM_WORDS=32,\n"," IMG_SIDE_LEN=128,\n"," EPOCHS=10,\n"," T5_NAME='t5_small',\n"," TRAIN_VALID_FRAC=0.5,\n"," TRAINING_DIRECTORY = '/content/training_20230731_061334',\n"," TIMESTEPS=25,\n"," OPTIM_LR=0.0001,\n"," ACCUM_ITER=1,\n"," CHCKPT_NUM=500,\n"," VALID_NUM=None,\n"," RESTART_DIRECTORY=None,\n"," TESTING=False,\n"," timestamp=None,\n"," )\n","# Replace relevant values in arg dict\n","args.__dict__ = {**args.__dict__, **hyperparameters}"],"metadata":{"id":"8hdqSoAXoxqs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Data"],"metadata":{"id":"0CvubtDKEmHm"}},{"cell_type":"code","source":["# Load subset of Conceptual Captions dataset.\n","train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=False)\n","indices = torch.arange(1000)\n","\n","#create train and validation datasets with given number of samples\n","train_dataset = data_utils.Subset(train_dataset, indices)\n","valid_dataset = data_utils.Subset(valid_dataset, indices)\n","\n","# Create dataloaders\n","dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}\n","train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)\n","valid_dataloader = 
torch.utils.data.DataLoader(valid_dataset, **dl_opts)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":105,"referenced_widgets":["f77bbc2602d846f8bb1c9e06f7b519ef","d2c87c66057f4e33bdbfe078ee47c0b2","fe33fa70d2c540e9831d408ffb5d3af9","61d7b89f706348548ebcaf0ced92a44e","a51b1054233342feaef0b16d2627a658","c4aa2aa530034af487957db71e9c509f","b4957dccdc5c4b83982212497febf4dc","592c8770c24c4366843336989c8cdb5f","f6c8ccfd74844b7ebafb37f181b51130","178935e2880043dc9eca862c09c05c16","aac7d4c2eebf4eb1914d2d98c52243ce"]},"id":"6LVG7NbZmNQq","executionInfo":{"status":"ok","timestamp":1690786766692,"user_tz":-330,"elapsed":9916,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"7b91d590-fb10-4c7a-e432-a315c656d54a"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["WARNING:datasets.builder:No config specified, defaulting to: conceptual_captions/unlabeled\n","WARNING:datasets.builder:Reusing dataset conceptual_captions (/root/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)\n"]},{"output_type":"display_data","data":{"text/plain":[" 0%| | 0/2 [00:00<?, ?it/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f77bbc2602d846f8bb1c9e06f7b519ef"}},"metadata":{}}]},{"cell_type":"markdown","source":["# UNet"],"metadata":{"id":"vmwLUZh2Eqhn"}},{"cell_type":"code","source":["# Instantiate Unet with default parameters and transfer to GPU if available\n","unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]\n","unets = [Unet(**unet_params).to(device) for unet_params in unets_params]"],"metadata":{"id":"IaKJG4IamPJD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Specify MinImagen parameters\n","imagen_params = dict(\n"," image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),\n"," timesteps=args.TIMESTEPS,\n"," cond_drop_prob=0.15,\n"," text_encoder_name=args.T5_NAME\n",")\n","\n","# Create MinImagen from UNets with specified imagen parameters\n","imagen = Imagen(unets=unets, **imagen_params).to(device)"],"metadata":{"id":"dl-w2Yy6mQ3Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Fill in unspecified arguments with defaults\n","unets_params = [{**get_default_args(Unet), **i} for i in unets_params]\n","imagen_params = {**get_default_args(Imagen), **imagen_params}\n","\n","# Get the size of the Imagen model in megabytes\n","model_size_MB = get_model_size(imagen)\n","\n","# Save all training info (config files, model size, etc.)\n","save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)"],"metadata":{"id":"tzkJfhuRmSqg"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Training"],"metadata":{"id":"0XcZq-Q8EthC"}},{"cell_type":"code","source":["# Create optimizer - Adam\n","optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)\n","\n","# Train the MinImagen instance\n","MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)"],"metadata":{"id":"I--5Lt18mUf8","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1690788612546,"user_tz":-330,"elapsed":1835395,"user":{"displayName":"Aditya 
Patkar","userId":"16560201646582853800"}},"outputId":"2e19e526-6526-473f-8e77-007e3df8425d"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 1 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:05<00:22, 5.54s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:31<00:53, 17.84s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:45<00:31, 15.65s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:51<00:11, 11.95s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [01:06<00:00, 13.20s/it]\n","1it [01:14, 74.38s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.1316, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0483, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:27, 29.42s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 2 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:11<00:46, 11.63s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:26<00:40, 13.34s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:34<00:21, 10.96s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:41<00:09, 9.59s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:56<00:00, 35.30s/it]\n","1it [03:10, 190.22s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0965, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0373, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:12, 50.50s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 3 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:07<00:31, 7.80s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:13<00:19, 6.65s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:27<00:20, 10.07s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:32<00:07, 7.78s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:54<00:00, 34.83s/it]\n","1it [03:54, 234.78s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0735, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0289, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:21, 52.38s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 4 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:15<01:02, 15.63s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:23<00:32, 10.99s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:29<00:17, 
8.55s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [02:48<01:00, 60.17s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:53<00:00, 34.62s/it]\n","1it [02:57, 177.30s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0502, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0210, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:07, 49.47s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 5 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:07<00:28, 7.14s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:31<00:52, 17.42s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:37<00:23, 11.92s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:43<00:09, 9.72s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:51<00:00, 10.29s/it]\n","1it [01:16, 76.25s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0274, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0135, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:32, 30.58s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 6 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:34<02:17, 34.45s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:45<01:01, 20.50s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:51<00:28, 14.07s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:56<00:10, 10.56s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [01:11<00:00, 14.37s/it]\n","1it [01:17, 77.59s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(1.0088, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0083, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:25, 29.04s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 7 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:05<00:21, 5.37s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:18<00:29, 9.75s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:34<00:25, 12.60s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:49<00:13, 13.81s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:53<00:00, 10.75s/it]\n","1it [01:00, 60.32s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9863, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0049, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:07, 25.52s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 8 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, 
?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:10<00:41, 10.40s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:15<00:22, 7.41s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:30<00:21, 10.65s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:41<00:10, 11.00s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:51<00:00, 10.24s/it]\n","1it [01:04, 64.48s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9715, device='cuda:0')\n","Unet 1 avg validation loss: tensor(1.0007, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:05, 25.04s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 9 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:10<00:41, 10.36s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:16<00:23, 7.89s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:24<00:15, 7.73s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:35<00:09, 9.07s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [02:51<00:00, 34.28s/it]\n","1it [03:30, 210.85s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9587, device='cuda:0')\n","Unet 1 avg validation loss: tensor(0.9981, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [04:11, 50.39s/it]\n"]},{"output_type":"stream","name":"stdout","text":["\n","-------------------- EPOCH 10 --------------------\n","\n","----------Training...----------\n"]},{"output_type":"stream","name":"stderr","text":["\r0it [00:00, ?it/s]"]},{"output_type":"stream","name":"stdout","text":["\n","----------Validation...----------\n"]},{"output_type":"stream","name":"stderr","text":["\n"," 0%| | 0/5 [00:00<?, ?it/s]\u001b[A\n"," 20%|ββ | 1/5 [00:23<01:32, 23.20s/it]\u001b[A\n"," 40%|ββββ | 2/5 [00:33<00:46, 15.36s/it]\u001b[A\n"," 60%|ββββββ | 3/5 [00:37<00:20, 10.14s/it]\u001b[A\n"," 80%|ββββββββ | 4/5 [00:43<00:08, 8.75s/it]\u001b[A\n","100%|ββββββββββ| 5/5 [00:50<00:00, 10.04s/it]\n","1it [01:37, 97.50s/it]"]},{"output_type":"stream","name":"stdout","text":["Unet 0 avg validation loss: tensor(0.9483, device='cuda:0')\n","Unet 1 avg validation loss: tensor(0.9955, device='cuda:0')\n"]},{"output_type":"stream","name":"stderr","text":["5it [02:03, 24.66s/it]\n"]}]},{"cell_type":"markdown","source":["# Inference"],"metadata":{"id":"BEnz_4zPEwgu"}},{"cell_type":"code","source":["from argparse import ArgumentParser\n","from minimagen.generate import load_minimagen, sample_and_save\n"],"metadata":{"id":"PUWUXmYDmdpm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Specify the caption(s) to generate images for\n","captions = ['happy']"],"metadata":{"id":"PPzAqX0qmeKa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["args_cls"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"BGpmb7jamimu","executionInfo":{"status":"ok","timestamp":1690784383671,"user_tz":-330,"elapsed":445,"user":{"displayName":"Aditya 
Patkar","userId":"16560201646582853800"}},"outputId":"7972ae9b-1573-48de-d4be-923370eeb9e4"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["__main__.args_cls"]},"metadata":{},"execution_count":22}]},{"cell_type":"code","source":["# Use `sample_and_save` to generate and save the iamges\n","sample_and_save(captions, training_directory='/content/training_20230731_065902')"],"metadata":{"id":"fMxM5zdNmf8e","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1690788695591,"user_tz":-330,"elapsed":3817,"user":{"displayName":"Aditya Patkar","userId":"16560201646582853800"}},"outputId":"01a1d7c8-bfcc-420b-e679-2f57525f8d30"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["0it [00:00, ?it/s]\n","sampling loop time step: 0%| | 0/25 [00:00<?, ?it/s]\u001b[A\n","sampling loop time step: 12%|ββ | 3/25 [00:00<00:01, 21.00it/s]\u001b[A\n","sampling loop time step: 24%|βββ | 6/25 [00:00<00:00, 20.08it/s]\u001b[A\n","sampling loop time step: 36%|ββββ | 9/25 [00:00<00:00, 20.17it/s]\u001b[A\n","sampling loop time step: 48%|βββββ | 12/25 [00:00<00:00, 20.03it/s]\u001b[A\n","sampling loop time step: 60%|ββββββ | 15/25 [00:00<00:00, 20.02it/s]\u001b[A\n","sampling loop time step: 72%|ββββββββ | 18/25 [00:00<00:00, 19.98it/s]\u001b[A\n","sampling loop time step: 80%|ββββββββ | 20/25 [00:01<00:00, 19.70it/s]\u001b[A\n","sampling loop time step: 88%|βββββββββ | 22/25 [00:01<00:00, 19.56it/s]\u001b[A\n","sampling loop time step: 100%|ββββββββββ| 25/25 [00:01<00:00, 19.65it/s]\n","1it [00:01, 1.28s/it]\n","sampling loop time step: 0%| | 0/25 [00:00<?, ?it/s]\u001b[A\n","sampling loop time step: 8%|β | 2/25 [00:00<00:01, 11.65it/s]\u001b[A\n","sampling loop time step: 16%|ββ | 4/25 [00:00<00:01, 11.37it/s]\u001b[A\n","sampling loop time step: 24%|βββ | 6/25 [00:00<00:01, 11.53it/s]\u001b[A\n","sampling loop time step: 32%|ββββ | 8/25 [00:00<00:01, 11.37it/s]\u001b[A\n","sampling loop time step: 40%|ββββ | 10/25 [00:00<00:01, 11.47it/s]\u001b[A\n","sampling loop time step: 48%|βββββ | 12/25 [00:01<00:01, 11.08it/s]\u001b[A\n","sampling loop time step: 56%|ββββββ | 14/25 [00:01<00:00, 11.27it/s]\u001b[A\n","sampling loop time step: 64%|βββββββ | 16/25 [00:01<00:00, 11.41it/s]\u001b[A\n","sampling loop time step: 72%|ββββββββ | 18/25 [00:01<00:00, 11.18it/s]\u001b[A\n","sampling loop time step: 80%|ββββββββ | 20/25 [00:01<00:00, 11.21it/s]\u001b[A\n","sampling loop time step: 88%|βββββββββ | 22/25 [00:01<00:00, 11.29it/s]\u001b[A\n","sampling loop time step: 100%|ββββββββββ| 25/25 [00:02<00:00, 11.27it/s]\n","2it [00:03, 1.76s/it]\n"]}]}]}
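Condensed from the code cells embedded in the notebook JSON above (pip output, widget state, and the hard-coded Colab TRAINING_DIRECTORY omitted; this is a paraphrase of the cells, not the notebook itself), the training flow is roughly:

# Condensed from the notebook's code cells: MinImagen training on a
# 1000-sample subset of Conceptual Captions, then sampling for one caption.
import torch
import torch.utils.data as data_utils
from torch import optim
from datetime import datetime

from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, BaseTest, SuperTest
from minimagen.generate import sample_and_save
from minimagen.training import (ConceptualCaptions, MinimagenTrain, create_directory,
                                get_default_args, get_minimagen_dl_opts,
                                get_model_size, save_training_info)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dir_path = f"./training_{timestamp}"
training_dir = create_directory(dir_path)

class args_cls:  # stands in for the argparse namespace used in the notebook
    pass

args = args_cls()
args.__dict__.update(dict(
    PARAMETERS=None, NUM_WORKERS=0, BATCH_SIZE=20, MAX_NUM_WORDS=32,
    IMG_SIDE_LEN=128, EPOCHS=10, T5_NAME="t5_small", TRAIN_VALID_FRAC=0.5,
    TIMESTEPS=25, OPTIM_LR=0.0001, ACCUM_ITER=1, CHCKPT_NUM=500,
    VALID_NUM=None, RESTART_DIRECTORY=None, TESTING=False, timestamp=None,
))

# Data: 1000-sample subsets of Conceptual Captions
train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=False)
indices = torch.arange(1000)
train_dataset = data_utils.Subset(train_dataset, indices)
valid_dataset = data_utils.Subset(valid_dataset, indices)
dl_opts = {**get_minimagen_dl_opts(device),
           "batch_size": args.BATCH_SIZE, "num_workers": args.NUM_WORKERS}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)

# Model: base + super-resolution UNets wrapped in an Imagen cascade
unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
unets = [Unet(**p).to(device) for p in unets_params]
imagen_params = dict(image_sizes=(args.IMG_SIDE_LEN // 2, args.IMG_SIDE_LEN),
                     timesteps=args.TIMESTEPS, cond_drop_prob=0.15,
                     text_encoder_name=args.T5_NAME)
imagen = Imagen(unets=unets, **imagen_params).to(device)

# Record configs and model size, then train and sample
unets_params = [{**get_default_args(Unet), **p} for p in unets_params]
imagen_params = {**get_default_args(Imagen), **imagen_params}
save_training_info(args, timestamp, unets_params, imagen_params,
                   get_model_size(imagen), training_dir)
optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)
MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader,
               training_dir, optimizer, timeout=30)
sample_and_save(["happy"], training_directory=dir_path)  # the notebook passes its Colab path here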
app.py
CHANGED
@@ -7,6 +7,7 @@ import streamlit as st

Before (lines 7-12):
7     from text_to_image import generate_image
8     from feature_to_sprite import generate_sprites
9
10    def setup():
11    """
12    Streamlit related setup. This has to be run for each page.

After (lines 7-13):
7     from text_to_image import generate_image
8     from feature_to_sprite import generate_sprites
9
10  +
11    def setup():
12    """
13    Streamlit related setup. This has to be run for each page.

@@ -81,33 +82,35 @@

Before (lines 81-113):
81    This mode generates 16*16 images of sprites based on a combination of features. It uses a custom model trained on a dataset of sprites.
82    """
83    )
84  -
85    form = st.form(key="my_form")
86
87  - #add sliders
88    hero = form.slider("Hero", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
89  - non_hero = form.slider(
90    food = form.slider("Food", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
91    spell = form.slider("Spell", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
92  - side_facing = form.slider(
93  -
94    submit_button = form.form_submit_button(label="Generate")
95
96  - #create feature vector
97    if submit_button:
98    feature_vector = [hero, non_hero, food, spell, side_facing]
99  - #show loader
100   with st.spinner("Generating sprite..."):
101 - #horizontal line and line break
102   st.markdown("<hr>", unsafe_allow_html=True)
103   st.markdown("<br>", unsafe_allow_html=True)
104 -
105   st.subheader("Your Sprite")
106   st.markdown("<br>", unsafe_allow_html=True)
107 -
108 - generate_sprites(feature_vector)
109 -
110
111
112
113   if __name__ == "__main__":

After (lines 82-116):
82    This mode generates 16*16 images of sprites based on a combination of features. It uses a custom model trained on a dataset of sprites.
83    """
84    )
85  +
86    form = st.form(key="my_form")
87
88  + # add sliders
89    hero = form.slider("Hero", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
90  + non_hero = form.slider(
91  + "Non Hero", min_value=0.0, max_value=1.0, value=0.0, step=0.01
92  + )
93    food = form.slider("Food", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
94    spell = form.slider("Spell", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
95  + side_facing = form.slider(
96  + "Side Facing", min_value=0.0, max_value=1.0, value=0.0, step=0.01
97  + )
98  + # add submit button
99    submit_button = form.form_submit_button(label="Generate")
100
101 + # create feature vector
102   if submit_button:
103   feature_vector = [hero, non_hero, food, spell, side_facing]
104 + # show loader
105   with st.spinner("Generating sprite..."):
106 + # horizontal line and line break
107   st.markdown("<hr>", unsafe_allow_html=True)
108   st.markdown("<br>", unsafe_allow_html=True)
109 +
110   st.subheader("Your Sprite")
111   st.markdown("<br>", unsafe_allow_html=True)
112
113 + _ = generate_sprites(feature_vector)
114
115
116   if __name__ == "__main__":
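For reference, the feature vector assembled by this form follows the slider order hero, non_hero, food, spell, side_facing. A minimal sketch of invoking the generator directly with such a vector (it renders through st.image, so it is meant to run inside a Streamlit page like this one):

# Sketch: what the form above submits. Order matches the sliders.
from feature_to_sprite import generate_sprites

feature_vector = [1.0, 0.0, 0.0, 0.0, 0.0]  # hero, non_hero, food, spell, side_facing
generate_sprites(feature_vector)  # samples a 16x16 sprite, upscales it to 256x256 and shows it via st.image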
constants.py
CHANGED
@@ -1,7 +1,7 @@

Before (lines 1-7):
1     """
2     This file contains all the constants used in the project.
3     """
4   - import os
5
6     MODEL_ID = "stabilityai/stable-diffusion-2-1"
7     WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

After (lines 1-7):
1     """
2     This file contains all the constants used in the project.
3     """
4   + import os
5
6     MODEL_ID = "stabilityai/stable-diffusion-2-1"
7     WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
feature_to_sprite.py
CHANGED
@@ -8,31 +8,35 @@ import streamlit as st
 from utilities import ContextUnet, setup_ddpm
 from constants import WANDB_API_KEY

+
 def load_model():
+    """
+    This function loads the model from the model registry.
+    """
+
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-    #login to wandb
-    #wandb.login(key=WANDB_API_KEY)
-
-
-    "Load the model from wandb artifacts"
+
+    # login to wandb
+    # wandb.login(key=WANDB_API_KEY)
+
     api = wandb.Api(api_key=WANDB_API_KEY)
-    artifact = api.artifact("teamaditya/model-registry/Feature2Sprite:
+    artifact = api.artifact("teamaditya/model-registry/Feature2Sprite:v1", type="model")
     model_path = Path(artifact.download())

     # recover model info from the registry
     producer_run = artifact.logged_by()

     # load the weights dictionary
-    model_weights = torch.load(model_path/"context_model.pth",
-                               map_location="cpu")
+    model_weights = torch.load(model_path / "context_model.pth", map_location="cpu")

     # create the model
-    model = ContextUnet(
-
-
-
-
+    model = ContextUnet(
+        in_channels=3,
+        n_feat=producer_run.config["n_feat"],
+        n_cfeat=producer_run.config["n_cfeat"],
+        height=producer_run.config["height"],
+    )
+
     # load the weights into the model
     model.load_state_dict(model_weights)

@@ -40,46 +44,44 @@ def load_model():
     model.eval()
     return model.to(DEVICE)

+
 def show_image(img):
+    """
+    This function shows the image in the streamlit app.
+    """
     img = (img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2
     st.image(img, clamp=True)
+    return img
+

 def generate_sprites(feature_vector):
+    """
+    This function generates sprites from a given feature vector.
+    """
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
     config = SimpleNamespace(
         # hyperparameters
-        num_samples
-
+        num_samples=30,
         # ddpm sampler hyperparameters
-        timesteps
-        beta1
-        beta2
-
+        timesteps=500,
+        beta1=1e-4,
+        beta2=0.02,
         # network hyperparameters
-        height
+        height=16,
     )
     nn_model = load_model()
-
-    _, sample_ddpm_context = setup_ddpm(
-
-
-
-
-
-        config.height, config.height).to(DEVICE)
-
+
+    _, sample_ddpm_context = setup_ddpm(
+        config.beta1, config.beta2, config.timesteps, DEVICE
+    )
+
+    noises = torch.randn(config.num_samples, 3, config.height, config.height).to(DEVICE)
+
     feature_vector = torch.tensor([feature_vector]).to(DEVICE).float()
     ddpm_samples, _ = sample_ddpm_context(nn_model, noises, feature_vector)

-    #upscale the 16*16 images to 256*256
+    # upscale the 16*16 images to 256*256
     ddpm_samples = F.interpolate(ddpm_samples, size=(256, 256), mode="bilinear")
     # show the images
-    show_image(ddpm_samples[0])
-
-
-
-
-
-
-
-
+    img = show_image(ddpm_samples[0])
+    return img
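For reference, a minimal usage sketch of the reformatted generate_sprites entry point above, assuming it is invoked directly rather than from the Streamlit form; the five context features correspond to the hero, non-hero, food, spell, and side-facing sliders in app.py:

# Usage sketch (assumption: called outside the Streamlit page; show_image still renders via st.image).
from feature_to_sprite import generate_sprites

feature_vector = [1.0, 0.0, 0.0, 0.0, 0.0]  # hero, non_hero, food, spell, side_facing
img = generate_sprites(feature_vector)  # runs the 500-step DDPM sampler and returns the upscaled 256x256 image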
text_to_image.py
CHANGED
@@ -20,7 +20,9 @@ def generate_image(prompt: str) -> torch.Tensor:
     """
     # load model

-    pipe = StableDiffusionPipeline.from_pretrained(
+    pipe = StableDiffusionPipeline.from_pretrained(
+        MODEL_ID
+    )  # torch_dtype=torch.float16
     pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
     if torch.cuda.is_available():
         pipe = pipe.to("cuda")  # move model to GPU if available
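The hunk above only shows how the pipeline is constructed; the rest of generate_image is outside this diff, so as a hedged sketch it presumably follows the standard diffusers call pattern, roughly:

prompt = "a pixel art sprite of a knight"  # hypothetical prompt
image = pipe(prompt, num_inference_steps=25).images[0]  # standard diffusers API, returns a PIL image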
utilities.py
CHANGED
@@ -3,36 +3,41 @@ import torch
 import torch.nn as nn
 from tqdm.auto import tqdm

+
 class ContextUnet(nn.Module):
-    def __init__(
+    def __init__(
+        self, in_channels, n_feat=256, n_cfeat=10, height=28
+    ):  # cfeat - context features
         super(ContextUnet, self).__init__()

         # number of input channels, number of intermediate feature maps and number of classes
         self.in_channels = in_channels
         self.n_feat = n_feat
         self.n_cfeat = n_cfeat
-        self.h = height #assume h == w. must be divisible by 4, so 28,24,20,16...
+        self.h = height  # assume h == w. must be divisible by 4, so 28,24,20,16...

         # Initialize the initial convolutional layer
         self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

         # Initialize the down-sampling path of the U-Net with two levels
-        self.down1 = UnetDown(n_feat, n_feat)
-        self.down2 = UnetDown(n_feat, 2 * n_feat)
-
-
+        self.down1 = UnetDown(n_feat, n_feat)  # down1 #[10, 256, 8, 8]
+        self.down2 = UnetDown(n_feat, 2 * n_feat)  # down2 #[10, 256, 4, 4]
+
+        # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
         self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())

         # Embed the timestep and context labels with a one-layer fully connected neural network
-        self.timeembed1 = EmbedFC(1, 2*n_feat)
-        self.timeembed2 = EmbedFC(1, 1*n_feat)
-        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
-        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)
+        self.timeembed1 = EmbedFC(1, 2 * n_feat)
+        self.timeembed2 = EmbedFC(1, 1 * n_feat)
+        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
+        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

         # Initialize the up-sampling path of the U-Net with three levels
         self.up0 = nn.Sequential(
-            nn.ConvTranspose2d(
-
+            nn.ConvTranspose2d(
+                2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4
+            ),  # up-sample
+            nn.GroupNorm(8, 2 * n_feat),  # normalize
             nn.ReLU(),
         )
         self.up1 = UnetUp(4 * n_feat, n_feat)
@@ -40,10 +45,14 @@ class ContextUnet(nn.Module):

         # Initialize the final convolutional layers to map to the same number of channels as the input image
         self.out = nn.Sequential(
-            nn.Conv2d(
-
+            nn.Conv2d(
+                2 * n_feat, n_feat, 3, 1, 1
+            ),  # reduce number of feature maps  #in_channels, out_channels, kernel_size, stride=1, padding=0
+            nn.GroupNorm(8, n_feat),  # normalize
             nn.ReLU(),
-            nn.Conv2d(
+            nn.Conv2d(
+                n_feat, self.in_channels, 3, 1, 1
+            ),  # map to same number of channels as input
         )

     def forward(self, x, t, c=None):
@@ -57,30 +66,32 @@ class ContextUnet(nn.Module):
         # pass the input image through the initial convolutional layer
         x = self.init_conv(x)
         # pass the result through the down-sampling path
-        down1 = self.down1(x)
-        down2 = self.down2(down1)
-
+        down1 = self.down1(x)  # [10, 256, 8, 8]
+        down2 = self.down2(down1)  # [10, 256, 4, 4]
+
         # convert the feature maps to a vector and apply an activation
         hiddenvec = self.to_vec(down2)
-
+
         # mask out context if context_mask == 1
         if c is None:
             c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
-
+
         # embed context and timestep
-        cemb1 = self.contextembed1(c).view(
+        cemb1 = self.contextembed1(c).view(
+            -1, self.n_feat * 2, 1, 1
+        )  # (batch, 2*n_feat, 1,1)
         temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
         cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
         temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
-        #print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")
-
+        # print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")

         up1 = self.up0(hiddenvec)
-        up2 = self.up1(cemb1*up1 + temb1, down2)  # add and multiply embeddings
-        up3 = self.up2(cemb2*up2 + temb2, down1)
+        up2 = self.up1(cemb1 * up1 + temb1, down2)  # add and multiply embeddings
+        up3 = self.up2(cemb2 * up2 + temb2, down1)
         out = self.out(torch.cat((up3, x), 1))
         return out

+
 class ResidualConvBlock(nn.Module):
     def __init__(
         self, in_channels: int, out_channels: int, is_res: bool = False
@@ -95,20 +106,23 @@ class ResidualConvBlock(nn.Module):

         # First convolutional layer
         self.conv1 = nn.Sequential(
-            nn.Conv2d(
-
-
+            nn.Conv2d(
+                in_channels, out_channels, 3, 1, 1
+            ),  # 3x3 kernel with stride 1 and padding 1
+            nn.BatchNorm2d(out_channels),  # Batch normalization
+            nn.GELU(),  # GELU activation function
         )

         # Second convolutional layer
         self.conv2 = nn.Sequential(
-            nn.Conv2d(
-
-
+            nn.Conv2d(
+                out_channels, out_channels, 3, 1, 1
+            ),  # 3x3 kernel with stride 1 and padding 1
+            nn.BatchNorm2d(out_channels),  # Batch normalization
+            nn.GELU(),  # GELU activation function
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
         # If using residual connection
         if self.is_res:
             # Apply first convolutional layer
@@ -122,9 +136,11 @@ class ResidualConvBlock(nn.Module):
             out = x + x2
         else:
             # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection
-            shortcut = nn.Conv2d(
+            shortcut = nn.Conv2d(
+                x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0
+            ).to(x.device)
             out = shortcut(x) + x2
-            #print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")
+            # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")

         # Normalize output tensor
         return out / 1.414
@@ -145,12 +161,11 @@ class ResidualConvBlock(nn.Module):
         self.conv2[0].in_channels = out_channels
         self.conv2[0].out_channels = out_channels

-

 class UnetUp(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(UnetUp, self).__init__()
-
+
         # Create a list of layers for the upsampling block
         # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
         layers = [
@@ -158,27 +173,31 @@ class UnetUp(nn.Module):
             ResidualConvBlock(out_channels, out_channels),
             ResidualConvBlock(out_channels, out_channels),
         ]
-
+
         # Use the layers to create a sequential model
         self.model = nn.Sequential(*layers)

     def forward(self, x, skip):
         # Concatenate the input tensor x with the skip connection tensor along the channel dimension
         x = torch.cat((x, skip), 1)
-
+
         # Pass the concatenated tensor through the sequential model and return the output
         x = self.model(x)
         return x

-
+
 class UnetDown(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(UnetDown, self).__init__()
-
+
         # Create a list of layers for the downsampling block
         # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
-        layers = [
-
+        layers = [
+            ResidualConvBlock(in_channels, out_channels),
+            ResidualConvBlock(out_channels, out_channels),
+            nn.MaxPool2d(2),
+        ]
+
         # Use the layers to create a sequential model
         self.model = nn.Sequential(*layers)

@@ -186,22 +205,23 @@ class UnetDown(nn.Module):
         # Pass the input through the sequential model and return the output
         return self.model(x)

+
 class EmbedFC(nn.Module):
     def __init__(self, input_dim, emb_dim):
         super(EmbedFC, self).__init__()
-
+        """
         This class defines a generic one layer feed-forward neural network for embedding input data of
         dimensionality input_dim to an embedding space of dimensionality emb_dim.
-
+        """
         self.input_dim = input_dim
-
+
         # define the layers for the network
         layers = [
             nn.Linear(input_dim, emb_dim),
             nn.GELU(),
             nn.Linear(emb_dim, emb_dim),
         ]
-
+
         # create a PyTorch sequential model consisting of the defined layers
         self.model = nn.Sequential(*layers)

@@ -210,46 +230,53 @@ class EmbedFC(nn.Module):
         x = x.view(-1, self.input_dim)
         # apply the model layers to the flattened tensor
         return self.model(x)
-
+
+
 def unorm(x):
     # unity norm. results in range of [0,1]
     # assume x (h,w,3)
-    xmax = x.max((0,1))
-    xmin = x.min((0,1))
-    return(x - xmin)/(xmax - xmin)
+    xmax = x.max((0, 1))
+    xmin = x.min((0, 1))
+    return (x - xmin) / (xmax - xmin)
+

 def norm_all(store, n_t, n_s):
     # runs unity norm on all timesteps of all samples
     nstore = np.zeros_like(store)
     for t in range(n_t):
         for s in range(n_s):
-            nstore[t,s] = unorm(store[t,s])
+            nstore[t, s] = unorm(store[t, s])
     return nstore

+
 def norm_torch(x_all):
     # runs unity norm on all timesteps of all samples
     # input is (n_samples, 3,h,w), the torch image format
     x = x_all.cpu().numpy()
-    xmax = x.max((2,3))
-    xmin = x.min((2,3))
-    xmax = np.expand_dims(xmax,(2,3))
-    xmin = np.expand_dims(xmin,(2,3))
-    nstore = (x - xmin)/(xmax - xmin)
+    xmax = x.max((2, 3))
+    xmin = x.min((2, 3))
+    xmax = np.expand_dims(xmax, (2, 3))
+    xmin = np.expand_dims(xmin, (2, 3))
+    nstore = (x - xmin) / (xmax - xmin)
     return torch.from_numpy(nstore)


 ## diffusion functions

+
 def setup_ddpm(beta1, beta2, timesteps, device):
     # construct DDPM noise schedule and sampling functions
     b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
     a_t = 1 - b_t
-    ab_t = torch.cumsum(a_t.log(), dim=0).exp()
+    ab_t = torch.cumsum(a_t.log(), dim=0).exp()
     ab_t[0] = 1

     # helper function: perturbs an image to a specified noise level
     def perturb_input(x, t, noise):
-        return
+        return (
+            ab_t.sqrt()[t, None, None, None] * x
+            + (1 - ab_t[t, None, None, None]) * noise
+        )

     # helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
     def _denoise_add_noise(x, t, pred_noise, z=None):
@@ -264,10 +291,10 @@ def setup_ddpm(beta1, beta2, timesteps, device):
     @torch.no_grad()
     def sample_ddpm_context(nn_model, noises, context, save_rate=20):
         # array to keep track of generated steps for plotting
-        intermediate = []
+        intermediate = []
         pbar = tqdm(range(timesteps, 0, -1), leave=False)
         for i in pbar:
-            pbar.set_description(f
+            pbar.set_description(f"sampling timestep {i:3d}")

             # reshape time tensor
             t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)
@@ -275,12 +302,12 @@ def setup_ddpm(beta1, beta2, timesteps, device):
             # sample some random noise to inject back in. For i = 1, don't add back in noise
             z = torch.randn_like(noises) if i > 1 else 0

-            eps = nn_model(noises, t, c=context)
+            eps = nn_model(noises, t, c=context)  # predict noise e_(x_t,t, ctx)
             noises = _denoise_add_noise(noises, i, eps, z)
-            if i % save_rate==0 or i==timesteps or i<8:
+            if i % save_rate == 0 or i == timesteps or i < 8:
                 intermediate.append(noises.detach().cpu().numpy())

         intermediate = np.stack(intermediate)
         return noises.clip(-1, 1), intermediate
-
+
     return perturb_input, sample_ddpm_context
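To make the relationship between the two closures returned by setup_ddpm concrete, here is a minimal sketch: perturb_input is the training-time noising step and sample_ddpm_context runs the reverse process conditioned on a context vector. The beta values, 500 timesteps, the 16x16 size and the 5 context features mirror generate_sprites above; n_feat=64 and the random stand-in tensors are assumptions for illustration only.

# Sketch only; hyperparameters and the stand-in data below are assumptions, not values taken from this repo.
import torch
from utilities import ContextUnet, setup_ddpm

device = "cuda" if torch.cuda.is_available() else "cpu"
perturb_input, sample_ddpm_context = setup_ddpm(beta1=1e-4, beta2=0.02, timesteps=500, device=device)

model = ContextUnet(in_channels=3, n_feat=64, n_cfeat=5, height=16).to(device)

# training direction: noise a clean batch at random timesteps (the denoiser's regression target)
x0 = torch.randn(8, 3, 16, 16, device=device)   # stand-in for a batch of sprites
t = torch.randint(1, 501, (8,), device=device)  # one timestep per sample
noise = torch.randn_like(x0)
x_t = perturb_input(x0, t, noise)

# sampling direction: start from pure noise and denoise step by step, conditioned on 5 context features
noises = torch.randn(4, 3, 16, 16, device=device)
context = torch.rand(4, 5, device=device)
samples, intermediate = sample_ddpm_context(model, noises, context)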