PyTorch Translation and implementation of DeePSim generators
"Generating Images with Perceptual Similarity Metrics based on Deep Networks", Alexey Dosovitskiy and Thomas Brox (2016)
- The network architecture is translated from the original Caffe implementation using the pytorch-caffe repo.
- The network definitions in pure PyTorch are included in GAN_utils.py.
- The weights are converted from the pre-trained Caffe weights available on Alexey Dosovitskiy's homepage and saved as PyTorch state dicts.
- This repo contains pre-trained state dicts for 9 generative models and one classification model (Caffenet). The generative models are trained to invert the representations of various Caffenet layers (norm1, norm2, conv3, conv4, pool5, fc6, fc7, fc8); fc6 has two generators, fc6 and fc6_eucl, which accounts for the ninth model. All of these models are relatively simple, consisting of linear, convolutional, and deconvolutional layers.
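To illustrate how a stack of deconvolution (transposed convolution) layers grows a small feature map into a full image, the sketch below applies the standard transposed-convolution output-size formula, out = (in - 1) * stride - 2 * padding + kernel (the one documented for PyTorch's ConvTranspose2d, with dilation 1 and no output padding). The layer layout used here (a 4x4 starting map, kernel-4/stride-2/padding-1 deconvolutions, and size-preserving 3x3 convolutions) is a hypothetical illustration, not values read from GAN_utils.py.

```python
def deconv_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    """Spatial output size of an ordinary convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Hypothetical decoder layout, for illustration only: (layer type, kernel, stride, padding).
LAYERS = [
    ("deconv", 4, 2, 1),
    ("conv", 3, 1, 1),
    ("deconv", 4, 2, 1),
    ("conv", 3, 1, 1),
    ("deconv", 4, 2, 1),
    ("conv", 3, 1, 1),
    ("deconv", 4, 2, 1),
    ("deconv", 4, 2, 1),
    ("deconv", 4, 2, 1),
]


def trace_sizes(start, layers=LAYERS):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [start]
    for kind, k, s, p in layers:
        step = deconv_out if kind == "deconv" else conv_out
        sizes.append(step(sizes[-1], k, s, p))
    return sizes


print(trace_sizes(4))  # [4, 8, 8, 16, 16, 32, 32, 64, 128, 256]
```

With kernel 4, stride 2 and padding 1, each deconvolution exactly doubles the spatial size, so six of them take a 4x4 map to 256x256, while the 3x3, padding-1 convolutions leave the size unchanged.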
Example usage
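import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from GAN_utils import Caffenet, upconvGAN

CNN = Caffenet(pretrained=True).cuda()  # classifier whose features are inverted
layer = "conv3"
invert_layer_id = 9  # index of the conv3 output in CNN.net (see the mapping below)
G = upconvGAN(name=layer, pretrained=True).cuda()  # generator trained to invert conv3
img = Image.open(...).convert("RGB")  # force 3 channels (drops alpha, expands grayscale)
img = img.resize((227, 227))  # Caffenet input resolution
img = np.array(img)  # H x W x 3, uint8, RGB order
RGB_mean = torch.tensor([123.0, 117.0, 104.0])  # per-channel mean in RGB order
RGB_mean = torch.reshape(RGB_mean, (1, 3, 1, 1))
img_preproc = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()  # 1 x 3 x H x W
img_preproc = (img_preproc - RGB_mean)[:, [2, 1, 0], :, :]  # subtract mean, then RGB -> BGR
with torch.no_grad():
    out = CNN.net[:invert_layer_id + 1](img_preproc.cuda())  # features of the chosen layer
    imgtsr_recon_pp = G.visualize(out).cpu()  # reconstructed image tensor
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(imgtsr_recon_pp[0].permute(1, 2, 0).numpy())
axes[0].set_title(f"Reconstructed-{layer}")
axes[0].axis("off")
axes[1].imshow(img)
axes[1].set_title("Original")
axes[1].axis("off")
plt.show()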
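The preprocessing in the example (subtract a per-channel mean, reorder RGB to BGR, move channels first) can also be written as a small NumPy-only sketch, which makes the inverse mapping back to a displayable image explicit. The function names `preprocess` and `deprocess` are hypothetical; the mean values are the ones used in the example.

```python
import numpy as np

# Per-channel mean in RGB order, as used in the example.
RGB_MEAN = np.array([123.0, 117.0, 104.0], dtype=np.float32)


def preprocess(img_rgb):
    """HxWx3 uint8 RGB image -> 1x3xHxW float32 BGR array with the mean removed."""
    x = img_rgb.astype(np.float32) - RGB_MEAN  # broadcast over the channel axis
    x = x[:, :, ::-1]  # RGB -> BGR channel order
    return np.ascontiguousarray(x.transpose(2, 0, 1)[None])  # HWC -> 1xCxHxW


def deprocess(x):
    """Inverse of preprocess: 1x3xHxW BGR array -> HxWx3 uint8 RGB image."""
    rgb = x[0].transpose(1, 2, 0)[:, :, ::-1] + RGB_MEAN  # CHW -> HWC, BGR -> RGB, add mean
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)


img = np.random.default_rng(0).integers(0, 256, size=(227, 227, 3), dtype=np.uint8)
x = preprocess(img)
print(x.shape)  # (1, 3, 227, 227)
```

Subtracting the RGB mean before flipping the channels is equivalent to the torch version, which subtracts in RGB order and then indexes channels `[2, 1, 0]`.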
Figure: reconstructions of ImageNet validation set images (first column) by the norm1 through fc8 generators, each applied to the representation of its corresponding layer.
To our understanding, each generator corresponds to the following layer index in Caffenet, i.e. the value to use as invert_layer_id in the example above:
invert_layer_map = {
    "norm1": 3,
    "norm2": 7,
    "conv3": 9,
    "conv4": 11,
    "pool5": 14,
    "fc6": 17,
    "fc6_eucl": 17,
    "fc7": 19,
    "fc8": 20,
}
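Because the example slices `CNN.net[:invert_layer_id + 1]`, each index in the mapping appears to be the position of the last Caffenet module kept, so a generator receives that module's output. The pure-Python sketch below wraps that lookup; `feature_slice` and `INVERT_LAYER_MAP` are hypothetical names, and the dictionary simply repeats the mapping above.

```python
# Generator name -> Caffenet module index, copied from the mapping above.
INVERT_LAYER_MAP = {
    "norm1": 3, "norm2": 7, "conv3": 9, "conv4": 11, "pool5": 14,
    "fc6": 17, "fc6_eucl": 17, "fc7": 19, "fc8": 20,
}


def feature_slice(generator_name):
    """Slice of CNN.net whose output is the input expected by the named generator."""
    try:
        idx = INVERT_LAYER_MAP[generator_name]
    except KeyError:
        known = ", ".join(sorted(INVERT_LAYER_MAP))
        raise ValueError(f"unknown generator {generator_name!r}; expected one of: {known}") from None
    # The example keeps modules 0..idx inclusive, i.e. CNN.net[:idx + 1].
    return slice(0, idx + 1)


print(feature_slice("conv3"))  # slice(0, 10, None)
```

With the objects from the example, `CNN.net[feature_slice(layer)]` would select the same modules as `CNN.net[:invert_layer_id + 1]`, assuming `CNN.net` supports slice indexing (as `nn.Sequential` does). Note that fc6 and fc6_eucl read the same features.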