Spaces:
Sleeping
Sleeping
StevenChen16
commited on
Commit
•
1adde95
1
Parent(s):
bc3dba1
Update train.py
Browse files
train.py
CHANGED
@@ -179,61 +179,61 @@ def train_one_step(model, noise_image, optimizer, target_content_features, targe
|
|
179 |
def main(content_img, style_img, epochs, step_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height):
|
180 |
global CONTENT_LOSS_FACTOR, STYLE_LOSS_FACTOR, CONTENT_IMAGE_PATH, STYLE_IMAGE_PATH, OUTPUT_DIR, EPOCHS, LEARNING_RATE, STEPS_PER_EPOCH, M, N, image_mean, image_std, IMG_WIDTH, IMG_HEIGHT
|
181 |
|
182 |
-
with tf.device('/cuda:0'):
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
|
238 |
return save_image_for_gradio(noise_image)
|
239 |
|
|
|
179 |
def main(content_img, style_img, epochs, step_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height):
|
180 |
global CONTENT_LOSS_FACTOR, STYLE_LOSS_FACTOR, CONTENT_IMAGE_PATH, STYLE_IMAGE_PATH, OUTPUT_DIR, EPOCHS, LEARNING_RATE, STEPS_PER_EPOCH, M, N, image_mean, image_std, IMG_WIDTH, IMG_HEIGHT
|
181 |
|
182 |
+
# with tf.device('/cuda:0'):
|
183 |
+
CONTENT_LOSS_FACTOR = content_loss_factor
|
184 |
+
STYLE_LOSS_FACTOR = style_loss_factor
|
185 |
+
CONTENT_IMAGE_PATH = content_img
|
186 |
+
STYLE_IMAGE_PATH = style_img
|
187 |
+
EPOCHS = epochs
|
188 |
+
LEARNING_RATE = learning_rate
|
189 |
+
STEPS_PER_EPOCH = step_per_epoch
|
190 |
+
|
191 |
+
# 内容特征层及损失加权系数
|
192 |
+
CONTENT_LAYERS = {"block4_conv2": 0.5, "block5_conv2": 0.5}
|
193 |
+
# 风格特征层及损失加权系数
|
194 |
+
STYLE_LAYERS = {
|
195 |
+
"block1_conv1": 0.2,
|
196 |
+
"block2_conv1": 0.2,
|
197 |
+
"block3_conv1": 0.2,
|
198 |
+
"block4_conv1": 0.2,
|
199 |
+
"block5_conv1": 0.2,
|
200 |
+
}
|
201 |
+
|
202 |
+
if img_size == "default size":
|
203 |
+
IMG_WIDTH = 450
|
204 |
+
IMG_HEIGHT = 300
|
205 |
+
else:
|
206 |
+
IMG_WIDTH = img_width
|
207 |
+
IMG_HEIGHT = img_height
|
208 |
+
|
209 |
+
print("IMG_WIDTH:", IMG_WIDTH)
|
210 |
+
print("IMG_HEIGHT:", IMG_HEIGHT)
|
211 |
+
|
212 |
+
# 我们准备使用经典网络在imagenet数据集上的预训练权重,所以归一化时也要使用imagenet的平均值和标准差
|
213 |
+
image_mean = tf.constant([0.485, 0.456, 0.406])
|
214 |
+
image_std = tf.constant([0.299, 0.224, 0.225])
|
215 |
+
|
216 |
+
model = NeuralStyleTransferModel(CONTENT_LAYERS, STYLE_LAYERS)
|
217 |
+
|
218 |
+
content_image = load_images_from_list(CONTENT_IMAGE_PATH, IMG_WIDTH, IMG_HEIGHT)
|
219 |
+
style_image = load_images_from_list(STYLE_IMAGE_PATH, IMG_WIDTH, IMG_HEIGHT)
|
220 |
+
|
221 |
+
target_content_features = model(content_image)["content"]
|
222 |
+
target_style_features = model(style_image)["style"]
|
223 |
+
|
224 |
+
M = IMG_WIDTH * IMG_HEIGHT
|
225 |
+
N = 3
|
226 |
+
|
227 |
+
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
|
228 |
+
|
229 |
+
noise_image = tf.Variable((content_image[0] + np.random.uniform(-0.2, 0.2, (1, IMG_HEIGHT, IMG_WIDTH, 3))) / 2)
|
230 |
+
|
231 |
+
for epoch in range(EPOCHS):
|
232 |
+
with tqdm(total=STEPS_PER_EPOCH, desc="Epoch {}/{}".format(epoch + 1, EPOCHS)) as pbar:
|
233 |
+
for step in range(STEPS_PER_EPOCH):
|
234 |
+
_loss = train_one_step(model, noise_image, optimizer, target_content_features, target_style_features)
|
235 |
+
pbar.set_postfix({"loss": "%.4f" % float(_loss)})
|
236 |
+
pbar.update(1)
|
237 |
|
238 |
return save_image_for_gradio(noise_image)
|
239 |
|