thecho7 committed on
Commit
c426e13
1 Parent(s): 86f5435
Files changed (47)
  1. .gitattributes +2 -0
  2. Dockerfile +54 -0
  3. LICENSE +21 -0
  4. README.md +171 -13
  5. __pycache__/kernel_utils.cpython-310.pyc +0 -0
  6. app.py +86 -0
  7. configs/b5.json +28 -0
  8. configs/b7.json +29 -0
  9. download_weights.sh +9 -0
  10. examples/liuujwwgpr.mp4 +3 -0
  11. examples/nlurbvsozt.mp4 +3 -0
  12. examples/rfjuhbnlro.mp4 +3 -0
  13. kernel_utils.py +365 -0
  14. libs/shape_predictor_68_face_landmarks.dat +3 -0
  15. training/__init__.py +0 -0
  16. training/__pycache__/__init__.cpython-310.pyc +0 -0
  17. training/__pycache__/__init__.cpython-39.pyc +0 -0
  18. training/__pycache__/losses.cpython-310.pyc +0 -0
  19. training/__pycache__/losses.cpython-39.pyc +0 -0
  20. training/datasets/__init__.py +0 -0
  21. training/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  22. training/datasets/__pycache__/classifier_dataset.cpython-310.pyc +0 -0
  23. training/datasets/__pycache__/validation_set.cpython-310.pyc +0 -0
  24. training/datasets/classifier_dataset.py +384 -0
  25. training/datasets/validation_set.py +60 -0
  26. training/losses.py +28 -0
  27. training/pipelines/__init__.py +0 -0
  28. training/pipelines/train_classifier.py +364 -0
  29. training/tools/__init__.py +0 -0
  30. training/tools/__pycache__/__init__.cpython-310.pyc +0 -0
  31. training/tools/__pycache__/config.cpython-310.pyc +0 -0
  32. training/tools/__pycache__/schedulers.cpython-310.pyc +0 -0
  33. training/tools/__pycache__/utils.cpython-310.pyc +0 -0
  34. training/tools/config.py +43 -0
  35. training/tools/schedulers.py +46 -0
  36. training/tools/utils.py +121 -0
  37. training/transforms/__init__.py +0 -0
  38. training/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  39. training/transforms/__pycache__/albu.cpython-310.pyc +0 -0
  40. training/transforms/albu.py +100 -0
  41. training/zoo/__init__.py +0 -0
  42. training/zoo/__pycache__/__init__.cpython-310.pyc +0 -0
  43. training/zoo/__pycache__/classifiers.cpython-310.pyc +0 -0
  44. training/zoo/classifiers.py +172 -0
  45. training/zoo/unet.py +151 -0
  46. weights/.gitkeep +0 -0
  47. weights/b7_ns_best.pth +3 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.dat filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,54 @@
+ ARG PYTORCH="1.10.0"
+ ARG CUDA="11.3"
+ ARG CUDNN="8"
+
+ FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+
+ # Set a noninteractive build, set up tzdata and configure the timezone
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV TZ=Europe/Berlin
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+ RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxrender-dev libxext6 nano mc glances vim git \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install cython
+ RUN conda install cython -y && conda clean --all
+
+ # Install NVIDIA Apex
+ RUN pip install -U pip
+ RUN git clone https://github.com/NVIDIA/apex
+ RUN sed -i 's/check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)/pass/g' apex/setup.py
+ RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
+ RUN apt-get update -y
+ RUN apt-get install build-essential cmake -y
+ RUN apt-get install libopenblas-dev liblapack-dev -y
+ RUN apt-get install libx11-dev libgtk-3-dev -y
+ RUN pip install dlib
+ RUN pip install facenet-pytorch
+ RUN pip install albumentations==1.0.0 timm==0.4.12 pytorch_toolbelt tensorboardx
+ RUN pip install cython jupyter jupyterlab ipykernel matplotlib tqdm pandas
+
+ # Download pretrained ImageNet models
+ RUN apt install wget
+ RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth -P /root/.cache/torch/hub/checkpoints/
+ RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth -P /root/.cache/torch/hub/checkpoints/
+
+ # Set the working directory
+ WORKDIR /workspace
+
+ # Copy the required codebase
+ COPY . /workspace
+
+ RUN chmod 777 preprocess_data.sh
+ RUN chmod 777 train.sh
+ RUN chmod 777 predict_submission.sh
+
+ ENV PYTHONPATH=.
+
+ CMD ["/bin/bash"]
+
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 Selim Seferbekov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,171 @@
- ---
- title: Deepfake
- emoji: 🔥
- colorFrom: indigo
- colorTo: purple
- sdk: gradio
- sdk_version: 3.29.0
- app_file: app.py
- pinned: false
- license: unlicense
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## DeepFake Detection (DFDC) Solution by @selimsef
+
+ ## Challenge details
+
+ [Kaggle Challenge Page](https://www.kaggle.com/c/deepfake-detection-challenge)
+
+
+ ### Fake detection articles
+ - [The Deepfake Detection Challenge (DFDC) Preview Dataset](https://arxiv.org/abs/1910.08854)
+ - [Deep Fake Image Detection Based on Pairwise Learning](https://www.mdpi.com/2076-3417/10/1/370)
+ - [DeeperForensics-1.0: A Large-Scale Dataset for Real-World Face Forgery Detection](https://arxiv.org/abs/2001.03024)
+ - [DeepFakes and Beyond: A Survey of Face Manipulation and Fake Detection](https://arxiv.org/abs/2001.00179)
+ - [Real or Fake? Spoofing State-Of-The-Art Face Synthesis Detection Systems](https://arxiv.org/abs/1911.05351)
+ - [CNN-generated images are surprisingly easy to spot... for now](https://arxiv.org/abs/1912.11035)
+ - [FakeSpotter: A Simple yet Robust Baseline for Spotting AI-Synthesized Fake Faces](https://arxiv.org/abs/1909.06122)
+ - [FakeLocator: Robust Localization of GAN-Based Face Manipulations via Semantic Segmentation Networks with Bells and Whistles](https://arxiv.org/abs/2001.09598)
+ - [Media Forensics and DeepFakes: an overview](https://arxiv.org/abs/2001.06564)
+ - [Face X-ray for More General Face Forgery Detection](https://arxiv.org/abs/1912.13458)
+
+ ## Solution description
+ In general, the solution is based on a frame-by-frame classification approach. More complex approaches did not work as well on the public leaderboard.
+
+ #### Face detector
+ The MTCNN detector was chosen due to kernel time limits. The S3FD detector would have been preferable, as it is more precise and robust, but the open-source PyTorch implementations don't have a license.
+
+ The input size for the face detector was calculated per video depending on its resolution (a small sketch of this rule follows the list):
+
+ - 2x scale for videos whose wider side is less than 300 pixels
+ - no rescaling for videos with a wider side between 300 and 1000 pixels
+ - 0.5x scale for videos with a wider side > 1000 pixels
+ - 0.33x scale for videos with a wider side > 1900 pixels
+
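+ A minimal sketch of this rule (the helper name `detector_scale` is illustrative, not part of the codebase):
+
+ ```python
+ def detector_scale(width: int, height: int) -> float:
+     """Return the factor the frame is resized by before running the face detector."""
+     wider_side = max(width, height)
+     if wider_side < 300:
+         return 2.0    # upscale small videos so faces are large enough for MTCNN
+     if wider_side > 1900:
+         return 0.33   # shrink very large videos the most
+     if wider_side > 1000:
+         return 0.5
+     return 1.0        # 300-1000 px: keep the original resolution
+ ```
+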
+ ### Input size
+ Once I discovered that EfficientNets significantly outperform other encoders, I used only them in my solution.
+ As I started with B4, I decided to use the "native" input size for that network (380x380).
+ Due to memory constraints I did not increase the input size even for the B7 encoder.
+
+ ### Margin
+ When I generated crops for training I added 30% of the face crop size on each side, and used only this setting during the competition.
+ See [extract_crops.py](preprocessing/extract_crops.py) for the details.
+
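+ A minimal sketch of that 30% margin, assuming a detector box `(xmin, ymin, xmax, ymax)` in pixel coordinates (the helper name is hypothetical; the preprocessing script above is the authoritative version):
+
+ ```python
+ import numpy as np
+
+ def crop_with_margin(frame: np.ndarray, xmin: int, ymin: int, xmax: int, ymax: int,
+                      margin: float = 0.3) -> np.ndarray:
+     """Expand the detected face box by `margin` of its size on every side, clipped to the frame."""
+     h, w = ymax - ymin, xmax - xmin
+     p_h, p_w = int(h * margin), int(w * margin)
+     return frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
+ ```
+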
+ ### Encoders
+ The winning encoder is a current state-of-the-art model (EfficientNet B7) pretrained on ImageNet with Noisy Student, see [Self-training with Noisy Student improves ImageNet classification](https://arxiv.org/abs/1911.04252).
+
+ ### Averaging predictions
+ I used 32 frames for each video.
+ For each model's output, instead of simple averaging I used the following heuristic, which worked quite well on the public leaderboard (0.25 -> 0.22 solo B5).
+ ```python
+ import numpy as np
+
+ def confident_strategy(pred, t=0.8):
+     pred = np.array(pred)
+     sz = len(pred)
+     fakes = np.count_nonzero(pred > t)
+     # at least 11 frames are detected as fake with high probability
+     if fakes > sz // 2.5 and fakes > 11:
+         return np.mean(pred[pred > t])
+     elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
+         return np.mean(pred[pred < 0.2])
+     else:
+         return np.mean(pred)
+ ```
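+
+ For example, on 32 hypothetical per-frame scores where 20 frames are confidently fake, the strategy averages only the confident ones instead of diluting them with the rest (numbers are made up for illustration):
+ ```python
+ frame_preds = [0.95] * 20 + [0.1] * 12
+ print(confident_strategy(frame_preds))  # ~0.95, while a plain mean would be ~0.63
+ ```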
+
+ ### Augmentations
+
+ I used heavy augmentations by default.
+ The [Albumentations](https://github.com/albumentations-team/albumentations) library supports most of the augmentations out of the box; I only needed to add an IsotropicResize augmentation.
+ ```python
+ import cv2
+ from albumentations import (Compose, ImageCompression, GaussNoise, GaussianBlur, HorizontalFlip,
+                             OneOf, PadIfNeeded, RandomBrightnessContrast, FancyPCA,
+                             HueSaturationValue, ToGray, ShiftScaleRotate)
+
+ from training.transforms.albu import IsotropicResize
+
+
+ def create_train_transforms(size=300):
+     return Compose([
+         ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
+         GaussNoise(p=0.1),
+         GaussianBlur(blur_limit=3, p=0.05),
+         HorizontalFlip(),
+         OneOf([
+             IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
+             IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
+             IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
+         ], p=1),
+         PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
+         OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
+         ToGray(p=0.2),
+         ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
+     ])
+ ```
+ In addition to these augmentations I wanted to achieve better generalization with:
+ - Cutout-like augmentations (dropping artefacts and parts of the face)
+ - Dropping out part of the image, inspired by [GridMask](https://arxiv.org/abs/2001.04086) and the [Severstal Winning Solution](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254)
+
+ ![augmentations](images/augmentations.jpg "Dropout augmentations")
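+
+ A minimal sketch of the half-image dropout idea (the full, mask-aware versions are in `training/datasets/classifier_dataset.py`, e.g. `prepare_bit_masks` and `blackout_random`):
+ ```python
+ import random
+ import numpy as np
+
+ def drop_half(image: np.ndarray) -> np.ndarray:
+     """Zero out a random horizontal or vertical half of a face crop."""
+     out = image.copy()
+     h, w = out.shape[:2]
+     if random.random() < 0.5:          # cut horizontally
+         if random.random() < 0.5:
+             out[:h // 2, :] = 0        # drop the top half
+         else:
+             out[h // 2:, :] = 0        # drop the bottom half
+     else:                              # cut vertically
+         if random.random() < 0.5:
+             out[:, :w // 2] = 0        # drop the left half
+         else:
+             out[:, w // 2:] = 0        # drop the right half
+     return out
+ ```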
+
+ ## Building the docker image
+ All libraries and the environment are already configured in the Dockerfile. It requires the docker engine (https://docs.docker.com/engine/install/ubuntu/) and nvidia-docker (https://github.com/NVIDIA/nvidia-docker) on your system.
+
+ To build the docker image run `docker build -t df .`
+
+ ## Running docker
+ `docker run --runtime=nvidia --ipc=host --rm --volume <DATA_ROOT>:/dataset -it df`
+
+ ## Data preparation
+
+ Once the DFDC dataset is downloaded, all scripts expect the `dfdc_train_xxx` folders to be under the data root directory.
+
+ Preprocessing is done by a single script, **`preprocess_data.sh`**, which takes the dataset directory as its first argument.
+ It executes the steps below:
+
+ ##### 1. Find face bboxes
+ To extract face bboxes I used the facenet library, basically only MTCNN.
+ `python preprocessing/detect_original_faces.py --root-dir DATA_ROOT`
+ This script detects faces in the real videos and stores them as JSONs in the DATA_ROOT/bboxes directory.
+
+ ##### 2. Extract crops from videos
+ To extract image crops I used the bboxes saved before. The bounding boxes from the original videos are reused for the corresponding fake videos as well.
+ `python preprocessing/extract_crops.py --root-dir DATA_ROOT --crops-dir crops`
+ This script extracts face crops from the videos and saves them in the DATA_ROOT/crops directory.
+
+ ##### 3. Generate landmarks
+ From the saved crops it is quite fast to process them with MTCNN and extract landmarks.
+ `python preprocessing/generate_landmarks.py --root-dir DATA_ROOT`
+ This script extracts landmarks and saves them in the DATA_ROOT/landmarks directory.
+
+ ##### 4. Generate diff SSIM masks
+ `python preprocessing/generate_diffs.py --root-dir DATA_ROOT`
+ This script extracts SSIM difference masks between real and fake images and saves them in the DATA_ROOT/diffs directory.
+
+ ##### 5. Generate folds
+ `python preprocessing/generate_folds.py --root-dir DATA_ROOT --out folds.csv`
+ By default it uses 16 splits, keeping folders 0-2 as a holdout set, though as few as 400 videos can also be used for validation.
+
+
+ ## Training
+
+ Training 5 B7 models with different seeds is done by the **`train.sh`** script.
+
+ During training, checkpoints are saved every epoch.
+
+ ## Hardware requirements
+ Mostly trained on a devbox configuration with 4x Titan V, thanks to Nvidia and the DSB2018 competition where I got these GPUs: https://www.kaggle.com/c/data-science-bowl-2018/
+
+ Overall, training requires 4 GPUs with 12 GB+ memory each.
+ The batch size needs to be adjusted for standard 1080Ti or 2080Ti graphics cards.
+
+ As I computed the fake loss and the real loss separately inside each batch, results might be better with a larger batch size, for example on V100 GPUs.
+ Even though SyncBN is used, a larger batch on each GPU leads to less noise, as the DFDC dataset has some fakes where the face detector failed and the face crops are not really fake.
+
+ ## Plotting losses to select checkpoints
+
+ `python plot_loss.py --log-file logs/<log file>`
+
+ ![loss plot](images/loss_plot.png "Weighted loss")
+
+ ## Inference
+
+ The Kaggle kernel is reproduced with the `predict_folder.py` script.
+
+ ## Pretrained models
+ The `download_weights.sh` script downloads the trained models to the `weights/` folder. They should be downloaded before building the docker image.
+
+ Ensemble inference is already preconfigured with the `predict_submission.sh` bash script. It expects a directory with videos as the first argument and an output CSV file as the second argument.
+
+ For example: `./predict_submission.sh /mnt/datasets/deepfake/test_videos submission.csv`
__pycache__/kernel_utils.cpython-310.pyc ADDED
Binary file (11.8 kB).
 
app.py ADDED
@@ -0,0 +1,86 @@
+ import argparse
+ import os
+ import re
+ import time
+
+ import torch
+ from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
+ from training.zoo.classifiers import DeepFakeClassifier
+
+ import gradio as gr
+
+ def model_fn(model_dir):
+     model_path = os.path.join(model_dir, 'b7_ns_best.pth')
+     model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")  # default: CPU
+     checkpoint = torch.load(model_path, map_location="cpu")
+     state_dict = checkpoint.get("state_dict", checkpoint)
+     # strip the "module." prefix left over from DataParallel training
+     model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
+     model.eval()
+     del checkpoint
+     # models.append(model.half())
+
+     return model
+
+ def convert_result(pred, class_names=["Real", "Fake"]):
+     # predict_on_video returns the fake probability (label 1 = fake during training)
+     preds = [1 - pred, pred]
+     assert len(class_names) == len(preds), "Class / Prediction should have the same length"
+     return {n: p for n, p in zip(class_names, preds)}
+
+ def predict_fn(model, video, meta):
+     start = time.time()
+     prediction = predict_on_video(face_extractor=meta["face_extractor"],
+                                   video_path=video,
+                                   batch_size=meta["fps"],
+                                   input_size=meta["input_size"],
+                                   models=model,
+                                   strategy=meta["strategy"],
+                                   apply_compression=False,
+                                   device='cpu')
+
+     elapsed_time = round(time.time() - start, 2)
+
+     prediction = convert_result(prediction)
+
+     return prediction, elapsed_time
+
+ # Create title and description strings
+ title = "Deepfake Detector (private)"
+ description = "A video Deepfake Classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"
+
+ example_list = ["examples/" + str(p) for p in os.listdir("examples/")]
+
+ # Environment
+ model_dir = 'weights'
+ frames_per_video = 32
+ video_reader = VideoReader()
+ video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
+ face_extractor = FaceExtractor(video_read_fn)
+ input_size = 380
+ strategy = confident_strategy
+ class_names = ["Real", "Fake"]
+
+ meta = {"fps": 32,
+         "face_extractor": face_extractor,
+         "input_size": input_size,
+         "strategy": strategy}
+
+ model = model_fn(model_dir)
+
+ """
+ if __name__ == '__main__':
+     video_path = "nlurbvsozt.mp4"
+     model = model_fn(model_dir)
+     a, b = predict_fn([model], video_path, meta)
+     print(a, b)
+ """
+
+ def gradio_predict(video):
+     # Gradio only passes the uploaded video path; the model and meta dict are closed over here.
+     return predict_fn([model], video, meta)
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=gradio_predict,  # mapping function from input to output
+                     inputs=gr.Video(autosize=True),
+                     outputs=[gr.Label(num_top_classes=2, label="Predictions"),
+                              gr.Number(label="Prediction time (s)")],  # our fn has two outputs
+                     examples=example_list,
+                     title=title,
+                     description=description)
+
+ # Launch the demo!
+ demo.launch(debug=False)  # Hugging Face Spaces don't need shareable links
configs/b5.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "network": "DeepFakeClassifier",
+   "encoder": "tf_efficientnet_b5_ns",
+   "batches_per_epoch": 2500,
+   "size": 380,
+   "fp16": true,
+   "optimizer": {
+     "batch_size": 20,
+     "type": "SGD",
+     "momentum": 0.9,
+     "weight_decay": 1e-4,
+     "learning_rate": 0.01,
+     "nesterov": true,
+     "schedule": {
+       "type": "poly",
+       "mode": "step",
+       "epochs": 30,
+       "params": {"max_iter": 75100}
+     }
+   },
+   "normalize": {
+     "mean": [0.485, 0.456, 0.406],
+     "std": [0.229, 0.224, 0.225]
+   },
+   "losses": {
+     "BinaryCrossentropy": 1
+   }
+ }
configs/b7.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "network": "DeepFakeClassifier",
+   "encoder": "tf_efficientnet_b7_ns",
+   "batches_per_epoch": 2500,
+   "size": 380,
+   "fp16": true,
+   "optimizer": {
+     "batch_size": 4,
+     "type": "SGD",
+     "momentum": 0.9,
+     "weight_decay": 1e-4,
+     "learning_rate": 1e-4,
+     "nesterov": true,
+     "schedule": {
+       "type": "poly",
+       "mode": "step",
+       "epochs": 20,
+       "params": {"max_iter": 100500}
+     }
+   },
+   "normalize": {
+     "mean": [0.485, 0.456, 0.406],
+     "std": [0.229, 0.224, 0.225]
+   },
+   "losses": {
+     "BinaryCrossentropy": 1
+   }
+ }
+
download_weights.sh ADDED
@@ -0,0 +1,9 @@
+ tag=0.0.1
+
+ wget -O weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
+ wget -O weights/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19
+ wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29
+ wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31
+ wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37
+ wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40
+ wget -O weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23
examples/liuujwwgpr.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3aaefb51aa5720cdabcc68d93da5c6a22573d8da06bdaf5e009c7a370943e85
+ size 12852441
examples/nlurbvsozt.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:300b7dea93132b512f35de76572e7fcde666c812b91aec6b189dafa6f100c9b5
+ size 4486723
examples/rfjuhbnlro.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6d0bb841ebe6a8e20cf265b45356a1ea3fed9837025e8d549b2437290d79273
+ size 16218775
kernel_utils.py ADDED
@@ -0,0 +1,365 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from albumentations.augmentations.functional import image_compression
8
+ from facenet_pytorch.models.mtcnn import MTCNN
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ from torchvision.transforms import Normalize
12
+
13
+ mean = [0.485, 0.456, 0.406]
14
+ std = [0.229, 0.224, 0.225]
15
+ normalize_transform = Normalize(mean, std)
16
+
17
+
18
+ class VideoReader:
19
+ """Helper class for reading one or more frames from a video file."""
20
+
21
+ def __init__(self, verbose=True, insets=(0, 0)):
22
+ """Creates a new VideoReader.
23
+
24
+ Arguments:
25
+ verbose: whether to print warnings and error messages
26
+ insets: amount to inset the image by, as a percentage of
27
+ (width, height). This lets you "zoom in" to an image
28
+ to remove unimportant content around the borders.
29
+ Useful for face detection, which may not work if the
30
+ faces are too small.
31
+ """
32
+ self.verbose = verbose
33
+ self.insets = insets
34
+
35
+ def read_frames(self, path, num_frames, jitter=0, seed=None):
36
+ """Reads frames that are always evenly spaced throughout the video.
37
+
38
+ Arguments:
39
+ path: the video file
40
+ num_frames: how many frames to read, -1 means the entire video
41
+ (warning: this will take up a lot of memory!)
42
+ jitter: if not 0, adds small random offsets to the frame indices;
43
+ this is useful so we don't always land on even or odd frames
44
+ seed: random seed for jittering; if you set this to a fixed value,
45
+ you probably want to set it only on the first video
46
+ """
47
+ assert num_frames > 0
48
+
49
+ capture = cv2.VideoCapture(path)
50
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
51
+ if frame_count <= 0: return None
52
+
53
+ frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
54
+ if jitter > 0:
55
+ np.random.seed(seed)
56
+ jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
57
+ frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)
58
+
59
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
60
+ capture.release()
61
+ return result
62
+
63
+ def read_random_frames(self, path, num_frames, seed=None):
64
+ """Picks the frame indices at random.
65
+
66
+ Arguments:
67
+ path: the video file
68
+ num_frames: how many frames to read, -1 means the entire video
69
+ (warning: this will take up a lot of memory!)
70
+ """
71
+ assert num_frames > 0
72
+ np.random.seed(seed)
73
+
74
+ capture = cv2.VideoCapture(path)
75
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
76
+ if frame_count <= 0: return None
77
+
78
+ frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
79
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
80
+
81
+ capture.release()
82
+ return result
83
+
84
+ def read_frames_at_indices(self, path, frame_idxs):
85
+ """Reads frames from a video and puts them into a NumPy array.
86
+
87
+ Arguments:
88
+ path: the video file
89
+ frame_idxs: a list of frame indices. Important: should be
90
+ sorted from low-to-high! If an index appears multiple
91
+ times, the frame is still read only once.
92
+
93
+ Returns:
94
+ - a NumPy array of shape (num_frames, height, width, 3)
95
+ - a list of the frame indices that were read
96
+
97
+ Reading stops if loading a frame fails, in which case the first
98
+ dimension returned may actually be less than num_frames.
99
+
100
+ Returns None if an exception is thrown for any reason, or if no
101
+ frames were read.
102
+ """
103
+ assert len(frame_idxs) > 0
104
+ capture = cv2.VideoCapture(path)
105
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
106
+ capture.release()
107
+ return result
108
+
109
+ def _read_frames_at_indices(self, path, capture, frame_idxs):
110
+ try:
111
+ frames = []
112
+ idxs_read = []
113
+ for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
114
+ # Get the next frame, but don't decode if we're not using it.
115
+ ret = capture.grab()
116
+ if not ret:
117
+ if self.verbose:
118
+ print("Error grabbing frame %d from movie %s" % (frame_idx, path))
119
+ break
120
+
121
+ # Need to look at this frame?
122
+ current = len(idxs_read)
123
+ if frame_idx == frame_idxs[current]:
124
+ ret, frame = capture.retrieve()
125
+ if not ret or frame is None:
126
+ if self.verbose:
127
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
128
+ break
129
+
130
+ frame = self._postprocess_frame(frame)
131
+ frames.append(frame)
132
+ idxs_read.append(frame_idx)
133
+
134
+ if len(frames) > 0:
135
+ return np.stack(frames), idxs_read
136
+ if self.verbose:
137
+ print("No frames read from movie %s" % path)
138
+ return None
139
+ except:
140
+ if self.verbose:
141
+ print("Exception while reading movie %s" % path)
142
+ return None
143
+
144
+ def read_middle_frame(self, path):
145
+ """Reads the frame from the middle of the video."""
146
+ capture = cv2.VideoCapture(path)
147
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
148
+ result = self._read_frame_at_index(path, capture, frame_count // 2)
149
+ capture.release()
150
+ return result
151
+
152
+ def read_frame_at_index(self, path, frame_idx):
153
+ """Reads a single frame from a video.
154
+
155
+ If you just want to read a single frame from the video, this is more
156
+ efficient than scanning through the video to find the frame. However,
157
+ for reading multiple frames it's not efficient.
158
+
159
+ My guess is that a "streaming" approach is more efficient than a
160
+ "random access" approach because, unless you happen to grab a keyframe,
161
+ the decoder still needs to read all the previous frames in order to
162
+ reconstruct the one you're asking for.
163
+
164
+ Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
165
+ or None if reading failed.
166
+ """
167
+ capture = cv2.VideoCapture(path)
168
+ result = self._read_frame_at_index(path, capture, frame_idx)
169
+ capture.release()
170
+ return result
171
+
172
+ def _read_frame_at_index(self, path, capture, frame_idx):
173
+ capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
174
+ ret, frame = capture.read()
175
+ if not ret or frame is None:
176
+ if self.verbose:
177
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
178
+ return None
179
+ else:
180
+ frame = self._postprocess_frame(frame)
181
+ return np.expand_dims(frame, axis=0), [frame_idx]
182
+
183
+ def _postprocess_frame(self, frame):
184
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
+
186
+ if self.insets[0] > 0:
187
+ W = frame.shape[1]
188
+ p = int(W * self.insets[0])
189
+ frame = frame[:, p:-p, :]
190
+
191
+ if self.insets[1] > 0:
192
+ H = frame.shape[0]  # frame height, used for the vertical inset
193
+ q = int(H * self.insets[1])
194
+ frame = frame[q:-q, :, :]
195
+
196
+ return frame
197
+
198
+
199
+ class FaceExtractor:
200
+ def __init__(self, video_read_fn):
201
+ self.video_read_fn = video_read_fn
202
+ self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cuda")
203
+
204
+ def process_videos(self, input_dir, filenames, video_idxs):
205
+ videos_read = []
206
+ frames_read = []
207
+ frames = []
208
+ results = []
209
+ for video_idx in video_idxs:
210
+ # Read the full-size frames from this video.
211
+ filename = filenames[video_idx]
212
+ video_path = os.path.join(input_dir, filename)
213
+ result = self.video_read_fn(video_path)
214
+ # Error? Then skip this video.
215
+ if result is None: continue
216
+
217
+ videos_read.append(video_idx)
218
+
219
+ # Keep track of the original frames (need them later).
220
+ my_frames, my_idxs = result
221
+
222
+ frames.append(my_frames)
223
+ frames_read.append(my_idxs)
224
+ for i, frame in enumerate(my_frames):
225
+ h, w = frame.shape[:2]
226
+ img = Image.fromarray(frame.astype(np.uint8))
227
+ img = img.resize(size=[s // 2 for s in img.size])
228
+
229
+ batch_boxes, probs = self.detector.detect(img, landmarks=False)
230
+
231
+ faces = []
232
+ scores = []
233
+ if batch_boxes is None:
234
+ continue
235
+ for bbox, score in zip(batch_boxes, probs):
236
+ if bbox is not None:
237
+ xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
238
+ w = xmax - xmin
239
+ h = ymax - ymin
240
+ p_h = h // 3
241
+ p_w = w // 3
242
+ crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
243
+ faces.append(crop)
244
+ scores.append(score)
245
+
246
+ frame_dict = {"video_idx": video_idx,
247
+ "frame_idx": my_idxs[i],
248
+ "frame_w": w,
249
+ "frame_h": h,
250
+ "faces": faces,
251
+ "scores": scores}
252
+ results.append(frame_dict)
253
+
254
+ return results
255
+
256
+ def process_video(self, video_path):
257
+ """Convenience method for doing face extraction on a single video."""
258
+ input_dir = os.path.dirname(video_path)
259
+ filenames = [os.path.basename(video_path)]
260
+ return self.process_videos(input_dir, filenames, [0])
261
+
262
+
263
+
264
+ def confident_strategy(pred, t=0.8):
265
+ pred = np.array(pred)
266
+ sz = len(pred)
267
+ fakes = np.count_nonzero(pred > t)
268
+ # 11 frames are detected as fakes with high probability
269
+ if fakes > sz // 2.5 and fakes > 11:
270
+ return np.mean(pred[pred > t])
271
+ elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
272
+ return np.mean(pred[pred < 0.2])
273
+ else:
274
+ return np.mean(pred)
275
+
276
+ strategy = confident_strategy
277
+
278
+
279
+ def put_to_center(img, input_size):
280
+ img = img[:input_size, :input_size]
281
+ image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
282
+ start_w = (input_size - img.shape[1]) // 2
283
+ start_h = (input_size - img.shape[0]) // 2
284
+ image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
285
+ return image
286
+
287
+
288
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
289
+ h, w = img.shape[:2]
290
+ if max(w, h) == size:
291
+ return img
292
+ if w > h:
293
+ scale = size / w
294
+ h = h * scale
295
+ w = size
296
+ else:
297
+ scale = size / h
298
+ w = w * scale
299
+ h = size
300
+ interpolation = interpolation_up if scale > 1 else interpolation_down
301
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
302
+ return resized
303
+
304
+
305
+ def predict_on_video(face_extractor, video_path, batch_size, input_size, models, strategy=np.mean,
306
+ apply_compression=False, device='cpu'):
307
+ batch_size *= 4
308
+ try:
309
+ faces = face_extractor.process_video(video_path)
310
+ if len(faces) > 0:
311
+ x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
312
+ n = 0
313
+ for frame_data in faces:
314
+ for face in frame_data["faces"]:
315
+ resized_face = isotropically_resize_image(face, input_size)
316
+ resized_face = put_to_center(resized_face, input_size)
317
+ if apply_compression:
318
+ resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
319
+ if n + 1 < batch_size:
320
+ x[n] = resized_face
321
+ n += 1
322
+ else:
323
+ pass
324
+ if n > 0:
325
+ if device == 'cpu':
326
+ x = torch.tensor(x, device='cpu').float()
327
+ else:
328
+ x = torch.tensor(x, device="cuda").float()
329
+ # Preprocess the images.
330
+ x = x.permute((0, 3, 1, 2))
331
+ for i in range(len(x)):
332
+ x[i] = normalize_transform(x[i] / 255.)
333
+ # Make a prediction, then take the average.
334
+ with torch.no_grad():
335
+ preds = []
336
+ for model in models:
337
+ if device == 'cpu':
338
+ y_pred = model(x[:n])
339
+ else:
340
+ y_pred = model(x[:n].half())
341
+ y_pred = torch.sigmoid(y_pred.squeeze())
342
+ bpred = y_pred[:n].cpu().numpy()
343
+ preds.append(strategy(bpred))
344
+ return np.mean(preds)
345
+ except Exception as e:
346
+ print("Prediction error on video %s: %s" % (video_path, str(e)))
347
+
348
+ return 0.5
349
+
350
+
351
+ def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
352
+ strategy=np.mean,
353
+ apply_compression=False):
354
+ def process_file(i):
355
+ filename = videos[i]
356
+ y_pred = predict_on_video(face_extractor=face_extractor, video_path=os.path.join(test_dir, filename),
357
+ input_size=input_size,
358
+ batch_size=frames_per_video,
359
+ models=models, strategy=strategy, apply_compression=apply_compression)
360
+ return y_pred
361
+
362
+ with ThreadPoolExecutor(max_workers=num_workers) as ex:
363
+ predictions = ex.map(process_file, range(len(videos)))
364
+ return list(predictions)
365
+
libs/shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
+ size 99693937
training/__init__.py ADDED
File without changes
training/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes).
 
training/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes).
 
training/__pycache__/losses.cpython-310.pyc ADDED
Binary file (1.54 kB).
 
training/__pycache__/losses.cpython-39.pyc ADDED
Binary file (1.53 kB).
 
training/datasets/__init__.py ADDED
File without changes
training/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes).
 
training/datasets/__pycache__/classifier_dataset.cpython-310.pyc ADDED
Binary file (10.8 kB).
 
training/datasets/__pycache__/validation_set.cpython-310.pyc ADDED
Binary file (4.99 kB).
 
training/datasets/classifier_dataset.py ADDED
@@ -0,0 +1,384 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import sys
5
+ import traceback
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import pandas as pd
10
+ import skimage.draw
11
+ from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
12
+ from albumentations.augmentations.functional import image_compression
13
+ from albumentations.augmentations.geometric.functional import rot90
14
+ from albumentations.pytorch.functional import img_to_tensor
15
+ from scipy.ndimage import binary_erosion, binary_dilation
16
+ from skimage import measure
17
+ from torch.utils.data import Dataset
18
+ import dlib
19
+
20
+ from training.datasets.validation_set import PUBLIC_SET
21
+
22
+
23
+ def prepare_bit_masks(mask):
24
+ h, w = mask.shape
25
+ mid_w = w // 2
26
+ mid_h = h // 2
27
+ masks = []
28
+ ones = np.ones_like(mask)
29
+ ones[:mid_h] = 0
30
+ masks.append(ones)
31
+ ones = np.ones_like(mask)
32
+ ones[mid_h:] = 0
33
+ masks.append(ones)
34
+ ones = np.ones_like(mask)
35
+ ones[:, :mid_w] = 0
36
+ masks.append(ones)
37
+ ones = np.ones_like(mask)
38
+ ones[:, mid_w:] = 0
39
+ masks.append(ones)
40
+ ones = np.ones_like(mask)
41
+ ones[:mid_h, :mid_w] = 0
42
+ ones[mid_h:, mid_w:] = 0
43
+ masks.append(ones)
44
+ ones = np.ones_like(mask)
45
+ ones[:mid_h, mid_w:] = 0
46
+ ones[mid_h:, :mid_w] = 0
47
+ masks.append(ones)
48
+ return masks
49
+
50
+
51
+ detector = dlib.get_frontal_face_detector()
52
+ predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')
53
+
54
+
55
+ def blackout_convex_hull(img):
56
+ try:
57
+ rect = detector(img)[0]
58
+ sp = predictor(img, rect)
59
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
60
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
61
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
62
+ cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
63
+ cropped_img[Y, X] = 1
64
+ # if random.random() > 0.5:
65
+ # img[cropped_img == 0] = 0
66
+ # #leave only face
67
+ # return img
68
+
69
+ y, x = measure.centroid(cropped_img)
70
+ y = int(y)
71
+ x = int(x)
72
+ first = random.random() > 0.5
73
+ if random.random() > 0.5:
74
+ if first:
75
+ cropped_img[:y, :] = 0
76
+ else:
77
+ cropped_img[y:, :] = 0
78
+ else:
79
+ if first:
80
+ cropped_img[:, :x] = 0
81
+ else:
82
+ cropped_img[:, x:] = 0
83
+
84
+ img[cropped_img > 0] = 0
85
+ except Exception as e:
86
+ pass
87
+
88
+
89
+ def dist(p1, p2):
90
+ return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
91
+
92
+
93
+ def remove_eyes(image, landmarks):
94
+ image = image.copy()
95
+ (x1, y1), (x2, y2) = landmarks[:2]
96
+ mask = np.zeros_like(image[..., 0])
97
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
98
+ w = dist((x1, y1), (x2, y2))
99
+ dilation = int(w // 4)
100
+ line = binary_dilation(line, iterations=dilation)
101
+ image[line, :] = 0
102
+ return image
103
+
104
+
105
+ def remove_nose(image, landmarks):
106
+ image = image.copy()
107
+ (x1, y1), (x2, y2) = landmarks[:2]
108
+ x3, y3 = landmarks[2]
109
+ mask = np.zeros_like(image[..., 0])
110
+ x4 = int((x1 + x2) / 2)
111
+ y4 = int((y1 + y2) / 2)
112
+ line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
113
+ w = dist((x1, y1), (x2, y2))
114
+ dilation = int(w // 4)
115
+ line = binary_dilation(line, iterations=dilation)
116
+ image[line, :] = 0
117
+ return image
118
+
119
+
120
+ def remove_mouth(image, landmarks):
121
+ image = image.copy()
122
+ (x1, y1), (x2, y2) = landmarks[-2:]
123
+ mask = np.zeros_like(image[..., 0])
124
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
125
+ w = dist((x1, y1), (x2, y2))
126
+ dilation = int(w // 3)
127
+ line = binary_dilation(line, iterations=dilation)
128
+ image[line, :] = 0
129
+ return image
130
+
131
+
132
+ def remove_landmark(image, landmarks):
133
+ if random.random() > 0.5:
134
+ image = remove_eyes(image, landmarks)
135
+ elif random.random() > 0.5:
136
+ image = remove_mouth(image, landmarks)
137
+ elif random.random() > 0.5:
138
+ image = remove_nose(image, landmarks)
139
+ return image
140
+
141
+
142
+ def change_padding(image, part=5):
143
+ h, w = image.shape[:2]
144
+ # original padding was done with 1/3 from each side, too much
145
+ pad_h = int(((3 / 5) * h) / part)
146
+ pad_w = int(((3 / 5) * w) / part)
147
+ image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
148
+ return image
149
+
150
+
151
+ def blackout_random(image, mask, label):
152
+ binary_mask = mask > 0.4 * 255
153
+ h, w = binary_mask.shape[:2]
154
+
155
+ tries = 50
156
+ current_try = 1
157
+ while current_try < tries:
158
+ first = random.random() < 0.5
159
+ if random.random() < 0.5:
160
+ pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
161
+ bitmap_msk = np.ones_like(binary_mask)
162
+ if first:
163
+ bitmap_msk[:pivot, :] = 0
164
+ else:
165
+ bitmap_msk[pivot:, :] = 0
166
+ else:
167
+ pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
168
+ bitmap_msk = np.ones_like(binary_mask)
169
+ if first:
170
+ bitmap_msk[:, :pivot] = 0
171
+ else:
172
+ bitmap_msk[:, pivot:] = 0
173
+
174
+ if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
175
+ or np.count_nonzero(binary_mask * bitmap_msk) > 40:
176
+ mask *= bitmap_msk
177
+ image *= np.expand_dims(bitmap_msk, axis=-1)
178
+ break
179
+ current_try += 1
180
+ return image
181
+
182
+
183
+ def blend_original(img):
184
+ img = img.copy()
185
+ h, w = img.shape[:2]
186
+ rect = detector(img)
187
+ if len(rect) == 0:
188
+ return img
189
+ else:
190
+ rect = rect[0]
191
+ sp = predictor(img, rect)
192
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
193
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
194
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
195
+ raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
196
+ raw_mask[Y, X] = 1
197
+ face = img * np.expand_dims(raw_mask, -1)
198
+
199
+ # add warping
200
+ h1 = random.randint(h - h // 2, h + h // 2)
201
+ w1 = random.randint(w - w // 2, w + w // 2)
202
+ while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
203
+ h1 = random.randint(h - h // 2, h + h // 2)
204
+ w1 = random.randint(w - w // 2, w + w // 2)
205
+ face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
206
+ face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
207
+
208
+ raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
209
+ img[raw_mask, :] = face[raw_mask, :]
210
+ if random.random() < 0.2:
211
+ img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
212
+ # image compression
213
+ if random.random() < 0.5:
214
+ img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
215
+ return img
216
+
217
+
218
+ class DeepFakeClassifierDataset(Dataset):
219
+
220
+ def __init__(self,
221
+ data_path="/mnt/sota/datasets/deepfake",
222
+ fold=0,
223
+ label_smoothing=0.01,
224
+ padding_part=3,
225
+ hardcore=True,
226
+ crops_dir="crops",
227
+ folds_csv="folds.csv",
228
+ normalize={"mean": [0.485, 0.456, 0.406],
229
+ "std": [0.229, 0.224, 0.225]},
230
+ rotation=False,
231
+ mode="train",
232
+ reduce_val=True,
233
+ oversample_real=True,
234
+ transforms=None
235
+ ):
236
+ super().__init__()
237
+ self.data_root = data_path
238
+ self.fold = fold
239
+ self.folds_csv = folds_csv
240
+ self.mode = mode
241
+ self.rotation = rotation
242
+ self.padding_part = padding_part
243
+ self.hardcore = hardcore
244
+ self.crops_dir = crops_dir
245
+ self.label_smoothing = label_smoothing
246
+ self.normalize = normalize
247
+ self.transforms = transforms
248
+ self.df = pd.read_csv(self.folds_csv)
249
+ self.oversample_real = oversample_real
250
+ self.reduce_val = reduce_val
251
+
252
+ def __getitem__(self, index: int):
253
+
254
+ while True:
255
+ video, img_file, label, ori_video, frame, fold = self.data[index]
256
+ try:
257
+ if self.mode == "train":
258
+ label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
259
+ img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
260
+ image = cv2.imread(img_path, cv2.IMREAD_COLOR)
261
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
262
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
263
+ diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
264
+ try:
265
+ msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
266
+ if msk is not None:
267
+ mask = msk
268
+ except:
269
+ print("not found mask", diff_path)
270
+ pass
271
+ if self.mode == "train" and self.hardcore and not self.rotation:
272
+ landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
273
+ if os.path.exists(landmark_path) and random.random() < 0.7:
274
+ landmarks = np.load(landmark_path)
275
+ image = remove_landmark(image, landmarks)
276
+ elif random.random() < 0.2:
277
+ blackout_convex_hull(image)
278
+ elif random.random() < 0.1:
279
+ binary_mask = mask > 0.4 * 255
280
+ masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
281
+ tries = 6
282
+ current_try = 1
283
+ while current_try < tries:
284
+ bitmap_msk = random.choice(masks)
285
+ if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
286
+ mask *= bitmap_msk
287
+ image *= np.expand_dims(bitmap_msk, axis=-1)
288
+ break
289
+ current_try += 1
290
+ if self.mode == "train" and self.padding_part > 3:
291
+ image = change_padding(image, self.padding_part)
292
+ valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
293
+ valid_label = 1 if valid_label else 0
294
+ rotation = 0
295
+ if self.transforms:
296
+ data = self.transforms(image=image, mask=mask)
297
+ image = data["image"]
298
+ mask = data["mask"]
299
+ if self.mode == "train" and self.hardcore and self.rotation:
300
+ # landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
301
+ dropout = 0.8 if label > 0.5 else 0.6
302
+ if self.rotation:
303
+ dropout *= 0.7
304
+ elif random.random() < dropout:
305
+ blackout_random(image, mask, label)
306
+
307
+ #
308
+ # os.makedirs("../images", exist_ok=True)
309
+ # cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])
310
+
311
+ if self.mode == "train" and self.rotation:
312
+ rotation = random.randint(0, 3)
313
+ image = rot90(image, rotation)
314
+
315
+ image = img_to_tensor(image, self.normalize)
316
+ return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
317
+ "valid": valid_label, "rotations": rotation}
318
+ except Exception as e:
319
+ traceback.print_exc(file=sys.stdout)
320
+ print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
321
+ index = random.randint(0, len(self.data) - 1)
322
+
323
+ def random_blackout_landmark(self, image, mask, landmarks):
324
+ x, y = random.choice(landmarks)
325
+ first = random.random() > 0.5
326
+ # crop half face either vertically or horizontally
327
+ if random.random() > 0.5:
328
+ # width
329
+ if first:
330
+ image[:, :x] = 0
331
+ mask[:, :x] = 0
332
+ else:
333
+ image[:, x:] = 0
334
+ mask[:, x:] = 0
335
+ else:
336
+ # height
337
+ if first:
338
+ image[:y, :] = 0
339
+ mask[:y, :] = 0
340
+ else:
341
+ image[y:, :] = 0
342
+ mask[y:, :] = 0
343
+
344
+ def reset(self, epoch, seed):
345
+ self.data = self._prepare_data(epoch, seed)
346
+
347
+ def __len__(self) -> int:
348
+ return len(self.data)
349
+
350
+ def get_distribution(self):
351
+ return self.n_real, self.n_fake
352
+
353
+ def _prepare_data(self, epoch, seed):
354
+ df = self.df
355
+ if self.mode == "train":
356
+ rows = df[df["fold"] != self.fold]
357
+ else:
358
+ rows = df[df["fold"] == self.fold]
359
+ seed = (epoch + 1) * seed
360
+ if self.oversample_real:
361
+ rows = self._oversample(rows, seed)
362
+ if self.mode == "val" and self.reduce_val:
363
+ # every 2nd frame, to speed up validation
364
+ rows = rows[rows["frame"] % 20 == 0]
365
+ # another option is to use public validation set
366
+ #rows = rows[rows["video"].isin(PUBLIC_SET)]
367
+
368
+ print(
369
+ "real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
370
+ data = rows.values
371
+
372
+ self.n_real = len(rows[rows["label"] == 0])
373
+ self.n_fake = len(rows[rows["label"] == 1])
374
+ np.random.seed(seed)
375
+ np.random.shuffle(data)
376
+ return data
377
+
378
+ def _oversample(self, rows: pd.DataFrame, seed):
379
+ real = rows[rows["label"] == 0]
380
+ fakes = rows[rows["label"] == 1]
381
+ num_real = real["video"].count()
382
+ if self.mode == "train":
383
+ fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
384
+ return pd.concat([real, fakes])
training/datasets/validation_set.py ADDED
@@ -0,0 +1,60 @@
1
+
2
+
3
+ PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
4
+ 'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
5
+ 'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
6
+ 'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
7
+ 'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
8
+ 'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
9
+ 'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
10
+ 'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
11
+ 'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
12
+ 'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
13
+ 'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
14
+ 'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
15
+ 'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
16
+ 'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
17
+ 'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
18
+ 'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
19
+ 'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
20
+ 'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
21
+ 'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
22
+ 'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
23
+ 'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
24
+ 'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
25
+ 'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
26
+ 'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
27
+ 'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
28
+ 'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
29
+ 'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
30
+ 'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
31
+ 'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
32
+ 'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
33
+ 'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
34
+ 'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
35
+ 'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
36
+ 'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
37
+ 'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
38
+ 'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
39
+ 'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
40
+ 'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
41
+ 'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
42
+ 'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
43
+ 'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
44
+ 'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
45
+ 'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
46
+ 'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
47
+ 'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
48
+ 'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
49
+ 'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
50
+ 'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
51
+ 'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
52
+ 'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
53
+ 'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
54
+ 'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
55
+ 'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
56
+ 'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
57
+ 'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
58
+ 'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
59
+ 'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
60
+ 'txmnoyiyte'}
training/losses.py ADDED
@@ -0,0 +1,28 @@
1
+ from typing import Any
2
+
3
+ from pytorch_toolbelt.losses import BinaryFocalLoss
4
+ from torch import nn
5
+ from torch.nn.modules.loss import BCEWithLogitsLoss
6
+
7
+
8
+ class WeightedLosses(nn.Module):
9
+ def __init__(self, losses, weights):
10
+ super().__init__()
11
+ self.losses = losses
12
+ self.weights = weights
13
+
14
+ def forward(self, *input: Any, **kwargs: Any):
15
+ cum_loss = 0
16
+ for loss, w in zip(self.losses, self.weights):
17
+ cum_loss += w * loss.forward(*input, **kwargs)
18
+ return cum_loss
19
+
20
+
21
+ class BinaryCrossentropy(BCEWithLogitsLoss):
22
+ pass
23
+
24
+
25
+ class FocalLoss(BinaryFocalLoss):
26
+ def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
27
+ reduced_threshold=None):
28
+ super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
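
A minimal usage sketch for these wrappers, assuming pytorch_toolbelt is installed and the repository root is on PYTHONPATH; the loss names and weights below are illustrative, the real ones come from the "losses" section of the JSON config:

import torch
from training.losses import BinaryCrossentropy, FocalLoss, WeightedLosses

# combine BCE-with-logits and focal loss with equal (illustrative) weights
criterion = WeightedLosses(losses=[BinaryCrossentropy(), FocalLoss()], weights=[0.5, 0.5])

logits = torch.randn(8, 1)                      # raw classifier outputs
targets = torch.randint(0, 2, (8, 1)).float()   # 0 = real, 1 = fake
print(criterion(logits, targets))               # weighted sum of the two losses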
training/pipelines/__init__.py ADDED
File without changes
training/pipelines/train_classifier.py ADDED
@@ -0,0 +1,364 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ from sklearn.metrics import log_loss
7
+ from torch import topk
8
+
9
+ import sys
10
+ # add the parent directory to sys.path so the training.* imports below resolve
11
+ sys.path.append('..')
12
+
13
+ from training import losses
14
+ from training.datasets.classifier_dataset import DeepFakeClassifierDataset
15
+ from training.losses import WeightedLosses
16
+ from training.tools.config import load_config
17
+ from training.tools.utils import create_optimizer, AverageMeter
18
+ from training.transforms.albu import IsotropicResize
19
+ from training.zoo import classifiers
20
+
21
+ os.environ["MKL_NUM_THREADS"] = "1"
22
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
23
+ os.environ["OMP_NUM_THREADS"] = "1"
24
+
25
+ import cv2
26
+
27
+ cv2.ocl.setUseOpenCL(False)
28
+ cv2.setNumThreads(0)
29
+ import numpy as np
30
+ from albumentations import Compose, RandomBrightnessContrast, \
31
+ HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
32
+ ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
33
+
34
+ from apex.parallel import DistributedDataParallel, convert_syncbn_model
35
+ from tensorboardX import SummaryWriter
36
+
37
+ from apex import amp
38
+
39
+ import torch
40
+ from torch.backends import cudnn
41
+ from torch.nn import DataParallel
42
+ from torch.utils.data import DataLoader
43
+ from tqdm import tqdm
44
+ import torch.distributed as dist
45
+
46
+ torch.backends.cudnn.benchmark = True
47
+
48
+ def create_train_transforms(size=300):
49
+ return Compose([
50
+ ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
51
+ GaussNoise(p=0.1),
52
+ GaussianBlur(blur_limit=3, p=0.05),
53
+ HorizontalFlip(),
54
+ OneOf([
55
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
56
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
57
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
58
+ ], p=1),
59
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
60
+ OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
61
+ ToGray(p=0.2),
62
+ ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
63
+ ]
64
+ )
65
+
66
+
67
+ def create_val_transforms(size=300):
68
+ return Compose([
69
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
70
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
71
+ ])
72
+
73
+
74
+ def main():
75
+ parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
76
+ arg = parser.add_argument
77
+ arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
78
+ arg('--workers', type=int, default=6, help='number of cpu threads to use')
79
+ arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
80
+ arg('--output-dir', type=str, default='weights/')
81
+ arg('--resume', type=str, default='')
82
+ arg('--fold', type=int, default=0)
83
+ arg('--prefix', type=str, default='classifier_')
84
+ arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
85
+ arg('--folds-csv', type=str, default='folds.csv')
86
+ arg('--crops-dir', type=str, default='crops')
87
+ arg('--label-smoothing', type=float, default=0.01)
88
+ arg('--logdir', type=str, default='logs')
89
+ arg('--zero-score', action='store_true', default=False)
90
+ arg('--from-zero', action='store_true', default=False)
91
+ arg('--distributed', action='store_true', default=False)
92
+ arg('--freeze-epochs', type=int, default=0)
93
+ arg("--local_rank", default=0, type=int)
94
+ arg("--seed", default=777, type=int)
95
+ arg("--padding-part", default=3, type=int)
96
+ arg("--opt-level", default='O1', type=str)
97
+ arg("--test_every", type=int, default=1)
98
+ arg("--no-oversample", action="store_true")
99
+ arg("--no-hardcore", action="store_true")
100
+ arg("--only-changed-frames", action="store_true")
101
+
102
+ args = parser.parse_args()
103
+ os.makedirs(args.output_dir, exist_ok=True)
104
+ if args.distributed:
105
+ torch.cuda.set_device(args.local_rank)
106
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
107
+ else:
108
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
109
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
110
+
111
+ cudnn.benchmark = True
112
+
113
+ conf = load_config(args.config)
114
+ model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
115
+
116
+ model = model.cuda()
117
+ if args.distributed:
118
+ model = convert_syncbn_model(model)
119
+ ohem = conf.get("ohem_samples", None)
120
+ reduction = "mean"
121
+ if ohem:
122
+ reduction = "none"
123
+ loss_fn = []
124
+ weights = []
125
+ for loss_name, weight in conf["losses"].items():
126
+ loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
127
+ weights.append(weight)
128
+ loss = WeightedLosses(loss_fn, weights)
129
+ loss_functions = {"classifier_loss": loss}
130
+ optimizer, scheduler = create_optimizer(conf['optimizer'], model)
131
+ bce_best = 100
132
+ start_epoch = 0
133
+ batch_size = conf['optimizer']['batch_size']
134
+
135
+ data_train = DeepFakeClassifierDataset(mode="train",
136
+ oversample_real=not args.no_oversample,
137
+ fold=args.fold,
138
+ padding_part=args.padding_part,
139
+ hardcore=not args.no_hardcore,
140
+ crops_dir=args.crops_dir,
141
+ data_path=args.data_dir,
142
+ label_smoothing=args.label_smoothing,
143
+ folds_csv=args.folds_csv,
144
+ transforms=create_train_transforms(conf["size"]),
145
+ normalize=conf.get("normalize", None))
146
+ data_val = DeepFakeClassifierDataset(mode="val",
147
+ fold=args.fold,
148
+ padding_part=args.padding_part,
149
+ crops_dir=args.crops_dir,
150
+ data_path=args.data_dir,
151
+ folds_csv=args.folds_csv,
152
+ transforms=create_val_transforms(conf["size"]),
153
+ normalize=conf.get("normalize", None))
154
+ val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
155
+ pin_memory=False)
156
+ os.makedirs(args.logdir, exist_ok=True)
157
+ summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
158
+ if args.resume:
159
+ if os.path.isfile(args.resume):
160
+ print("=> loading checkpoint '{}'".format(args.resume))
161
+ checkpoint = torch.load(args.resume, map_location='cpu')
162
+ state_dict = checkpoint['state_dict']
163
+ state_dict = {k[7:]: w for k, w in state_dict.items()}
164
+ model.load_state_dict(state_dict, strict=False)
165
+ if not args.from_zero:
166
+ start_epoch = checkpoint['epoch']
167
+ if not args.zero_score:
168
+ bce_best = checkpoint.get('bce_best', 0)
169
+ print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
170
+ .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
171
+ else:
172
+ print("=> no checkpoint found at '{}'".format(args.resume))
173
+ if args.from_zero:
174
+ start_epoch = 0
175
+ current_epoch = start_epoch
176
+
177
+ if conf['fp16']:
178
+ model, optimizer = amp.initialize(model, optimizer,
179
+ opt_level=args.opt_level,
180
+ loss_scale='dynamic')
181
+
182
+ snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
183
+
184
+ if args.distributed:
185
+ model = DistributedDataParallel(model, delay_allreduce=True)
186
+ else:
187
+ model = DataParallel(model).cuda()
188
+ data_val.reset(1, args.seed)
189
+ max_epochs = conf['optimizer']['schedule']['epochs']
190
+ for epoch in range(start_epoch, max_epochs):
191
+ data_train.reset(epoch, args.seed)
192
+ train_sampler = None
193
+ if args.distributed:
194
+ train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
195
+ train_sampler.set_epoch(epoch)
196
+ if epoch < args.freeze_epochs:
197
+ print("Freezing encoder!!!")
198
+ model.module.encoder.eval()
199
+ for p in model.module.encoder.parameters():
200
+ p.requires_grad = False
201
+ else:
202
+ model.module.encoder.train()
203
+ for p in model.module.encoder.parameters():
204
+ p.requires_grad = True
205
+
206
+ train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
207
+ shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
208
+ drop_last=True)
209
+
210
+ train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
211
+ args.local_rank, args.only_changed_frames)
212
+ model = model.eval()
213
+
214
+ if args.local_rank == 0:
215
+ torch.save({
216
+ 'epoch': current_epoch + 1,
217
+ 'state_dict': model.state_dict(),
218
+ 'bce_best': bce_best,
219
+ }, args.output_dir + '/' + snapshot_name + "_last")
220
+ torch.save({
221
+ 'epoch': current_epoch + 1,
222
+ 'state_dict': model.state_dict(),
223
+ 'bce_best': bce_best,
224
+ }, args.output_dir + snapshot_name + "_{}".format(current_epoch))
225
+ if (epoch + 1) % args.test_every == 0:
226
+ bce_best = evaluate_val(args, val_data_loader, bce_best, model,
227
+ snapshot_name=snapshot_name,
228
+ current_epoch=current_epoch,
229
+ summary_writer=summary_writer)
230
+ current_epoch += 1
231
+
232
+
233
+ def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
234
+ print("Test phase")
235
+ model = model.eval()
236
+
237
+ bce, probs, targets = validate(model, data_loader=data_val)
238
+ if args.local_rank == 0:
239
+ summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
240
+ if bce < bce_best:
241
+ print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
242
+ if args.output_dir is not None:
243
+ torch.save({
244
+ 'epoch': current_epoch + 1,
245
+ 'state_dict': model.state_dict(),
246
+ 'bce_best': bce,
247
+ }, args.output_dir + snapshot_name + "_best_dice")
248
+ bce_best = bce
249
+ with open("predictions_{}.json".format(args.fold), "w") as f:
250
+ json.dump({"probs": probs, "targets": targets}, f)
251
+ torch.save({
252
+ 'epoch': current_epoch + 1,
253
+ 'state_dict': model.state_dict(),
254
+ 'bce_best': bce_best,
255
+ }, args.output_dir + snapshot_name + "_last")
256
+ print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
257
+ return bce_best
258
+
259
+
260
+ def validate(net, data_loader, prefix=""):
261
+ probs = defaultdict(list)
262
+ targets = defaultdict(list)
263
+
264
+ with torch.no_grad():
265
+ for sample in tqdm(data_loader):
266
+ imgs = sample["image"].cuda()
267
+ img_names = sample["img_name"]
268
+ labels = sample["labels"].cuda().float()
269
+ out = net(imgs)
270
+ labels = labels.cpu().numpy()
271
+ preds = torch.sigmoid(out).cpu().numpy()
272
+ for i in range(out.shape[0]):
273
+ video, img_id = img_names[i].split("/")
274
+ probs[video].append(preds[i].tolist())
275
+ targets[video].append(labels[i].tolist())
276
+ data_x = []
277
+ data_y = []
278
+ for vid, score in probs.items():
279
+ score = np.array(score)
280
+ lbl = targets[vid]
281
+
282
+ score = np.mean(score)
283
+ lbl = np.mean(lbl)
284
+ data_x.append(score)
285
+ data_y.append(lbl)
286
+ y = np.array(data_y)
287
+ x = np.array(data_x)
288
+ fake_idx = y > 0.1
289
+ real_idx = y < 0.1
290
+ fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
291
+ real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
292
+ print("{}fake_loss".format(prefix), fake_loss)
293
+ print("{}real_loss".format(prefix), real_loss)
294
+
295
+ return (fake_loss + real_loss) / 2, probs, targets
296
+
297
+
298
+ def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
299
+ local_rank, only_valid):
300
+ losses = AverageMeter()
301
+ fake_losses = AverageMeter()
302
+ real_losses = AverageMeter()
303
+ max_iters = conf["batches_per_epoch"]
304
+ print("training epoch {}".format(current_epoch))
305
+ model.train()
306
+ pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
307
+ if conf["optimizer"]["schedule"]["mode"] == "epoch":
308
+ scheduler.step(current_epoch)
309
+ for i, sample in pbar:
310
+ imgs = sample["image"].cuda()
311
+ labels = sample["labels"].cuda().float()
312
+ out_labels = model(imgs)
313
+ if only_valid:
314
+ valid_idx = sample["valid"].cuda().float() > 0
315
+ out_labels = out_labels[valid_idx]
316
+ labels = labels[valid_idx]
317
+ if labels.size(0) == 0:
318
+ continue
319
+
320
+ fake_loss = 0
321
+ real_loss = 0
322
+ fake_idx = labels > 0.5
323
+ real_idx = labels <= 0.5
324
+
325
+ ohem = conf.get("ohem_samples", None)
326
+ if torch.sum(fake_idx * 1) > 0:
327
+ fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
328
+ if torch.sum(real_idx * 1) > 0:
329
+ real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
330
+ if ohem:
331
+ fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
332
+ real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()
333
+
334
+ loss = (fake_loss + real_loss) / 2
335
+ losses.update(loss.item(), imgs.size(0))
336
+ fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
337
+ real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
338
+
339
+ optimizer.zero_grad()
340
+ pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
341
+ "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})
342
+
343
+ if conf['fp16']:
344
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
345
+ scaled_loss.backward()
346
+ else:
347
+ loss.backward()
348
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
349
+ optimizer.step()
350
+ torch.cuda.synchronize()
351
+ if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
352
+ scheduler.step(i + current_epoch * max_iters)
353
+ if i == max_iters - 1:
354
+ break
355
+ pbar.close()
356
+ if local_rank == 0:
357
+ for idx, param_group in enumerate(optimizer.param_groups):
358
+ lr = param_group['lr']
359
+ summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
360
+ summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
361
+
362
+
363
+ if __name__ == '__main__':
364
+ main()
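
A rough launch sketch (paths, config name and fold are illustrative, not prescribed by this commit): from the repository root, a single-GPU run would look roughly like

python training/pipelines/train_classifier.py --config configs/b7.json --data-dir /path/to/dfdc --folds-csv folds.csv --crops-dir crops --gpu 0 --fold 0

Multi-GPU runs add --distributed and are expected to be launched through torch.distributed.launch, which supplies --local_rank. Note that the script imports NVIDIA Apex unconditionally (and timm via training.tools.utils), so both must be installed even for FP32 training.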
training/tools/__init__.py ADDED
File without changes
training/tools/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes).
 
training/tools/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.06 kB).
 
training/tools/__pycache__/schedulers.cpython-310.pyc ADDED
Binary file (3.01 kB).
 
training/tools/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.65 kB).
 
training/tools/config.py ADDED
@@ -0,0 +1,43 @@
1
+ import json
2
+
3
+ DEFAULTS = {
4
+ "network": "dpn",
5
+ "encoder": "dpn92",
6
+ "model_params": {},
7
+ "optimizer": {
8
+ "batch_size": 32,
9
+ "type": "SGD", # supported: SGD, Adam
10
+ "momentum": 0.9,
11
+ "weight_decay": 0,
12
+ "clip": 1.,
13
+ "learning_rate": 0.1,
14
+ "classifier_lr": -1,
15
+ "nesterov": True,
16
+ "schedule": {
17
+ "type": "constant", # supported: constant, step, multistep, exponential, linear, poly
18
+ "mode": "epoch", # supported: epoch, step
19
+ "epochs": 10,
20
+ "params": {}
21
+ }
22
+ },
23
+ "normalize": {
24
+ "mean": [0.485, 0.456, 0.406],
25
+ "std": [0.229, 0.224, 0.225]
26
+ }
27
+ }
28
+
29
+
30
+ def _merge(src, dst):
31
+ for k, v in src.items():
32
+ if k in dst:
33
+ if isinstance(v, dict):
34
+ _merge(src[k], dst[k])
35
+ else:
36
+ dst[k] = v
37
+
38
+
39
+ def load_config(config_file, defaults=DEFAULTS):
40
+ with open(config_file, "r") as fd:
41
+ config = json.load(fd)
42
+ _merge(defaults, config)
43
+ return config
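
To make the expected shape concrete, here is a hypothetical config in Python-dict form covering the keys train_classifier.py actually reads; the shipped values live in configs/b5.json and configs/b7.json, and the numbers below are placeholders, not taken from this commit:

example_config = {
    "network": "DeepFakeClassifier",            # resolved via training.zoo.classifiers.__dict__
    "encoder": "tf_efficientnet_b7_ns",
    "losses": {"BinaryCrossentropy": 1.0},      # loss name -> weight, resolved via training.losses.__dict__
    "size": 380,                                # face crop side used by the transforms
    "fp16": True,                               # enables apex.amp mixed precision
    "batches_per_epoch": 2500,
    "optimizer": {
        "batch_size": 16,
        "type": "SGD",
        "learning_rate": 0.01,
        "schedule": {"type": "poly", "mode": "step", "epochs": 20, "params": {"max_iter": 50000}},
    },
}

# In the pipeline the same structure is read from JSON and merged with DEFAULTS:
# from training.tools.config import load_config
# conf = load_config("configs/b7.json")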
training/tools/schedulers.py ADDED
@@ -0,0 +1,46 @@
1
+ from bisect import bisect_right
2
+
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
+ class LRStepScheduler(_LRScheduler):
7
+ def __init__(self, optimizer, steps, last_epoch=-1):
8
+ self.lr_steps = steps
9
+ super().__init__(optimizer, last_epoch)
10
+
11
+ def get_lr(self):
12
+ pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
13
+ return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]
14
+
15
+
16
+ class PolyLR(_LRScheduler):
17
+ """Sets the learning rate of each parameter group according to poly learning rate policy
18
+ """
19
+ def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
20
+ self.max_iter = max_iter
21
+ self.power = power
22
+ super(PolyLR, self).__init__(optimizer, last_epoch)
23
+
24
+ def get_lr(self):
25
+ self.last_epoch = (self.last_epoch + 1) % self.max_iter
26
+ return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]
27
+
28
+ class ExponentialLRScheduler(_LRScheduler):
29
+ """Decays the learning rate of each parameter group by gamma every epoch.
30
+ When last_epoch=-1, sets initial lr as lr.
31
+
32
+ Args:
33
+ optimizer (Optimizer): Wrapped optimizer.
34
+ gamma (float): Multiplicative factor of learning rate decay.
35
+ last_epoch (int): The index of last epoch. Default: -1.
36
+ """
37
+
38
+ def __init__(self, optimizer, gamma, last_epoch=-1):
39
+ self.gamma = gamma
40
+ super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)
41
+
42
+ def get_lr(self):
43
+ if self.last_epoch <= 0:
44
+ return self.base_lrs
45
+ return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
46
+
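
A minimal wiring sketch for these schedulers, assuming the repository root is on PYTHONPATH; the model and numbers are placeholders:

import torch
from torch import nn, optim
from training.tools.schedulers import LRStepScheduler, PolyLR

model = nn.Linear(16, 1)                      # stand-in for the classifier
opt = optim.SGD(model.parameters(), lr=0.1)

# polynomial decay towards zero over max_iter steps (stepped per batch when the schedule mode is "step")
sched = PolyLR(opt, max_iter=1000, power=0.9)
# alternatively, explicit (iteration, lr) breakpoints:
# sched = LRStepScheduler(opt, steps=[(0, 0.1), (500, 0.01), (800, 0.001)])

for step in range(3):
    opt.step()
    sched.step()
    print(opt.param_groups[0]["lr"])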
training/tools/utils.py ADDED
@@ -0,0 +1,121 @@
1
+ import cv2
2
+ from apex.optimizers import FusedAdam, FusedSGD
3
+ from timm.optim import AdamW
4
+ from torch import optim
5
+ from torch.optim import lr_scheduler
6
+ from torch.optim.rmsprop import RMSprop
7
+ from torch.optim.adamw import AdamW  # note: this shadows the timm AdamW imported above
8
+ from torch.optim.lr_scheduler import MultiStepLR, CyclicLR
9
+
10
+ from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler
11
+
12
+ cv2.ocl.setUseOpenCL(False)
13
+ cv2.setNumThreads(0)
14
+
15
+
16
+ class AverageMeter(object):
17
+ """Computes and stores the average and current value"""
18
+
19
+ def __init__(self):
20
+ self.reset()
21
+
22
+ def reset(self):
23
+ self.val = 0
24
+ self.avg = 0
25
+ self.sum = 0
26
+ self.count = 0
27
+
28
+ def update(self, val, n=1):
29
+ self.val = val
30
+ self.sum += val * n
31
+ self.count += n
32
+ self.avg = self.sum / self.count
33
+
34
+ def create_optimizer(optimizer_config, model, master_params=None):
35
+ """Creates optimizer and schedule from configuration
36
+
37
+ Parameters
38
+ ----------
39
+ optimizer_config : dict
40
+ Dictionary containing the configuration options for the optimizer.
41
+ model : Model
42
+ The network model.
43
+
44
+ Returns
45
+ -------
46
+ optimizer : Optimizer
47
+ The optimizer.
48
+ scheduler : LRScheduler
49
+ The learning rate scheduler.
50
+ """
51
+ if optimizer_config.get("classifier_lr", -1) != -1:
52
+ # Separate classifier parameters from all others
53
+ net_params = []
54
+ classifier_params = []
55
+ for k, v in model.named_parameters():
56
+ if not v.requires_grad:
57
+ continue
58
+ if k.find("encoder") != -1:
59
+ net_params.append(v)
60
+ else:
61
+ classifier_params.append(v)
62
+ params = [
63
+ {"params": net_params},
64
+ {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
65
+ ]
66
+ else:
67
+ if master_params:
68
+ params = master_params
69
+ else:
70
+ params = model.parameters()
71
+
72
+ if optimizer_config["type"] == "SGD":
73
+ optimizer = optim.SGD(params,
74
+ lr=optimizer_config["learning_rate"],
75
+ momentum=optimizer_config["momentum"],
76
+ weight_decay=optimizer_config["weight_decay"],
77
+ nesterov=optimizer_config["nesterov"])
78
+ elif optimizer_config["type"] == "FusedSGD":
79
+ optimizer = FusedSGD(params,
80
+ lr=optimizer_config["learning_rate"],
81
+ momentum=optimizer_config["momentum"],
82
+ weight_decay=optimizer_config["weight_decay"],
83
+ nesterov=optimizer_config["nesterov"])
84
+ elif optimizer_config["type"] == "Adam":
85
+ optimizer = optim.Adam(params,
86
+ lr=optimizer_config["learning_rate"],
87
+ weight_decay=optimizer_config["weight_decay"])
88
+ elif optimizer_config["type"] == "FusedAdam":
89
+ optimizer = FusedAdam(params,
90
+ lr=optimizer_config["learning_rate"],
91
+ weight_decay=optimizer_config["weight_decay"])
92
+ elif optimizer_config["type"] == "AdamW":
93
+ optimizer = AdamW(params,
94
+ lr=optimizer_config["learning_rate"],
95
+ weight_decay=optimizer_config["weight_decay"])
96
+ elif optimizer_config["type"] == "RmsProp":
97
+ optimizer = RMSprop(params,
98
+ lr=optimizer_config["learning_rate"],
99
+ weight_decay=optimizer_config["weight_decay"])
100
+ else:
101
+ raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))
102
+
103
+ if optimizer_config["schedule"]["type"] == "step":
104
+ scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
105
+ elif optimizer_config["schedule"]["type"] == "clr":
106
+ scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
107
+ elif optimizer_config["schedule"]["type"] == "multistep":
108
+ scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
109
+ elif optimizer_config["schedule"]["type"] == "exponential":
110
+ scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
111
+ elif optimizer_config["schedule"]["type"] == "poly":
112
+ scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
113
+ elif optimizer_config["schedule"]["type"] == "constant":
114
+ scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
115
+ elif optimizer_config["schedule"]["type"] == "linear":
116
+ def linear_lr(it):
117
+ return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]
118
+
119
+ scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
120
+
121
+ return optimizer, scheduler
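
A minimal sketch of driving create_optimizer with the same dictionary layout as DEFAULTS in training/tools/config.py; note that importing training.tools.utils requires apex and timm to be installed, and the model below is a stand-in:

from torch import nn
from training.tools.utils import create_optimizer

model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1))

optimizer_config = {
    "type": "SGD",
    "learning_rate": 0.1,
    "momentum": 0.9,
    "weight_decay": 0,
    "nesterov": True,
    "classifier_lr": -1,      # -1 = no separate learning rate for the classifier head
    "schedule": {"type": "constant", "mode": "epoch", "epochs": 10, "params": {}},
}

optimizer, scheduler = create_optimizer(optimizer_config, model)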
training/transforms/__init__.py ADDED
File without changes
training/transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes).
 
training/transforms/__pycache__/albu.cpython-310.pyc ADDED
Binary file (4.36 kB).
 
training/transforms/albu.py ADDED
@@ -0,0 +1,100 @@
1
+ import random
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from albumentations import DualTransform, ImageOnlyTransform
6
+ from albumentations.augmentations.crops.functional import crop
7
+ #from albumentations.augmentations.functional import crop
8
+
9
+
10
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
11
+ h, w = img.shape[:2]
12
+ if max(w, h) == size:
13
+ return img
14
+ if w > h:
15
+ scale = size / w
16
+ h = h * scale
17
+ w = size
18
+ else:
19
+ scale = size / h
20
+ w = w * scale
21
+ h = size
22
+ interpolation = interpolation_up if scale > 1 else interpolation_down
23
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
24
+ return resized
25
+
26
+
27
+ class IsotropicResize(DualTransform):
28
+ def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
29
+ always_apply=False, p=1):
30
+ super(IsotropicResize, self).__init__(always_apply, p)
31
+ self.max_side = max_side
32
+ self.interpolation_down = interpolation_down
33
+ self.interpolation_up = interpolation_up
34
+
35
+ def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
36
+ return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
37
+ interpolation_up=interpolation_up)
38
+
39
+ def apply_to_mask(self, img, **params):
40
+ return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)
41
+
42
+ def get_transform_init_args_names(self):
43
+ return ("max_side", "interpolation_down", "interpolation_up")
44
+
45
+
46
+ class Resize4xAndBack(ImageOnlyTransform):
47
+ def __init__(self, always_apply=False, p=0.5):
48
+ super(Resize4xAndBack, self).__init__(always_apply, p)
49
+
50
+ def apply(self, img, **params):
51
+ h, w = img.shape[:2]
52
+ scale = random.choice([2, 4])
53
+ img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
54
+ img = cv2.resize(img, (w, h),
55
+ interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
56
+ return img
57
+
58
+
59
+ class RandomSizedCropNonEmptyMaskIfExists(DualTransform):
60
+
61
+ def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
62
+ super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)
63
+
64
+ self.min_max_height = min_max_height
65
+ self.w2h_ratio = w2h_ratio
66
+
67
+ def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
68
+ cropped = crop(img, x_min, y_min, x_max, y_max)
69
+ return cropped
70
+
71
+ @property
72
+ def targets_as_params(self):
73
+ return ["mask"]
74
+
75
+ def get_params_dependent_on_targets(self, params):
76
+ mask = params["mask"]
77
+ mask_height, mask_width = mask.shape[:2]
78
+ crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
79
+ w2h_ratio = random.uniform(*self.w2h_ratio)
80
+ crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
81
+ if mask.sum() == 0:
82
+ x_min = random.randint(0, mask_width - crop_width + 1)
83
+ y_min = random.randint(0, mask_height - crop_height + 1)
84
+ else:
85
+ mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
86
+ non_zero_yx = np.argwhere(mask)
87
+ y, x = random.choice(non_zero_yx)
88
+ x_min = x - random.randint(0, crop_width - 1)
89
+ y_min = y - random.randint(0, crop_height - 1)
90
+ x_min = np.clip(x_min, 0, mask_width - crop_width)
91
+ y_min = np.clip(y_min, 0, mask_height - crop_height)
92
+
93
+ x_max = x_min + crop_height
94
+ y_max = y_min + crop_width
95
+ y_max = min(mask_height, y_max)
96
+ x_max = min(mask_width, x_max)
97
+ return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
98
+
99
+ def get_transform_init_args_names(self):
100
+ return "min_max_height", "height", "width", "w2h_ratio"
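
A minimal sketch mirroring create_val_transforms from the training pipeline: resize the longer side to the target size, then pad to a square. The input array is a stand-in for a face crop and 380 is an illustrative size:

import cv2
import numpy as np
from albumentations import Compose, PadIfNeeded
from training.transforms.albu import IsotropicResize

size = 380
transform = Compose([
    IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
    PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
])

face = np.random.randint(0, 255, (271, 203, 3), dtype=np.uint8)
out = transform(image=face)["image"]
print(out.shape)   # (380, 380, 3)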
training/zoo/__init__.py ADDED
File without changes
training/zoo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes).
 
training/zoo/__pycache__/classifiers.cpython-310.pyc ADDED
Binary file (5.55 kB).
 
training/zoo/classifiers.py ADDED
@@ -0,0 +1,172 @@
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import torch
5
+ from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
6
+ tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
7
+ from torch import nn
8
+ from torch.nn.modules.dropout import Dropout
9
+ from torch.nn.modules.linear import Linear
10
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
11
+
12
+ encoder_params = {
13
+ "tf_efficientnet_b3_ns": {
14
+ "features": 1536,
15
+ "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
16
+ },
17
+ "tf_efficientnet_b2_ns": {
18
+ "features": 1408,
19
+ "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
20
+ },
21
+ "tf_efficientnet_b4_ns": {
22
+ "features": 1792,
23
+ "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
24
+ },
25
+ "tf_efficientnet_b5_ns": {
26
+ "features": 2048,
27
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
28
+ },
29
+ "tf_efficientnet_b4_ns_03d": {
30
+ "features": 1792,
31
+ "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
32
+ },
33
+ "tf_efficientnet_b5_ns_03d": {
34
+ "features": 2048,
35
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
36
+ },
37
+ "tf_efficientnet_b5_ns_04d": {
38
+ "features": 2048,
39
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
40
+ },
41
+ "tf_efficientnet_b6_ns": {
42
+ "features": 2304,
43
+ "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
44
+ },
45
+ "tf_efficientnet_b7_ns": {
46
+ "features": 2560,
47
+ "init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
48
+ },
49
+ "tf_efficientnet_b6_ns_04d": {
50
+ "features": 2304,
51
+ "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
52
+ },
53
+ }
54
+
55
+
56
+ def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
57
+ """Creates the SRM kernels for noise analysis."""
58
+ # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
59
+ srm_kernel = torch.from_numpy(np.array([
60
+ [ # srm 1/2 horiz
61
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
62
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
63
+ [0., 1., -2., 1., 0.], # noqa: E241,E201
64
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
65
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
66
+ ], [ # srm 1/4
67
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
68
+ [0., -1., 2., -1., 0.], # noqa: E241,E201
69
+ [0., 2., -4., 2., 0.], # noqa: E241,E201
70
+ [0., -1., 2., -1., 0.], # noqa: E241,E201
71
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
72
+ ], [ # srm 1/12
73
+ [-1., 2., -2., 2., -1.], # noqa: E241,E201
74
+ [2., -6., 8., -6., 2.], # noqa: E241,E201
75
+ [-2., 8., -12., 8., -2.], # noqa: E241,E201
76
+ [2., -6., 8., -6., 2.], # noqa: E241,E201
77
+ [-1., 2., -2., 2., -1.], # noqa: E241,E201
78
+ ]
79
+ ])).float()
80
+ srm_kernel[0] /= 2
81
+ srm_kernel[1] /= 4
82
+ srm_kernel[2] /= 12
83
+ return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
84
+
85
+
86
+ def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
87
+ """Creates a SRM convolution layer for noise analysis."""
88
+ weights = setup_srm_weights(input_channels)
89
+ conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
90
+ with torch.no_grad():
91
+ conv.weight = torch.nn.Parameter(weights, requires_grad=False)
92
+ return conv
93
+
94
+
95
+ class DeepFakeClassifierSRM(nn.Module):
96
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
97
+ super().__init__()
98
+ self.encoder = encoder_params[encoder]["init_op"]()
99
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
100
+ self.srm_conv = setup_srm_layer(3)
101
+ self.dropout = Dropout(dropout_rate)
102
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
103
+
104
+ def forward(self, x):
105
+ noise = self.srm_conv(x)
106
+ x = self.encoder.forward_features(noise)
107
+ x = self.avg_pool(x).flatten(1)
108
+ x = self.dropout(x)
109
+ x = self.fc(x)
110
+ return x
111
+
112
+
113
+ class GlobalWeightedAvgPool2d(nn.Module):
114
+ """
115
+ Global Weighted Average Pooling from paper "Global Weighted Average
116
+ Pooling Bridges Pixel-level Localization and Image-level Classification"
117
+ """
118
+
119
+ def __init__(self, features: int, flatten=False):
120
+ super().__init__()
121
+ self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
122
+ self.flatten = flatten
123
+
124
+ def fscore(self, x):
125
+ m = self.conv(x)
126
+ m = m.sigmoid().exp()
127
+ return m
128
+
129
+ def norm(self, x: torch.Tensor):
130
+ return x / x.sum(dim=[2, 3], keepdim=True)
131
+
132
+ def forward(self, x):
133
+ input_x = x
134
+ x = self.fscore(x)
135
+ x = self.norm(x)
136
+ x = x * input_x
137
+ x = x.sum(dim=[2, 3], keepdim=not self.flatten)
138
+ return x
139
+
140
+
141
+ class DeepFakeClassifier(nn.Module):
142
+ def __init__(self, encoder, dropout_rate=0.0) -> None:
143
+ super().__init__()
144
+ self.encoder = encoder_params[encoder]["init_op"]()
145
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
146
+ self.dropout = Dropout(dropout_rate)
147
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
148
+
149
+ def forward(self, x):
150
+ x = self.encoder.forward_features(x)
151
+ x = self.avg_pool(x).flatten(1)
152
+ x = self.dropout(x)
153
+ x = self.fc(x)
154
+ return x
155
+
156
+
157
+
158
+
159
+ class DeepFakeClassifierGWAP(nn.Module):
160
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
161
+ super().__init__()
162
+ self.encoder = encoder_params[encoder]["init_op"]()
163
+ self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
164
+ self.dropout = Dropout(dropout_rate)
165
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
166
+
167
+ def forward(self, x):
168
+ x = self.encoder.forward_features(x)
169
+ x = self.avg_pool(x).flatten(1)
170
+ x = self.dropout(x)
171
+ x = self.fc(x)
172
+ return x
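
A minimal inference sketch; most init_op entries (including b7) use pretrained=True, so constructing a model will try to download the ImageNet-pretrained encoder through timm, and the 380-pixel input size is illustrative:

import torch
from training.zoo.classifiers import DeepFakeClassifier

model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").eval()

with torch.no_grad():
    faces = torch.rand(2, 3, 380, 380)     # stand-in batch of face crops
    logits = model(faces)                  # shape (2, 1), raw logits
    probs = torch.sigmoid(logits)          # fake probability, as in validate() above
print(probs.shape)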
training/zoo/unet.py ADDED
@@ -0,0 +1,151 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
5
+ from torch import nn
6
+ from torch.nn import Dropout2d, Conv2d
7
+ from torch.nn.modules.dropout import Dropout
8
+ from torch.nn.modules.linear import Linear
9
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
10
+ from torch.nn.modules.upsampling import UpsamplingBilinear2d
11
+
12
+ encoder_params = {
13
+ "tf_efficientnet_b3_ns": {
14
+ "features": 1536,
15
+ "filters": [40, 32, 48, 136, 1536],
16
+ "decoder_filters": [64, 128, 256, 256],
17
+ "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
18
+ },
19
+ "tf_efficientnet_b5_ns": {
20
+ "features": 2048,
21
+ "filters": [48, 40, 64, 176, 2048],
22
+ "decoder_filters": [64, 128, 256, 256],
23
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
24
+ },
25
+ }
26
+
27
+
28
+ class DecoderBlock(nn.Module):
29
+ def __init__(self, in_channels, out_channels):
30
+ super().__init__()
31
+ self.layer = nn.Sequential(
32
+ nn.Upsample(scale_factor=2),
33
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
34
+ nn.ReLU(inplace=True)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layer(x)
39
+
40
+
41
+ class ConcatBottleneck(nn.Module):
42
+ def __init__(self, in_channels, out_channels):
43
+ super().__init__()
44
+ self.seq = nn.Sequential(
45
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
46
+ nn.ReLU(inplace=True)
47
+ )
48
+
49
+ def forward(self, dec, enc):
50
+ x = torch.cat([dec, enc], dim=1)
51
+ return self.seq(x)
52
+
53
+
54
+ class Decoder(nn.Module):
55
+ def __init__(self, decoder_filters, filters, upsample_filters=None,
56
+ decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
57
+ super().__init__()
58
+ self.decoder_filters = decoder_filters
59
+ self.filters = filters
60
+ self.decoder_block = decoder_block
61
+ self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
62
+ self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
63
+ for i, f in enumerate(reversed(decoder_filters))])
64
+ self.dropout = Dropout2d(dropout) if dropout > 0 else None
65
+ self.last_block = None
66
+ if upsample_filters:
67
+ self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
68
+ else:
69
+ self.last_block = UpsamplingBilinear2d(scale_factor=2)
70
+
71
+ def forward(self, encoder_results: list):
72
+ x = encoder_results[0]
73
+ bottlenecks = self.bottlenecks
74
+ for idx, bottleneck in enumerate(bottlenecks):
75
+ rev_idx = - (idx + 1)
76
+ x = self.decoder_stages[rev_idx](x)
77
+ x = bottleneck(x, encoder_results[-rev_idx])
78
+ if self.last_block:
79
+ x = self.last_block(x)
80
+ if self.dropout:
81
+ x = self.dropout(x)
82
+ return x
83
+
84
+ def _get_decoder(self, layer):
85
+ idx = layer + 1
86
+ if idx == len(self.decoder_filters):
87
+ in_channels = self.filters[idx]
88
+ else:
89
+ in_channels = self.decoder_filters[idx]
90
+ return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
91
+
92
+
93
+ def _initialize_weights(module):
94
+ for m in module.modules():
95
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
96
+ m.weight.data = nn.init.kaiming_normal_(m.weight.data)
97
+ if m.bias is not None:
98
+ m.bias.data.zero_()
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class EfficientUnetClassifier(nn.Module):
105
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
106
+ super().__init__()
107
+ self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
108
+ filters=encoder_params[encoder]["filters"])
109
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
110
+ self.dropout = Dropout(dropout_rate)
111
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
112
+ self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
113
+ _initialize_weights(self)
114
+ self.encoder = encoder_params[encoder]["init_op"]()
115
+
116
+ def get_encoder_features(self, x):
117
+ encoder_results = []
118
+ x = self.encoder.conv_stem(x)
119
+ x = self.encoder.bn1(x)
120
+ x = self.encoder.act1(x)
121
+ encoder_results.append(x)
122
+ x = self.encoder.blocks[:2](x)
123
+ encoder_results.append(x)
124
+ x = self.encoder.blocks[2:3](x)
125
+ encoder_results.append(x)
126
+ x = self.encoder.blocks[3:5](x)
127
+ encoder_results.append(x)
128
+ x = self.encoder.blocks[5:](x)
129
+ x = self.encoder.conv_head(x)
130
+ x = self.encoder.bn2(x)
131
+ x = self.encoder.act2(x)
132
+ encoder_results.append(x)
133
+ encoder_results = list(reversed(encoder_results))
134
+ return encoder_results
135
+
136
+ def forward(self, x):
137
+ encoder_results = self.get_encoder_features(x)
138
+ seg = self.final(self.decoder(encoder_results))
139
+ x = encoder_results[0]
140
+ x = self.avg_pool(x).flatten(1)
141
+ x = self.dropout(x)
142
+ x = self.fc(x)
143
+ return x, seg
144
+
145
+
146
+ if __name__ == '__main__':
147
+ model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
148
+ model.eval()
149
+ with torch.no_grad():
150
+ input = torch.rand(4, 3, 224, 224)
151
+ print(model(input))
weights/.gitkeep ADDED
File without changes
weights/b7_ns_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9db77ab9318863e2f8ab287c8eb83c2232584b82dc2fb41f1d614ddd7900cccb
3
+ size 266910617