LFS dat
Browse files- .gitattributes +2 -0
- Dockerfile +54 -0
- LICENSE +21 -0
- README.md +171 -13
- __pycache__/kernel_utils.cpython-310.pyc +0 -0
- app.py +86 -0
- configs/b5.json +28 -0
- configs/b7.json +29 -0
- download_weights.sh +9 -0
- examples/liuujwwgpr.mp4 +3 -0
- examples/nlurbvsozt.mp4 +3 -0
- examples/rfjuhbnlro.mp4 +3 -0
- kernel_utils.py +365 -0
- libs/shape_predictor_68_face_landmarks.dat +3 -0
- training/__init__.py +0 -0
- training/__pycache__/__init__.cpython-310.pyc +0 -0
- training/__pycache__/__init__.cpython-39.pyc +0 -0
- training/__pycache__/losses.cpython-310.pyc +0 -0
- training/__pycache__/losses.cpython-39.pyc +0 -0
- training/datasets/__init__.py +0 -0
- training/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- training/datasets/__pycache__/classifier_dataset.cpython-310.pyc +0 -0
- training/datasets/__pycache__/validation_set.cpython-310.pyc +0 -0
- training/datasets/classifier_dataset.py +384 -0
- training/datasets/validation_set.py +60 -0
- training/losses.py +28 -0
- training/pipelines/__init__.py +0 -0
- training/pipelines/train_classifier.py +364 -0
- training/tools/__init__.py +0 -0
- training/tools/__pycache__/__init__.cpython-310.pyc +0 -0
- training/tools/__pycache__/config.cpython-310.pyc +0 -0
- training/tools/__pycache__/schedulers.cpython-310.pyc +0 -0
- training/tools/__pycache__/utils.cpython-310.pyc +0 -0
- training/tools/config.py +43 -0
- training/tools/schedulers.py +46 -0
- training/tools/utils.py +121 -0
- training/transforms/__init__.py +0 -0
- training/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
- training/transforms/__pycache__/albu.cpython-310.pyc +0 -0
- training/transforms/albu.py +100 -0
- training/zoo/__init__.py +0 -0
- training/zoo/__pycache__/__init__.cpython-310.pyc +0 -0
- training/zoo/__pycache__/classifiers.cpython-310.pyc +0 -0
- training/zoo/classifiers.py +172 -0
- training/zoo/unet.py +151 -0
- weights/.gitkeep +0 -0
- weights/b7_ns_best.pth +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.dat filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ARG PYTORCH="1.10.0"
|
2 |
+
ARG CUDA="11.3"
|
3 |
+
ARG CUDNN="8"
|
4 |
+
|
5 |
+
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
6 |
+
|
7 |
+
ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
|
8 |
+
ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
|
9 |
+
|
10 |
+
# Setting noninteractive build, setting up tzdata and configuring timezones
|
11 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
12 |
+
ENV TZ=Europe/Berlin
|
13 |
+
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
14 |
+
|
15 |
+
RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxrender-dev libxext6 nano mc glances vim git \
|
16 |
+
&& apt-get clean \
|
17 |
+
&& rm -rf /var/lib/apt/lists/*
|
18 |
+
|
19 |
+
# Install cython
|
20 |
+
RUN conda install cython -y && conda clean --all
|
21 |
+
|
22 |
+
# Installing APEX
|
23 |
+
RUN pip install -U pip
|
24 |
+
RUN git clone https://github.com/NVIDIA/apex
|
25 |
+
RUN sed -i 's/check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)/pass/g' apex/setup.py
|
26 |
+
RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
|
27 |
+
RUN apt-get update -y
|
28 |
+
RUN apt-get install build-essential cmake -y
|
29 |
+
RUN apt-get install libopenblas-dev liblapack-dev -y
|
30 |
+
RUN apt-get install libx11-dev libgtk-3-dev -y
|
31 |
+
RUN pip install dlib
|
32 |
+
RUN pip install facenet-pytorch
|
33 |
+
RUN pip install albumentations==1.0.0 timm==0.4.12 pytorch_toolbelt tensorboardx
|
34 |
+
RUN pip install cython jupyter jupyterlab ipykernel matplotlib tqdm pandas
|
35 |
+
|
36 |
+
# download pretraned Imagenet models
|
37 |
+
RUN apt install wget
|
38 |
+
RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth -P /root/.cache/torch/hub/checkpoints/
|
39 |
+
RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth -P /root/.cache/torch/hub/checkpoints/
|
40 |
+
|
41 |
+
# Setting the working directory
|
42 |
+
WORKDIR /workspace
|
43 |
+
|
44 |
+
# Copying the required codebase
|
45 |
+
COPY . /workspace
|
46 |
+
|
47 |
+
RUN chmod 777 preprocess_data.sh
|
48 |
+
RUN chmod 777 train.sh
|
49 |
+
RUN chmod 777 predict_submission.sh
|
50 |
+
|
51 |
+
ENV PYTHONPATH=.
|
52 |
+
|
53 |
+
CMD ["/bin/bash"]
|
54 |
+
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2020 Selim Seferbekov
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,171 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## DeepFake Detection (DFDC) Solution by @selimsef
|
2 |
+
|
3 |
+
## Challenge details:
|
4 |
+
|
5 |
+
[Kaggle Challenge Page](https://www.kaggle.com/c/deepfake-detection-challenge)
|
6 |
+
|
7 |
+
|
8 |
+
### Fake detection articles
|
9 |
+
- [The Deepfake Detection Challenge (DFDC) Preview Dataset](https://arxiv.org/abs/1910.08854)
|
10 |
+
- [Deep Fake Image Detection Based on Pairwise Learning](https://www.mdpi.com/2076-3417/10/1/370)
|
11 |
+
- [DeeperForensics-1.0: A Large-Scale Dataset for Real-World Face Forgery Detection](https://arxiv.org/abs/2001.03024)
|
12 |
+
- [DeepFakes and Beyond: A Survey of Face Manipulation and Fake Detection](https://arxiv.org/abs/2001.00179)
|
13 |
+
- [Real or Fake? Spoofing State-Of-The-Art Face Synthesis Detection Systems](https://arxiv.org/abs/1911.05351)
|
14 |
+
- [CNN-generated images are surprisingly easy to spot... for now](https://arxiv.org/abs/1912.11035)
|
15 |
+
- [FakeSpotter: A Simple yet Robust Baseline for Spotting AI-Synthesized Fake Faces](https://arxiv.org/abs/1909.06122)
|
16 |
+
- [FakeLocator: Robust Localization of GAN-Based Face Manipulations via Semantic Segmentation Networks with Bells and Whistles](https://arxiv.org/abs/2001.09598)
|
17 |
+
- [Media Forensics and DeepFakes: an overview](https://arxiv.org/abs/2001.06564)
|
18 |
+
- [Face X-ray for More General Face Forgery Detection](https://arxiv.org/abs/1912.13458)
|
19 |
+
|
20 |
+
## Solution description
|
21 |
+
In general solution is based on frame-by-frame classification approach. Other complex things did not work so well on public leaderboard.
|
22 |
+
|
23 |
+
#### Face-Detector
|
24 |
+
MTCNN detector is chosen due to kernel time limits. It would be better to use S3FD detector as more precise and robust, but opensource Pytorch implementations don't have a license.
|
25 |
+
|
26 |
+
Input size for face detector was calculated for each video depending on video resolution.
|
27 |
+
|
28 |
+
- 2x scale for videos with less than 300 pixels wider side
|
29 |
+
- no rescale for videos with wider side between 300 and 1000
|
30 |
+
- 0.5x scale for videos with wider side > 1000 pixels
|
31 |
+
- 0.33x scale for videos with wider side > 1900 pixels
|
32 |
+
|
33 |
+
### Input size
|
34 |
+
As soon as I discovered that EfficientNets significantly outperform other encoders I used only them in my solution.
|
35 |
+
As I started with B4 I decided to use "native" size for that network (380x380).
|
36 |
+
Due to memory costraints I did not increase input size even for B7 encoder.
|
37 |
+
|
38 |
+
### Margin
|
39 |
+
When I generated crops for training I added 30% of face crop size from each side and used only this setting during the competition.
|
40 |
+
See [extract_crops.py](preprocessing/extract_crops.py) for the details
|
41 |
+
|
42 |
+
### Encoders
|
43 |
+
The winning encoder is current state-of-the-art model (EfficientNet B7) pretrained with ImageNet and noisy student [Self-training with Noisy Student improves ImageNet classification
|
44 |
+
](https://arxiv.org/abs/1911.04252)
|
45 |
+
|
46 |
+
### Averaging predictions
|
47 |
+
I used 32 frames for each video.
|
48 |
+
For each model output instead of simple averaging I used the following heuristic which worked quite well on public leaderbord (0.25 -> 0.22 solo B5).
|
49 |
+
```python
|
50 |
+
import numpy as np
|
51 |
+
|
52 |
+
def confident_strategy(pred, t=0.8):
|
53 |
+
pred = np.array(pred)
|
54 |
+
sz = len(pred)
|
55 |
+
fakes = np.count_nonzero(pred > t)
|
56 |
+
# 11 frames are detected as fakes with high probability
|
57 |
+
if fakes > sz // 2.5 and fakes > 11:
|
58 |
+
return np.mean(pred[pred > t])
|
59 |
+
elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
|
60 |
+
return np.mean(pred[pred < 0.2])
|
61 |
+
else:
|
62 |
+
return np.mean(pred)
|
63 |
+
```
|
64 |
+
|
65 |
+
### Augmentations
|
66 |
+
|
67 |
+
I used heavy augmentations by default.
|
68 |
+
[Albumentations](https://github.com/albumentations-team/albumentations) library supports most of the augmentations out of the box. Only needed to add IsotropicResize augmentation.
|
69 |
+
```
|
70 |
+
|
71 |
+
def create_train_transforms(size=300):
|
72 |
+
return Compose([
|
73 |
+
ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
|
74 |
+
GaussNoise(p=0.1),
|
75 |
+
GaussianBlur(blur_limit=3, p=0.05),
|
76 |
+
HorizontalFlip(),
|
77 |
+
OneOf([
|
78 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
|
79 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
|
80 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
|
81 |
+
], p=1),
|
82 |
+
PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
|
83 |
+
OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
|
84 |
+
ToGray(p=0.2),
|
85 |
+
ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
|
86 |
+
]
|
87 |
+
)
|
88 |
+
```
|
89 |
+
In addition to these augmentations I wanted to achieve better generalization with
|
90 |
+
- Cutout like augmentations (dropping artefacts and parts of face)
|
91 |
+
- Dropout part of the image, inspired by [GridMask](https://arxiv.org/abs/2001.04086) and [Severstal Winning Solution](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254)
|
92 |
+
|
93 |
+
![augmentations](images/augmentations.jpg "Dropout augmentations")
|
94 |
+
|
95 |
+
## Building docker image
|
96 |
+
All libraries and enviroment is already configured with Dockerfile. It requires docker engine https://docs.docker.com/engine/install/ubuntu/ and nvidia docker in your system https://github.com/NVIDIA/nvidia-docker.
|
97 |
+
|
98 |
+
To build a docker image run `docker build -t df .`
|
99 |
+
|
100 |
+
## Running docker
|
101 |
+
`docker run --runtime=nvidia --ipc=host --rm --volume <DATA_ROOT>:/dataset -it df`
|
102 |
+
|
103 |
+
## Data preparation
|
104 |
+
|
105 |
+
Once DFDC dataset is downloaded all the scripts expect to have `dfdc_train_xxx` folders under data root directory.
|
106 |
+
|
107 |
+
Preprocessing is done in a single script **`preprocess_data.sh`** which requires dataset directory as first argument.
|
108 |
+
It will execute the steps below:
|
109 |
+
|
110 |
+
##### 1. Find face bboxes
|
111 |
+
To extract face bboxes I used facenet library, basically only MTCNN.
|
112 |
+
`python preprocessing/detect_original_faces.py --root-dir DATA_ROOT`
|
113 |
+
This script will detect faces in real videos and store them as jsons in DATA_ROOT/bboxes directory
|
114 |
+
|
115 |
+
##### 2. Extract crops from videos
|
116 |
+
To extract image crops I used bboxes saved before. It will use bounding boxes from original videos for face videos as well.
|
117 |
+
`python preprocessing/extract_crops.py --root-dir DATA_ROOT --crops-dir crops`
|
118 |
+
This script will extract face crops from videos and save them in DATA_ROOT/crops directory
|
119 |
+
|
120 |
+
##### 3. Generate landmarks
|
121 |
+
From the saved crops it is quite fast to process crops with MTCNN and extract landmarks
|
122 |
+
`python preprocessing/generate_landmarks.py --root-dir DATA_ROOT`
|
123 |
+
This script will extract landmarks and save them in DATA_ROOT/landmarks directory
|
124 |
+
|
125 |
+
##### 4. Generate diff SSIM masks
|
126 |
+
`python preprocessing/generate_diffs.py --root-dir DATA_ROOT`
|
127 |
+
This script will extract SSIM difference masks between real and fake images and save them in DATA_ROOT/diffs directory
|
128 |
+
|
129 |
+
##### 5. Generate folds
|
130 |
+
`python preprocessing/generate_folds.py --root-dir DATA_ROOT --out folds.csv`
|
131 |
+
By default it will use 16 splits to have 0-2 folders as a holdout set. Though only 400 videos can be used for validation as well.
|
132 |
+
|
133 |
+
|
134 |
+
## Training
|
135 |
+
|
136 |
+
Training 5 B7 models with different seeds is done in **`train.sh`** script.
|
137 |
+
|
138 |
+
During training checkpoints are saved for every epoch.
|
139 |
+
|
140 |
+
## Hardware requirements
|
141 |
+
Mostly trained on devbox configuration with 4xTitan V, thanks to Nvidia and DSB2018 competition where I got these gpus https://www.kaggle.com/c/data-science-bowl-2018/
|
142 |
+
|
143 |
+
Overall training requires 4 GPUs with 12gb+ memory.
|
144 |
+
Batch size needs to be adjusted for standard 1080Ti or 2080Ti graphic cards.
|
145 |
+
|
146 |
+
As I computed fake loss and real loss separately inside each batch, results might be better with larger batch size, for example on V100 gpus.
|
147 |
+
Even though SyncBN is used larger batch on each GPU will lead to less noise as DFDC dataset has some fakes where face detector failed and face crops are not really fakes.
|
148 |
+
|
149 |
+
## Plotting losses to select checkpoints
|
150 |
+
|
151 |
+
`python plot_loss.py --log-file logs/<log file>`
|
152 |
+
|
153 |
+
![loss plot](images/loss_plot.png "Weighted loss")
|
154 |
+
|
155 |
+
## Inference
|
156 |
+
|
157 |
+
|
158 |
+
Kernel is reproduced with `predict_folder.py` script.
|
159 |
+
|
160 |
+
|
161 |
+
## Pretrained models
|
162 |
+
`download_weights.sh` script will download trained models to `weights/` folder. They should be downloaded before building a docker image.
|
163 |
+
|
164 |
+
Ensemble inference is already preconfigured with `predict_submission.sh` bash script. It expects a directory with videos as first argument and an output csv file as second argument.
|
165 |
+
|
166 |
+
For example `./predict_submission.sh /mnt/datasets/deepfake/test_videos submission.csv`
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
|
__pycache__/kernel_utils.cpython-310.pyc
ADDED
Binary file (11.8 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
|
8 |
+
from training.zoo.classifiers import DeepFakeClassifier
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
def model_fn(model_dir):
|
13 |
+
model_path = os.path.join(model_dir, 'b7_ns_best.pth')
|
14 |
+
model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns") # default: CPU
|
15 |
+
checkpoint = torch.load(model_path, map_location="cpu")
|
16 |
+
state_dict = checkpoint.get("state_dict", checkpoint)
|
17 |
+
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
|
18 |
+
model.eval()
|
19 |
+
del checkpoint
|
20 |
+
#models.append(model.half())
|
21 |
+
|
22 |
+
return model
|
23 |
+
|
24 |
+
def convert_result(pred, class_names=["Real", "Fake"]):
|
25 |
+
preds = [pred, 1 - pred]
|
26 |
+
assert len(class_names) == len(preds), "Class / Prediction should have the same length"
|
27 |
+
return {n: p for n, p in zip(class_names, preds)}
|
28 |
+
|
29 |
+
def predict_fn(model, video, meta):
|
30 |
+
start = time.time()
|
31 |
+
prediction = predict_on_video(face_extractor=meta["face_extractor"],
|
32 |
+
video_path=video,
|
33 |
+
batch_size=meta["fps"],
|
34 |
+
input_size=meta["input_size"],
|
35 |
+
models=model,
|
36 |
+
strategy=meta["strategy"],
|
37 |
+
apply_compression=False,
|
38 |
+
device='cpu')
|
39 |
+
|
40 |
+
elapsed_time = round(time.time() - start, 2)
|
41 |
+
|
42 |
+
prediction = convert_result(prediction)
|
43 |
+
|
44 |
+
return prediction, elapsed_time
|
45 |
+
|
46 |
+
# Create title, description and article strings
|
47 |
+
title = "Deepfake Detector (private)"
|
48 |
+
description = "A video Deepfake Classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"
|
49 |
+
|
50 |
+
example_list = ["examples/" + str(p) for p in os.listdir("examples/")]
|
51 |
+
|
52 |
+
# Environments
|
53 |
+
model_dir = 'weights'
|
54 |
+
frames_per_video = 32
|
55 |
+
video_reader = VideoReader()
|
56 |
+
video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
|
57 |
+
face_extractor = FaceExtractor(video_read_fn)
|
58 |
+
input_size = 380
|
59 |
+
strategy = confident_strategy
|
60 |
+
class_names = ["Real", "Fake"]
|
61 |
+
|
62 |
+
meta = {"fps": 32,
|
63 |
+
"face_extractor": face_extractor,
|
64 |
+
"input_size": input_size,
|
65 |
+
"strategy": strategy}
|
66 |
+
|
67 |
+
model = model_fn(model_dir)
|
68 |
+
|
69 |
+
"""
|
70 |
+
if __name__ == '__main__':
|
71 |
+
video_path = "nlurbvsozt.mp4"
|
72 |
+
model = model_fn(model_dir)
|
73 |
+
a, b = predict_fn([model], video_path, meta)
|
74 |
+
print(a, b)
|
75 |
+
"""
|
76 |
+
# Create the Gradio demo
|
77 |
+
demo = gr.Interface(fn=predict_fn, # mapping function from input to output
|
78 |
+
inputs=[[model], gr.Video(autosize=True), meta],
|
79 |
+
outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs?
|
80 |
+
gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
|
81 |
+
examples=example_list,
|
82 |
+
title=title,
|
83 |
+
description=description)
|
84 |
+
|
85 |
+
# Launch the demo!
|
86 |
+
demo.launch(debug=False,) # Hugging face space don't need shareable_links
|
configs/b5.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"network": "DeepFakeClassifier",
|
3 |
+
"encoder": "tf_efficientnet_b5_ns",
|
4 |
+
"batches_per_epoch": 2500,
|
5 |
+
"size": 380,
|
6 |
+
"fp16": true,
|
7 |
+
"optimizer": {
|
8 |
+
"batch_size": 20,
|
9 |
+
"type": "SGD",
|
10 |
+
"momentum": 0.9,
|
11 |
+
"weight_decay": 1e-4,
|
12 |
+
"learning_rate": 0.01,
|
13 |
+
"nesterov": true,
|
14 |
+
"schedule": {
|
15 |
+
"type": "poly",
|
16 |
+
"mode": "step",
|
17 |
+
"epochs": 30,
|
18 |
+
"params": {"max_iter": 75100}
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"normalize": {
|
22 |
+
"mean": [0.485, 0.456, 0.406],
|
23 |
+
"std": [0.229, 0.224, 0.225]
|
24 |
+
},
|
25 |
+
"losses": {
|
26 |
+
"BinaryCrossentropy": 1
|
27 |
+
}
|
28 |
+
}
|
configs/b7.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"network": "DeepFakeClassifier",
|
3 |
+
"encoder": "tf_efficientnet_b7_ns",
|
4 |
+
"batches_per_epoch": 2500,
|
5 |
+
"size": 380,
|
6 |
+
"fp16": true,
|
7 |
+
"optimizer": {
|
8 |
+
"batch_size": 4,
|
9 |
+
"type": "SGD",
|
10 |
+
"momentum": 0.9,
|
11 |
+
"weight_decay": 1e-4,
|
12 |
+
"learning_rate": 1e-4,
|
13 |
+
"nesterov": true,
|
14 |
+
"schedule": {
|
15 |
+
"type": "poly",
|
16 |
+
"mode": "step",
|
17 |
+
"epochs": 20,
|
18 |
+
"params": {"max_iter": 100500}
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"normalize": {
|
22 |
+
"mean": [0.485, 0.456, 0.406],
|
23 |
+
"std": [0.229, 0.224, 0.225]
|
24 |
+
},
|
25 |
+
"losses": {
|
26 |
+
"BinaryCrossentropy": 1
|
27 |
+
}
|
28 |
+
}
|
29 |
+
|
download_weights.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tag=0.0.1
|
2 |
+
|
3 |
+
wget -O weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
|
4 |
+
wget -O weights/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19
|
5 |
+
wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29
|
6 |
+
wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31
|
7 |
+
wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37
|
8 |
+
wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40
|
9 |
+
wget -O weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23
|
examples/liuujwwgpr.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b3aaefb51aa5720cdabcc68d93da5c6a22573d8da06bdaf5e009c7a370943e85
|
3 |
+
size 12852441
|
examples/nlurbvsozt.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:300b7dea93132b512f35de76572e7fcde666c812b91aec6b189dafa6f100c9b5
|
3 |
+
size 4486723
|
examples/rfjuhbnlro.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6d0bb841ebe6a8e20cf265b45356a1ea3fed9837025e8d549b2437290d79273
|
3 |
+
size 16218775
|
kernel_utils.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from albumentations.augmentations.functional import image_compression
|
8 |
+
from facenet_pytorch.models.mtcnn import MTCNN
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
|
11 |
+
from torchvision.transforms import Normalize
|
12 |
+
|
13 |
+
mean = [0.485, 0.456, 0.406]
|
14 |
+
std = [0.229, 0.224, 0.225]
|
15 |
+
normalize_transform = Normalize(mean, std)
|
16 |
+
|
17 |
+
|
18 |
+
class VideoReader:
|
19 |
+
"""Helper class for reading one or more frames from a video file."""
|
20 |
+
|
21 |
+
def __init__(self, verbose=True, insets=(0, 0)):
|
22 |
+
"""Creates a new VideoReader.
|
23 |
+
|
24 |
+
Arguments:
|
25 |
+
verbose: whether to print warnings and error messages
|
26 |
+
insets: amount to inset the image by, as a percentage of
|
27 |
+
(width, height). This lets you "zoom in" to an image
|
28 |
+
to remove unimportant content around the borders.
|
29 |
+
Useful for face detection, which may not work if the
|
30 |
+
faces are too small.
|
31 |
+
"""
|
32 |
+
self.verbose = verbose
|
33 |
+
self.insets = insets
|
34 |
+
|
35 |
+
def read_frames(self, path, num_frames, jitter=0, seed=None):
|
36 |
+
"""Reads frames that are always evenly spaced throughout the video.
|
37 |
+
|
38 |
+
Arguments:
|
39 |
+
path: the video file
|
40 |
+
num_frames: how many frames to read, -1 means the entire video
|
41 |
+
(warning: this will take up a lot of memory!)
|
42 |
+
jitter: if not 0, adds small random offsets to the frame indices;
|
43 |
+
this is useful so we don't always land on even or odd frames
|
44 |
+
seed: random seed for jittering; if you set this to a fixed value,
|
45 |
+
you probably want to set it only on the first video
|
46 |
+
"""
|
47 |
+
assert num_frames > 0
|
48 |
+
|
49 |
+
capture = cv2.VideoCapture(path)
|
50 |
+
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
51 |
+
if frame_count <= 0: return None
|
52 |
+
|
53 |
+
frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
|
54 |
+
if jitter > 0:
|
55 |
+
np.random.seed(seed)
|
56 |
+
jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
|
57 |
+
frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)
|
58 |
+
|
59 |
+
result = self._read_frames_at_indices(path, capture, frame_idxs)
|
60 |
+
capture.release()
|
61 |
+
return result
|
62 |
+
|
63 |
+
def read_random_frames(self, path, num_frames, seed=None):
|
64 |
+
"""Picks the frame indices at random.
|
65 |
+
|
66 |
+
Arguments:
|
67 |
+
path: the video file
|
68 |
+
num_frames: how many frames to read, -1 means the entire video
|
69 |
+
(warning: this will take up a lot of memory!)
|
70 |
+
"""
|
71 |
+
assert num_frames > 0
|
72 |
+
np.random.seed(seed)
|
73 |
+
|
74 |
+
capture = cv2.VideoCapture(path)
|
75 |
+
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
76 |
+
if frame_count <= 0: return None
|
77 |
+
|
78 |
+
frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
|
79 |
+
result = self._read_frames_at_indices(path, capture, frame_idxs)
|
80 |
+
|
81 |
+
capture.release()
|
82 |
+
return result
|
83 |
+
|
84 |
+
def read_frames_at_indices(self, path, frame_idxs):
|
85 |
+
"""Reads frames from a video and puts them into a NumPy array.
|
86 |
+
|
87 |
+
Arguments:
|
88 |
+
path: the video file
|
89 |
+
frame_idxs: a list of frame indices. Important: should be
|
90 |
+
sorted from low-to-high! If an index appears multiple
|
91 |
+
times, the frame is still read only once.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
- a NumPy array of shape (num_frames, height, width, 3)
|
95 |
+
- a list of the frame indices that were read
|
96 |
+
|
97 |
+
Reading stops if loading a frame fails, in which case the first
|
98 |
+
dimension returned may actually be less than num_frames.
|
99 |
+
|
100 |
+
Returns None if an exception is thrown for any reason, or if no
|
101 |
+
frames were read.
|
102 |
+
"""
|
103 |
+
assert len(frame_idxs) > 0
|
104 |
+
capture = cv2.VideoCapture(path)
|
105 |
+
result = self._read_frames_at_indices(path, capture, frame_idxs)
|
106 |
+
capture.release()
|
107 |
+
return result
|
108 |
+
|
109 |
+
def _read_frames_at_indices(self, path, capture, frame_idxs):
|
110 |
+
try:
|
111 |
+
frames = []
|
112 |
+
idxs_read = []
|
113 |
+
for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
|
114 |
+
# Get the next frame, but don't decode if we're not using it.
|
115 |
+
ret = capture.grab()
|
116 |
+
if not ret:
|
117 |
+
if self.verbose:
|
118 |
+
print("Error grabbing frame %d from movie %s" % (frame_idx, path))
|
119 |
+
break
|
120 |
+
|
121 |
+
# Need to look at this frame?
|
122 |
+
current = len(idxs_read)
|
123 |
+
if frame_idx == frame_idxs[current]:
|
124 |
+
ret, frame = capture.retrieve()
|
125 |
+
if not ret or frame is None:
|
126 |
+
if self.verbose:
|
127 |
+
print("Error retrieving frame %d from movie %s" % (frame_idx, path))
|
128 |
+
break
|
129 |
+
|
130 |
+
frame = self._postprocess_frame(frame)
|
131 |
+
frames.append(frame)
|
132 |
+
idxs_read.append(frame_idx)
|
133 |
+
|
134 |
+
if len(frames) > 0:
|
135 |
+
return np.stack(frames), idxs_read
|
136 |
+
if self.verbose:
|
137 |
+
print("No frames read from movie %s" % path)
|
138 |
+
return None
|
139 |
+
except:
|
140 |
+
if self.verbose:
|
141 |
+
print("Exception while reading movie %s" % path)
|
142 |
+
return None
|
143 |
+
|
144 |
+
def read_middle_frame(self, path):
|
145 |
+
"""Reads the frame from the middle of the video."""
|
146 |
+
capture = cv2.VideoCapture(path)
|
147 |
+
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
148 |
+
result = self._read_frame_at_index(path, capture, frame_count // 2)
|
149 |
+
capture.release()
|
150 |
+
return result
|
151 |
+
|
152 |
+
def read_frame_at_index(self, path, frame_idx):
|
153 |
+
"""Reads a single frame from a video.
|
154 |
+
|
155 |
+
If you just want to read a single frame from the video, this is more
|
156 |
+
efficient than scanning through the video to find the frame. However,
|
157 |
+
for reading multiple frames it's not efficient.
|
158 |
+
|
159 |
+
My guess is that a "streaming" approach is more efficient than a
|
160 |
+
"random access" approach because, unless you happen to grab a keyframe,
|
161 |
+
the decoder still needs to read all the previous frames in order to
|
162 |
+
reconstruct the one you're asking for.
|
163 |
+
|
164 |
+
Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
|
165 |
+
or None if reading failed.
|
166 |
+
"""
|
167 |
+
capture = cv2.VideoCapture(path)
|
168 |
+
result = self._read_frame_at_index(path, capture, frame_idx)
|
169 |
+
capture.release()
|
170 |
+
return result
|
171 |
+
|
172 |
+
def _read_frame_at_index(self, path, capture, frame_idx):
|
173 |
+
capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
|
174 |
+
ret, frame = capture.read()
|
175 |
+
if not ret or frame is None:
|
176 |
+
if self.verbose:
|
177 |
+
print("Error retrieving frame %d from movie %s" % (frame_idx, path))
|
178 |
+
return None
|
179 |
+
else:
|
180 |
+
frame = self._postprocess_frame(frame)
|
181 |
+
return np.expand_dims(frame, axis=0), [frame_idx]
|
182 |
+
|
183 |
+
def _postprocess_frame(self, frame):
|
184 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
185 |
+
|
186 |
+
if self.insets[0] > 0:
|
187 |
+
W = frame.shape[1]
|
188 |
+
p = int(W * self.insets[0])
|
189 |
+
frame = frame[:, p:-p, :]
|
190 |
+
|
191 |
+
if self.insets[1] > 0:
|
192 |
+
H = frame.shape[1]
|
193 |
+
q = int(H * self.insets[1])
|
194 |
+
frame = frame[q:-q, :, :]
|
195 |
+
|
196 |
+
return frame
|
197 |
+
|
198 |
+
|
199 |
+
class FaceExtractor:
|
200 |
+
def __init__(self, video_read_fn):
|
201 |
+
self.video_read_fn = video_read_fn
|
202 |
+
self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cuda")
|
203 |
+
|
204 |
+
def process_videos(self, input_dir, filenames, video_idxs):
|
205 |
+
videos_read = []
|
206 |
+
frames_read = []
|
207 |
+
frames = []
|
208 |
+
results = []
|
209 |
+
for video_idx in video_idxs:
|
210 |
+
# Read the full-size frames from this video.
|
211 |
+
filename = filenames[video_idx]
|
212 |
+
video_path = os.path.join(input_dir, filename)
|
213 |
+
result = self.video_read_fn(video_path)
|
214 |
+
# Error? Then skip this video.
|
215 |
+
if result is None: continue
|
216 |
+
|
217 |
+
videos_read.append(video_idx)
|
218 |
+
|
219 |
+
# Keep track of the original frames (need them later).
|
220 |
+
my_frames, my_idxs = result
|
221 |
+
|
222 |
+
frames.append(my_frames)
|
223 |
+
frames_read.append(my_idxs)
|
224 |
+
for i, frame in enumerate(my_frames):
|
225 |
+
h, w = frame.shape[:2]
|
226 |
+
img = Image.fromarray(frame.astype(np.uint8))
|
227 |
+
img = img.resize(size=[s // 2 for s in img.size])
|
228 |
+
|
229 |
+
batch_boxes, probs = self.detector.detect(img, landmarks=False)
|
230 |
+
|
231 |
+
faces = []
|
232 |
+
scores = []
|
233 |
+
if batch_boxes is None:
|
234 |
+
continue
|
235 |
+
for bbox, score in zip(batch_boxes, probs):
|
236 |
+
if bbox is not None:
|
237 |
+
xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
|
238 |
+
w = xmax - xmin
|
239 |
+
h = ymax - ymin
|
240 |
+
p_h = h // 3
|
241 |
+
p_w = w // 3
|
242 |
+
crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
|
243 |
+
faces.append(crop)
|
244 |
+
scores.append(score)
|
245 |
+
|
246 |
+
frame_dict = {"video_idx": video_idx,
|
247 |
+
"frame_idx": my_idxs[i],
|
248 |
+
"frame_w": w,
|
249 |
+
"frame_h": h,
|
250 |
+
"faces": faces,
|
251 |
+
"scores": scores}
|
252 |
+
results.append(frame_dict)
|
253 |
+
|
254 |
+
return results
|
255 |
+
|
256 |
+
def process_video(self, video_path):
|
257 |
+
"""Convenience method for doing face extraction on a single video."""
|
258 |
+
input_dir = os.path.dirname(video_path)
|
259 |
+
filenames = [os.path.basename(video_path)]
|
260 |
+
return self.process_videos(input_dir, filenames, [0])
|
261 |
+
|
262 |
+
|
263 |
+
|
264 |
+
def confident_strategy(pred, t=0.8):
|
265 |
+
pred = np.array(pred)
|
266 |
+
sz = len(pred)
|
267 |
+
fakes = np.count_nonzero(pred > t)
|
268 |
+
# 11 frames are detected as fakes with high probability
|
269 |
+
if fakes > sz // 2.5 and fakes > 11:
|
270 |
+
return np.mean(pred[pred > t])
|
271 |
+
elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
|
272 |
+
return np.mean(pred[pred < 0.2])
|
273 |
+
else:
|
274 |
+
return np.mean(pred)
|
275 |
+
|
276 |
+
strategy = confident_strategy
|
277 |
+
|
278 |
+
|
279 |
+
def put_to_center(img, input_size):
|
280 |
+
img = img[:input_size, :input_size]
|
281 |
+
image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
|
282 |
+
start_w = (input_size - img.shape[1]) // 2
|
283 |
+
start_h = (input_size - img.shape[0]) // 2
|
284 |
+
image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
|
285 |
+
return image
|
286 |
+
|
287 |
+
|
288 |
+
def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
|
289 |
+
h, w = img.shape[:2]
|
290 |
+
if max(w, h) == size:
|
291 |
+
return img
|
292 |
+
if w > h:
|
293 |
+
scale = size / w
|
294 |
+
h = h * scale
|
295 |
+
w = size
|
296 |
+
else:
|
297 |
+
scale = size / h
|
298 |
+
w = w * scale
|
299 |
+
h = size
|
300 |
+
interpolation = interpolation_up if scale > 1 else interpolation_down
|
301 |
+
resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
|
302 |
+
return resized
|
303 |
+
|
304 |
+
|
305 |
+
def predict_on_video(face_extractor, video_path, batch_size, input_size, models, strategy=np.mean,
|
306 |
+
apply_compression=False, device='cpu'):
|
307 |
+
batch_size *= 4
|
308 |
+
try:
|
309 |
+
faces = face_extractor.process_video(video_path)
|
310 |
+
if len(faces) > 0:
|
311 |
+
x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
|
312 |
+
n = 0
|
313 |
+
for frame_data in faces:
|
314 |
+
for face in frame_data["faces"]:
|
315 |
+
resized_face = isotropically_resize_image(face, input_size)
|
316 |
+
resized_face = put_to_center(resized_face, input_size)
|
317 |
+
if apply_compression:
|
318 |
+
resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
|
319 |
+
if n + 1 < batch_size:
|
320 |
+
x[n] = resized_face
|
321 |
+
n += 1
|
322 |
+
else:
|
323 |
+
pass
|
324 |
+
if n > 0:
|
325 |
+
if device == 'cpu':
|
326 |
+
x = torch.tensor(x, device='cpu').float()
|
327 |
+
else:
|
328 |
+
x = torch.tensor(x, device="cuda").float()
|
329 |
+
# Preprocess the images.
|
330 |
+
x = x.permute((0, 3, 1, 2))
|
331 |
+
for i in range(len(x)):
|
332 |
+
x[i] = normalize_transform(x[i] / 255.)
|
333 |
+
# Make a prediction, then take the average.
|
334 |
+
with torch.no_grad():
|
335 |
+
preds = []
|
336 |
+
for model in models:
|
337 |
+
if device == 'cpu':
|
338 |
+
y_pred = model(x[:n])
|
339 |
+
else:
|
340 |
+
y_pred = model(x[:n].half())
|
341 |
+
y_pred = torch.sigmoid(y_pred.squeeze())
|
342 |
+
bpred = y_pred[:n].cpu().numpy()
|
343 |
+
preds.append(strategy(bpred))
|
344 |
+
return np.mean(preds)
|
345 |
+
except Exception as e:
|
346 |
+
print("Prediction error on video %s: %s" % (video_path, str(e)))
|
347 |
+
|
348 |
+
return 0.5
|
349 |
+
|
350 |
+
|
351 |
+
def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
|
352 |
+
strategy=np.mean,
|
353 |
+
apply_compression=False):
|
354 |
+
def process_file(i):
|
355 |
+
filename = videos[i]
|
356 |
+
y_pred = predict_on_video(face_extractor=face_extractor, video_path=os.path.join(test_dir, filename),
|
357 |
+
input_size=input_size,
|
358 |
+
batch_size=frames_per_video,
|
359 |
+
models=models, strategy=strategy, apply_compression=apply_compression)
|
360 |
+
return y_pred
|
361 |
+
|
362 |
+
with ThreadPoolExecutor(max_workers=num_workers) as ex:
|
363 |
+
predictions = ex.map(process_file, range(len(videos)))
|
364 |
+
return list(predictions)
|
365 |
+
|
libs/shape_predictor_68_face_landmarks.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
|
3 |
+
size 99693937
|
training/__init__.py
ADDED
File without changes
|
training/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (148 Bytes). View file
|
|
training/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes). View file
|
|
training/__pycache__/losses.cpython-310.pyc
ADDED
Binary file (1.54 kB). View file
|
|
training/__pycache__/losses.cpython-39.pyc
ADDED
Binary file (1.53 kB). View file
|
|
training/datasets/__init__.py
ADDED
File without changes
|
training/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
|
|
training/datasets/__pycache__/classifier_dataset.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
training/datasets/__pycache__/validation_set.cpython-310.pyc
ADDED
Binary file (4.99 kB). View file
|
|
training/datasets/classifier_dataset.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import sys
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import skimage.draw
|
11 |
+
from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
|
12 |
+
from albumentations.augmentations.functional import image_compression
|
13 |
+
from albumentations.augmentations.geometric.functional import rot90
|
14 |
+
from albumentations.pytorch.functional import img_to_tensor
|
15 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
16 |
+
from skimage import measure
|
17 |
+
from torch.utils.data import Dataset
|
18 |
+
import dlib
|
19 |
+
|
20 |
+
from training.datasets.validation_set import PUBLIC_SET
|
21 |
+
|
22 |
+
|
23 |
+
def prepare_bit_masks(mask):
|
24 |
+
h, w = mask.shape
|
25 |
+
mid_w = w // 2
|
26 |
+
mid_h = w // 2
|
27 |
+
masks = []
|
28 |
+
ones = np.ones_like(mask)
|
29 |
+
ones[:mid_h] = 0
|
30 |
+
masks.append(ones)
|
31 |
+
ones = np.ones_like(mask)
|
32 |
+
ones[mid_h:] = 0
|
33 |
+
masks.append(ones)
|
34 |
+
ones = np.ones_like(mask)
|
35 |
+
ones[:, :mid_w] = 0
|
36 |
+
masks.append(ones)
|
37 |
+
ones = np.ones_like(mask)
|
38 |
+
ones[:, mid_w:] = 0
|
39 |
+
masks.append(ones)
|
40 |
+
ones = np.ones_like(mask)
|
41 |
+
ones[:mid_h, :mid_w] = 0
|
42 |
+
ones[mid_h:, mid_w:] = 0
|
43 |
+
masks.append(ones)
|
44 |
+
ones = np.ones_like(mask)
|
45 |
+
ones[:mid_h, mid_w:] = 0
|
46 |
+
ones[mid_h:, :mid_w] = 0
|
47 |
+
masks.append(ones)
|
48 |
+
return masks
|
49 |
+
|
50 |
+
|
51 |
+
detector = dlib.get_frontal_face_detector()
|
52 |
+
predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')
|
53 |
+
|
54 |
+
|
55 |
+
def blackout_convex_hull(img):
|
56 |
+
try:
|
57 |
+
rect = detector(img)[0]
|
58 |
+
sp = predictor(img, rect)
|
59 |
+
landmarks = np.array([[p.x, p.y] for p in sp.parts()])
|
60 |
+
outline = landmarks[[*range(17), *range(26, 16, -1)]]
|
61 |
+
Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
|
62 |
+
cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
|
63 |
+
cropped_img[Y, X] = 1
|
64 |
+
# if random.random() > 0.5:
|
65 |
+
# img[cropped_img == 0] = 0
|
66 |
+
# #leave only face
|
67 |
+
# return img
|
68 |
+
|
69 |
+
y, x = measure.centroid(cropped_img)
|
70 |
+
y = int(y)
|
71 |
+
x = int(x)
|
72 |
+
first = random.random() > 0.5
|
73 |
+
if random.random() > 0.5:
|
74 |
+
if first:
|
75 |
+
cropped_img[:y, :] = 0
|
76 |
+
else:
|
77 |
+
cropped_img[y:, :] = 0
|
78 |
+
else:
|
79 |
+
if first:
|
80 |
+
cropped_img[:, :x] = 0
|
81 |
+
else:
|
82 |
+
cropped_img[:, x:] = 0
|
83 |
+
|
84 |
+
img[cropped_img > 0] = 0
|
85 |
+
except Exception as e:
|
86 |
+
pass
|
87 |
+
|
88 |
+
|
89 |
+
def dist(p1, p2):
|
90 |
+
return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
|
91 |
+
|
92 |
+
|
93 |
+
def remove_eyes(image, landmarks):
|
94 |
+
image = image.copy()
|
95 |
+
(x1, y1), (x2, y2) = landmarks[:2]
|
96 |
+
mask = np.zeros_like(image[..., 0])
|
97 |
+
line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
|
98 |
+
w = dist((x1, y1), (x2, y2))
|
99 |
+
dilation = int(w // 4)
|
100 |
+
line = binary_dilation(line, iterations=dilation)
|
101 |
+
image[line, :] = 0
|
102 |
+
return image
|
103 |
+
|
104 |
+
|
105 |
+
def remove_nose(image, landmarks):
|
106 |
+
image = image.copy()
|
107 |
+
(x1, y1), (x2, y2) = landmarks[:2]
|
108 |
+
x3, y3 = landmarks[2]
|
109 |
+
mask = np.zeros_like(image[..., 0])
|
110 |
+
x4 = int((x1 + x2) / 2)
|
111 |
+
y4 = int((y1 + y2) / 2)
|
112 |
+
line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
|
113 |
+
w = dist((x1, y1), (x2, y2))
|
114 |
+
dilation = int(w // 4)
|
115 |
+
line = binary_dilation(line, iterations=dilation)
|
116 |
+
image[line, :] = 0
|
117 |
+
return image
|
118 |
+
|
119 |
+
|
120 |
+
def remove_mouth(image, landmarks):
|
121 |
+
image = image.copy()
|
122 |
+
(x1, y1), (x2, y2) = landmarks[-2:]
|
123 |
+
mask = np.zeros_like(image[..., 0])
|
124 |
+
line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
|
125 |
+
w = dist((x1, y1), (x2, y2))
|
126 |
+
dilation = int(w // 3)
|
127 |
+
line = binary_dilation(line, iterations=dilation)
|
128 |
+
image[line, :] = 0
|
129 |
+
return image
|
130 |
+
|
131 |
+
|
132 |
+
def remove_landmark(image, landmarks):
|
133 |
+
if random.random() > 0.5:
|
134 |
+
image = remove_eyes(image, landmarks)
|
135 |
+
elif random.random() > 0.5:
|
136 |
+
image = remove_mouth(image, landmarks)
|
137 |
+
elif random.random() > 0.5:
|
138 |
+
image = remove_nose(image, landmarks)
|
139 |
+
return image
|
140 |
+
|
141 |
+
|
142 |
+
def change_padding(image, part=5):
|
143 |
+
h, w = image.shape[:2]
|
144 |
+
# original padding was done with 1/3 from each side, too much
|
145 |
+
pad_h = int(((3 / 5) * h) / part)
|
146 |
+
pad_w = int(((3 / 5) * w) / part)
|
147 |
+
image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
|
148 |
+
return image
|
149 |
+
|
150 |
+
|
151 |
+
def blackout_random(image, mask, label):
|
152 |
+
binary_mask = mask > 0.4 * 255
|
153 |
+
h, w = binary_mask.shape[:2]
|
154 |
+
|
155 |
+
tries = 50
|
156 |
+
current_try = 1
|
157 |
+
while current_try < tries:
|
158 |
+
first = random.random() < 0.5
|
159 |
+
if random.random() < 0.5:
|
160 |
+
pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
|
161 |
+
bitmap_msk = np.ones_like(binary_mask)
|
162 |
+
if first:
|
163 |
+
bitmap_msk[:pivot, :] = 0
|
164 |
+
else:
|
165 |
+
bitmap_msk[pivot:, :] = 0
|
166 |
+
else:
|
167 |
+
pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
|
168 |
+
bitmap_msk = np.ones_like(binary_mask)
|
169 |
+
if first:
|
170 |
+
bitmap_msk[:, :pivot] = 0
|
171 |
+
else:
|
172 |
+
bitmap_msk[:, pivot:] = 0
|
173 |
+
|
174 |
+
if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
|
175 |
+
or np.count_nonzero(binary_mask * bitmap_msk) > 40:
|
176 |
+
mask *= bitmap_msk
|
177 |
+
image *= np.expand_dims(bitmap_msk, axis=-1)
|
178 |
+
break
|
179 |
+
current_try += 1
|
180 |
+
return image
|
181 |
+
|
182 |
+
|
183 |
+
def blend_original(img):
|
184 |
+
img = img.copy()
|
185 |
+
h, w = img.shape[:2]
|
186 |
+
rect = detector(img)
|
187 |
+
if len(rect) == 0:
|
188 |
+
return img
|
189 |
+
else:
|
190 |
+
rect = rect[0]
|
191 |
+
sp = predictor(img, rect)
|
192 |
+
landmarks = np.array([[p.x, p.y] for p in sp.parts()])
|
193 |
+
outline = landmarks[[*range(17), *range(26, 16, -1)]]
|
194 |
+
Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
|
195 |
+
raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
196 |
+
raw_mask[Y, X] = 1
|
197 |
+
face = img * np.expand_dims(raw_mask, -1)
|
198 |
+
|
199 |
+
# add warping
|
200 |
+
h1 = random.randint(h - h // 2, h + h // 2)
|
201 |
+
w1 = random.randint(w - w // 2, w + w // 2)
|
202 |
+
while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
|
203 |
+
h1 = random.randint(h - h // 2, h + h // 2)
|
204 |
+
w1 = random.randint(w - w // 2, w + w // 2)
|
205 |
+
face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
|
206 |
+
face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
|
207 |
+
|
208 |
+
raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
|
209 |
+
img[raw_mask, :] = face[raw_mask, :]
|
210 |
+
if random.random() < 0.2:
|
211 |
+
img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
|
212 |
+
# image compression
|
213 |
+
if random.random() < 0.5:
|
214 |
+
img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
|
215 |
+
return img
|
216 |
+
|
217 |
+
|
218 |
+
class DeepFakeClassifierDataset(Dataset):
|
219 |
+
|
220 |
+
def __init__(self,
|
221 |
+
data_path="/mnt/sota/datasets/deepfake",
|
222 |
+
fold=0,
|
223 |
+
label_smoothing=0.01,
|
224 |
+
padding_part=3,
|
225 |
+
hardcore=True,
|
226 |
+
crops_dir="crops",
|
227 |
+
folds_csv="folds.csv",
|
228 |
+
normalize={"mean": [0.485, 0.456, 0.406],
|
229 |
+
"std": [0.229, 0.224, 0.225]},
|
230 |
+
rotation=False,
|
231 |
+
mode="train",
|
232 |
+
reduce_val=True,
|
233 |
+
oversample_real=True,
|
234 |
+
transforms=None
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
self.data_root = data_path
|
238 |
+
self.fold = fold
|
239 |
+
self.folds_csv = folds_csv
|
240 |
+
self.mode = mode
|
241 |
+
self.rotation = rotation
|
242 |
+
self.padding_part = padding_part
|
243 |
+
self.hardcore = hardcore
|
244 |
+
self.crops_dir = crops_dir
|
245 |
+
self.label_smoothing = label_smoothing
|
246 |
+
self.normalize = normalize
|
247 |
+
self.transforms = transforms
|
248 |
+
self.df = pd.read_csv(self.folds_csv)
|
249 |
+
self.oversample_real = oversample_real
|
250 |
+
self.reduce_val = reduce_val
|
251 |
+
|
252 |
+
def __getitem__(self, index: int):
|
253 |
+
|
254 |
+
while True:
|
255 |
+
video, img_file, label, ori_video, frame, fold = self.data[index]
|
256 |
+
try:
|
257 |
+
if self.mode == "train":
|
258 |
+
label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
|
259 |
+
img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
|
260 |
+
image = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
261 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
262 |
+
mask = np.zeros(image.shape[:2], dtype=np.uint8)
|
263 |
+
diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
|
264 |
+
try:
|
265 |
+
msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
|
266 |
+
if msk is not None:
|
267 |
+
mask = msk
|
268 |
+
except:
|
269 |
+
print("not found mask", diff_path)
|
270 |
+
pass
|
271 |
+
if self.mode == "train" and self.hardcore and not self.rotation:
|
272 |
+
landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
|
273 |
+
if os.path.exists(landmark_path) and random.random() < 0.7:
|
274 |
+
landmarks = np.load(landmark_path)
|
275 |
+
image = remove_landmark(image, landmarks)
|
276 |
+
elif random.random() < 0.2:
|
277 |
+
blackout_convex_hull(image)
|
278 |
+
elif random.random() < 0.1:
|
279 |
+
binary_mask = mask > 0.4 * 255
|
280 |
+
masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
|
281 |
+
tries = 6
|
282 |
+
current_try = 1
|
283 |
+
while current_try < tries:
|
284 |
+
bitmap_msk = random.choice(masks)
|
285 |
+
if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
|
286 |
+
mask *= bitmap_msk
|
287 |
+
image *= np.expand_dims(bitmap_msk, axis=-1)
|
288 |
+
break
|
289 |
+
current_try += 1
|
290 |
+
if self.mode == "train" and self.padding_part > 3:
|
291 |
+
image = change_padding(image, self.padding_part)
|
292 |
+
valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
|
293 |
+
valid_label = 1 if valid_label else 0
|
294 |
+
rotation = 0
|
295 |
+
if self.transforms:
|
296 |
+
data = self.transforms(image=image, mask=mask)
|
297 |
+
image = data["image"]
|
298 |
+
mask = data["mask"]
|
299 |
+
if self.mode == "train" and self.hardcore and self.rotation:
|
300 |
+
# landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
|
301 |
+
dropout = 0.8 if label > 0.5 else 0.6
|
302 |
+
if self.rotation:
|
303 |
+
dropout *= 0.7
|
304 |
+
elif random.random() < dropout:
|
305 |
+
blackout_random(image, mask, label)
|
306 |
+
|
307 |
+
#
|
308 |
+
# os.makedirs("../images", exist_ok=True)
|
309 |
+
# cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])
|
310 |
+
|
311 |
+
if self.mode == "train" and self.rotation:
|
312 |
+
rotation = random.randint(0, 3)
|
313 |
+
image = rot90(image, rotation)
|
314 |
+
|
315 |
+
image = img_to_tensor(image, self.normalize)
|
316 |
+
return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
|
317 |
+
"valid": valid_label, "rotations": rotation}
|
318 |
+
except Exception as e:
|
319 |
+
traceback.print_exc(file=sys.stdout)
|
320 |
+
print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
|
321 |
+
index = random.randint(0, len(self.data) - 1)
|
322 |
+
|
323 |
+
def random_blackout_landmark(self, image, mask, landmarks):
|
324 |
+
x, y = random.choice(landmarks)
|
325 |
+
first = random.random() > 0.5
|
326 |
+
# crop half face either vertically or horizontally
|
327 |
+
if random.random() > 0.5:
|
328 |
+
# width
|
329 |
+
if first:
|
330 |
+
image[:, :x] = 0
|
331 |
+
mask[:, :x] = 0
|
332 |
+
else:
|
333 |
+
image[:, x:] = 0
|
334 |
+
mask[:, x:] = 0
|
335 |
+
else:
|
336 |
+
# height
|
337 |
+
if first:
|
338 |
+
image[:y, :] = 0
|
339 |
+
mask[:y, :] = 0
|
340 |
+
else:
|
341 |
+
image[y:, :] = 0
|
342 |
+
mask[y:, :] = 0
|
343 |
+
|
344 |
+
def reset(self, epoch, seed):
|
345 |
+
self.data = self._prepare_data(epoch, seed)
|
346 |
+
|
347 |
+
def __len__(self) -> int:
|
348 |
+
return len(self.data)
|
349 |
+
|
350 |
+
def get_distribution(self):
|
351 |
+
return self.n_real, self.n_fake
|
352 |
+
|
353 |
+
def _prepare_data(self, epoch, seed):
|
354 |
+
df = self.df
|
355 |
+
if self.mode == "train":
|
356 |
+
rows = df[df["fold"] != self.fold]
|
357 |
+
else:
|
358 |
+
rows = df[df["fold"] == self.fold]
|
359 |
+
seed = (epoch + 1) * seed
|
360 |
+
if self.oversample_real:
|
361 |
+
rows = self._oversample(rows, seed)
|
362 |
+
if self.mode == "val" and self.reduce_val:
|
363 |
+
# every 2nd frame, to speed up validation
|
364 |
+
rows = rows[rows["frame"] % 20 == 0]
|
365 |
+
# another option is to use public validation set
|
366 |
+
#rows = rows[rows["video"].isin(PUBLIC_SET)]
|
367 |
+
|
368 |
+
print(
|
369 |
+
"real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
|
370 |
+
data = rows.values
|
371 |
+
|
372 |
+
self.n_real = len(rows[rows["label"] == 0])
|
373 |
+
self.n_fake = len(rows[rows["label"] == 1])
|
374 |
+
np.random.seed(seed)
|
375 |
+
np.random.shuffle(data)
|
376 |
+
return data
|
377 |
+
|
378 |
+
def _oversample(self, rows: pd.DataFrame, seed):
|
379 |
+
real = rows[rows["label"] == 0]
|
380 |
+
fakes = rows[rows["label"] == 1]
|
381 |
+
num_real = real["video"].count()
|
382 |
+
if self.mode == "train":
|
383 |
+
fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
|
384 |
+
return pd.concat([real, fakes])
|
training/datasets/validation_set.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
|
4 |
+
'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
|
5 |
+
'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
|
6 |
+
'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
|
7 |
+
'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
|
8 |
+
'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
|
9 |
+
'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
|
10 |
+
'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
|
11 |
+
'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
|
12 |
+
'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
|
13 |
+
'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
|
14 |
+
'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
|
15 |
+
'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
|
16 |
+
'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
|
17 |
+
'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
|
18 |
+
'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
|
19 |
+
'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
|
20 |
+
'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
|
21 |
+
'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
|
22 |
+
'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
|
23 |
+
'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
|
24 |
+
'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
|
25 |
+
'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
|
26 |
+
'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
|
27 |
+
'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
|
28 |
+
'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
|
29 |
+
'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
|
30 |
+
'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
|
31 |
+
'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
|
32 |
+
'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
|
33 |
+
'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
|
34 |
+
'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
|
35 |
+
'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
|
36 |
+
'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
|
37 |
+
'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
|
38 |
+
'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
|
39 |
+
'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
|
40 |
+
'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
|
41 |
+
'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
|
42 |
+
'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
|
43 |
+
'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
|
44 |
+
'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
|
45 |
+
'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
|
46 |
+
'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
|
47 |
+
'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
|
48 |
+
'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
|
49 |
+
'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
|
50 |
+
'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
|
51 |
+
'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
|
52 |
+
'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
|
53 |
+
'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
|
54 |
+
'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
|
55 |
+
'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
|
56 |
+
'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
|
57 |
+
'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
|
58 |
+
'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
|
59 |
+
'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
|
60 |
+
'txmnoyiyte'}
|
training/losses.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
from pytorch_toolbelt.losses import BinaryFocalLoss
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn.modules.loss import BCEWithLogitsLoss
|
6 |
+
|
7 |
+
|
8 |
+
class WeightedLosses(nn.Module):
|
9 |
+
def __init__(self, losses, weights):
|
10 |
+
super().__init__()
|
11 |
+
self.losses = losses
|
12 |
+
self.weights = weights
|
13 |
+
|
14 |
+
def forward(self, *input: Any, **kwargs: Any):
|
15 |
+
cum_loss = 0
|
16 |
+
for loss, w in zip(self.losses, self.weights):
|
17 |
+
cum_loss += w * loss.forward(*input, **kwargs)
|
18 |
+
return cum_loss
|
19 |
+
|
20 |
+
|
21 |
+
class BinaryCrossentropy(BCEWithLogitsLoss):
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
class FocalLoss(BinaryFocalLoss):
|
26 |
+
def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
|
27 |
+
reduced_threshold=None):
|
28 |
+
super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
|
training/pipelines/__init__.py
ADDED
File without changes
|
training/pipelines/train_classifier.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from sklearn.metrics import log_loss
|
7 |
+
from torch import topk
|
8 |
+
|
9 |
+
import sys
|
10 |
+
print('@@@@@@@@@@@@@@@@@@')
|
11 |
+
sys.path.append('..')
|
12 |
+
|
13 |
+
from training import losses
|
14 |
+
from training.datasets.classifier_dataset import DeepFakeClassifierDataset
|
15 |
+
from training.losses import WeightedLosses
|
16 |
+
from training.tools.config import load_config
|
17 |
+
from training.tools.utils import create_optimizer, AverageMeter
|
18 |
+
from training.transforms.albu import IsotropicResize
|
19 |
+
from training.zoo import classifiers
|
20 |
+
|
21 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
22 |
+
os.environ["NUMEXPR_NUM_THREADS"] = "1"
|
23 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
24 |
+
|
25 |
+
import cv2
|
26 |
+
|
27 |
+
cv2.ocl.setUseOpenCL(False)
|
28 |
+
cv2.setNumThreads(0)
|
29 |
+
import numpy as np
|
30 |
+
from albumentations import Compose, RandomBrightnessContrast, \
|
31 |
+
HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
|
32 |
+
ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
|
33 |
+
|
34 |
+
from apex.parallel import DistributedDataParallel, convert_syncbn_model
|
35 |
+
from tensorboardX import SummaryWriter
|
36 |
+
|
37 |
+
from apex import amp
|
38 |
+
|
39 |
+
import torch
|
40 |
+
from torch.backends import cudnn
|
41 |
+
from torch.nn import DataParallel
|
42 |
+
from torch.utils.data import DataLoader
|
43 |
+
from tqdm import tqdm
|
44 |
+
import torch.distributed as dist
|
45 |
+
|
46 |
+
torch.backends.cudnn.benchmark = True
|
47 |
+
|
48 |
+
def create_train_transforms(size=300):
|
49 |
+
return Compose([
|
50 |
+
ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
|
51 |
+
GaussNoise(p=0.1),
|
52 |
+
GaussianBlur(blur_limit=3, p=0.05),
|
53 |
+
HorizontalFlip(),
|
54 |
+
OneOf([
|
55 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
|
56 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
|
57 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
|
58 |
+
], p=1),
|
59 |
+
PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
|
60 |
+
OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
|
61 |
+
ToGray(p=0.2),
|
62 |
+
ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
|
63 |
+
]
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def create_val_transforms(size=300):
|
68 |
+
return Compose([
|
69 |
+
IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
|
70 |
+
PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
|
71 |
+
])
|
72 |
+
|
73 |
+
|
74 |
+
def main():
|
75 |
+
parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
|
76 |
+
arg = parser.add_argument
|
77 |
+
arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
|
78 |
+
arg('--workers', type=int, default=6, help='number of cpu threads to use')
|
79 |
+
arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
|
80 |
+
arg('--output-dir', type=str, default='weights/')
|
81 |
+
arg('--resume', type=str, default='')
|
82 |
+
arg('--fold', type=int, default=0)
|
83 |
+
arg('--prefix', type=str, default='classifier_')
|
84 |
+
arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
|
85 |
+
arg('--folds-csv', type=str, default='folds.csv')
|
86 |
+
arg('--crops-dir', type=str, default='crops')
|
87 |
+
arg('--label-smoothing', type=float, default=0.01)
|
88 |
+
arg('--logdir', type=str, default='logs')
|
89 |
+
arg('--zero-score', action='store_true', default=False)
|
90 |
+
arg('--from-zero', action='store_true', default=False)
|
91 |
+
arg('--distributed', action='store_true', default=False)
|
92 |
+
arg('--freeze-epochs', type=int, default=0)
|
93 |
+
arg("--local_rank", default=0, type=int)
|
94 |
+
arg("--seed", default=777, type=int)
|
95 |
+
arg("--padding-part", default=3, type=int)
|
96 |
+
arg("--opt-level", default='O1', type=str)
|
97 |
+
arg("--test_every", type=int, default=1)
|
98 |
+
arg("--no-oversample", action="store_true")
|
99 |
+
arg("--no-hardcore", action="store_true")
|
100 |
+
arg("--only-changed-frames", action="store_true")
|
101 |
+
|
102 |
+
args = parser.parse_args()
|
103 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
104 |
+
if args.distributed:
|
105 |
+
torch.cuda.set_device(args.local_rank)
|
106 |
+
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
107 |
+
else:
|
108 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
109 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
110 |
+
|
111 |
+
cudnn.benchmark = True
|
112 |
+
|
113 |
+
conf = load_config(args.config)
|
114 |
+
model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
|
115 |
+
|
116 |
+
model = model.cuda()
|
117 |
+
if args.distributed:
|
118 |
+
model = convert_syncbn_model(model)
|
119 |
+
ohem = conf.get("ohem_samples", None)
|
120 |
+
reduction = "mean"
|
121 |
+
if ohem:
|
122 |
+
reduction = "none"
|
123 |
+
loss_fn = []
|
124 |
+
weights = []
|
125 |
+
for loss_name, weight in conf["losses"].items():
|
126 |
+
loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
|
127 |
+
weights.append(weight)
|
128 |
+
loss = WeightedLosses(loss_fn, weights)
|
129 |
+
loss_functions = {"classifier_loss": loss}
|
130 |
+
optimizer, scheduler = create_optimizer(conf['optimizer'], model)
|
131 |
+
bce_best = 100
|
132 |
+
start_epoch = 0
|
133 |
+
batch_size = conf['optimizer']['batch_size']
|
134 |
+
|
135 |
+
data_train = DeepFakeClassifierDataset(mode="train",
|
136 |
+
oversample_real=not args.no_oversample,
|
137 |
+
fold=args.fold,
|
138 |
+
padding_part=args.padding_part,
|
139 |
+
hardcore=not args.no_hardcore,
|
140 |
+
crops_dir=args.crops_dir,
|
141 |
+
data_path=args.data_dir,
|
142 |
+
label_smoothing=args.label_smoothing,
|
143 |
+
folds_csv=args.folds_csv,
|
144 |
+
transforms=create_train_transforms(conf["size"]),
|
145 |
+
normalize=conf.get("normalize", None))
|
146 |
+
data_val = DeepFakeClassifierDataset(mode="val",
|
147 |
+
fold=args.fold,
|
148 |
+
padding_part=args.padding_part,
|
149 |
+
crops_dir=args.crops_dir,
|
150 |
+
data_path=args.data_dir,
|
151 |
+
folds_csv=args.folds_csv,
|
152 |
+
transforms=create_val_transforms(conf["size"]),
|
153 |
+
normalize=conf.get("normalize", None))
|
154 |
+
val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
|
155 |
+
pin_memory=False)
|
156 |
+
os.makedirs(args.logdir, exist_ok=True)
|
157 |
+
summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
|
158 |
+
if args.resume:
|
159 |
+
if os.path.isfile(args.resume):
|
160 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
161 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
162 |
+
state_dict = checkpoint['state_dict']
|
163 |
+
state_dict = {k[7:]: w for k, w in state_dict.items()}
|
164 |
+
model.load_state_dict(state_dict, strict=False)
|
165 |
+
if not args.from_zero:
|
166 |
+
start_epoch = checkpoint['epoch']
|
167 |
+
if not args.zero_score:
|
168 |
+
bce_best = checkpoint.get('bce_best', 0)
|
169 |
+
print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
|
170 |
+
.format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
|
171 |
+
else:
|
172 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
173 |
+
if args.from_zero:
|
174 |
+
start_epoch = 0
|
175 |
+
current_epoch = start_epoch
|
176 |
+
|
177 |
+
if conf['fp16']:
|
178 |
+
model, optimizer = amp.initialize(model, optimizer,
|
179 |
+
opt_level=args.opt_level,
|
180 |
+
loss_scale='dynamic')
|
181 |
+
|
182 |
+
snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
|
183 |
+
|
184 |
+
if args.distributed:
|
185 |
+
model = DistributedDataParallel(model, delay_allreduce=True)
|
186 |
+
else:
|
187 |
+
model = DataParallel(model).cuda()
|
188 |
+
data_val.reset(1, args.seed)
|
189 |
+
max_epochs = conf['optimizer']['schedule']['epochs']
|
190 |
+
for epoch in range(start_epoch, max_epochs):
|
191 |
+
data_train.reset(epoch, args.seed)
|
192 |
+
train_sampler = None
|
193 |
+
if args.distributed:
|
194 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
|
195 |
+
train_sampler.set_epoch(epoch)
|
196 |
+
if epoch < args.freeze_epochs:
|
197 |
+
print("Freezing encoder!!!")
|
198 |
+
model.module.encoder.eval()
|
199 |
+
for p in model.module.encoder.parameters():
|
200 |
+
p.requires_grad = False
|
201 |
+
else:
|
202 |
+
model.module.encoder.train()
|
203 |
+
for p in model.module.encoder.parameters():
|
204 |
+
p.requires_grad = True
|
205 |
+
|
206 |
+
train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
|
207 |
+
shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
|
208 |
+
drop_last=True)
|
209 |
+
|
210 |
+
train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
|
211 |
+
args.local_rank, args.only_changed_frames)
|
212 |
+
model = model.eval()
|
213 |
+
|
214 |
+
if args.local_rank == 0:
|
215 |
+
torch.save({
|
216 |
+
'epoch': current_epoch + 1,
|
217 |
+
'state_dict': model.state_dict(),
|
218 |
+
'bce_best': bce_best,
|
219 |
+
}, args.output_dir + '/' + snapshot_name + "_last")
|
220 |
+
torch.save({
|
221 |
+
'epoch': current_epoch + 1,
|
222 |
+
'state_dict': model.state_dict(),
|
223 |
+
'bce_best': bce_best,
|
224 |
+
}, args.output_dir + snapshot_name + "_{}".format(current_epoch))
|
225 |
+
if (epoch + 1) % args.test_every == 0:
|
226 |
+
bce_best = evaluate_val(args, val_data_loader, bce_best, model,
|
227 |
+
snapshot_name=snapshot_name,
|
228 |
+
current_epoch=current_epoch,
|
229 |
+
summary_writer=summary_writer)
|
230 |
+
current_epoch += 1
|
231 |
+
|
232 |
+
|
233 |
+
def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
|
234 |
+
print("Test phase")
|
235 |
+
model = model.eval()
|
236 |
+
|
237 |
+
bce, probs, targets = validate(model, data_loader=data_val)
|
238 |
+
if args.local_rank == 0:
|
239 |
+
summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
|
240 |
+
if bce < bce_best:
|
241 |
+
print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
|
242 |
+
if args.output_dir is not None:
|
243 |
+
torch.save({
|
244 |
+
'epoch': current_epoch + 1,
|
245 |
+
'state_dict': model.state_dict(),
|
246 |
+
'bce_best': bce,
|
247 |
+
}, args.output_dir + snapshot_name + "_best_dice")
|
248 |
+
bce_best = bce
|
249 |
+
with open("predictions_{}.json".format(args.fold), "w") as f:
|
250 |
+
json.dump({"probs": probs, "targets": targets}, f)
|
251 |
+
torch.save({
|
252 |
+
'epoch': current_epoch + 1,
|
253 |
+
'state_dict': model.state_dict(),
|
254 |
+
'bce_best': bce_best,
|
255 |
+
}, args.output_dir + snapshot_name + "_last")
|
256 |
+
print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
|
257 |
+
return bce_best
|
258 |
+
|
259 |
+
|
260 |
+
def validate(net, data_loader, prefix=""):
|
261 |
+
probs = defaultdict(list)
|
262 |
+
targets = defaultdict(list)
|
263 |
+
|
264 |
+
with torch.no_grad():
|
265 |
+
for sample in tqdm(data_loader):
|
266 |
+
imgs = sample["image"].cuda()
|
267 |
+
img_names = sample["img_name"]
|
268 |
+
labels = sample["labels"].cuda().float()
|
269 |
+
out = net(imgs)
|
270 |
+
labels = labels.cpu().numpy()
|
271 |
+
preds = torch.sigmoid(out).cpu().numpy()
|
272 |
+
for i in range(out.shape[0]):
|
273 |
+
video, img_id = img_names[i].split("/")
|
274 |
+
probs[video].append(preds[i].tolist())
|
275 |
+
targets[video].append(labels[i].tolist())
|
276 |
+
data_x = []
|
277 |
+
data_y = []
|
278 |
+
for vid, score in probs.items():
|
279 |
+
score = np.array(score)
|
280 |
+
lbl = targets[vid]
|
281 |
+
|
282 |
+
score = np.mean(score)
|
283 |
+
lbl = np.mean(lbl)
|
284 |
+
data_x.append(score)
|
285 |
+
data_y.append(lbl)
|
286 |
+
y = np.array(data_y)
|
287 |
+
x = np.array(data_x)
|
288 |
+
fake_idx = y > 0.1
|
289 |
+
real_idx = y < 0.1
|
290 |
+
fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
|
291 |
+
real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
|
292 |
+
print("{}fake_loss".format(prefix), fake_loss)
|
293 |
+
print("{}real_loss".format(prefix), real_loss)
|
294 |
+
|
295 |
+
return (fake_loss + real_loss) / 2, probs, targets
|
296 |
+
|
297 |
+
|
298 |
+
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
|
299 |
+
local_rank, only_valid):
|
300 |
+
losses = AverageMeter()
|
301 |
+
fake_losses = AverageMeter()
|
302 |
+
real_losses = AverageMeter()
|
303 |
+
max_iters = conf["batches_per_epoch"]
|
304 |
+
print("training epoch {}".format(current_epoch))
|
305 |
+
model.train()
|
306 |
+
pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
|
307 |
+
if conf["optimizer"]["schedule"]["mode"] == "epoch":
|
308 |
+
scheduler.step(current_epoch)
|
309 |
+
for i, sample in pbar:
|
310 |
+
imgs = sample["image"].cuda()
|
311 |
+
labels = sample["labels"].cuda().float()
|
312 |
+
out_labels = model(imgs)
|
313 |
+
if only_valid:
|
314 |
+
valid_idx = sample["valid"].cuda().float() > 0
|
315 |
+
out_labels = out_labels[valid_idx]
|
316 |
+
labels = labels[valid_idx]
|
317 |
+
if labels.size(0) == 0:
|
318 |
+
continue
|
319 |
+
|
320 |
+
fake_loss = 0
|
321 |
+
real_loss = 0
|
322 |
+
fake_idx = labels > 0.5
|
323 |
+
real_idx = labels <= 0.5
|
324 |
+
|
325 |
+
ohem = conf.get("ohem_samples", None)
|
326 |
+
if torch.sum(fake_idx * 1) > 0:
|
327 |
+
fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
|
328 |
+
if torch.sum(real_idx * 1) > 0:
|
329 |
+
real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
|
330 |
+
if ohem:
|
331 |
+
fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
|
332 |
+
real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()
|
333 |
+
|
334 |
+
loss = (fake_loss + real_loss) / 2
|
335 |
+
losses.update(loss.item(), imgs.size(0))
|
336 |
+
fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
|
337 |
+
real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
|
338 |
+
|
339 |
+
optimizer.zero_grad()
|
340 |
+
pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
|
341 |
+
"fake_loss": fake_losses.avg, "real_loss": real_losses.avg})
|
342 |
+
|
343 |
+
if conf['fp16']:
|
344 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
345 |
+
scaled_loss.backward()
|
346 |
+
else:
|
347 |
+
loss.backward()
|
348 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
|
349 |
+
optimizer.step()
|
350 |
+
torch.cuda.synchronize()
|
351 |
+
if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
|
352 |
+
scheduler.step(i + current_epoch * max_iters)
|
353 |
+
if i == max_iters - 1:
|
354 |
+
break
|
355 |
+
pbar.close()
|
356 |
+
if local_rank == 0:
|
357 |
+
for idx, param_group in enumerate(optimizer.param_groups):
|
358 |
+
lr = param_group['lr']
|
359 |
+
summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
|
360 |
+
summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
|
361 |
+
|
362 |
+
|
363 |
+
if __name__ == '__main__':
|
364 |
+
main()
|
training/tools/__init__.py
ADDED
File without changes
|
training/tools/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes). View file
|
|
training/tools/__pycache__/config.cpython-310.pyc
ADDED
Binary file (1.06 kB). View file
|
|
training/tools/__pycache__/schedulers.cpython-310.pyc
ADDED
Binary file (3.01 kB). View file
|
|
training/tools/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
training/tools/config.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
DEFAULTS = {
|
4 |
+
"network": "dpn",
|
5 |
+
"encoder": "dpn92",
|
6 |
+
"model_params": {},
|
7 |
+
"optimizer": {
|
8 |
+
"batch_size": 32,
|
9 |
+
"type": "SGD", # supported: SGD, Adam
|
10 |
+
"momentum": 0.9,
|
11 |
+
"weight_decay": 0,
|
12 |
+
"clip": 1.,
|
13 |
+
"learning_rate": 0.1,
|
14 |
+
"classifier_lr": -1,
|
15 |
+
"nesterov": True,
|
16 |
+
"schedule": {
|
17 |
+
"type": "constant", # supported: constant, step, multistep, exponential, linear, poly
|
18 |
+
"mode": "epoch", # supported: epoch, step
|
19 |
+
"epochs": 10,
|
20 |
+
"params": {}
|
21 |
+
}
|
22 |
+
},
|
23 |
+
"normalize": {
|
24 |
+
"mean": [0.485, 0.456, 0.406],
|
25 |
+
"std": [0.229, 0.224, 0.225]
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
def _merge(src, dst):
|
31 |
+
for k, v in src.items():
|
32 |
+
if k in dst:
|
33 |
+
if isinstance(v, dict):
|
34 |
+
_merge(src[k], dst[k])
|
35 |
+
else:
|
36 |
+
dst[k] = v
|
37 |
+
|
38 |
+
|
39 |
+
def load_config(config_file, defaults=DEFAULTS):
|
40 |
+
with open(config_file, "r") as fd:
|
41 |
+
config = json.load(fd)
|
42 |
+
_merge(defaults, config)
|
43 |
+
return config
|
training/tools/schedulers.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from bisect import bisect_right
|
2 |
+
|
3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
4 |
+
|
5 |
+
|
6 |
+
class LRStepScheduler(_LRScheduler):
|
7 |
+
def __init__(self, optimizer, steps, last_epoch=-1):
|
8 |
+
self.lr_steps = steps
|
9 |
+
super().__init__(optimizer, last_epoch)
|
10 |
+
|
11 |
+
def get_lr(self):
|
12 |
+
pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
|
13 |
+
return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]
|
14 |
+
|
15 |
+
|
16 |
+
class PolyLR(_LRScheduler):
|
17 |
+
"""Sets the learning rate of each parameter group according to poly learning rate policy
|
18 |
+
"""
|
19 |
+
def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
|
20 |
+
self.max_iter = max_iter
|
21 |
+
self.power = power
|
22 |
+
super(PolyLR, self).__init__(optimizer, last_epoch)
|
23 |
+
|
24 |
+
def get_lr(self):
|
25 |
+
self.last_epoch = (self.last_epoch + 1) % self.max_iter
|
26 |
+
return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]
|
27 |
+
|
28 |
+
class ExponentialLRScheduler(_LRScheduler):
|
29 |
+
"""Decays the learning rate of each parameter group by gamma every epoch.
|
30 |
+
When last_epoch=-1, sets initial lr as lr.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
optimizer (Optimizer): Wrapped optimizer.
|
34 |
+
gamma (float): Multiplicative factor of learning rate decay.
|
35 |
+
last_epoch (int): The index of last epoch. Default: -1.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, optimizer, gamma, last_epoch=-1):
|
39 |
+
self.gamma = gamma
|
40 |
+
super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)
|
41 |
+
|
42 |
+
def get_lr(self):
|
43 |
+
if self.last_epoch <= 0:
|
44 |
+
return self.base_lrs
|
45 |
+
return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
|
46 |
+
|
training/tools/utils.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
from apex.optimizers import FusedAdam, FusedSGD
|
3 |
+
from timm.optim import AdamW
|
4 |
+
from torch import optim
|
5 |
+
from torch.optim import lr_scheduler
|
6 |
+
from torch.optim.rmsprop import RMSprop
|
7 |
+
from torch.optim.adamw import AdamW
|
8 |
+
from torch.optim.lr_scheduler import MultiStepLR, CyclicLR
|
9 |
+
|
10 |
+
from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler
|
11 |
+
|
12 |
+
cv2.ocl.setUseOpenCL(False)
|
13 |
+
cv2.setNumThreads(0)
|
14 |
+
|
15 |
+
|
16 |
+
class AverageMeter(object):
|
17 |
+
"""Computes and stores the average and current value"""
|
18 |
+
|
19 |
+
def __init__(self):
|
20 |
+
self.reset()
|
21 |
+
|
22 |
+
def reset(self):
|
23 |
+
self.val = 0
|
24 |
+
self.avg = 0
|
25 |
+
self.sum = 0
|
26 |
+
self.count = 0
|
27 |
+
|
28 |
+
def update(self, val, n=1):
|
29 |
+
self.val = val
|
30 |
+
self.sum += val * n
|
31 |
+
self.count += n
|
32 |
+
self.avg = self.sum / self.count
|
33 |
+
|
34 |
+
def create_optimizer(optimizer_config, model, master_params=None):
|
35 |
+
"""Creates optimizer and schedule from configuration
|
36 |
+
|
37 |
+
Parameters
|
38 |
+
----------
|
39 |
+
optimizer_config : dict
|
40 |
+
Dictionary containing the configuration options for the optimizer.
|
41 |
+
model : Model
|
42 |
+
The network model.
|
43 |
+
|
44 |
+
Returns
|
45 |
+
-------
|
46 |
+
optimizer : Optimizer
|
47 |
+
The optimizer.
|
48 |
+
scheduler : LRScheduler
|
49 |
+
The learning rate scheduler.
|
50 |
+
"""
|
51 |
+
if optimizer_config.get("classifier_lr", -1) != -1:
|
52 |
+
# Separate classifier parameters from all others
|
53 |
+
net_params = []
|
54 |
+
classifier_params = []
|
55 |
+
for k, v in model.named_parameters():
|
56 |
+
if not v.requires_grad:
|
57 |
+
continue
|
58 |
+
if k.find("encoder") != -1:
|
59 |
+
net_params.append(v)
|
60 |
+
else:
|
61 |
+
classifier_params.append(v)
|
62 |
+
params = [
|
63 |
+
{"params": net_params},
|
64 |
+
{"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
|
65 |
+
]
|
66 |
+
else:
|
67 |
+
if master_params:
|
68 |
+
params = master_params
|
69 |
+
else:
|
70 |
+
params = model.parameters()
|
71 |
+
|
72 |
+
if optimizer_config["type"] == "SGD":
|
73 |
+
optimizer = optim.SGD(params,
|
74 |
+
lr=optimizer_config["learning_rate"],
|
75 |
+
momentum=optimizer_config["momentum"],
|
76 |
+
weight_decay=optimizer_config["weight_decay"],
|
77 |
+
nesterov=optimizer_config["nesterov"])
|
78 |
+
elif optimizer_config["type"] == "FusedSGD":
|
79 |
+
optimizer = FusedSGD(params,
|
80 |
+
lr=optimizer_config["learning_rate"],
|
81 |
+
momentum=optimizer_config["momentum"],
|
82 |
+
weight_decay=optimizer_config["weight_decay"],
|
83 |
+
nesterov=optimizer_config["nesterov"])
|
84 |
+
elif optimizer_config["type"] == "Adam":
|
85 |
+
optimizer = optim.Adam(params,
|
86 |
+
lr=optimizer_config["learning_rate"],
|
87 |
+
weight_decay=optimizer_config["weight_decay"])
|
88 |
+
elif optimizer_config["type"] == "FusedAdam":
|
89 |
+
optimizer = FusedAdam(params,
|
90 |
+
lr=optimizer_config["learning_rate"],
|
91 |
+
weight_decay=optimizer_config["weight_decay"])
|
92 |
+
elif optimizer_config["type"] == "AdamW":
|
93 |
+
optimizer = AdamW(params,
|
94 |
+
lr=optimizer_config["learning_rate"],
|
95 |
+
weight_decay=optimizer_config["weight_decay"])
|
96 |
+
elif optimizer_config["type"] == "RmsProp":
|
97 |
+
optimizer = RMSprop(params,
|
98 |
+
lr=optimizer_config["learning_rate"],
|
99 |
+
weight_decay=optimizer_config["weight_decay"])
|
100 |
+
else:
|
101 |
+
raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))
|
102 |
+
|
103 |
+
if optimizer_config["schedule"]["type"] == "step":
|
104 |
+
scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
|
105 |
+
elif optimizer_config["schedule"]["type"] == "clr":
|
106 |
+
scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
|
107 |
+
elif optimizer_config["schedule"]["type"] == "multistep":
|
108 |
+
scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
|
109 |
+
elif optimizer_config["schedule"]["type"] == "exponential":
|
110 |
+
scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
|
111 |
+
elif optimizer_config["schedule"]["type"] == "poly":
|
112 |
+
scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
|
113 |
+
elif optimizer_config["schedule"]["type"] == "constant":
|
114 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
|
115 |
+
elif optimizer_config["schedule"]["type"] == "linear":
|
116 |
+
def linear_lr(it):
|
117 |
+
return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]
|
118 |
+
|
119 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
|
120 |
+
|
121 |
+
return optimizer, scheduler
|
training/transforms/__init__.py
ADDED
File without changes
|
training/transforms/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (159 Bytes). View file
|
|
training/transforms/__pycache__/albu.cpython-310.pyc
ADDED
Binary file (4.36 kB). View file
|
|
training/transforms/albu.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from albumentations import DualTransform, ImageOnlyTransform
|
6 |
+
from albumentations.augmentations.crops.functional import crop
|
7 |
+
#from albumentations.augmentations.functional import crop
|
8 |
+
|
9 |
+
|
10 |
+
def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
|
11 |
+
h, w = img.shape[:2]
|
12 |
+
if max(w, h) == size:
|
13 |
+
return img
|
14 |
+
if w > h:
|
15 |
+
scale = size / w
|
16 |
+
h = h * scale
|
17 |
+
w = size
|
18 |
+
else:
|
19 |
+
scale = size / h
|
20 |
+
w = w * scale
|
21 |
+
h = size
|
22 |
+
interpolation = interpolation_up if scale > 1 else interpolation_down
|
23 |
+
resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
|
24 |
+
return resized
|
25 |
+
|
26 |
+
|
27 |
+
class IsotropicResize(DualTransform):
|
28 |
+
def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
|
29 |
+
always_apply=False, p=1):
|
30 |
+
super(IsotropicResize, self).__init__(always_apply, p)
|
31 |
+
self.max_side = max_side
|
32 |
+
self.interpolation_down = interpolation_down
|
33 |
+
self.interpolation_up = interpolation_up
|
34 |
+
|
35 |
+
def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
|
36 |
+
return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
|
37 |
+
interpolation_up=interpolation_up)
|
38 |
+
|
39 |
+
def apply_to_mask(self, img, **params):
|
40 |
+
return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)
|
41 |
+
|
42 |
+
def get_transform_init_args_names(self):
|
43 |
+
return ("max_side", "interpolation_down", "interpolation_up")
|
44 |
+
|
45 |
+
|
46 |
+
class Resize4xAndBack(ImageOnlyTransform):
|
47 |
+
def __init__(self, always_apply=False, p=0.5):
|
48 |
+
super(Resize4xAndBack, self).__init__(always_apply, p)
|
49 |
+
|
50 |
+
def apply(self, img, **params):
|
51 |
+
h, w = img.shape[:2]
|
52 |
+
scale = random.choice([2, 4])
|
53 |
+
img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
|
54 |
+
img = cv2.resize(img, (w, h),
|
55 |
+
interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
|
56 |
+
return img
|
57 |
+
|
58 |
+
|
59 |
+
class RandomSizedCropNonEmptyMaskIfExists(DualTransform):
|
60 |
+
|
61 |
+
def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
|
62 |
+
super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)
|
63 |
+
|
64 |
+
self.min_max_height = min_max_height
|
65 |
+
self.w2h_ratio = w2h_ratio
|
66 |
+
|
67 |
+
def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
|
68 |
+
cropped = crop(img, x_min, y_min, x_max, y_max)
|
69 |
+
return cropped
|
70 |
+
|
71 |
+
@property
|
72 |
+
def targets_as_params(self):
|
73 |
+
return ["mask"]
|
74 |
+
|
75 |
+
def get_params_dependent_on_targets(self, params):
|
76 |
+
mask = params["mask"]
|
77 |
+
mask_height, mask_width = mask.shape[:2]
|
78 |
+
crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
|
79 |
+
w2h_ratio = random.uniform(*self.w2h_ratio)
|
80 |
+
crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
|
81 |
+
if mask.sum() == 0:
|
82 |
+
x_min = random.randint(0, mask_width - crop_width + 1)
|
83 |
+
y_min = random.randint(0, mask_height - crop_height + 1)
|
84 |
+
else:
|
85 |
+
mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
|
86 |
+
non_zero_yx = np.argwhere(mask)
|
87 |
+
y, x = random.choice(non_zero_yx)
|
88 |
+
x_min = x - random.randint(0, crop_width - 1)
|
89 |
+
y_min = y - random.randint(0, crop_height - 1)
|
90 |
+
x_min = np.clip(x_min, 0, mask_width - crop_width)
|
91 |
+
y_min = np.clip(y_min, 0, mask_height - crop_height)
|
92 |
+
|
93 |
+
x_max = x_min + crop_height
|
94 |
+
y_max = y_min + crop_width
|
95 |
+
y_max = min(mask_height, y_max)
|
96 |
+
x_max = min(mask_width, x_max)
|
97 |
+
return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
|
98 |
+
|
99 |
+
def get_transform_init_args_names(self):
|
100 |
+
return "min_max_height", "height", "width", "w2h_ratio"
|
training/zoo/__init__.py
ADDED
File without changes
|
training/zoo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes). View file
|
|
training/zoo/__pycache__/classifiers.cpython-310.pyc
ADDED
Binary file (5.55 kB). View file
|
|
training/zoo/classifiers.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
|
6 |
+
tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.modules.dropout import Dropout
|
9 |
+
from torch.nn.modules.linear import Linear
|
10 |
+
from torch.nn.modules.pooling import AdaptiveAvgPool2d
|
11 |
+
|
12 |
+
encoder_params = {
|
13 |
+
"tf_efficientnet_b3_ns": {
|
14 |
+
"features": 1536,
|
15 |
+
"init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
|
16 |
+
},
|
17 |
+
"tf_efficientnet_b2_ns": {
|
18 |
+
"features": 1408,
|
19 |
+
"init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
|
20 |
+
},
|
21 |
+
"tf_efficientnet_b4_ns": {
|
22 |
+
"features": 1792,
|
23 |
+
"init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
|
24 |
+
},
|
25 |
+
"tf_efficientnet_b5_ns": {
|
26 |
+
"features": 2048,
|
27 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
|
28 |
+
},
|
29 |
+
"tf_efficientnet_b4_ns_03d": {
|
30 |
+
"features": 1792,
|
31 |
+
"init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
|
32 |
+
},
|
33 |
+
"tf_efficientnet_b5_ns_03d": {
|
34 |
+
"features": 2048,
|
35 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
|
36 |
+
},
|
37 |
+
"tf_efficientnet_b5_ns_04d": {
|
38 |
+
"features": 2048,
|
39 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
|
40 |
+
},
|
41 |
+
"tf_efficientnet_b6_ns": {
|
42 |
+
"features": 2304,
|
43 |
+
"init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
|
44 |
+
},
|
45 |
+
"tf_efficientnet_b7_ns": {
|
46 |
+
"features": 2560,
|
47 |
+
"init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
|
48 |
+
},
|
49 |
+
"tf_efficientnet_b6_ns_04d": {
|
50 |
+
"features": 2304,
|
51 |
+
"init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
|
52 |
+
},
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
|
57 |
+
"""Creates the SRM kernels for noise analysis."""
|
58 |
+
# note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
|
59 |
+
srm_kernel = torch.from_numpy(np.array([
|
60 |
+
[ # srm 1/2 horiz
|
61 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
62 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
63 |
+
[0., 1., -2., 1., 0.], # noqa: E241,E201
|
64 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
65 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
66 |
+
], [ # srm 1/4
|
67 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
68 |
+
[0., -1., 2., -1., 0.], # noqa: E241,E201
|
69 |
+
[0., 2., -4., 2., 0.], # noqa: E241,E201
|
70 |
+
[0., -1., 2., -1., 0.], # noqa: E241,E201
|
71 |
+
[0., 0., 0., 0., 0.], # noqa: E241,E201
|
72 |
+
], [ # srm 1/12
|
73 |
+
[-1., 2., -2., 2., -1.], # noqa: E241,E201
|
74 |
+
[2., -6., 8., -6., 2.], # noqa: E241,E201
|
75 |
+
[-2., 8., -12., 8., -2.], # noqa: E241,E201
|
76 |
+
[2., -6., 8., -6., 2.], # noqa: E241,E201
|
77 |
+
[-1., 2., -2., 2., -1.], # noqa: E241,E201
|
78 |
+
]
|
79 |
+
])).float()
|
80 |
+
srm_kernel[0] /= 2
|
81 |
+
srm_kernel[1] /= 4
|
82 |
+
srm_kernel[2] /= 12
|
83 |
+
return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
|
84 |
+
|
85 |
+
|
86 |
+
def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
|
87 |
+
"""Creates a SRM convolution layer for noise analysis."""
|
88 |
+
weights = setup_srm_weights(input_channels)
|
89 |
+
conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
|
90 |
+
with torch.no_grad():
|
91 |
+
conv.weight = torch.nn.Parameter(weights, requires_grad=False)
|
92 |
+
return conv
|
93 |
+
|
94 |
+
|
95 |
+
class DeepFakeClassifierSRM(nn.Module):
|
96 |
+
def __init__(self, encoder, dropout_rate=0.5) -> None:
|
97 |
+
super().__init__()
|
98 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
99 |
+
self.avg_pool = AdaptiveAvgPool2d((1, 1))
|
100 |
+
self.srm_conv = setup_srm_layer(3)
|
101 |
+
self.dropout = Dropout(dropout_rate)
|
102 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
noise = self.srm_conv(x)
|
106 |
+
x = self.encoder.forward_features(noise)
|
107 |
+
x = self.avg_pool(x).flatten(1)
|
108 |
+
x = self.dropout(x)
|
109 |
+
x = self.fc(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class GlobalWeightedAvgPool2d(nn.Module):
|
114 |
+
"""
|
115 |
+
Global Weighted Average Pooling from paper "Global Weighted Average
|
116 |
+
Pooling Bridges Pixel-level Localization and Image-level Classification"
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, features: int, flatten=False):
|
120 |
+
super().__init__()
|
121 |
+
self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
|
122 |
+
self.flatten = flatten
|
123 |
+
|
124 |
+
def fscore(self, x):
|
125 |
+
m = self.conv(x)
|
126 |
+
m = m.sigmoid().exp()
|
127 |
+
return m
|
128 |
+
|
129 |
+
def norm(self, x: torch.Tensor):
|
130 |
+
return x / x.sum(dim=[2, 3], keepdim=True)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
input_x = x
|
134 |
+
x = self.fscore(x)
|
135 |
+
x = self.norm(x)
|
136 |
+
x = x * input_x
|
137 |
+
x = x.sum(dim=[2, 3], keepdim=not self.flatten)
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class DeepFakeClassifier(nn.Module):
|
142 |
+
def __init__(self, encoder, dropout_rate=0.0) -> None:
|
143 |
+
super().__init__()
|
144 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
145 |
+
self.avg_pool = AdaptiveAvgPool2d((1, 1))
|
146 |
+
self.dropout = Dropout(dropout_rate)
|
147 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
x = self.encoder.forward_features(x)
|
151 |
+
x = self.avg_pool(x).flatten(1)
|
152 |
+
x = self.dropout(x)
|
153 |
+
x = self.fc(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
class DeepFakeClassifierGWAP(nn.Module):
|
160 |
+
def __init__(self, encoder, dropout_rate=0.5) -> None:
|
161 |
+
super().__init__()
|
162 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
163 |
+
self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
|
164 |
+
self.dropout = Dropout(dropout_rate)
|
165 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
x = self.encoder.forward_features(x)
|
169 |
+
x = self.avg_pool(x).flatten(1)
|
170 |
+
x = self.dropout(x)
|
171 |
+
x = self.fc(x)
|
172 |
+
return x
|
training/zoo/unet.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import Dropout2d, Conv2d
|
7 |
+
from torch.nn.modules.dropout import Dropout
|
8 |
+
from torch.nn.modules.linear import Linear
|
9 |
+
from torch.nn.modules.pooling import AdaptiveAvgPool2d
|
10 |
+
from torch.nn.modules.upsampling import UpsamplingBilinear2d
|
11 |
+
|
12 |
+
encoder_params = {
|
13 |
+
"tf_efficientnet_b3_ns": {
|
14 |
+
"features": 1536,
|
15 |
+
"filters": [40, 32, 48, 136, 1536],
|
16 |
+
"decoder_filters": [64, 128, 256, 256],
|
17 |
+
"init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
|
18 |
+
},
|
19 |
+
"tf_efficientnet_b5_ns": {
|
20 |
+
"features": 2048,
|
21 |
+
"filters": [48, 40, 64, 176, 2048],
|
22 |
+
"decoder_filters": [64, 128, 256, 256],
|
23 |
+
"init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
|
24 |
+
},
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
class DecoderBlock(nn.Module):
|
29 |
+
def __init__(self, in_channels, out_channels):
|
30 |
+
super().__init__()
|
31 |
+
self.layer = nn.Sequential(
|
32 |
+
nn.Upsample(scale_factor=2),
|
33 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
34 |
+
nn.ReLU(inplace=True)
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.layer(x)
|
39 |
+
|
40 |
+
|
41 |
+
class ConcatBottleneck(nn.Module):
|
42 |
+
def __init__(self, in_channels, out_channels):
|
43 |
+
super().__init__()
|
44 |
+
self.seq = nn.Sequential(
|
45 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
46 |
+
nn.ReLU(inplace=True)
|
47 |
+
)
|
48 |
+
|
49 |
+
def forward(self, dec, enc):
|
50 |
+
x = torch.cat([dec, enc], dim=1)
|
51 |
+
return self.seq(x)
|
52 |
+
|
53 |
+
|
54 |
+
class Decoder(nn.Module):
|
55 |
+
def __init__(self, decoder_filters, filters, upsample_filters=None,
|
56 |
+
decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
|
57 |
+
super().__init__()
|
58 |
+
self.decoder_filters = decoder_filters
|
59 |
+
self.filters = filters
|
60 |
+
self.decoder_block = decoder_block
|
61 |
+
self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
|
62 |
+
self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
|
63 |
+
for i, f in enumerate(reversed(decoder_filters))])
|
64 |
+
self.dropout = Dropout2d(dropout) if dropout > 0 else None
|
65 |
+
self.last_block = None
|
66 |
+
if upsample_filters:
|
67 |
+
self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
|
68 |
+
else:
|
69 |
+
self.last_block = UpsamplingBilinear2d(scale_factor=2)
|
70 |
+
|
71 |
+
def forward(self, encoder_results: list):
|
72 |
+
x = encoder_results[0]
|
73 |
+
bottlenecks = self.bottlenecks
|
74 |
+
for idx, bottleneck in enumerate(bottlenecks):
|
75 |
+
rev_idx = - (idx + 1)
|
76 |
+
x = self.decoder_stages[rev_idx](x)
|
77 |
+
x = bottleneck(x, encoder_results[-rev_idx])
|
78 |
+
if self.last_block:
|
79 |
+
x = self.last_block(x)
|
80 |
+
if self.dropout:
|
81 |
+
x = self.dropout(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
def _get_decoder(self, layer):
|
85 |
+
idx = layer + 1
|
86 |
+
if idx == len(self.decoder_filters):
|
87 |
+
in_channels = self.filters[idx]
|
88 |
+
else:
|
89 |
+
in_channels = self.decoder_filters[idx]
|
90 |
+
return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
|
91 |
+
|
92 |
+
|
93 |
+
def _initialize_weights(module):
|
94 |
+
for m in module.modules():
|
95 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
|
96 |
+
m.weight.data = nn.init.kaiming_normal_(m.weight.data)
|
97 |
+
if m.bias is not None:
|
98 |
+
m.bias.data.zero_()
|
99 |
+
elif isinstance(m, nn.BatchNorm2d):
|
100 |
+
m.weight.data.fill_(1)
|
101 |
+
m.bias.data.zero_()
|
102 |
+
|
103 |
+
|
104 |
+
class EfficientUnetClassifier(nn.Module):
|
105 |
+
def __init__(self, encoder, dropout_rate=0.5) -> None:
|
106 |
+
super().__init__()
|
107 |
+
self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
|
108 |
+
filters=encoder_params[encoder]["filters"])
|
109 |
+
self.avg_pool = AdaptiveAvgPool2d((1, 1))
|
110 |
+
self.dropout = Dropout(dropout_rate)
|
111 |
+
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
112 |
+
self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
|
113 |
+
_initialize_weights(self)
|
114 |
+
self.encoder = encoder_params[encoder]["init_op"]()
|
115 |
+
|
116 |
+
def get_encoder_features(self, x):
|
117 |
+
encoder_results = []
|
118 |
+
x = self.encoder.conv_stem(x)
|
119 |
+
x = self.encoder.bn1(x)
|
120 |
+
x = self.encoder.act1(x)
|
121 |
+
encoder_results.append(x)
|
122 |
+
x = self.encoder.blocks[:2](x)
|
123 |
+
encoder_results.append(x)
|
124 |
+
x = self.encoder.blocks[2:3](x)
|
125 |
+
encoder_results.append(x)
|
126 |
+
x = self.encoder.blocks[3:5](x)
|
127 |
+
encoder_results.append(x)
|
128 |
+
x = self.encoder.blocks[5:](x)
|
129 |
+
x = self.encoder.conv_head(x)
|
130 |
+
x = self.encoder.bn2(x)
|
131 |
+
x = self.encoder.act2(x)
|
132 |
+
encoder_results.append(x)
|
133 |
+
encoder_results = list(reversed(encoder_results))
|
134 |
+
return encoder_results
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
encoder_results = self.get_encoder_features(x)
|
138 |
+
seg = self.final(self.decoder(encoder_results))
|
139 |
+
x = encoder_results[0]
|
140 |
+
x = self.avg_pool(x).flatten(1)
|
141 |
+
x = self.dropout(x)
|
142 |
+
x = self.fc(x)
|
143 |
+
return x, seg
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
|
148 |
+
model.eval()
|
149 |
+
with torch.no_grad():
|
150 |
+
input = torch.rand(4, 3, 224, 224)
|
151 |
+
print(model(input))
|
weights/.gitkeep
ADDED
File without changes
|
weights/b7_ns_best.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9db77ab9318863e2f8ab287c8eb83c2232584b82dc2fb41f1d614ddd7900cccb
|
3 |
+
size 266910617
|