File size: 2,519 Bytes
27486b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# KL-regularized autoencoder. Throughout this file `target` is presumably the
# import path of the class to instantiate and `params` its keyword arguments
# (confirm against the config loader in the training script).
model:
  base_learning_rate: 4.5e-6
  target: sgm.models.autoencoder.AutoencodingEngine
  params:
    # Presumably the same `jpg` key the data postprocessors transform.
    input_key: jpg
    # Must match, character for character, a metric key logged in validation.
    monitor: val/loss/rec
    # NOTE(review): this engine-level `disc_start_iter` (0) and the loss's
    # `disc_start` (20001) are separate settings — confirm the intended
    # interplay before changing either one.
    disc_start_iter: 0

    encoder_config:
      target: sgm.modules.diffusionmodules.model.Encoder
      params:
        attn_type: vanilla-xformers
        # Presumably doubles the encoder output (mean + log-variance) for the
        # diagonal Gaussian regularizer configured below — confirm.
        double_z: true
        z_channels: 8
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        # Presumably per-resolution-level multipliers of `ch`.
        ch_mult: [1, 2, 4, 4]
        num_res_blocks: 2
        # Empty: no resolution is listed for attention layers.
        attn_resolutions: []
        dropout: 0.0

    decoder_config:
      target: sgm.modules.diffusionmodules.model.Decoder
      # Interpolation (resolved by the config loader, presumably OmegaConf):
      # the decoder reuses the encoder's `params` mapping verbatim, so both
      # halves always share the same hyperparameters.
      params: ${model.params.encoder_config.params}

    regularizer_config:
      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

    loss_config:
      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 0.25
        disc_start: 20001
        disc_weight: 0.5
        learn_logvar: true

        # Keys presumably name terms reported by the regularizer (here the KL
        # term) — confirm against the regularizer's log output.
        regularization_weights:
          kl_loss: 1.0

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    # NOTE(review): only a `train` split is configured, yet the trainer
    # settings in this file (`limit_val_batches`, `val_check_interval`) and
    # the model's `monitor: val/loss/rec` assume validation runs — confirm
    # validation data is supplied elsewhere.
    train:
      datapipeline:
        urls:
          # Appears to be a placeholder — replace with the real shard location(s).
          - DATA-PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000

        decoders:
          - pil

        # Presumably applied in list order. The mapper names suggest:
        # resize + tensor conversion, rescale, then record the original size
        # and crop to a square.
        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg
              transforms:
                # An int `size` resizes the shorter edge to 256, keeping aspect.
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    # 3 is PIL's integer code for BICUBIC resampling.
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              h_key: height
              w_key: width

      loader:
        batch_size: 8
        num_workers: 4


lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy
    params:
      # Forwarded to DDP so it tolerates parameters that get no gradient in a
      # given forward pass.
      find_unused_parameters: true

  # No `target`: presumably merged into a default checkpoint-callback config
  # by the training script — confirm.
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    # Also target-less; presumably completed from the script's defaults.
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 50000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true

  trainer:
    # Deliberately the string "0," (quoted only to make that explicit).
    # Presumably forwarded to the Lightning Trainer. Lightning reads a
    # comma-separated string as a list of device indices (here GPU 0),
    # whereas a bare integer 0 would request zero devices.
    # Keep the trailing comma.
    devices: "0,"
    limit_val_batches: 50
    benchmark: true
    accumulate_grad_batches: 1
    val_check_interval: 10000