Aditya Patkar committed on
Commit c046d7f • 1 Parent(s): cb0d40a

Added training files, enforced code formatting

Training/MSML612_Project_DDPMTraining.ipynb ADDED
The diff for this file is too large to render. See raw diff
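The DDPM training notebook itself is not rendered above, so purely as a hedged illustration (not the file's actual contents), the sketch below shows what a single DDPM training step with a linear beta schedule typically looks like. The schedule values mirror the sampler settings that appear later in this commit (timesteps=500, beta1=1e-4, beta2=0.02 in feature_to_sprite.py); the model call signature and the helper name are assumptions.

# Illustrative DDPM training step (generic sketch, not this notebook's code).
# Assumptions: `model` predicts the injected noise from (noisy image, normalized t, context),
# and the linear beta schedule matches the sampler settings used elsewhere in this commit.
import torch
import torch.nn.functional as F

def ddpm_train_step(model, optimizer, x0, context, timesteps=500,
                    beta1=1e-4, beta2=0.02, device="cpu"):
    # linear beta schedule and the cumulative products used by the forward process
    betas = torch.linspace(beta1, beta2, timesteps, device=device)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)

    # pick a random timestep per sample and noise the clean images x0
    t = torch.randint(0, timesteps, (x0.shape[0],), device=device)
    a_bar = alpha_bar[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

    # the network predicts the injected noise; train with the simple MSE objective
    pred_noise = model(x_t, t.float() / timesteps, context)
    loss = F.mse_loss(pred_noise, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()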
 
Training/MSML_612_Project_MinImagen.ipynb ADDED
# Introduction

This is a training script for a diffusion model called MinImagen, a smaller adaptation of the original Imagen architecture introduced by Google.

# Setup

# install the minimagen package
!pip install minimagen

# utility imports
import os
from datetime import datetime

# pytorch related imports
import torch
import torch.utils.data as data_utils
from torch import optim

# minimagen related imports
from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
from minimagen.generate import load_minimagen, load_params
from minimagen.t5 import get_encoded_dim
from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \
    create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \
    load_testing_parameters

# Get device: Connect to GPU runtime for better performance
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Command line argument parser
parser = get_minimagen_parser()

# container class for run arguments
class args_cls:
    a = 0

# get an instance of the args_cls
args = args_cls()

# directory creation for training
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dir_path = f"./training_{timestamp}"
training_dir = create_directory(dir_path)

# A dictionary of hyperparameters
hyperparameters = dict(
    PARAMETERS=None,
    NUM_WORKERS=0,
    BATCH_SIZE=20,
    MAX_NUM_WORDS=32,
    IMG_SIDE_LEN=128,
    EPOCHS=10,
    T5_NAME='t5_small',
    TRAIN_VALID_FRAC=0.5,
    TRAINING_DIRECTORY='/content/training_20230731_061334',
    TIMESTEPS=25,
    OPTIM_LR=0.0001,
    ACCUM_ITER=1,
    CHCKPT_NUM=500,
    VALID_NUM=None,
    RESTART_DIRECTORY=None,
    TESTING=False,
    timestamp=None,
)
# Replace relevant values in the arg dict
args.__dict__ = {**args.__dict__, **hyperparameters}

# Data

# Load a subset of the Conceptual Captions dataset
train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=False)
indices = torch.arange(1000)

# create train and validation datasets with the given number of samples
train_dataset = data_utils.Subset(train_dataset, indices)
valid_dataset = data_utils.Subset(valid_dataset, indices)

# Create dataloaders
dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)

# UNet

# Instantiate Unets with default parameters and transfer to GPU if available
unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
unets = [Unet(**unet_params).to(device) for unet_params in unets_params]

# Specify MinImagen parameters
imagen_params = dict(
    image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),
    timesteps=args.TIMESTEPS,
    cond_drop_prob=0.15,
    text_encoder_name=args.T5_NAME
)

# Create MinImagen from the UNets with the specified imagen parameters
imagen = Imagen(unets=unets, **imagen_params).to(device)

# Fill in unspecified arguments with defaults
unets_params = [{**get_default_args(Unet), **i} for i in unets_params]
imagen_params = {**get_default_args(Imagen), **imagen_params}

# Get the size of the Imagen model in megabytes
model_size_MB = get_model_size(imagen)

# Save all training info (config files, model size, etc.)
save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)

# Training

# Create optimizer - Adam
optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)

# Train the MinImagen instance
MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)

Over the recorded 10-epoch run, the average validation loss fell from 1.1316 to 0.9483 for Unet 0 and from 1.0483 to 0.9955 for Unet 1.

# Inference

from argparse import ArgumentParser
from minimagen.generate import load_minimagen, sample_and_save

# Specify the caption(s) to generate images for
captions = ['happy']

# Use `sample_and_save` to generate and save the images
sample_and_save(captions, training_directory='/content/training_20230731_065902')
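The trained run can also be reused outside the notebook. Below is a minimal standalone sketch, assuming `load_minimagen` (imported above but not otherwise used) rebuilds the trained Imagen instance from a training directory; the caption and directory path are only examples.

# Minimal inference sketch: reuse a completed MinImagen training run.
# Assumptions: the training directory below exists and load_minimagen() returns
# the trained Imagen instance saved by the script above.
from minimagen.generate import load_minimagen, sample_and_save

training_directory = "/content/training_20230731_065902"  # example path from the run above

# generate and save images for a new caption in one call
sample_and_save(["a red bicycle"], training_directory=training_directory)

# or load the trained MinImagen instance itself for further experimentation
minimagen = load_minimagen(training_directory)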
app.py CHANGED
@@ -7,6 +7,7 @@ import streamlit as st
7
  from text_to_image import generate_image
8
  from feature_to_sprite import generate_sprites
9
 
 
10
  def setup():
11
  """
12
  Streamlit related setup. This has to be run for each page.
@@ -81,33 +82,35 @@ def main():
81
  This mode generates 16*16 images of sprites based on a combination of features. It uses a custom model trained on a dataset of sprites.
82
  """
83
  )
84
-
85
  form = st.form(key="my_form")
86
 
87
- #add sliders
88
  hero = form.slider("Hero", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
89
- non_hero = form.slider("Non Hero", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
 
 
90
  food = form.slider("Food", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
91
  spell = form.slider("Spell", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
92
- side_facing = form.slider("Side Facing", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
93
- #add submit button
 
 
94
  submit_button = form.form_submit_button(label="Generate")
95
 
96
- #create feature vector
97
  if submit_button:
98
  feature_vector = [hero, non_hero, food, spell, side_facing]
99
- #show loader
100
  with st.spinner("Generating sprite..."):
101
- #horizontal line and line break
102
  st.markdown("<hr>", unsafe_allow_html=True)
103
  st.markdown("<br>", unsafe_allow_html=True)
104
-
105
  st.subheader("Your Sprite")
106
  st.markdown("<br>", unsafe_allow_html=True)
107
-
108
- generate_sprites(feature_vector)
109
-
110
 
 
111
 
112
 
113
  if __name__ == "__main__":
 
7
  from text_to_image import generate_image
8
  from feature_to_sprite import generate_sprites
9
 
10
+
11
  def setup():
12
  """
13
  Streamlit related setup. This has to be run for each page.
 
82
  This mode generates 16*16 images of sprites based on a combination of features. It uses a custom model trained on a dataset of sprites.
83
  """
84
  )
85
+
86
  form = st.form(key="my_form")
87
 
88
+ # add sliders
89
  hero = form.slider("Hero", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
90
+ non_hero = form.slider(
91
+ "Non Hero", min_value=0.0, max_value=1.0, value=0.0, step=0.01
92
+ )
93
  food = form.slider("Food", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
94
  spell = form.slider("Spell", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
95
+ side_facing = form.slider(
96
+ "Side Facing", min_value=0.0, max_value=1.0, value=0.0, step=0.01
97
+ )
98
+ # add submit button
99
  submit_button = form.form_submit_button(label="Generate")
100
 
101
+ # create feature vector
102
  if submit_button:
103
  feature_vector = [hero, non_hero, food, spell, side_facing]
104
+ # show loader
105
  with st.spinner("Generating sprite..."):
106
+ # horizontal line and line break
107
  st.markdown("<hr>", unsafe_allow_html=True)
108
  st.markdown("<br>", unsafe_allow_html=True)
109
+
110
  st.subheader("Your Sprite")
111
  st.markdown("<br>", unsafe_allow_html=True)
 
 
 
112
 
113
+ _ = generate_sprites(feature_vector)
114
 
115
 
116
  if __name__ == "__main__":
constants.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
  This file contains all the constants used in the project.
3
  """
4
- import os
5
 
6
  MODEL_ID = "stabilityai/stable-diffusion-2-1"
7
  WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
 
1
  """
2
  This file contains all the constants used in the project.
3
  """
4
+ import os
5
 
6
  MODEL_ID = "stabilityai/stable-diffusion-2-1"
7
  WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
feature_to_sprite.py CHANGED
@@ -8,31 +8,35 @@ import streamlit as st
 from utilities import ContextUnet, setup_ddpm
 from constants import WANDB_API_KEY
 
+
 def load_model():
+    """
+    This function loads the model from the model registry.
+    """
+
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-    #login to wandb
-    #wandb.login(key=WANDB_API_KEY)
-
-
-    "Load the model from wandb artifacts"
+
+    # login to wandb
+    # wandb.login(key=WANDB_API_KEY)
+
     api = wandb.Api(api_key=WANDB_API_KEY)
-    artifact = api.artifact("teamaditya/model-registry/Feature2Sprite:v0", type="model")
+    artifact = api.artifact("teamaditya/model-registry/Feature2Sprite:v1", type="model")
     model_path = Path(artifact.download())
 
     # recover model info from the registry
     producer_run = artifact.logged_by()
 
     # load the weights dictionary
-    model_weights = torch.load(model_path/"context_model.pth",
-                               map_location="cpu")
+    model_weights = torch.load(model_path / "context_model.pth", map_location="cpu")
 
     # create the model
-    model = ContextUnet(in_channels=3,
-                        n_feat=producer_run.config["n_feat"],
-                        n_cfeat=producer_run.config["n_cfeat"],
-                        height=producer_run.config["height"])
-
+    model = ContextUnet(
+        in_channels=3,
+        n_feat=producer_run.config["n_feat"],
+        n_cfeat=producer_run.config["n_cfeat"],
+        height=producer_run.config["height"],
+    )
+
     # load the weights into the model
     model.load_state_dict(model_weights)
 
@@ -40,46 +44,44 @@ def load_model():
     model.eval()
     return model.to(DEVICE)
 
+
 def show_image(img):
+    """
+    This function shows the image in the streamlit app.
+    """
     img = (img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2
     st.image(img, clamp=True)
+    return img
+
 
 def generate_sprites(feature_vector):
+    """
+    This function generates sprites from a given feature vector.
+    """
     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
     config = SimpleNamespace(
         # hyperparameters
-        num_samples = 30,
-
+        num_samples=30,
         # ddpm sampler hyperparameters
-        timesteps = 500,
-        beta1 = 1e-4,
-        beta2 = 0.02,
-
+        timesteps=500,
+        beta1=1e-4,
+        beta2=0.02,
         # network hyperparameters
-        height = 16,
+        height=16,
     )
     nn_model = load_model()
-
-    _, sample_ddpm_context = setup_ddpm(config.beta1,
-                                        config.beta2,
-                                        config.timesteps,
-                                        DEVICE)
-
-    noises = torch.randn(config.num_samples, 3,
-                         config.height, config.height).to(DEVICE)
-
+
+    _, sample_ddpm_context = setup_ddpm(
+        config.beta1, config.beta2, config.timesteps, DEVICE
+    )
+
+    noises = torch.randn(config.num_samples, 3, config.height, config.height).to(DEVICE)
+
     feature_vector = torch.tensor([feature_vector]).to(DEVICE).float()
     ddpm_samples, _ = sample_ddpm_context(nn_model, noises, feature_vector)
 
-    #upscale the 16*16 images to 256*256
+    # upscale the 16*16 images to 256*256
     ddpm_samples = F.interpolate(ddpm_samples, size=(256, 256), mode="bilinear")
     # show the images
-    show_image(ddpm_samples[0])
-
+    img = show_image(ddpm_samples[0])
+    return img
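
Note: the show_image helper maps a DDPM sample from the model's [-1, 1], channel-first layout to the [0, 1], channel-last layout that st.image expects. A self-contained sketch of that transform on a dummy tensor (illustrative only, not part of the commit):

import torch

# Same permute/clip/rescale as show_image, applied to random data.
sample = torch.rand(3, 16, 16) * 2 - 1                       # CHW in [-1, 1]
img = (sample.permute(1, 2, 0).clip(-1, 1).numpy() + 1) / 2  # HWC in [0, 1]
assert img.shape == (16, 16, 3) and img.min() >= 0.0 and img.max() <= 1.0
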
 
 
 
 
 
 
 
text_to_image.py CHANGED
@@ -20,7 +20,9 @@ def generate_image(prompt: str) -> torch.Tensor:
     """
     # load model
 
-    pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID) #torch_dtype=torch.float16
+    pipe = StableDiffusionPipeline.from_pretrained(
+        MODEL_ID
+    )  # torch_dtype=torch.float16
     pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
     if torch.cuda.is_available():
         pipe = pipe.to("cuda")  # move model to GPU if available
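
Note: this hunk only shows the pipeline being constructed and moved to the GPU; the prompt-to-image call itself falls outside the hunk. A minimal sketch of the standard diffusers pattern generate_image wraps (the final pipe(prompt) line is an assumption about usage, not a quote from the repo):

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
if torch.cuda.is_available():
    pipe = pipe.to("cuda")  # move model to GPU if available
image = pipe("a watercolor painting of a lighthouse").images[0]  # PIL image
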
utilities.py CHANGED
@@ -3,36 +3,41 @@ import torch
 import torch.nn as nn
 from tqdm.auto import tqdm
 
+
 class ContextUnet(nn.Module):
-    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28): # cfeat - context features
+    def __init__(
+        self, in_channels, n_feat=256, n_cfeat=10, height=28
+    ):  # cfeat - context features
         super(ContextUnet, self).__init__()
 
         # number of input channels, number of intermediate feature maps and number of classes
         self.in_channels = in_channels
         self.n_feat = n_feat
         self.n_cfeat = n_cfeat
-        self.h = height #assume h == w. must be divisible by 4, so 28,24,20,16...
+        self.h = height  # assume h == w. must be divisible by 4, so 28,24,20,16...
 
         # Initialize the initial convolutional layer
         self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
 
         # Initialize the down-sampling path of the U-Net with two levels
-        self.down1 = UnetDown(n_feat, n_feat) # down1 #[10, 256, 8, 8]
-        self.down2 = UnetDown(n_feat, 2 * n_feat) # down2 #[10, 256, 4, 4]
-
-        # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
+        self.down1 = UnetDown(n_feat, n_feat)  # down1 #[10, 256, 8, 8]
+        self.down2 = UnetDown(n_feat, 2 * n_feat)  # down2 #[10, 256, 4, 4]
+
+        # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
         self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())
 
         # Embed the timestep and context labels with a one-layer fully connected neural network
-        self.timeembed1 = EmbedFC(1, 2*n_feat)
-        self.timeembed2 = EmbedFC(1, 1*n_feat)
-        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
-        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)
+        self.timeembed1 = EmbedFC(1, 2 * n_feat)
+        self.timeembed2 = EmbedFC(1, 1 * n_feat)
+        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
+        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)
 
         # Initialize the up-sampling path of the U-Net with three levels
         self.up0 = nn.Sequential(
-            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample
-            nn.GroupNorm(8, 2 * n_feat), # normalize
+            nn.ConvTranspose2d(
+                2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4
+            ),  # up-sample
+            nn.GroupNorm(8, 2 * n_feat),  # normalize
             nn.ReLU(),
         )
         self.up1 = UnetUp(4 * n_feat, n_feat)
@@ -40,10 +45,14 @@ class ContextUnet(nn.Module):
 
         # Initialize the final convolutional layers to map to the same number of channels as the input image
         self.out = nn.Sequential(
-            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps #in_channels, out_channels, kernel_size, stride=1, padding=0
-            nn.GroupNorm(8, n_feat), # normalize
+            nn.Conv2d(
+                2 * n_feat, n_feat, 3, 1, 1
+            ),  # reduce number of feature maps #in_channels, out_channels, kernel_size, stride=1, padding=0
+            nn.GroupNorm(8, n_feat),  # normalize
             nn.ReLU(),
-            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input
+            nn.Conv2d(
+                n_feat, self.in_channels, 3, 1, 1
+            ),  # map to same number of channels as input
         )
 
     def forward(self, x, t, c=None):
@@ -57,30 +66,32 @@ class ContextUnet(nn.Module):
         # pass the input image through the initial convolutional layer
         x = self.init_conv(x)
         # pass the result through the down-sampling path
-        down1 = self.down1(x) #[10, 256, 8, 8]
-        down2 = self.down2(down1) #[10, 256, 4, 4]
-
+        down1 = self.down1(x)  # [10, 256, 8, 8]
+        down2 = self.down2(down1)  # [10, 256, 4, 4]
+
         # convert the feature maps to a vector and apply an activation
         hiddenvec = self.to_vec(down2)
-
+
         # mask out context if context_mask == 1
         if c is None:
             c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
-
+
         # embed context and timestep
-        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1) # (batch, 2*n_feat, 1,1)
+        cemb1 = self.contextembed1(c).view(
+            -1, self.n_feat * 2, 1, 1
+        )  # (batch, 2*n_feat, 1,1)
         temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
         cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
         temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
-        #print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")
-
+        # print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")
 
         up1 = self.up0(hiddenvec)
-        up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings
-        up3 = self.up2(cemb2*up2 + temb2, down1)
+        up2 = self.up1(cemb1 * up1 + temb1, down2)  # add and multiply embeddings
+        up3 = self.up2(cemb2 * up2 + temb2, down1)
         out = self.out(torch.cat((up3, x), 1))
         return out
 
+
 class ResidualConvBlock(nn.Module):
     def __init__(
         self, in_channels: int, out_channels: int, is_res: bool = False
@@ -95,20 +106,23 @@ class ResidualConvBlock(nn.Module):
 
         # First convolutional layer
         self.conv1 = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1
-            nn.BatchNorm2d(out_channels), # Batch normalization
-            nn.GELU(), # GELU activation function
+            nn.Conv2d(
+                in_channels, out_channels, 3, 1, 1
+            ),  # 3x3 kernel with stride 1 and padding 1
+            nn.BatchNorm2d(out_channels),  # Batch normalization
+            nn.GELU(),  # GELU activation function
         )
 
         # Second convolutional layer
         self.conv2 = nn.Sequential(
-            nn.Conv2d(out_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1
-            nn.BatchNorm2d(out_channels), # Batch normalization
-            nn.GELU(), # GELU activation function
+            nn.Conv2d(
+                out_channels, out_channels, 3, 1, 1
+            ),  # 3x3 kernel with stride 1 and padding 1
+            nn.BatchNorm2d(out_channels),  # Batch normalization
+            nn.GELU(),  # GELU activation function
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
         # If using residual connection
         if self.is_res:
             # Apply first convolutional layer
@@ -122,9 +136,11 @@ class ResidualConvBlock(nn.Module):
             out = x + x2
         else:
             # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection
-            shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device)
+            shortcut = nn.Conv2d(
+                x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0
+            ).to(x.device)
             out = shortcut(x) + x2
-            #print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")
+            # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")
 
         # Normalize output tensor
         return out / 1.414
@@ -145,12 +161,11 @@ class ResidualConvBlock(nn.Module):
         self.conv2[0].in_channels = out_channels
         self.conv2[0].out_channels = out_channels
 
-
 
 class UnetUp(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(UnetUp, self).__init__()
-
+
         # Create a list of layers for the upsampling block
         # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
         layers = [
@@ -158,27 +173,31 @@ class UnetUp(nn.Module):
             ResidualConvBlock(out_channels, out_channels),
             ResidualConvBlock(out_channels, out_channels),
         ]
-
+
         # Use the layers to create a sequential model
         self.model = nn.Sequential(*layers)
 
     def forward(self, x, skip):
         # Concatenate the input tensor x with the skip connection tensor along the channel dimension
         x = torch.cat((x, skip), 1)
-
+
         # Pass the concatenated tensor through the sequential model and return the output
         x = self.model(x)
         return x
 
-
+
 class UnetDown(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(UnetDown, self).__init__()
-
+
         # Create a list of layers for the downsampling block
         # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
-        layers = [ResidualConvBlock(in_channels, out_channels), ResidualConvBlock(out_channels, out_channels), nn.MaxPool2d(2)]
-
+        layers = [
+            ResidualConvBlock(in_channels, out_channels),
+            ResidualConvBlock(out_channels, out_channels),
+            nn.MaxPool2d(2),
+        ]
+
         # Use the layers to create a sequential model
         self.model = nn.Sequential(*layers)
 
@@ -186,22 +205,23 @@ class UnetDown(nn.Module):
         # Pass the input through the sequential model and return the output
         return self.model(x)
 
+
 class EmbedFC(nn.Module):
     def __init__(self, input_dim, emb_dim):
         super(EmbedFC, self).__init__()
-        '''
+        """
         This class defines a generic one layer feed-forward neural network for embedding input data of
         dimensionality input_dim to an embedding space of dimensionality emb_dim.
-        '''
+        """
         self.input_dim = input_dim
-
+
         # define the layers for the network
         layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
-
+
         # create a PyTorch sequential model consisting of the defined layers
         self.model = nn.Sequential(*layers)
 
@@ -210,46 +230,53 @@ class EmbedFC(nn.Module):
         x = x.view(-1, self.input_dim)
         # apply the model layers to the flattened tensor
         return self.model(x)
-
+
+
 def unorm(x):
     # unity norm. results in range of [0,1]
     # assume x (h,w,3)
-    xmax = x.max((0,1))
-    xmin = x.min((0,1))
-    return(x - xmin)/(xmax - xmin)
+    xmax = x.max((0, 1))
+    xmin = x.min((0, 1))
+    return (x - xmin) / (xmax - xmin)
+
 
 def norm_all(store, n_t, n_s):
     # runs unity norm on all timesteps of all samples
     nstore = np.zeros_like(store)
     for t in range(n_t):
         for s in range(n_s):
-            nstore[t,s] = unorm(store[t,s])
+            nstore[t, s] = unorm(store[t, s])
     return nstore
 
+
 def norm_torch(x_all):
     # runs unity norm on all timesteps of all samples
     # input is (n_samples, 3,h,w), the torch image format
     x = x_all.cpu().numpy()
-    xmax = x.max((2,3))
-    xmin = x.min((2,3))
-    xmax = np.expand_dims(xmax,(2,3))
-    xmin = np.expand_dims(xmin,(2,3))
-    nstore = (x - xmin)/(xmax - xmin)
+    xmax = x.max((2, 3))
+    xmin = x.min((2, 3))
+    xmax = np.expand_dims(xmax, (2, 3))
+    xmin = np.expand_dims(xmin, (2, 3))
+    nstore = (x - xmin) / (xmax - xmin)
     return torch.from_numpy(nstore)
 
 
 ## diffusion functions
 
+
 def setup_ddpm(beta1, beta2, timesteps, device):
     # construct DDPM noise schedule and sampling functions
     b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
     a_t = 1 - b_t
-    ab_t = torch.cumsum(a_t.log(), dim=0).exp()
+    ab_t = torch.cumsum(a_t.log(), dim=0).exp()
     ab_t[0] = 1
 
     # helper function: perturbs an image to a specified noise level
     def perturb_input(x, t, noise):
-        return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise
+        return (
+            ab_t.sqrt()[t, None, None, None] * x
+            + (1 - ab_t[t, None, None, None]) * noise
+        )
 
     # helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
     def _denoise_add_noise(x, t, pred_noise, z=None):
@@ -264,10 +291,10 @@ def setup_ddpm(beta1, beta2, timesteps, device):
     @torch.no_grad()
     def sample_ddpm_context(nn_model, noises, context, save_rate=20):
         # array to keep track of generated steps for plotting
-        intermediate = []
+        intermediate = []
         pbar = tqdm(range(timesteps, 0, -1), leave=False)
         for i in pbar:
-            pbar.set_description(f'sampling timestep {i:3d}')
+            pbar.set_description(f"sampling timestep {i:3d}")
 
             # reshape time tensor
             t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)
@@ -275,12 +302,12 @@ def setup_ddpm(beta1, beta2, timesteps, device):
             # sample some random noise to inject back in. For i = 1, don't add back in noise
            z = torch.randn_like(noises) if i > 1 else 0
 
-            eps = nn_model(noises, t, c=context) # predict noise e_(x_t,t, ctx)
+            eps = nn_model(noises, t, c=context)  # predict noise e_(x_t,t, ctx)
             noises = _denoise_add_noise(noises, i, eps, z)
-            if i % save_rate==0 or i==timesteps or i<8:
+            if i % save_rate == 0 or i == timesteps or i < 8:
                 intermediate.append(noises.detach().cpu().numpy())
 
         intermediate = np.stack(intermediate)
         return noises.clip(-1, 1), intermediate
-
+
     return perturb_input, sample_ddpm_context
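
Note: setup_ddpm returns two closures over the precomputed noise schedule: perturb_input (forward noising, used during training) and sample_ddpm_context (reverse sampling, used by feature_to_sprite.py). A small usage sketch with a stand-in model, purely illustrative; a trained ContextUnet would replace the lambda, and the shapes/hyperparameters here are arbitrary.

import torch
from utilities import setup_ddpm

perturb_input, sample_ddpm_context = setup_ddpm(
    beta1=1e-4, beta2=0.02, timesteps=50, device="cpu"
)

dummy_model = lambda x, t, c=None: torch.zeros_like(x)  # pretends no noise was added
noises = torch.randn(4, 3, 16, 16)   # 4 random 16x16 starting points
context = torch.zeros(4, 5)          # 5 sprite features, as in the app
samples, intermediate = sample_ddpm_context(dummy_model, noises, context)
print(samples.shape, intermediate.shape)  # (4, 3, 16, 16) and (n_saved, 4, 3, 16, 16)
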