yixin1121 committed
Commit 513e1fb
1 Parent(s): 68ebe2b

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+videos/3249402410.mp4 filter=lfs diff=lfs merge=lfs -text
+videos/4882821564.mp4 filter=lfs diff=lfs merge=lfs -text
+videos/6233408665.mp4 filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+name: Run Python script
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.9'
+
+      - name: Install Gradio
+        run: python -m pip install gradio
+
+      - name: Log in to Hugging Face
+        run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+      - name: Deploy to Spaces
+        run: gradio deploy
Infer.py ADDED
@@ -0,0 +1,149 @@
+
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+import math
+from tqdm import tqdm
+import argparse
+from collections import OrderedDict
+import json
+
+from collections import defaultdict
+from model.deberta_moe import DebertaV2ForMaskedLM
+from transformers import DebertaV2Tokenizer
+
+import clip
+import ffmpeg
+from VideoLoader import VideoLoader
+
+def get_mask(lengths, max_length):
+    """ Computes a batch of padding masks given batched lengths """
+    mask = 1 * (
+        torch.arange(max_length).unsqueeze(1) < lengths
+    ).transpose(0, 1)
+    return mask
+
+class Infer:
+    def __init__(self, device):
+        pretrained_ckpt = torch.load("ckpts/model.pth")
+        args = pretrained_ckpt['args']
+        args.n_ans = 2
+        args.max_tokens = 256
+        self.args = args
+        self.clip_model = clip.load("ViT-L/14", device=device)[0]
+        self.tokenizer = DebertaV2Tokenizer.from_pretrained(
+            "ckpts/deberta-v2-xlarge", local_files_only=True
+        )
+
+        self.model = DebertaV2ForMaskedLM.from_pretrained(
+            features_dim=args.features_dim if args.use_video else 0,
+            max_feats=args.max_feats,
+            freeze_lm=args.freeze_lm,
+            freeze_mlm=args.freeze_mlm,
+            ft_ln=args.ft_ln,
+            ds_factor_attn=args.ds_factor_attn,
+            ds_factor_ff=args.ds_factor_ff,
+            dropout=args.dropout,
+            n_ans=args.n_ans,
+            freeze_last=args.freeze_last,
+            pretrained_model_name_or_path="ckpts/deberta-v2-xlarge",
+            local_files_only=False,
+            add_video_feat=args.add_video_feat,
+            freeze_ad=args.freeze_ad,
+        )
+        new_state_dict = OrderedDict()
+        for k, v in pretrained_ckpt['model'].items():
+            new_state_dict[k.replace("module.", "")] = v
+        self.model.load_state_dict(new_state_dict, strict=False)  # load the "module."-stripped weights
+        self.model.eval()
+        self.model.to(device)
+        self.device = device
+
+        self.video_loader = VideoLoader()
+        self.set_answer()
+
+    def _get_clip_feature(self, video):
+        feat = self.clip_model.encode_image(video.to(self.device))
+        # feat = F.normalize(feat, dim=1)
+        return feat
+
+    def set_answer(self):
+        tok_yes = torch.tensor(
+            self.tokenizer(
+                "Yes",
+                add_special_tokens=False,
+                max_length=1,
+                truncation=True,
+                padding="max_length",
+            )["input_ids"],
+            dtype=torch.long,
+        )
+        tok_no = torch.tensor(
+            self.tokenizer(
+                "No",
+                add_special_tokens=False,
+                max_length=1,
+                truncation=True,
+                padding="max_length",
+            )["input_ids"],
+            dtype=torch.long,
+        )
+
+        a2tok = torch.stack([tok_yes, tok_no])
+        self.model.set_answer_embeddings(
+            a2tok.to(self.model.device), freeze_last=self.args.freeze_last
+        )
+
+    def generate(self, text, video_path, candidates=None):
+        video, video_len = self.video_loader(video_path)
+        video = self._get_clip_feature(video).unsqueeze(0).float()
+        video_mask = get_mask(video_len, 10)
+        video_mask = torch.cat([torch.ones((1, 1)), video_mask], dim=1)
+        logits_list = []
+
+        question = text.capitalize().strip()
+        if question[-1] != "?":
+            question = str(question) + "?"
+
+        for aid in range(len(candidates)):
+            prompt = (
+                f" Question: {question} Is it '{candidates[aid]}'? {self.tokenizer.mask_token}. Subtitles: "
+            )
+            prompt = prompt.strip()
+            encoded = self.tokenizer(
+                prompt,
+                add_special_tokens=True,
+                max_length=self.args.max_tokens,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+            )
+            # forward
+
+            output = self.model(
+                video=video.to(self.device),
+                video_mask=video_mask.to(self.device),
+                input_ids=encoded["input_ids"].to(self.device),
+                attention_mask=encoded["attention_mask"].to(self.device),
+            )
+            # += output['loads'].detach().cpu()
+            logits = output["logits"]
+            # get logits for the mask token
+            delay = 11
+            logits = logits[:, delay : encoded["input_ids"].size(1) + delay][
+                encoded["input_ids"] == self.tokenizer.mask_token_id
+            ]
+            logits_list.append(logits.softmax(-1)[:, 0])
+
+        logits = torch.stack(logits_list, 1)
+        if logits.shape[1] == 1:
+            preds = logits.round().long().squeeze(1)
+        else:
+            preds = logits.max(1).indices
+
+        return candidates[preds]
+
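For reference, a minimal sketch (not part of the commit) of driving the Infer class above outside the Gradio app. It assumes the ckpts/ folder from this upload is in place and a CUDA device is available; the candidate answers are purely illustrative, and the question/video pair is taken from the app.py examples below.

# sketch only: call Infer.generate() directly on one of the bundled example videos
from Infer import Infer

device = "cuda"            # Infer loads the checkpoint and CLIP onto this device
handler = Infer(device)    # same construction app.py uses

# generate() scores each candidate with a Yes/No mask prediction and returns the best one;
# this candidate list is made up for illustration only.
answer = handler.generate(
    text="why did the boy clap his hands when he ran to the christmas tree",
    video_path="videos/4882821564.mp4",
    candidates=["he was excited", "he was dancing", "he fell down"],
)
print(answer)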
README.md CHANGED
@@ -1,13 +1,6 @@
 ---
-title: T MoENet
-emoji: 😻
-colorFrom: yellow
-colorTo: green
-sdk: gradio
-sdk_version: 4.38.1
+title: T-MoENet
 app_file: app.py
-pinned: false
-license: apache-2.0
+sdk: gradio
+sdk_version: 3.46.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
T-MoENet_result.json ADDED
The diff for this file is too large to render. See raw diff
 
VideoLoader.py ADDED
@@ -0,0 +1,133 @@
+
+import torch as th
+import os
+import numpy as np
+import ffmpeg
+
+
+class Normalize(object):
+    def __init__(self, mean, std):
+        self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
+        self.std = th.FloatTensor(std).view(1, 3, 1, 1)
+
+    def __call__(self, tensor):
+        tensor = (tensor - self.mean) / (self.std + 1e-8)
+        return tensor
+
+
+class Preprocessing(object):
+    def __init__(self):
+        self.norm = Normalize(
+            mean=[0.48145466, 0.4578275, 0.40821073],
+            std=[0.26862954, 0.26130258, 0.27577711],
+        )
+
+    def __call__(self, tensor):
+        tensor = tensor / 255.0
+        tensor = self.norm(tensor)
+        return tensor
+
+
+class VideoLoader:
+    """Pytorch video loader."""
+
+    def __init__(
+        self,
+        framerate=1,
+        size=224,
+        centercrop=True,
+    ):
+        self.centercrop = centercrop
+        self.size = size
+        self.framerate = framerate
+        self.preprocess = Preprocessing()
+        self.max_feats = 10
+        self.features_dim = 768
+
+    def _get_video_dim(self, video_path):
+        probe = ffmpeg.probe(video_path)
+        video_stream = next(
+            (stream for stream in probe["streams"] if stream["codec_type"] == "video"),
+            None,
+        )
+        width = int(video_stream["width"])
+        height = int(video_stream["height"])
+        num, denum = video_stream["avg_frame_rate"].split("/")
+        frame_rate = int(num) / int(denum)
+        return height, width, frame_rate
+
+    def _get_output_dim(self, h, w):
+        if isinstance(self.size, tuple) and len(self.size) == 2:
+            return self.size
+        elif h >= w:
+            return int(h * self.size / w), self.size
+        else:
+            return self.size, int(w * self.size / h)
+
+    def _getvideo(self, video_path):
+
+        if os.path.isfile(video_path):
+            print("Decoding video: {}".format(video_path))
+            try:
+                h, w, fr = self._get_video_dim(video_path)
+            except:
+                print("ffprobe failed at: {}".format(video_path))
+                return {
+                    "video": th.zeros(1),
+                    "input": video_path
+                }
+            if fr < 1:
+                print("Corrupted Frame Rate: {}".format(video_path))
+                return {
+                    "video": th.zeros(1),
+                    "input": video_path
+                }
+            height, width = self._get_output_dim(h, w)
+
+            try:
+                cmd = (
+                    ffmpeg.input(video_path)
+                    .filter("fps", fps=self.framerate)
+                    .filter("scale", width, height)
+                )
+                if self.centercrop:
+                    x = int((width - self.size) / 2.0)
+                    y = int((height - self.size) / 2.0)
+                    cmd = cmd.crop(x, y, self.size, self.size)
+                out, _ = cmd.output("pipe:", format="rawvideo", pix_fmt="rgb24").run(
+                    capture_stdout=True, quiet=True
+                )
+            except:
+                print("ffmpeg error at: {}".format(video_path))
+                return {
+                    "video": th.zeros(1),
+                    "input": video_path,
+                }
+            if self.centercrop and isinstance(self.size, int):
+                height, width = self.size, self.size
+            video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
+            video = th.from_numpy(video.astype("float32"))
+            video = video.permute(0, 3, 1, 2)  # t,c,h,w
+        else:
+            video = th.zeros(1)
+
+        return {"video": video, "input": video_path}
+
+    def __call__(self, video_path):
+
+        video = self._getvideo(video_path)['video']
+
+        if len(video) > self.max_feats:
+            sampled = []
+            for j in range(self.max_feats):
+                sampled.append(video[(j * len(video)) // self.max_feats])
+            video = th.stack(sampled)
+            video_len = self.max_feats
+        elif len(video) < self.max_feats:
+            video_len = len(video)
+            video = th.cat(
+                [video, th.zeros(self.max_feats - video_len, self.features_dim)], 0
+            )
+        video = self.preprocess(video)
+        return video, video_len
+
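A quick sketch of what VideoLoader returns and how Infer consumes it, assuming a local copy of one of the bundled example videos and the CLIP ViT-L/14 weights are available: up to max_feats=10 frames as a (t, 3, 224, 224) tensor plus the real frame count, which Infer then encodes into 768-dimensional CLIP features.

# sketch only: inspect VideoLoader output and encode it the way Infer._get_clip_feature does
import clip
import torch

from VideoLoader import VideoLoader

loader = VideoLoader()                        # framerate=1, size=224, centercrop=True
video, video_len = loader("videos/3249402410.mp4")
print(video.shape, video_len)                 # e.g. torch.Size([10, 3, 224, 224]) 10

clip_model, _ = clip.load("ViT-L/14", device="cpu")
with torch.no_grad():
    feats = clip_model.encode_image(video)    # one 768-d feature per sampled frame
print(feats.shape)                            # e.g. torch.Size([10, 768])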
__pycache__/Infer.cpython-38.pyc ADDED
Binary file (3.86 kB). View file
 
__pycache__/VideoLoader.cpython-38.pyc ADDED
Binary file (4.17 kB). View file
 
app.py ADDED
@@ -0,0 +1,157 @@
+import shutil
+import gradio as gr
+import torch
+from fastapi import FastAPI
+import os
+import tempfile
+from Infer import Infer
+
+title_markdown = ("""
+<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+<div>
+<h1 >Temporal-guided Mixture-of-Experts for Zero-Shot Video Question Answering</h1>
+<h5 style="margin: 0;">Under review.</h5>
+</div>
+</div>
+
+<div align="center">
+<div style="display:flex; gap: 0.25rem;" align="center">
+<a href='https://github.com/qyx1121/T-MoENet'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
+</div>
+</div>
+""")
+
+block_css = """
+#buttons button {
+    min-width: min(120px,100%);
+}
+"""
+
+def save_video_to_local(video_path):
+    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
+    shutil.copyfile(video_path, filename)
+    return filename
+
+
+def generate(video, textbox_in, first_run, state, state_):
+    flag = 1
+    if not textbox_in:
+        if len(state_.messages) > 0:
+            textbox_in = state_.messages[-1][1]
+            state_.messages.pop(-1)
+            flag = 0
+        else:
+            return "Please enter instruction"
+    video = video if video else "none"
+    # assert not (os.path.exists(image1) and os.path.exists(video))
+
+    first_run = False if len(state.messages) > 0 else True
+
+    text_en_in = textbox_in.replace("picture", "image")
+
+    # images_tensor = [[], []]
+    image_processor = handler.image_processor
+    if os.path.exists(image1) and not os.path.exists(video):
+        tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
+        # print(tensor.shape)
+        tensor = tensor.to(handler.model.device, dtype=dtype)
+        images_tensor[0] = images_tensor[0] + [tensor]
+        images_tensor[1] = images_tensor[1] + ['image']
+        print(torch.cuda.memory_allocated())
+        print(torch.cuda.max_memory_allocated())
+    video_processor = handler.video_processor
+    if not os.path.exists(image1) and os.path.exists(video):
+        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
+        # print(tensor.shape)
+        tensor = tensor.to(handler.model.device, dtype=dtype)
+        images_tensor[0] = images_tensor[0] + [tensor]
+        images_tensor[1] = images_tensor[1] + ['video']
+        print(torch.cuda.memory_allocated())
+        print(torch.cuda.max_memory_allocated())
+    if os.path.exists(image1) and os.path.exists(video):
+        tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
+        # print(tensor.shape)
+        tensor = tensor.to(handler.model.device, dtype=dtype)
+        images_tensor[0] = images_tensor[0] + [tensor]
+        images_tensor[1] = images_tensor[1] + ['video']
+
+
+        tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0]
+        # print(tensor.shape)
+        tensor = tensor.to(handler.model.device, dtype=dtype)
+        images_tensor[0] = images_tensor[0] + [tensor]
+        images_tensor[1] = images_tensor[1] + ['image']
+        print(torch.cuda.memory_allocated())
+        print(torch.cuda.max_memory_allocated())
+
+
+    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
+    state_.messages[-1] = (state_.roles[1], text_en_out)
+
+    text_en_out = text_en_out.split('#')[0]
+    textbox_out = text_en_out
+
+    show_images = ""
+    if flag:
+        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
+    state.append_message(state.roles[1], textbox_out)
+    torch.cuda.empty_cache()
+    return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
+
+
+device = "cuda"
+handler = Infer(device)
+# handler.model.to(dtype=dtype)
+if not os.path.exists("temp"):
+    os.makedirs("temp")
+
+print(torch.cuda.memory_allocated())
+print(torch.cuda.max_memory_allocated())
+
+textbox = gr.Textbox(
+    show_label=False, placeholder="Enter text and press ENTER", container=False
+)
+with gr.Blocks(title='T-MoENet', theme=gr.themes.Default(), css=block_css) as demo:
+    gr.Markdown(title_markdown)
+    state = gr.State()
+    state_ = gr.State()
+    first_run = gr.State()
+    images_tensor = gr.State()
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            video = gr.Video(label="Input Video")
+            cur_dir = os.path.dirname(os.path.abspath(__file__))
+            print(cur_dir)
+            gr.Examples(
+                examples=[
+                    [
+                        cur_dir + "/videos/3249402410.mp4",
+                        "what did the lady in black on the left do after she finished spreading the sauce on her pizza?",
+                    ],
+                    [
+                        cur_dir + "/videos/4882821564.mp4",
+                        "why did the boy clap his hands when he ran to the christmas tree?",
+                    ],
+                    [
+                        cur_dir + "/videos/6233408665.mp4",
+                        "what did the people on the sofa do after the lady in pink finished singing?",
+                    ],
+                ],
+                inputs=[video, textbox],
+            )
+
+        with gr.Column(scale=7):
+            chatbot = gr.Chatbot(label="T-MoENet", bubble_full_width=True)
+            with gr.Row():
+                with gr.Column(scale=2):
+                    textbox.render()
+                with gr.Column(scale=1, min_width=50):
+                    submit_btn = gr.Button(
+                        value="Send", variant="primary", interactive=True
+                    )
+
+    submit_btn.click(generate, [video, textbox, first_run, state, state_],
+                     [state, state_, chatbot, first_run, textbox, video])
+
+demo.launch(share=True)
ckpts/deberta-v2-xlarge/config.json ADDED
@@ -0,0 +1,25 @@
+{
+  "model_type": "deberta-v2",
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 1536,
+  "initializer_range": 0.02,
+  "intermediate_size": 6144,
+  "max_position_embeddings": 512,
+  "relative_attention": true,
+  "position_buckets": 256,
+  "norm_rel_ebd": "layer_norm",
+  "share_att_key": true,
+  "pos_att_type": "p2c|c2p",
+  "layer_norm_eps": 1e-7,
+  "conv_kernel_size": 3,
+  "conv_act": "gelu",
+  "max_relative_positions": -1,
+  "position_biased_input": false,
+  "num_attention_heads": 24,
+  "attention_head_size": 64,
+  "num_hidden_layers": 24,
+  "type_vocab_size": 0,
+  "vocab_size": 128100
+}
ckpts/deberta-v2-xlarge/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7088de0d6925bbd824e5dfd33db6ca5145231b8fd9f702363f18275f14d50ab9
+size 1775809831
ckpts/deberta-v2-xlarge/spm.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5598d5e96f339a8d980c15f9afd405a2e5e1be7db41de3ed13b0f03fac1e8c17
+size 2447305
ckpts/deberta-v2-xlarge/tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "do_lower_case": false,
+  "vocab_type": "spm"
+}
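The files above are the standard DeBERTa-v2-xlarge assets that Infer.py loads locally. A small sketch (assuming the ckpts/ folder from this upload is in place) of reading them back with transformers, as Infer.__init__ does for the tokenizer:

# sketch only: load the bundled DeBERTa-v2 config and tokenizer from ckpts/
from transformers import DebertaV2Config, DebertaV2Tokenizer

config = DebertaV2Config.from_pretrained("ckpts/deberta-v2-xlarge")
tokenizer = DebertaV2Tokenizer.from_pretrained("ckpts/deberta-v2-xlarge", local_files_only=True)

print(config.hidden_size, config.num_hidden_layers)  # 1536 24, matching config.json above
print(tokenizer.mask_token)                          # the token Infer scores for Yes/No answers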
ckpts/model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0220bc2e00d2f07d89746628f10ed0deb069618d4702f43ee6615f4c9f3a406a
+size 499452921
model/__pycache__/adapter.cpython-38.pyc ADDED
Binary file (2.52 kB). View file
 
model/__pycache__/deberta_moe.cpython-38.pyc ADDED
Binary file (41.1 kB). View file
 
model/__pycache__/evl.cpython-38.pyc ADDED
Binary file (10.3 kB). View file
 
model/__pycache__/moe.cpython-38.pyc ADDED
Binary file (15.1 kB). View file
 
model/adapter.py ADDED
@@ -0,0 +1,77 @@
+import torch.nn as nn
+import torch
+import math
+
+class Adapter(nn.Module):
+    def __init__(
+        self, ds_factor, hidden_dim, ln_after=False, ln_before=False, dropout=0.1
+    ):
+        super().__init__()
+        assert not hidden_dim % ds_factor
+        self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
+        self.act = nn.ReLU()
+        self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
+        self.apply(self.init_weights)
+        self.ln_after = ln_after
+        self.ln_before = ln_before
+        self.dropout = dropout
+        if ln_after or ln_before:
+            self.ln = nn.LayerNorm(hidden_dim)
+        if dropout:
+            self.dropout = nn.Dropout(dropout)
+
+    def init_weights(self, m: nn.Module, std=1e-3):
+        if isinstance(m, nn.Linear):
+            torch.nn.init.normal_(m.weight, std=std)
+            torch.nn.init.normal_(m.bias, std=std)
+            m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
+            m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
+        elif isinstance(m, nn.LayerNorm):
+            m.bias.data.zero_()
+            m.weight.data.fill_(1.0)
+
+    def forward(self, hidden_states):
+        if self.ln_before:
+            residual = self.ln(hidden_states)
+            residual = self.down(residual)
+        else:
+            residual = self.down(hidden_states)
+        residual = self.act(residual)
+        if self.dropout:
+            residual = self.dropout(residual)
+        residual = self.up(residual)
+        if self.ln_after:
+            residual = self.ln(hidden_states)
+        return hidden_states + residual
+
+
+class ST_Adapter(nn.Module):
+
+    def __init__(self, ds_factor, hidden_dim):
+        super().__init__()
+
+        self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
+        self.conv = nn.Conv1d(
+            hidden_dim // ds_factor, hidden_dim // ds_factor,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            groups=hidden_dim // ds_factor
+        )
+        self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
+        nn.init.constant_(self.conv.weight, 0.)
+        nn.init.constant_(self.conv.bias, 0.)
+        nn.init.constant_(self.down.bias, 0.)
+        nn.init.constant_(self.up.bias, 0.)
+
+    def forward(self, x):
+        N, T, C = x.size()
+        ori_x = x
+        x = self.down(x)
+        x = x.permute(0, 2, 1).contiguous()
+        x = self.conv(x)
+        x = x.permute(0, 2, 1).contiguous()
+        x = self.up(x)
+        x = x + ori_x
+        return x
+
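Both adapter variants above are shape-preserving: Adapter applies a down/up bottleneck residual, while ST_Adapter runs a depthwise 1-D convolution over the frame/token axis. A minimal shape check (illustrative only, random inputs; ds_factor=8 is just an example value, hidden size 1536 matches the config above):

# sketch only: both adapters map (batch, tokens, hidden) -> (batch, tokens, hidden)
import torch

from model.adapter import Adapter, ST_Adapter

x = torch.randn(2, 10, 1536)                            # (batch, frames/tokens, hidden_size)

adapter = Adapter(ds_factor=8, hidden_dim=1536)         # bottleneck residual adapter
st_adapter = ST_Adapter(ds_factor=8, hidden_dim=1536)   # depthwise temporal conv adapter

print(adapter(x).shape)     # torch.Size([2, 10, 1536])
print(st_adapter(x).shape)  # torch.Size([2, 10, 1536])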
model/deberta_moe.py ADDED
@@ -0,0 +1,1735 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch DeBERTa-v2 model. """
16
+
17
+ import math
18
+ from collections.abc import Sequence
19
+ from typing import Tuple, Optional
20
+
21
+ import clip
22
+ import numpy as np
23
+ import torch
24
+ from torch import _softmax_backward_data, nn
25
+ from torch.nn import CrossEntropyLoss, LayerNorm
26
+
27
+ from .adapter import Adapter
28
+ from .moe import MoE
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import ModelOutput
31
+
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers import DebertaV2Config, DebertaV2ForSequenceClassification
34
+ from .evl import EVLTransformer, recursive_gumbel_softmax
35
+
36
+ from transformers import pytorch_utils
37
+
38
+ _CONFIG_FOR_DOC = "DebertaV2Config"
39
+ _TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
40
+ _CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
41
+
42
+ DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "microsoft/deberta-v2-xlarge",
44
+ "microsoft/deberta-v2-xxlarge",
45
+ "microsoft/deberta-v2-xlarge-mnli",
46
+ "microsoft/deberta-v2-xxlarge-mnli",
47
+ ]
48
+
49
+ class MaskedLMOutput(ModelOutput):
50
+ """
51
+ Base class for masked language models outputs.
52
+
53
+ Args:
54
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
55
+ Masked language modeling (MLM) loss.
56
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
57
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
58
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
59
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
60
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
61
+
62
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
63
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
64
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
65
+ sequence_length)`.
66
+
67
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
68
+ heads.
69
+ """
70
+
71
+ loss: Optional[torch.FloatTensor] = None
72
+ logits: torch.FloatTensor = None
73
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
74
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
75
+ loss_moe: Optional[torch.FloatTensor] = None
76
+ loads: Optional[torch.FloatTensor] = None
77
+ embeddings: Optional[torch.FloatTensor] = None
78
+
79
+
80
+ class BaseModelOutput(ModelOutput):
81
+ """
82
+ Base class for model's outputs, with potential hidden states and attentions.
83
+ Args:
84
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
85
+ Sequence of hidden-states at the output of the last layer of the model.
86
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
87
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
88
+ shape `(batch_size, sequence_length, hidden_size)`.
89
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
90
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
91
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
92
+ sequence_length)`.
93
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
94
+ heads.
95
+ """
96
+
97
+ last_hidden_state: torch.FloatTensor = None
98
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
99
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
100
+ position_embeddings: torch.FloatTensor = None
101
+ attention_mask: torch.BoolTensor = None
102
+ loss_moe: torch.FloatTensor = None
103
+ video_g: torch.FloatTensor = None
104
+ loads: torch.LongTensor = None
105
+ embeddings: torch.FloatTensor = None
106
+
107
+
108
+ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
109
+ class ContextPooler(nn.Module):
110
+ def __init__(self, config):
111
+ super().__init__()
112
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
113
+ self.dropout = StableDropout(config.pooler_dropout)
114
+ self.config = config
115
+
116
+ def forward(self, hidden_states):
117
+ # We "pool" the model by simply taking the hidden state corresponding
118
+ # to the first token.
119
+
120
+ context_token = hidden_states[:, 0]
121
+ context_token = self.dropout(context_token)
122
+ pooled_output = self.dense(context_token)
123
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
124
+ return pooled_output
125
+
126
+ @property
127
+ def output_dim(self):
128
+ return self.config.hidden_size
129
+
130
+
131
+ # Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
132
+ class XSoftmax(torch.autograd.Function):
133
+ """
134
+ Masked Softmax which is optimized for saving memory
135
+
136
+ Args:
137
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
138
+ mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
139
+ dim (int): The dimension that will apply softmax
140
+
141
+ Example::
142
+
143
+ import torch
144
+ from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
145
+
146
+ # Make a tensor
147
+ x = torch.randn([4,20,100])
148
+
149
+ # Create a mask
150
+ mask = (x>0).int()
151
+
152
+ y = XSoftmax.apply(x, mask, dim=-1)
153
+ """
154
+
155
+ @staticmethod
156
+ def forward(self, input, mask, dim):
157
+ self.dim = dim
158
+ rmask = ~(mask.bool())
159
+
160
+ output = input.masked_fill(rmask, float("-inf"))
161
+ output = torch.softmax(output, self.dim)
162
+ output.masked_fill_(rmask, 0)
163
+ self.save_for_backward(output)
164
+ return output
165
+
166
+ @staticmethod
167
+ def backward(self, grad_output):
168
+ (output,) = self.saved_tensors
169
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
170
+ return inputGrad, None, None
171
+
172
+
173
+ # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
174
+ class DropoutContext(object):
175
+ def __init__(self):
176
+ self.dropout = 0
177
+ self.mask = None
178
+ self.scale = 1
179
+ self.reuse_mask = True
180
+
181
+
182
+ # Copied from transformers.models.deberta.modeling_deberta.get_mask
183
+ def get_mask(input, local_context):
184
+ if not isinstance(local_context, DropoutContext):
185
+ dropout = local_context
186
+ mask = None
187
+ else:
188
+ dropout = local_context.dropout
189
+ dropout *= local_context.scale
190
+ mask = local_context.mask if local_context.reuse_mask else None
191
+
192
+ if dropout > 0 and mask is None:
193
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
194
+
195
+ if isinstance(local_context, DropoutContext):
196
+ if local_context.mask is None:
197
+ local_context.mask = mask
198
+
199
+ return mask, dropout
200
+
201
+
202
+ # Copied from transformers.models.deberta.modeling_deberta.XDropout
203
+ class XDropout(torch.autograd.Function):
204
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
205
+
206
+ @staticmethod
207
+ def forward(ctx, input, local_ctx):
208
+ mask, dropout = get_mask(input, local_ctx)
209
+ ctx.scale = 1.0 / (1 - dropout)
210
+ if dropout > 0:
211
+ ctx.save_for_backward(mask)
212
+ return input.masked_fill(mask, 0) * ctx.scale
213
+ else:
214
+ return input
215
+
216
+ @staticmethod
217
+ def backward(ctx, grad_output):
218
+ if ctx.scale > 1:
219
+ (mask,) = ctx.saved_tensors
220
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
221
+ else:
222
+ return grad_output, None
223
+
224
+
225
+ # Copied from transformers.models.deberta.modeling_deberta.StableDropout
226
+ class StableDropout(nn.Module):
227
+ """
228
+ Optimized dropout module for stabilizing the training
229
+
230
+ Args:
231
+ drop_prob (float): the dropout probabilities
232
+ """
233
+
234
+ def __init__(self, drop_prob):
235
+ super().__init__()
236
+ self.drop_prob = drop_prob
237
+ self.count = 0
238
+ self.context_stack = None
239
+
240
+ def forward(self, x):
241
+ """
242
+ Call the module
243
+
244
+ Args:
245
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
246
+ """
247
+ if self.training and self.drop_prob > 0:
248
+ return XDropout.apply(x, self.get_context())
249
+ return x
250
+
251
+ def clear_context(self):
252
+ self.count = 0
253
+ self.context_stack = None
254
+
255
+ def init_context(self, reuse_mask=True, scale=1):
256
+ if self.context_stack is None:
257
+ self.context_stack = []
258
+ self.count = 0
259
+ for c in self.context_stack:
260
+ c.reuse_mask = reuse_mask
261
+ c.scale = scale
262
+
263
+ def get_context(self):
264
+ if self.context_stack is not None:
265
+ if self.count >= len(self.context_stack):
266
+ self.context_stack.append(DropoutContext())
267
+ ctx = self.context_stack[self.count]
268
+ ctx.dropout = self.drop_prob
269
+ self.count += 1
270
+ return ctx
271
+ else:
272
+ return self.drop_prob
273
+
274
+
275
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
276
+ class DebertaV2SelfOutput(nn.Module):
277
+ def __init__(self, config, ds_factor, dropout, add_moe, gating):
278
+ super().__init__()
279
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
280
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
281
+ self.dropout = StableDropout(config.hidden_dropout_prob)
282
+ self.add_moe = add_moe
283
+ if not self.add_moe and ds_factor:
284
+ self.adapter = Adapter(ds_factor, config.hidden_size, dropout=dropout)
285
+ else:
286
+ self.moe_layer = MoE(ds_factor = ds_factor, moe_input_size=config.hidden_size, dropout=dropout, num_experts=4, top_k=2, gating=gating)
287
+
288
+ def forward(self, hidden_states, input_tensor, temporal_factor = None, train_mode = True):
289
+ hidden_states = self.dense(hidden_states)
290
+ if not self.add_moe:
291
+ hidden_states = self.adapter(hidden_states)
292
+ else:
293
+ hidden_states, loss_moe, load = self.moe_layer(temporal_factor, hidden_states, train=train_mode)
294
+ hidden_states = self.dropout(hidden_states)
295
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
296
+
297
+ if not self.add_moe:
298
+ return hidden_states, None, None
299
+
300
+ return hidden_states, loss_moe, load
301
+
302
+
303
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
304
+ class DebertaV2Attention(nn.Module):
305
+ def __init__(self, config, ds_factor, dropout, add_moe = False, gating='linear'):
306
+ super().__init__()
307
+ self.self = DisentangledSelfAttention(config)
308
+ self.output = DebertaV2SelfOutput(config, ds_factor, dropout, add_moe, gating)
309
+ self.config = config
310
+
311
+ def forward(
312
+ self,
313
+ hidden_states,
314
+ attention_mask,
315
+ return_att=False,
316
+ query_states=None,
317
+ relative_pos=None,
318
+ rel_embeddings=None,
319
+ temporal_factor=None,
320
+ train_mode=True
321
+ ):
322
+ self_output = self.self(
323
+ hidden_states,
324
+ attention_mask,
325
+ return_att,
326
+ query_states=query_states,
327
+ relative_pos=relative_pos,
328
+ rel_embeddings=rel_embeddings,
329
+ )
330
+ if return_att:
331
+ self_output, att_matrix = self_output
332
+ if query_states is None:
333
+ query_states = hidden_states
334
+ attention_output, loss_moe, load = self.output(self_output, query_states, temporal_factor, train_mode)
335
+
336
+ if return_att:
337
+ return (attention_output, att_matrix, loss_moe)
338
+ else:
339
+ return attention_output, loss_moe, load
340
+
341
+
342
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
343
+ class DebertaV2Intermediate(nn.Module):
344
+ def __init__(self, config):
345
+ super().__init__()
346
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
347
+ if isinstance(config.hidden_act, str):
348
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
349
+ else:
350
+ self.intermediate_act_fn = config.hidden_act
351
+
352
+ def forward(self, hidden_states):
353
+ hidden_states = self.dense(hidden_states)
354
+ hidden_states = self.intermediate_act_fn(hidden_states)
355
+ return hidden_states
356
+
357
+
358
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
359
+ class DebertaV2Output(nn.Module):
360
+ def __init__(self, config, ds_factor, dropout, add_moe = False, gating='linear',layer_id=0):
361
+ super().__init__()
362
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
363
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
364
+ self.dropout = StableDropout(config.hidden_dropout_prob)
365
+ self.config = config
366
+ self.ds_factor = ds_factor
367
+ self.add_moe = add_moe
368
+ if not self.add_moe and self.ds_factor:
369
+ self.adapter = Adapter(ds_factor, config.hidden_size, dropout=dropout)
370
+ elif self.add_moe:
371
+ self.moe_layer = MoE(ds_factor=ds_factor, moe_input_size=config.hidden_size, dropout=dropout, num_experts=4, top_k=1, gating=gating, layer_id=layer_id)
372
+ #self.adapter = Adapter(ds_factor, config.hidden_size, dropout=dropout)
373
+
374
+ def forward(self, hidden_states, input_tensor, temporal_factor, train_mode):
375
+ hidden_states = self.dense(hidden_states)
376
+ if not self.add_moe and self.ds_factor:
377
+ hidden_states = self.adapter(hidden_states)
378
+ elif self.add_moe:
379
+ hidden_states, loss_moe, load = self.moe_layer(temporal_factor, hidden_states, train=train_mode)
380
+ hidden_states = self.dropout(hidden_states)
381
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
382
+
383
+ if not self.add_moe:
384
+ return hidden_states, None, None
385
+
386
+ return hidden_states, loss_moe, load
387
+
388
+
389
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
390
+ class DebertaV2Layer(nn.Module):
391
+ def __init__(
392
+ self,
393
+ config,
394
+ ds_factor_attn,
395
+ ds_factor_ff,
396
+ dropout,
397
+ layer_id,
398
+ ):
399
+ super().__init__()
400
+ self.layer_id = layer_id
401
+ self.add_moe = False
402
+
403
+ #if layer_id >= config.num_hidden_layers - 2:
404
+ # self.add_moe = True
405
+
406
+ if layer_id < 2:
407
+ self.add_moe = True
408
+
409
+ self.attention = DebertaV2Attention(config, ds_factor_attn, dropout, False)
410
+ self.intermediate = DebertaV2Intermediate(config)
411
+ self.output = DebertaV2Output(config, ds_factor_ff, dropout, self.add_moe, gating="linear", layer_id = layer_id)
412
+
413
+ def forward(
414
+ self,
415
+ temporal_factor,
416
+ hidden_states,
417
+ attention_mask,
418
+ return_att=False,
419
+ query_states=None,
420
+ relative_pos=None,
421
+ rel_embeddings=None,
422
+ train_mode=True,
423
+ ):
424
+ attention_output = self.attention(
425
+ hidden_states,
426
+ attention_mask,
427
+ return_att=return_att,
428
+ query_states=query_states,
429
+ relative_pos=relative_pos,
430
+ rel_embeddings=rel_embeddings,
431
+ temporal_factor=temporal_factor,
432
+ train_mode=train_mode
433
+ )
434
+
435
+ if return_att:
436
+ attention_output, att_matrix, loss_moe_attn = attention_output
437
+ else:
438
+ attention_output, loss_moe_attn, load = attention_output
439
+ intermediate_output = self.intermediate(attention_output)
440
+ layer_output, loss_moe_ffn, load = self.output(intermediate_output, attention_output, temporal_factor=temporal_factor, train_mode=train_mode)
441
+
442
+ loss_moe = loss_moe_attn if loss_moe_attn else loss_moe_ffn
443
+ if return_att:
444
+ return (layer_output, att_matrix)
445
+
446
+
447
+ return layer_output, loss_moe, load
448
+
449
+
450
+
451
+ class ConvLayer(nn.Module):
452
+ def __init__(self, config):
453
+ super().__init__()
454
+ kernel_size = getattr(config, "conv_kernel_size", 3)
455
+ groups = getattr(config, "conv_groups", 1)
456
+ self.conv_act = getattr(config, "conv_act", "tanh")
457
+ self.conv = nn.Conv1d(
458
+ config.hidden_size,
459
+ config.hidden_size,
460
+ kernel_size,
461
+ padding=(kernel_size - 1) // 2,
462
+ groups=groups,
463
+ )
464
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
465
+ self.dropout = StableDropout(config.hidden_dropout_prob)
466
+ self.config = config
467
+
468
+ def forward(self, hidden_states, residual_states, input_mask):
469
+ out = (
470
+ self.conv(hidden_states.permute(0, 2, 1).contiguous())
471
+ .permute(0, 2, 1)
472
+ .contiguous()
473
+ )
474
+ rmask = (1 - input_mask).bool()
475
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
476
+ out = ACT2FN[self.conv_act](self.dropout(out))
477
+
478
+ layer_norm_input = residual_states + out
479
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
480
+
481
+ if input_mask is None:
482
+ output_states = output
483
+ else:
484
+ if input_mask.dim() != layer_norm_input.dim():
485
+ if input_mask.dim() == 4:
486
+ input_mask = input_mask.squeeze(1).squeeze(1)
487
+ input_mask = input_mask.unsqueeze(2)
488
+
489
+ input_mask = input_mask.to(output.dtype)
490
+ output_states = output * input_mask
491
+
492
+ return output_states
493
+
494
+
495
+ class DebertaV2Encoder(nn.Module):
496
+ """Modified BertEncoder with relative position bias support"""
497
+
498
+ def __init__(
499
+ self,
500
+ config,
501
+ ds_factor_attn,
502
+ ds_factor_ff,
503
+ dropout,
504
+ ):
505
+ super().__init__()
506
+
507
+ self.layer = nn.ModuleList(
508
+ [
509
+ DebertaV2Layer(
510
+ config,
511
+ ds_factor_attn,
512
+ ds_factor_ff,
513
+ dropout,
514
+ _,
515
+ )
516
+ for _ in range(config.num_hidden_layers)
517
+ ]
518
+ )
519
+
520
+ self.relative_attention = getattr(config, "relative_attention", False)
521
+
522
+ if self.relative_attention:
523
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
524
+ if self.max_relative_positions < 1:
525
+ self.max_relative_positions = config.max_position_embeddings
526
+
527
+ self.position_buckets = getattr(config, "position_buckets", -1)
528
+ pos_ebd_size = self.max_relative_positions * 2
529
+
530
+ if self.position_buckets > 0:
531
+ pos_ebd_size = self.position_buckets * 2
532
+
533
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
534
+
535
+ self.norm_rel_ebd = [
536
+ x.strip()
537
+ for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")
538
+ ]
539
+
540
+ if "layer_norm" in self.norm_rel_ebd:
541
+ self.LayerNorm = LayerNorm(
542
+ config.hidden_size, config.layer_norm_eps, elementwise_affine=True
543
+ )
544
+
545
+ self.conv = (
546
+ ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
547
+ )
548
+
549
+ def get_rel_embedding(self):
550
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
551
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
552
+ rel_embeddings = self.LayerNorm(rel_embeddings)
553
+ return rel_embeddings
554
+
555
+ def get_attention_mask(self, attention_mask):
556
+ if attention_mask.dim() <= 2:
557
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
558
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(
559
+ -2
560
+ ).unsqueeze(-1)
561
+ attention_mask = attention_mask.byte()
562
+ elif attention_mask.dim() == 3:
563
+ attention_mask = attention_mask.unsqueeze(1)
564
+
565
+ return attention_mask
566
+
567
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
568
+ if self.relative_attention and relative_pos is None:
569
+ q = (
570
+ query_states.size(-2)
571
+ if query_states is not None
572
+ else hidden_states.size(-2)
573
+ )
574
+ relative_pos = build_relative_position(
575
+ q,
576
+ hidden_states.size(-2),
577
+ bucket_size=self.position_buckets,
578
+ max_position=self.max_relative_positions,
579
+ )
580
+ return relative_pos
581
+
582
+ def forward(
583
+ self,
584
+ temporal_factor,
585
+ hidden_states,
586
+ attention_mask,
587
+ output_hidden_states=True,
588
+ output_attentions=False,
589
+ query_states=None,
590
+ relative_pos=None,
591
+ return_dict=True,
592
+ train_mode=True
593
+ ):
594
+ if attention_mask.dim() <= 2:
595
+ input_mask = attention_mask
596
+ else:
597
+ input_mask = (attention_mask.sum(-2) > 0).byte()
598
+ attention_mask = self.get_attention_mask(attention_mask)
599
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
600
+
601
+ all_hidden_states = () if output_hidden_states else None
602
+ all_attentions = () if output_attentions else None
603
+
604
+ if isinstance(hidden_states, Sequence):
605
+ next_kv = hidden_states[0]
606
+ else:
607
+ next_kv = hidden_states
608
+ rel_embeddings = self.get_rel_embedding()
609
+ output_states = next_kv
610
+
611
+ loss_moe = 0
612
+ loads = []
613
+ embeddings = []
614
+
615
+ for i, layer_module in enumerate(self.layer):
616
+
617
+ if output_hidden_states:
618
+ all_hidden_states = all_hidden_states + (output_states,)
619
+
620
+ output_states, _, load = layer_module(
621
+ temporal_factor,
622
+ next_kv,
623
+ attention_mask,
624
+ output_attentions,
625
+ query_states=query_states,
626
+ relative_pos=relative_pos,
627
+ rel_embeddings=rel_embeddings,
628
+ train_mode=train_mode
629
+ )
630
+ if isinstance(load, torch.Tensor):
631
+ loads.append(load)
632
+
633
+ if _:
634
+ loss_moe = loss_moe + _
635
+
636
+ if output_attentions:
637
+ output_states, att_m = output_states
638
+
639
+ if i == 0 and self.conv is not None:
640
+ output_states = self.conv(hidden_states, output_states, input_mask)
641
+
642
+ if query_states is not None:
643
+ query_states = output_states
644
+ if isinstance(hidden_states, Sequence):
645
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
646
+ else:
647
+ next_kv = output_states
648
+
649
+ if output_attentions:
650
+ all_attentions = all_attentions + (att_m,)
651
+
652
+ if output_hidden_states:
653
+ all_hidden_states = all_hidden_states + (output_states,)
654
+
655
+ if not return_dict:
656
+ return tuple(
657
+ v
658
+ for v in [output_states, all_hidden_states, all_attentions]
659
+ if v is not None
660
+ )
661
+
662
+ if len(loads)>0:
663
+ loads = torch.stack(loads, dim = 0)
664
+
665
+ if len(embeddings) >0:
666
+ embeddings = torch.cat(embeddings, dim=0)
667
+
668
+ return BaseModelOutput(
669
+ last_hidden_state=output_states,
670
+ hidden_states=all_hidden_states,
671
+ attentions=all_attentions,
672
+ loss_moe=loss_moe,
673
+ loads=loads
674
+ )
675
+
676
+
677
+ def make_log_bucket_position(relative_pos, bucket_size, max_position):
678
+ sign = np.sign(relative_pos)
679
+ mid = bucket_size // 2
680
+ abs_pos = np.where(
681
+ (relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)
682
+ )
683
+ log_pos = (
684
+ np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1))
685
+ + mid
686
+ )
687
+ bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
688
+ return bucket_pos
689
+
690
+
691
+ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
692
+ """
693
+ Build relative position according to the query and key
694
+
695
+ We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key
696
+ :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} =
697
+ P_q - P_k`
698
+
699
+ Args:
700
+ query_size (int): the length of query
701
+ key_size (int): the length of key
702
+ bucket_size (int): the size of position bucket
703
+ max_position (int): the maximum allowed absolute position
704
+
705
+ Return:
706
+ :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
707
+
708
+ """
709
+ q_ids = np.arange(0, query_size)
710
+ k_ids = np.arange(0, key_size)
711
+ rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
712
+ if bucket_size > 0 and max_position > 0:
713
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
714
+ rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
715
+ rel_pos_ids = rel_pos_ids[:query_size, :]
716
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
717
+ return rel_pos_ids
718
+
719
+
720
+ @torch.jit.script
721
+ # Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
722
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
723
+ return c2p_pos.expand(
724
+ [
725
+ query_layer.size(0),
726
+ query_layer.size(1),
727
+ query_layer.size(2),
728
+ relative_pos.size(-1),
729
+ ]
730
+ )
731
+
732
+
733
+ @torch.jit.script
734
+ # Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
735
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
736
+ return c2p_pos.expand(
737
+ [
738
+ query_layer.size(0),
739
+ query_layer.size(1),
740
+ key_layer.size(-2),
741
+ key_layer.size(-2),
742
+ ]
743
+ )
744
+
745
+
746
+ @torch.jit.script
747
+ # Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
748
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
749
+ return pos_index.expand(
750
+ p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
751
+ )
752
+
753
+
754
+ class DisentangledSelfAttention(nn.Module):
755
+ """
756
+ Disentangled self-attention module
757
+
758
+ Parameters:
759
+ config (:obj:`DebertaV2Config`):
760
+ A model config class instance with the configuration to build a new model. The schema is similar to
761
+ `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config`
762
+
763
+ """
764
+
765
+ def __init__(self, config):
766
+ super().__init__()
767
+ if config.hidden_size % config.num_attention_heads != 0:
768
+ raise ValueError(
769
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
770
+ f"heads ({config.num_attention_heads})"
771
+ )
772
+ self.num_attention_heads = config.num_attention_heads
773
+ _attention_head_size = config.hidden_size // config.num_attention_heads
774
+ self.attention_head_size = getattr(
775
+ config, "attention_head_size", _attention_head_size
776
+ )
777
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
778
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
779
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
780
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
781
+
782
+ self.share_att_key = getattr(config, "share_att_key", False)
783
+ self.pos_att_type = (
784
+ config.pos_att_type if config.pos_att_type is not None else []
785
+ )
786
+ self.relative_attention = getattr(config, "relative_attention", False)
787
+
788
+ if self.relative_attention:
789
+ self.position_buckets = getattr(config, "position_buckets", -1)
790
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
791
+ if self.max_relative_positions < 1:
792
+ self.max_relative_positions = config.max_position_embeddings
793
+ self.pos_ebd_size = self.max_relative_positions
794
+ if self.position_buckets > 0:
795
+ self.pos_ebd_size = self.position_buckets
796
+
797
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
798
+
799
+ if not self.share_att_key:
800
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
801
+ self.pos_key_proj = nn.Linear(
802
+ config.hidden_size, self.all_head_size, bias=True
803
+ )
804
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
805
+ self.pos_query_proj = nn.Linear(
806
+ config.hidden_size, self.all_head_size
807
+ )
808
+
809
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
810
+
811
+ def transpose_for_scores(self, x, attention_heads):
812
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
813
+ x = x.view(*new_x_shape)
814
+ return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
815
+
816
+ def forward(
817
+ self,
818
+ hidden_states,
819
+ attention_mask,
820
+ return_att=False,
821
+ query_states=None,
822
+ relative_pos=None,
823
+ rel_embeddings=None,
824
+ ):
825
+ """
826
+ Call the module
827
+
828
+ Args:
829
+ hidden_states (:obj:`torch.FloatTensor`):
830
+ Input states to the module usually the output from previous layer, it will be the Q,K and V in
831
+ `Attention(Q,K,V)`
832
+
833
+ attention_mask (:obj:`torch.ByteTensor`):
834
+ An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum
835
+ sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
836
+ th token.
837
+
838
+ return_att (:obj:`bool`, optional):
839
+ Whether return the attention matrix.
840
+
841
+ query_states (:obj:`torch.FloatTensor`, optional):
842
+ The `Q` state in `Attention(Q,K,V)`.
843
+
844
+ relative_pos (:obj:`torch.LongTensor`):
845
+ The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with
846
+ values ranging in [`-max_relative_positions`, `max_relative_positions`].
847
+
848
+ rel_embeddings (:obj:`torch.FloatTensor`):
849
+ The embedding of relative distances. It's a tensor of shape [:math:`2 \\times
850
+ \\text{max_relative_positions}`, `hidden_size`].
851
+
852
+
853
+ """
854
+ if query_states is None:
855
+ query_states = hidden_states
856
+ query_layer = self.transpose_for_scores(
857
+ self.query_proj(query_states), self.num_attention_heads
858
+ )
859
+ key_layer = self.transpose_for_scores(
860
+ self.key_proj(hidden_states), self.num_attention_heads
861
+ )
862
+ value_layer = self.transpose_for_scores(
863
+ self.value_proj(hidden_states), self.num_attention_heads
864
+ )
865
+
866
+ rel_att = None
867
+ # Take the dot product between "query" and "key" to get the raw attention scores.
868
+ scale_factor = 1
869
+ if "c2p" in self.pos_att_type:
870
+ scale_factor += 1
871
+ if "p2c" in self.pos_att_type:
872
+ scale_factor += 1
873
+ if "p2p" in self.pos_att_type:
874
+ scale_factor += 1
875
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
876
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
877
+ if self.relative_attention:
878
+ rel_embeddings = self.pos_dropout(rel_embeddings)
879
+ rel_att = self.disentangled_attention_bias(
880
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
881
+ )
882
+
883
+ if rel_att is not None:
884
+ attention_scores = attention_scores + rel_att
885
+ attention_scores = attention_scores
886
+ attention_scores = attention_scores.view(
887
+ -1,
888
+ self.num_attention_heads,
889
+ attention_scores.size(-2),
890
+ attention_scores.size(-1),
891
+ )
892
+
893
+ # bsz x height x length x dimension
894
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
895
+ attention_probs = self.dropout(attention_probs)
896
+ context_layer = torch.bmm(
897
+ attention_probs.view(
898
+ -1, attention_probs.size(-2), attention_probs.size(-1)
899
+ ),
900
+ value_layer,
901
+ )
902
+ context_layer = (
903
+ context_layer.view(
904
+ -1,
905
+ self.num_attention_heads,
906
+ context_layer.size(-2),
907
+ context_layer.size(-1),
908
+ )
909
+ .permute(0, 2, 1, 3)
910
+ .contiguous()
911
+ )
912
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
913
+ context_layer = context_layer.view(*new_context_layer_shape)
914
+ if return_att:
915
+ return (context_layer, attention_probs)
916
+ else:
917
+ return context_layer
918
+
919
+ def disentangled_attention_bias(
920
+ self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
921
+ ):
922
+ if relative_pos is None:
923
+ q = query_layer.size(-2)
924
+ relative_pos = build_relative_position(
925
+ q,
926
+ key_layer.size(-2),
927
+ bucket_size=self.position_buckets,
928
+ max_position=self.max_relative_positions,
929
+ )
930
+ if relative_pos.dim() == 2:
931
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
932
+ elif relative_pos.dim() == 3:
933
+ relative_pos = relative_pos.unsqueeze(1)
934
+ # bsz x height x query x key
935
+ elif relative_pos.dim() != 4:
936
+ raise ValueError(
937
+ f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}"
938
+ )
939
+
940
+ att_span = self.pos_ebd_size
941
+ relative_pos = relative_pos.long().to(query_layer.device)
942
+
943
+ rel_embeddings = rel_embeddings[
944
+ self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
945
+ ].unsqueeze(0)
946
+ if self.share_att_key:
947
+ pos_query_layer = self.transpose_for_scores(
948
+ self.query_proj(rel_embeddings), self.num_attention_heads
949
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
950
+ pos_key_layer = self.transpose_for_scores(
951
+ self.key_proj(rel_embeddings), self.num_attention_heads
952
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
953
+ else:
954
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
955
+ pos_key_layer = self.transpose_for_scores(
956
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
957
+ ).repeat(
958
+ query_layer.size(0) // self.num_attention_heads, 1, 1
959
+ ) # .split(self.all_head_size, dim=-1)
960
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
961
+ pos_query_layer = self.transpose_for_scores(
962
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
963
+ ).repeat(
964
+ query_layer.size(0) // self.num_attention_heads, 1, 1
965
+ ) # .split(self.all_head_size, dim=-1)
966
+
967
+ score = 0
968
+ # content->position
969
+ if "c2p" in self.pos_att_type:
970
+ scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
971
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
972
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
973
+ c2p_att = torch.gather(
974
+ c2p_att,
975
+ dim=-1,
976
+ index=c2p_pos.squeeze(0).expand(
977
+ [query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]
978
+ ),
979
+ )
980
+ score += c2p_att / scale
981
+
982
+ # position->content
983
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
984
+ scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
985
+ if key_layer.size(-2) != query_layer.size(-2):
986
+ r_pos = build_relative_position(
987
+ key_layer.size(-2),
988
+ key_layer.size(-2),
989
+ bucket_size=self.position_buckets,
990
+ max_position=self.max_relative_positions,
991
+ ).to(query_layer.device)
992
+ r_pos = r_pos.unsqueeze(0)
993
+ else:
994
+ r_pos = relative_pos
995
+
996
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
997
+ if query_layer.size(-2) != key_layer.size(-2):
998
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
999
+
1000
+ if "p2c" in self.pos_att_type:
1001
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
1002
+ p2c_att = torch.gather(
1003
+ p2c_att,
1004
+ dim=-1,
1005
+ index=p2c_pos.squeeze(0).expand(
1006
+ [query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]
1007
+ ),
1008
+ ).transpose(-1, -2)
1009
+ if query_layer.size(-2) != key_layer.size(-2):
1010
+ p2c_att = torch.gather(
1011
+ p2c_att,
1012
+ dim=-2,
1013
+ index=pos_index.expand(
1014
+ p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
1015
+ ),
1016
+ )
1017
+ score += p2c_att / scale
1018
+
1019
+ # position->position
1020
+ if "p2p" in self.pos_att_type:
1021
+ pos_query = pos_query_layer[:, :, att_span:, :]
1022
+ p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
1023
+ p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
1024
+ if query_layer.size(-2) != key_layer.size(-2):
1025
+ p2p_att = torch.gather(
1026
+ p2p_att,
1027
+ dim=-2,
1028
+ index=pos_index.expand(
1029
+ query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))
1030
+ ),
1031
+ )
1032
+ p2p_att = torch.gather(
1033
+ p2p_att,
1034
+ dim=-1,
1035
+ index=c2p_pos.expand(
1036
+ [
1037
+ query_layer.size(0),
1038
+ query_layer.size(1),
1039
+ query_layer.size(2),
1040
+ relative_pos.size(-1),
1041
+ ]
1042
+ ),
1043
+ )
1044
+ score += p2p_att
1045
+
1046
+ return score
1047
+
1048
+
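A small illustrative sketch (not from the upload) of the index clamping used in disentangled_attention_bias above: relative offsets are shifted by att_span and clipped to [0, 2*att_span - 1] before the gather into the relative-position embedding table.

import torch

# Toy att_span; offsets outside [-att_span, att_span - 1] saturate at the table edges.
att_span = 4
rel = torch.arange(-6, 7)                                   # relative offsets -6 .. 6
c2p_pos = torch.clamp(rel + att_span, 0, att_span * 2 - 1)  # indices 0 .. 7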
1049
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
1050
+ class DebertaV2Embeddings(nn.Module):
1051
+ """Construct the embeddings from word, position and token_type embeddings."""
1052
+
1053
+ def __init__(
1054
+ self,
1055
+ config,
1056
+ features_dim,
1057
+ add_video_feat=False,
1058
+ max_feats = 10
1059
+ ):
1060
+ super().__init__()
1061
+ pad_token_id = getattr(config, "pad_token_id", 0)
1062
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
1063
+ self.word_embeddings = nn.Embedding(
1064
+ config.vocab_size, self.embedding_size, padding_idx=pad_token_id
1065
+ )
1066
+
1067
+ self.position_biased_input = getattr(config, "position_biased_input", True)
1068
+ self.position_embeddings = nn.Embedding(
1069
+ config.max_position_embeddings, self.embedding_size
1070
+ ) # it is used for the decoder anyway
1071
+
1072
+ if config.type_vocab_size > 0:
1073
+ self.token_type_embeddings = nn.Embedding(
1074
+ config.type_vocab_size, self.embedding_size
1075
+ )
1076
+
1077
+ if self.embedding_size != config.hidden_size:
1078
+ self.embed_proj = nn.Linear(
1079
+ self.embedding_size, config.hidden_size, bias=False
1080
+ )
1081
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
1082
+ self.dropout = StableDropout(config.hidden_dropout_prob)
1083
+ self.config = config
1084
+
1085
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
1086
+ self.register_buffer(
1087
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
1088
+ )
1089
+
1090
+ self.add_video_feat = add_video_feat
1091
+ self.features_dim = features_dim
1092
+ if self.features_dim:
1093
+ self.linear_video = nn.Linear(features_dim, config.hidden_size)
1094
+ if self.add_video_feat:
1095
+ self.evl = EVLTransformer(max_feats, decoder_num_layers=1,
1096
+ decoder_qkv_dim=768, add_video_feat=self.add_video_feat,
1097
+ add_mask=True)
1098
+ #self.evl = ConvNet()
1099
+
1100
+ def get_video_embedding(self, video, video_mask):
1101
+
1102
+ if self.add_video_feat:
1103
+ video_g = self.evl(video, video_mask)
1104
+ video_feat = self.linear_video(video)
1105
+ video_feat_l = torch.cat([video_g, video_feat], dim = 1)
1106
+
1107
+ else:
1108
+ video_feat_l = self.linear_video(video)
1109
+ video_feat_tmp = video_feat_l * video_mask.unsqueeze(-1)
1110
+ video_g = torch.sum(video_feat_tmp, dim = 1) / video_mask.sum(dim = 1, keepdim=True)
1111
+ return video_g, video_feat_l
1112
+
1113
+ def forward(
1114
+ self,
1115
+ input_ids=None,
1116
+ token_type_ids=None,
1117
+ position_ids=None,
1118
+ mask=None,
1119
+ inputs_embeds=None,
1120
+ video=None,
1121
+ video_mask=None
1122
+ ):
1123
+ if input_ids is not None:
1124
+ input_shape = input_ids.size()
1125
+ else:
1126
+ input_shape = inputs_embeds.size()[:-1]
1127
+
1128
+ if inputs_embeds is None:
1129
+ inputs_embeds = self.word_embeddings(input_ids)
1130
+ video_global = None  # defined up front so the return dict below cannot hit a NameError when no video features are passed
+ if self.features_dim and video is not None:
1131
+ video_global, video = self.get_video_embedding(video, video_mask)
1132
+ inputs_embeds = torch.cat([video, inputs_embeds], 1)
1133
+ input_shape = inputs_embeds[:, :, 0].shape
1134
+
1135
+ seq_length = input_shape[1]
1136
+
1137
+ if position_ids is None:
1138
+ position_ids = self.position_ids[:, :seq_length]
1139
+
1140
+ if token_type_ids is None:
1141
+ token_type_ids = torch.zeros(
1142
+ input_shape, dtype=torch.long, device=self.position_ids.device
1143
+ )
1144
+
1145
+ if self.position_embeddings is not None:
1146
+ position_embeddings = self.position_embeddings(position_ids.long())
1147
+ else:
1148
+ position_embeddings = torch.zeros_like(inputs_embeds)
1149
+
1150
+ embeddings = inputs_embeds
1151
+ if self.position_biased_input:
1152
+ embeddings = embeddings + position_embeddings
1153
+ if self.config.type_vocab_size > 0:
1154
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
1155
+ embeddings = embeddings + token_type_embeddings
1156
+
1157
+ if self.embedding_size != self.config.hidden_size:
1158
+ embeddings = self.embed_proj(embeddings)
1159
+
1160
+ embeddings = self.LayerNorm(embeddings)
1161
+
1162
+ if mask is not None:
1163
+ if mask.dim() != embeddings.dim():
1164
+ if mask.dim() == 4:
1165
+ mask = mask.squeeze(1).squeeze(1)
1166
+ mask = mask.unsqueeze(2)
1167
+ mask = mask.to(embeddings.dtype)
1168
+
1169
+ embeddings = embeddings * mask
1170
+
1171
+ embeddings = self.dropout(embeddings)
1172
+ return {
1173
+ "embeddings": embeddings,
1174
+ "position_embeddings": position_embeddings,
1175
+ "video_global": video_global
1176
+ }
1177
+
1178
+
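A shape-level sketch (hypothetical sizes) of what get_video_embedding and forward do above: projected frame features are prepended to the word embeddings, and with add_video_feat one extra global token from the EVL decoder sits in front of them, giving a sequence of length T + 1 + L.

import torch
import torch.nn as nn

B, T, L, feat_dim, hidden = 2, 10, 20, 768, 1536            # toy sizes
linear_video = nn.Linear(feat_dim, hidden)

video = torch.randn(B, T, feat_dim)                         # CLIP frame features
video_feat = linear_video(video)                            # (B, T, hidden)
word_embeds = torch.randn(B, L, hidden)                     # stand-in for word_embeddings(input_ids)
inputs_embeds = torch.cat([video_feat, word_embeds], 1)     # (B, T + L, hidden)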
1179
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
1180
+
1181
+
1182
+ class DebertaV2PreTrainedModel(PreTrainedModel):
1183
+ """
1184
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1185
+ models.
1186
+ """
1187
+
1188
+ config_class = DebertaV2Config
1189
+ base_model_prefix = "deberta"
1190
+ _keys_to_ignore_on_load_missing = ["position_ids"]
1191
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
1192
+
1193
+ def __init__(self, config):
1194
+ super().__init__(config)
1195
+ self._register_load_state_dict_pre_hook(self._pre_load_hook)
1196
+
1197
+ def _init_weights(self, module):
1198
+ """Initialize the weights."""
1199
+ if isinstance(module, nn.Linear):
1200
+ # Slightly different from the TF version which uses truncated_normal for initialization
1201
+ # cf https://github.com/pytorch/pytorch/pull/5617
1202
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1203
+ if module.bias is not None:
1204
+ module.bias.data.zero_()
1205
+ elif isinstance(module, nn.Embedding):
1206
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1207
+ if module.padding_idx is not None:
1208
+ module.weight.data[module.padding_idx].zero_()
1209
+
1210
+ def _pre_load_hook(
1211
+ self,
1212
+ state_dict,
1213
+ prefix,
1214
+ local_metadata,
1215
+ strict,
1216
+ missing_keys,
1217
+ unexpected_keys,
1218
+ error_msgs,
1219
+ ):
1220
+ """
1221
+ Removes the classifier if it doesn't have the correct number of labels.
1222
+ """
1223
+ self_state = self.state_dict()
1224
+ if (
1225
+ ("classifier.weight" in self_state)
1226
+ and ("classifier.weight" in state_dict)
1227
+ and self_state["classifier.weight"].size()
1228
+ != state_dict["classifier.weight"].size()
1229
+ ):
1230
+ print(
1231
+ f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
1232
+ f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
1233
+ f"weights. You should train your model on new data."
1234
+ )
1235
+ del state_dict["classifier.weight"]
1236
+ if "classifier.bias" in state_dict:
1237
+ del state_dict["classifier.bias"]
1238
+
1239
+
1240
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
1241
+ class DebertaV2Model(DebertaV2PreTrainedModel):
1242
+ def __init__(
1243
+ self,
1244
+ config,
1245
+ max_feats=10,
1246
+ features_dim=768,
1247
+ freeze_lm=False,
1248
+ ds_factor_attn=8,
1249
+ ds_factor_ff=8,
1250
+ ft_ln=False,
1251
+ dropout=0.1,
1252
+ add_video_feat = False,
1253
+ freeze_ad=False,
1254
+ ):
1255
+ super().__init__(config)
1256
+
1257
+ self.embeddings = DebertaV2Embeddings(
1258
+ config,
1259
+ features_dim,
1260
+ add_video_feat,
1261
+ max_feats
1262
+ )
1263
+ self.encoder = DebertaV2Encoder(
1264
+ config,
1265
+ ds_factor_attn,
1266
+ ds_factor_ff,
1267
+ dropout,
1268
+ )
1269
+ self.z_steps = 0
1270
+ self.config = config
1271
+
1272
+ self.features_dim = features_dim
1273
+ self.max_feats = max_feats
1274
+ if freeze_lm:
1275
+ for n, p in self.named_parameters():
1276
+ #if (not "linear_video" in n) and (not "adapter" in n):
1277
+ # if ft_ln and "LayerNorm" in n:
1278
+ # continue
1279
+ # else:
1280
+ # p.requires_grad_(False)
1281
+ if not freeze_ad:
1282
+ if (not "evl" in n) and (not "linear_video" in n) and (not "adapter" in n) and (not "moe" in n):
1283
+ if ft_ln and "LayerNorm" in n:
1284
+ continue
1285
+ else:
1286
+ p.requires_grad_(False)
1287
+
1288
+ else:
1289
+ if not "evl" in n:
1290
+ p.requires_grad_(False)
1291
+
1292
+
1293
+
1294
+ self.init_weights()
1295
+
1296
+ def get_input_embeddings(self):
1297
+ return self.embeddings.word_embeddings
1298
+
1299
+ def set_input_embeddings(self, new_embeddings):
1300
+ self.embeddings.word_embeddings = new_embeddings
1301
+
1302
+ def _prune_heads(self, heads_to_prune):
1303
+ """
1304
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1305
+ class PreTrainedModel
1306
+ """
1307
+ raise NotImplementedError(
1308
+ "The prune function is not implemented in DeBERTa model."
1309
+ )
1310
+
1311
+ def forward(
1312
+ self,
1313
+ input_ids=None,
1314
+ attention_mask=None,
1315
+ token_type_ids=None,
1316
+ position_ids=None,
1317
+ inputs_embeds=None,
1318
+ output_attentions=None,
1319
+ output_hidden_states=None,
1320
+ return_dict=None,
1321
+ video=None,
1322
+ video_mask=None,
1323
+ train_mode = True
1324
+ ):
1325
+ output_attentions = (
1326
+ output_attentions
1327
+ if output_attentions is not None
1328
+ else self.config.output_attentions
1329
+ )
1330
+ output_hidden_states = (
1331
+ output_hidden_states
1332
+ if output_hidden_states is not None
1333
+ else self.config.output_hidden_states
1334
+ )
1335
+ return_dict = (
1336
+ return_dict if return_dict is not None else self.config.use_return_dict
1337
+ )
1338
+
1339
+ if input_ids is not None and inputs_embeds is not None:
1340
+ raise ValueError(
1341
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1342
+ )
1343
+ elif input_ids is not None:
1344
+ input_shape = input_ids.size()
1345
+ elif inputs_embeds is not None:
1346
+ input_shape = inputs_embeds.size()[:-1]
1347
+ else:
1348
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1349
+
1350
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1351
+
1352
+ if attention_mask is None:
1353
+ attention_mask = torch.ones(input_shape, device=device)
1354
+
1355
+ if self.features_dim and video is not None:
1356
+ if video_mask is None:
1357
+ video_shape = video[:, :, 0].size()
1358
+ video_mask = torch.ones(video_shape, device=device)
1359
+ attention_mask = torch.cat([video_mask, attention_mask], 1)
1360
+ input_shape = attention_mask.size()
1361
+
1362
+ if token_type_ids is None:
1363
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1364
+
1365
+ embedding_output = self.embeddings(
1366
+ input_ids=input_ids,
1367
+ token_type_ids=token_type_ids,
1368
+ position_ids=position_ids,
1369
+ mask=attention_mask,
1370
+ inputs_embeds=inputs_embeds,
1371
+ video=video,
1372
+ video_mask=video_mask[:, 1:] if video_mask.shape[1] != video.shape[1] else video_mask
1373
+ )
1374
+ embedding_output, position_embeddings, video_g = (
1375
+ embedding_output["embeddings"],
1376
+ embedding_output["position_embeddings"],
1377
+ embedding_output["video_global"]
1378
+ )
1379
+
1380
+ video_g = video_g.squeeze()
1381
+ encoder_outputs = self.encoder(
1382
+ video_g,
1383
+ embedding_output,
1384
+ attention_mask,
1385
+ output_hidden_states=True,
1386
+ output_attentions=output_attentions,
1387
+ return_dict=return_dict,
1388
+ train_mode=train_mode
1389
+ )
1390
+ encoded_layers = encoder_outputs[1]
1391
+ loss_moe = encoder_outputs.loss_moe
1392
+
1393
+ if self.z_steps > 1:
1394
+ hidden_states = encoded_layers[-2]
1395
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
1396
+ query_states = encoded_layers[-1]
1397
+ rel_embeddings = self.encoder.get_rel_embedding()
1398
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
1399
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
1400
+ for layer in layers[1:]:
1401
+ query_states = layer(
1402
+ hidden_states,
1403
+ attention_mask,
1404
+ return_att=False,
1405
+ query_states=query_states,
1406
+ relative_pos=rel_pos,
1407
+ rel_embeddings=rel_embeddings,
1408
+ )
1409
+ encoded_layers.append(query_states)
1410
+
1411
+ sequence_output = encoded_layers[-1]
1412
+
1413
+ if not return_dict:
1414
+ return (sequence_output,) + encoder_outputs[
1415
+ (1 if output_hidden_states else 2) :
1416
+ ]
1417
+
1418
+ return BaseModelOutput(
1419
+ last_hidden_state=sequence_output,
1420
+ hidden_states=encoder_outputs.hidden_states
1421
+ if output_hidden_states
1422
+ else None,
1423
+ attentions=encoder_outputs.attentions,
1424
+ position_embeddings=position_embeddings,
1425
+ attention_mask=attention_mask,
1426
+ video_g=video_g,
1427
+ loss_moe = loss_moe,
1428
+ loads=encoder_outputs.loads
1429
+ )
1430
+
1431
+
1432
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
1433
+ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
1434
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1435
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1436
+
1437
+ def __init__(
1438
+ self,
1439
+ config,
1440
+ max_feats=10,
1441
+ features_dim=768,
1442
+ freeze_lm=True,
1443
+ freeze_mlm=True,
1444
+ ds_factor_attn=8,
1445
+ ds_factor_ff=8,
1446
+ ft_ln=True,
1447
+ dropout=0.1,
1448
+ n_ans=0,
1449
+ freeze_last=True,
1450
+ add_video_feat = False,
1451
+ freeze_ad=False,
1452
+ add_temporal_trans = False
1453
+ ):
1454
+ """
1455
+ :param config: BiLM configuration
1456
+ :param max_feats: maximum number of frames used by the model
1457
+ :param features_dim: embedding dimension of the visual features, set = 0 for text-only mode
1458
+ :param freeze_lm: whether to freeze or not the language model (Transformer encoder + token embedder)
1459
+ :param freeze_mlm: whether to freeze or not the MLM head
1460
+ :param ds_factor_attn: downsampling factor for the adapter after self-attention, no adapter if set to 0
1461
+ :param ds_factor_ff: downsampling factor for the adapter after feed-forward, no adapter if set to 0
1462
+ :param ft_ln: whether to finetune or not the normalization layers
1463
+ :param dropout: dropout probability in the adapter
1464
+ :param n_ans: number of answers in the downstream vocabulary, set = 0 during cross-modal training
1465
+ :param freeze_last: whether to freeze or not the answer embedding module
1466
+ """
1467
+ super().__init__(config)
1468
+
1469
+ # self.clip, _ = clip.load("ViT-L/14")
1470
+ # for p in self.clip.parameters():
1471
+ # p.requires_grad_(False)
1472
+
1473
+ self.deberta = DebertaV2Model(
1474
+ config,
1475
+ max_feats,
1476
+ features_dim,
1477
+ freeze_lm,
1478
+ ds_factor_attn,
1479
+ ds_factor_ff,
1480
+ ft_ln,
1481
+ dropout,
1482
+ add_video_feat,
1483
+ freeze_ad
1484
+ )
1485
+
1486
+ self.add_video_feat = add_video_feat
1487
+ self.lm_predictions = DebertaV2OnlyMLMHead(config)
1488
+ self.features_dim = features_dim
1489
+ if freeze_mlm:
1490
+ for n, p in self.lm_predictions.named_parameters():
1491
+ if ft_ln and "LayerNorm" in n:
1492
+ continue
1493
+ else:
1494
+ p.requires_grad_(False)
1495
+
1496
+ self.init_weights()
1497
+ self.n_ans = n_ans
1498
+ if n_ans:
1499
+ self.answer_embeddings = nn.Embedding(
1500
+ n_ans, self.deberta.embeddings.embedding_size
1501
+ )
1502
+ self.answer_bias = nn.Parameter(torch.zeros(n_ans))
1503
+ if freeze_last:
1504
+ self.answer_embeddings.requires_grad_(False)
1505
+ self.answer_bias.requires_grad_(False)
1506
+
1507
+ def set_answer_embeddings(self, a2tok, freeze_last=True):
1508
+ a2v = self.deberta.embeddings.word_embeddings(a2tok) # answer embeddings (ans_vocab_num, 1, dim)
1509
+ pad_token_id = getattr(self.config, "pad_token_id", 0)
1510
+ sum_tokens = (a2tok != pad_token_id).sum(1, keepdims=True) # n_ans (1000, 1) n_tokens
1511
+ if len(a2v) != self.n_ans: # reinitialize the answer embeddings
1512
+ assert not self.training
1513
+ self.n_ans = len(a2v)
1514
+ self.answer_embeddings = nn.Embedding(
1515
+ self.n_ans, self.deberta.embeddings.embedding_size
1516
+ ).to(self.device)
1517
+ self.answer_bias.requires_grad = False
1518
+ self.answer_bias.resize_(self.n_ans)
1519
+ self.answer_embeddings.weight.data = torch.div(
1520
+ (a2v * (a2tok != pad_token_id).float()[:, :, None]).sum(1),
1521
+ sum_tokens.clamp(min=1),
1522
+ ) # n_ans
1523
+ a2b = self.lm_predictions.lm_head.bias[a2tok]
1524
+ self.answer_bias.weight = torch.div(
1525
+ (a2b * (a2tok != pad_token_id).float()).sum(1), sum_tokens.clamp(min=1)
1526
+ )
1527
+ if freeze_last:
1528
+ self.answer_embeddings.requires_grad_(False)
1529
+ self.answer_bias.requires_grad_(False)
1530
+
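A minimal sketch (toy vocabulary, not the real tokenizer) of the averaging in set_answer_embeddings: each answer vector is the mean of its non-pad token embeddings.

import torch
import torch.nn as nn

a2tok = torch.tensor([[5, 7, 0], [9, 0, 0]])                # two answers padded to 3 tokens, pad id 0
word_emb = nn.Embedding(16, 4, padding_idx=0)               # toy vocab of 16, dim 4
a2v = word_emb(a2tok)                                       # (n_ans, 3, 4)

sum_tokens = (a2tok != 0).sum(1, keepdim=True)              # real tokens per answer
ans_emb = torch.div(
    (a2v * (a2tok != 0).float()[:, :, None]).sum(1),
    sum_tokens.clamp(min=1),
)                                                           # (n_ans, 4) mean over real tokens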
1531
+ def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, temporal_factor, train_mode):
1532
+ if attention_mask.dim() <= 2:
1533
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
1534
+ att_mask = extended_attention_mask.byte()
1535
+ attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
1536
+ elif attention_mask.dim() == 3:
1537
+ attention_mask = attention_mask.unsqueeze(1)
1538
+ hidden_states = encoder_layers[-2]
1539
+ if not self.config.position_biased_input:
1540
+ layers = [encoder.layer[-1] for _ in range(2)]
1541
+ z_states = z_states + hidden_states
1542
+ query_states = z_states
1543
+ query_mask = attention_mask
1544
+ outputs = []
1545
+ rel_embeddings = encoder.get_rel_embedding()
1546
+
1547
+ for layer in layers:
1548
+ output = layer(
1549
+ temporal_factor,
1550
+ hidden_states,
1551
+ query_mask,
1552
+ return_att=False,
1553
+ query_states=query_states,
1554
+ relative_pos=None,
1555
+ rel_embeddings=rel_embeddings,
1556
+ train_mode=train_mode
1557
+ )
1558
+ query_states = output[0]
1559
+ outputs.append(query_states)
1560
+ else:
1561
+ outputs = [encoder_layers[-1]]
1562
+
1563
+ return outputs
1564
+
1565
+ def forward(
1566
+ self,
1567
+ input_ids=None,
1568
+ attention_mask=None,
1569
+ labels=None,
1570
+ video=None,
1571
+ video_mask=None,
1572
+ train_mode=False,
1573
+ ):
1574
+ token_type_ids=None
1575
+ position_ids=None
1576
+ inputs_embeds=None
1577
+ output_attentions=None
1578
+ return_dict=None
1579
+ mlm=False
1580
+ r"""
1581
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1582
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1583
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1584
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1585
+ """
1586
+
1587
+ return_dict = (
1588
+ return_dict if return_dict is not None else self.config.use_return_dict
1589
+ )
1590
+
1591
+
1592
+ # rand_video = torch.randn(1,30,3,224,224).cuda()
1593
+ # video = self.clip.encode_image(rand_video.squeeze()).unsqueeze(0)
1594
+ # video = video.to(torch.float)
1595
+
1596
+ outputs = self.deberta(
1597
+ input_ids,
1598
+ attention_mask=attention_mask,
1599
+ token_type_ids=token_type_ids,
1600
+ position_ids=position_ids,
1601
+ inputs_embeds=inputs_embeds,
1602
+ output_attentions=output_attentions,
1603
+ output_hidden_states=True,
1604
+ return_dict=return_dict,
1605
+ video=video,
1606
+ video_mask=video_mask,
1607
+ train_mode = train_mode
1608
+ )
1609
+
1610
+ loss_moe = outputs['loss_moe']
1611
+
1612
+ if labels is not None:
1613
+ if (
1614
+ self.features_dim and video is not None
1615
+ ): # ignore the label predictions for visual tokens
1616
+ video_shape = video[:, :, 0].size()
1617
+ # add video_general
1618
+ if self.add_video_feat:
1619
+ video_shape = (video_shape[0], video_shape[1] + 1)
1620
+
1621
+ video_labels = torch.tensor(
1622
+ [[-100] * video_shape[1]] * video_shape[0],
1623
+ dtype=torch.long,
1624
+ device=labels.device,
1625
+ )
1626
+ labels = torch.cat([video_labels, labels], 1)
1627
+
1628
+ # sequence_output = outputs[0]
1629
+ modified = self.emd_context_layer(
1630
+ encoder_layers=outputs["hidden_states"],
1631
+ z_states=outputs["position_embeddings"].repeat(
1632
+ input_ids.shape[0] // len(outputs["position_embeddings"]), 1, 1
1633
+ ),
1634
+ attention_mask=outputs["attention_mask"],
1635
+ encoder=self.deberta.encoder,
1636
+ temporal_factor=outputs["video_g"],
1637
+ train_mode = train_mode
1638
+ )
1639
+ bias = None
1640
+ if self.n_ans and (not mlm): # downstream mode
1641
+ embeddings = self.answer_embeddings.weight
1642
+ bias = self.answer_bias
1643
+ else:
1644
+ embeddings = self.deberta.embeddings.word_embeddings.weight
1645
+ prediction_scores = self.lm_predictions(modified[-1], embeddings, bias)
1646
+
1647
+ masked_lm_loss = None
1648
+ if labels is not None:
1649
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1650
+
1651
+ masked_lm_loss = loss_fct(
1652
+ prediction_scores.view(-1, self.config.vocab_size),
1653
+ labels.view(-1), # labels[labels > 0].view(-1)
1654
+ )
1655
+
1656
+ if not return_dict:
1657
+ output = (prediction_scores,) + outputs[1:]
1658
+ return (
1659
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1660
+ )
1661
+
1662
+ return MaskedLMOutput(
1663
+ loss_moe=loss_moe,
1664
+ loss=masked_lm_loss,
1665
+ logits=prediction_scores,
1666
+ hidden_states=outputs.hidden_states,
1667
+ attentions=outputs.attentions,
1668
+ loads=outputs.loads,
1669
+ embeddings=outputs.video_g
1670
+ )
1671
+
1672
+
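A hedged sketch of the label handling above: video positions are filled with -100, which CrossEntropyLoss ignores, so the masked-LM loss is computed only over text tokens.

import torch
from torch.nn import CrossEntropyLoss

B, T, L, vocab = 2, 3, 4, 11                                # toy sizes
text_labels = torch.randint(0, vocab, (B, L))
video_labels = torch.full((B, T), -100, dtype=torch.long)   # ignored by the loss
labels = torch.cat([video_labels, text_labels], 1)          # (B, T + L)

logits = torch.randn(B, T + L, vocab)
loss = CrossEntropyLoss()(logits.view(-1, vocab), labels.view(-1))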
1673
+ # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
1674
+ class DebertaV2PredictionHeadTransform(nn.Module):
1675
+ def __init__(self, config):
1676
+ super().__init__()
1677
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1678
+ if isinstance(config.hidden_act, str):
1679
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1680
+ else:
1681
+ self.transform_act_fn = config.hidden_act
1682
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1683
+
1684
+ def forward(self, hidden_states):
1685
+ hidden_states = self.dense(hidden_states)
1686
+ hidden_states = self.transform_act_fn(hidden_states)
1687
+ hidden_states = self.LayerNorm(hidden_states)
1688
+ return hidden_states
1689
+
1690
+
1691
+ # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
1692
+ class DebertaV2LMPredictionHead(nn.Module):
1693
+ def __init__(self, config):
1694
+ super().__init__()
1695
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1696
+ if isinstance(config.hidden_act, str):
1697
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1698
+ else:
1699
+ self.transform_act_fn = config.hidden_act
1700
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1701
+
1702
+ # The output weights are the same as the input embeddings, but there is
1703
+ # an output-only bias for each token.
1704
+
1705
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1706
+
1707
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1708
+
1709
+ def forward(self, hidden_states, embedding_weight, bias=None):
1710
+ hidden_states = self.dense(hidden_states)
1711
+ hidden_states = self.transform_act_fn(hidden_states)
1712
+ hidden_states = self.LayerNorm(hidden_states)
1713
+ if bias is not None:
1714
+ logits = (
1715
+ torch.matmul(hidden_states, embedding_weight.t().to(hidden_states))
1716
+ + bias
1717
+ )
1718
+ else:
1719
+ logits = (
1720
+ torch.matmul(hidden_states, embedding_weight.t().to(hidden_states))
1721
+ + self.bias
1722
+ )
1723
+ return logits
1724
+
1725
+
1726
+ # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
1727
+ class DebertaV2OnlyMLMHead(nn.Module):
1728
+ def __init__(self, config):
1729
+ super().__init__()
1730
+ # self.predictions = DebertaV2LMPredictionHead(config)
1731
+ self.lm_head = DebertaV2LMPredictionHead(config)
1732
+
1733
+ def forward(self, sequence_output, embedding_weight, bias=None):
1734
+ prediction_scores = self.lm_head(sequence_output, embedding_weight, bias=bias)
1735
+ return prediction_scores
model/evl.py ADDED
@@ -0,0 +1,345 @@
1
+
2
+ from typing import Dict, Iterable, List, Tuple
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from collections import OrderedDict
10
+
11
+
12
+ class QuickGELU(nn.Module):
13
+ def forward(self, x: torch.Tensor):
14
+ return x * torch.sigmoid(1.702 * x)
15
+
16
+ class LayerNorm(nn.LayerNorm):
17
+ """Subclass torch's LayerNorm to handle fp16."""
18
+
19
+ def forward(self, x: torch.Tensor):
20
+ orig_type = x.dtype
21
+ ret = super().forward(x.type(torch.float32))
22
+ return ret.type(orig_type)
23
+
24
+
25
+ class Attention(nn.Module):
26
+ '''
27
+ A generalized attention module with more flexibility.
28
+ '''
29
+
30
+ def __init__(
31
+ self, q_in_dim: int, k_in_dim: int, v_in_dim: int,
32
+ qk_proj_dim: int, v_proj_dim: int, num_heads: int, out_dim: int,
33
+ return_all_features: bool = False, add_mask: bool = False, dropout: float = 0.0
34
+ ):
35
+ super().__init__()
36
+
37
+ self.q_proj = nn.Linear(q_in_dim, qk_proj_dim)
38
+ self.k_proj = nn.Linear(k_in_dim, qk_proj_dim)
39
+ self.v_proj = nn.Linear(v_in_dim, v_proj_dim)
40
+ self.out_proj = nn.Linear(v_proj_dim, out_dim)
41
+
42
+ self.num_heads = num_heads
43
+ self.return_all_features = return_all_features
44
+ assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0
45
+
46
+ self.add_mask = add_mask
47
+ self._initialize_weights()
48
+
49
+ def _initialize_weights(self):
50
+ for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
51
+ nn.init.xavier_uniform_(m.weight)
52
+ nn.init.constant_(m.bias, 0.)
53
+
54
+ def forward(self, q, k, v, mask):
55
+ if not self.add_mask:
56
+ mask = torch.ones_like(mask)
57
+
58
+ assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
59
+ N = q.size(0); assert k.size(0) == N and v.size(0) == N
60
+ Lq, Lkv = q.size(1), k.size(1); assert v.size(1) == Lkv
61
+
62
+ q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
63
+
64
+ H = self.num_heads
65
+ Cqk, Cv = q.size(-1) // H, v.size(-1) // H
66
+
67
+ q = q.view(N, Lq, H, Cqk)
68
+ k = k.view(N, Lkv, H, Cqk)
69
+ v = v.view(N, Lkv, H, Cv)
70
+
71
+ aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k)
72
+ #aff = aff.softmax(dim=-2)
73
+
74
+ rmask = ~(mask.bool())
75
+ aff = aff.masked_fill(rmask.unsqueeze(1).unsqueeze(-1).to(aff.device), float("-inf"))
76
+ aff = aff.softmax(dim = -2)
77
+
78
+ mix = torch.einsum('nqlh,nlhc->nqhc', aff, v)
79
+
80
+ out = self.out_proj(mix.flatten(-2))
81
+
82
+ if self.return_all_features:
83
+ return dict(q=q, k=k, v=v, aff=aff, out=out)
84
+ else:
85
+ return out
86
+
87
+
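A usage sketch of the Attention module above with made-up sizes; the mask marks valid key/value positions, and padded positions are filled with -inf before the softmax.

import torch

attn = Attention(q_in_dim=768, k_in_dim=768, v_in_dim=768,
                 qk_proj_dim=768, v_proj_dim=768, num_heads=12,
                 out_dim=768, add_mask=True)
q = torch.randn(2, 1, 768)           # one query token per clip
kv = torch.randn(2, 10, 768)         # ten frame features per clip
mask = torch.ones(2, 10)
mask[:, 8:] = 0                      # last two frames are padding
out = attn(q, kv, kv, mask)          # (2, 1, 768)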
88
+ class TransformerDecoderLayer(nn.Module):
89
+
90
+ def __init__(
91
+ self,
92
+ in_feature_dim: int = 768,
93
+ qkv_dim: int = 768,
94
+ num_heads: int = 12,
95
+ mlp_factor: float = 4.0,
96
+ mlp_dropout: float = 0.0,
97
+ act: nn.Module = QuickGELU,
98
+ add_mask: bool = False
99
+ ):
100
+ super().__init__()
101
+
102
+ self.attn = Attention(
103
+ q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
104
+ qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim, add_mask=add_mask
105
+ )
106
+
107
+ mlp_dim = round(mlp_factor * in_feature_dim)
108
+ self.mlp = nn.Sequential(OrderedDict([
109
+ ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
110
+ ('act', act()),
111
+ ('dropout', nn.Dropout(mlp_dropout)),
112
+ ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
113
+ ]))
114
+
115
+ self.norm1 = LayerNorm(in_feature_dim)
116
+ self.norm2 = LayerNorm(in_feature_dim)
117
+ self.norm3 = LayerNorm(in_feature_dim)
118
+
119
+ self._initialize_weights()
120
+
121
+
122
+ def _initialize_weights(self):
123
+ for m in (self.mlp[0], self.mlp[-1]):
124
+ nn.init.xavier_uniform_(m.weight)
125
+ nn.init.normal_(m.bias, std=1e-6)
126
+
127
+
128
+ def forward(self, x, y, mask):
129
+ y_norm = self.norm3(y)
130
+ x = x + self.attn(self.norm1(x), y_norm, y_norm, mask)
131
+ x = x + self.mlp(self.norm2(x))
132
+
133
+ return x
134
+
135
+
136
+ class EVLDecoder(nn.Module):
137
+
138
+ def __init__(
139
+ self,
140
+ num_frames: int = 8,
141
+ spatial_size: Tuple[int, int] = (14, 14),
142
+ num_layers: int = 4,
143
+ in_feature_dim: int = 768,
144
+ qkv_dim: int = 768,
145
+ num_heads: int = 12,
146
+ mlp_factor: float = 4.0,
147
+ enable_temporal_conv: bool = True,
148
+ enable_temporal_pos_embed: bool = True,
149
+ mlp_dropout: float = 0.5,
150
+ add_vid_feat: bool = False,
151
+ add_mask: bool = False,
152
+ ):
153
+ super().__init__()
154
+
155
+ self.num_layers = num_layers
156
+
157
+ self.add_vid_feat = add_vid_feat
158
+
159
+ if add_vid_feat:
160
+ self.decoder_layers = nn.ModuleList(
161
+ [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout, add_mask=add_mask) for _ in range(num_layers)]
162
+ )
163
+ self.cls_token = nn.Parameter(torch.zeros([in_feature_dim]))
164
+ self._initialize_weights()
165
+
166
+ if enable_temporal_conv:
167
+ self.temporal_conv = nn.ModuleList(
168
+ [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)]
169
+ )
170
+
171
+ # self.temporal_conv = nn.ModuleList(
172
+ # [nn.Linear(in_feature_dim, in_feature_dim) for _ in range(num_layers)]
173
+ # )
174
+
175
+ if enable_temporal_pos_embed:
176
+ self.temporal_pos_embed = nn.ParameterList(
177
+ [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)]
178
+ )
179
+
180
+ def _initialize_weights(self):
181
+ nn.init.normal_(self.cls_token, std=0.02)
182
+
183
+ def forward(self, in_features, video_mask):
184
+ N, T, C = in_features.size()
185
+
186
+ if self.add_vid_feat:
187
+ x = self.cls_token.view(1, 1, -1).repeat(N, 1, 1)
188
+
189
+ frame_features = in_features
190
+ for i in range(self.num_layers):
191
+ frame_features = in_features
192
+ feat = in_features
193
+
194
+ feat = feat.permute(0, 2, 1).contiguous() # N, C, T
195
+
196
+
197
+ feat = self.temporal_conv[i](feat)
198
+ feat = feat.view(N, C, T).permute(0, 2, 1,).contiguous() # N, T, C
199
+ frame_features = frame_features + feat
200
+
201
+ frame_features = frame_features + self.temporal_pos_embed[i].view(1, T, C)
202
+
203
+ if self.add_vid_feat:
204
+ x = self.decoder_layers[i](x, frame_features, video_mask)
205
+
206
+ if self.add_vid_feat:
207
+ return x
208
+
209
+ return frame_features
210
+
211
+
212
+ class EVLTransformer(nn.Module):
213
+
214
+ def __init__(
215
+ self,
216
+ num_frames: int = 8,
217
+ decoder_num_layers: int = 2,
218
+ decoder_qkv_dim: int = 768,
219
+ decoder_num_heads: int = 16,
220
+ decoder_mlp_factor: float = 4.0,
221
+ enable_temporal_conv: bool = True,
222
+ enable_temporal_pos_embed: bool = True,
223
+ enable_temporal_cross_attention: bool = False,
224
+ decoder_mlp_dropout: float = 0.5,
225
+ add_video_feat: bool = False,
226
+ output_dim: int = 1536,
227
+ add_mask: bool = False
228
+ ):
229
+ super().__init__()
230
+
231
+ self.decoder_num_layers = decoder_num_layers
232
+
233
+ backbone_feature_dim = 768
234
+ backbone_spatial_size = (16, 16)
235
+
236
+ self.decoder = EVLDecoder(
237
+ num_frames=num_frames,
238
+ spatial_size=backbone_spatial_size,
239
+ num_layers=decoder_num_layers,
240
+ in_feature_dim=backbone_feature_dim,
241
+ qkv_dim=decoder_qkv_dim,
242
+ num_heads=decoder_num_heads,
243
+ mlp_factor=decoder_mlp_factor,
244
+ enable_temporal_conv=enable_temporal_conv,
245
+ enable_temporal_pos_embed=enable_temporal_pos_embed,
246
+ mlp_dropout=decoder_mlp_dropout,
247
+ add_vid_feat = add_video_feat,
248
+ add_mask=add_mask
249
+ )
250
+ self.add_vid_feat = add_video_feat
251
+ if self.add_vid_feat:
252
+ self.norm = nn.LayerNorm(backbone_feature_dim)
253
+ #self.dropout = nn.Dropout(0.5)
254
+ self.proj = nn.Linear(decoder_qkv_dim, output_dim)
255
+
256
+ def forward(self, x, video_mask):
257
+
258
+ features = x
259
+ x = self.decoder(features, video_mask)
260
+ if self.add_vid_feat:
261
+ x = self.norm(x)
262
+ #x = self.dropout(x)
263
+ x = self.proj(x)
264
+
265
+ return x
266
+
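A sketch of how EVLTransformer is called from deberta_moe.py (toy batch; in the real pipeline the input is CLIP ViT-L/14 frame features): it pools T frame features into a single 1536-d global video token.

import torch

evl = EVLTransformer(num_frames=10, decoder_num_layers=1, decoder_qkv_dim=768,
                     add_video_feat=True, add_mask=True)
frames = torch.randn(2, 10, 768)     # (B, T, 768) frame features
mask = torch.ones(2, 10)             # 1 = real frame, 0 = padding
video_g = evl(frames, mask)          # (B, 1, 1536) global video token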
267
+ class TemporalAttention(nn.Module):
268
+ def __init__(
269
+ self,
270
+ in_feature_dim: int = 768,
271
+ qkv_dim: int = 768,
272
+ num_heads: int = 8,
273
+ max_frames: int = 40,
274
+ stride: int = 4,
275
+ kernel_size: int = 4,
276
+ add_mask: bool = True,
277
+ ):
278
+ super().__init__()
279
+
280
+ self.num_layers = 2
281
+ self.kernel_size = kernel_size
282
+ self.stride = stride
283
+ max_frames = (max_frames - self.kernel_size) // self.stride + 1
284
+
285
+ self.decoder_layers = nn.ModuleList(
286
+ [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, 2.0, 0.5, add_mask=add_mask) for _ in range(self.num_layers)]
287
+ )
288
+
289
+ '''
290
+ self.attn = Attention(
291
+ q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
292
+ qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim, add_mask=add_mask
293
+ )'''
294
+
295
+ self.temporal_pos_embed = nn.Parameter(torch.zeros([max_frames, in_feature_dim]))
296
+ self.norm = nn.LayerNorm(in_feature_dim)
297
+
298
+ def forward(self, x, video_mask):
299
+
300
+ x, video_mask = avg_1d_pool(x, self.kernel_size, self.stride, video_mask, return_mask=True)
301
+
302
+ x = x + self.temporal_pos_embed.unsqueeze(0)
303
+ for i in range(self.num_layers):
304
+ x = self.decoder_layers[i](x, x, video_mask)
305
+
306
+ #x_norm = self.norm(x)
307
+ #x = x + self.attn(x_norm, x_norm, x_norm, video_mask)
308
+
309
+ return x
310
+
311
+ def recursive_gumbel_softmax(sim, x, video_mask, topk):
312
+ # sim: bs, T
313
+ # x: bs, T, dim
314
+
315
+ feats = []
316
+ bs = x.shape[0]
317
+ idxs = torch.zeros(bs, 10)
318
+ v_masks = []
319
+
320
+ rmask = ~(video_mask.bool())
321
+ sim = sim.masked_fill(rmask.unsqueeze(1).to(sim.device), float("-inf"))
322
+
323
+ for i in range(topk):
324
+ choice = F.gumbel_softmax(sim/0.01, hard=True, dim = -1, tau=0.1).squeeze(1) # bs, T
325
+ idxs[:, i] = torch.argsort(choice, descending=True)[:, 0]
326
+ tmp = torch.sum(choice.unsqueeze(-1) * x, dim = 1, keepdim=True) # bs, dim
327
+ feats.append(tmp)
328
+
329
+ mask_tmp = video_mask[torch.arange(bs), idxs[:, i].to(torch.long)]
330
+ v_masks.append(mask_tmp)
331
+ sim = sim - choice.unsqueeze(1)
332
+
333
+ rank = torch.argsort(idxs, dim = 1)
334
+
335
+ feats = torch.cat(feats, dim= 1) # bs, 10, dim
336
+ res = [feats[torch.arange(bs), rank[:, i]] for i in range(10)]
337
+ res = torch.stack(res, dim=1)
338
+
339
+ video_mask = torch.stack(v_masks, dim=1)
340
+ video_mask = [video_mask[torch.arange(bs), rank[:, i]] for i in range(10)]
341
+ video_mask = torch.stack(video_mask, dim = 1)
342
+
343
+ return res, video_mask
344
+
345
+
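A small sketch (toy scores) of the hard Gumbel-softmax step used in recursive_gumbel_softmax above: each draw returns a one-hot row over frames, which selects a frame index while keeping the selection differentiable.

import torch
import torch.nn.functional as F

sim = torch.randn(2, 1, 10)                                          # (bs, 1, T) frame scores
choice = F.gumbel_softmax(sim / 0.01, tau=0.1, hard=True, dim=-1)    # one-hot over T
choice = choice.squeeze(1)                                           # (bs, T)
picked = torch.argsort(choice, descending=True)[:, 0]                # selected frame index per clip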
model/moe.py ADDED
@@ -0,0 +1,442 @@
1
+
2
+ ### copy from LIMoE
3
+
4
+
5
+ #from distutils.command.config import config
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.distributions.normal import Normal
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+
13
+ from transformers.activations import ACT2FN
14
+ from .adapter import Adapter
15
+ from collections import OrderedDict
16
+ from copy import deepcopy
17
+
18
+ #-------------------#
19
+ # MoE
20
+
21
+ class MLP(nn.Module):
22
+ def __init__(self, input_size:int, output_size:int, hidden_size:int):
23
+ super(MLP, self).__init__()
24
+ self.fc1 = nn.Linear(input_size, hidden_size)
25
+ self.fc2 = nn.Linear(hidden_size, output_size)
26
+ self.dropout = nn.Dropout(0.1)
27
+ self.activation = ACT2FN["gelu"]
28
+ self.log_soft = nn.LogSoftmax(1)
29
+ self.apply(self.init_weights)
30
+
31
+ def init_weights(self, m: nn.Module, std=1e-3):
32
+ if isinstance(m, nn.Linear):
33
+ torch.nn.init.normal_(m.weight, std=std)
34
+ torch.nn.init.normal_(m.bias, std=std)
35
+ m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
36
+ m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
37
+ elif isinstance(m, nn.LayerNorm):
38
+ m.bias.data.zero_()
39
+ m.weight.data.fill_(1.0)
40
+
41
+ def forward(self, x):
42
+ out = self.fc1(x)
43
+ out = self.activation(out)
44
+ out = self.dropout(out)
45
+ out = self.fc2(out)
46
+ out = self.log_soft(out)
47
+ return out
48
+
49
+
50
+
51
+ class SparseDispatcher(object):
52
+ """Helper for implementing a mixture of experts.
53
+ The purpose of this class is to create input minibatches for the
54
+ experts and to combine the results of the experts to form a unified
55
+ output tensor.
56
+ There are two functions:
57
+ dispatch - take an input Tensor and create input Tensors for each expert.
58
+ combine - take output Tensors from each expert and form a combined output
59
+ Tensor. Outputs from different experts for the same batch element are
60
+ summed together, weighted by the provided "gates".
61
+ The class is initialized with a "gates" Tensor, which specifies which
62
+ batch elements go to which experts, and the weights to use when combining
63
+ the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
64
+ The inputs and outputs are all two-dimensional [batch, depth].
65
+ Caller is responsible for collapsing additional dimensions prior to
66
+ calling this class and reshaping the output to the original shape.
67
+ See common_layers.reshape_like().
68
+ Example use:
69
+ gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
70
+ inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
71
+ experts: a list of length `num_experts` containing sub-networks.
72
+ dispatcher = SparseDispatcher(num_experts, gates)
73
+ expert_inputs = dispatcher.dispatch(inputs)
74
+ expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
75
+ outputs = dispatcher.combine(expert_outputs)
76
+ The preceding code sets the output for a particular example b to:
77
+ output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
78
+ This class takes advantage of sparsity in the gate matrix by including in the
79
+ `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
80
+ """
81
+
82
+ def __init__(self, num_experts, gates):
83
+ """Create a SparseDispatcher."""
84
+
85
+ self._gates = gates
86
+ self._num_experts = num_experts
87
+ # sort experts
88
+ sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) # torch.nonzero: returns the (row, col) indices of nonzero gates, ordered by row and then column
89
+ # drop indices
90
+ _, self._expert_index = sorted_experts.split(1, dim=1)
91
+ # get according batch index for each expert
92
+ self._batch_index = sorted_experts[index_sorted_experts[:, 1],0]
93
+ # calculate num samples that each expert gets
94
+ self._part_sizes = list((gates > 0).sum(0).cpu().numpy())
95
+ # expand gates to match with self._batch_index
96
+ gates_exp = gates[self._batch_index.flatten()]
97
+ self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
98
+
99
+ def dispatch(self, inp):
100
+ """Create one input Tensor for each expert.
101
+ The `Tensor` for a expert `i` contains the slices of `inp` corresponding
102
+ to the batch elements `b` where `gates[b, i] > 0`.
103
+ Args:
104
+ inp: a `Tensor` of shape "[batch_size, <extra_input_dims>]`
105
+ Returns:
106
+ a list of `num_experts` `Tensor`s with shapes
107
+ `[expert_batch_size_i, <extra_input_dims>]`.
108
+ """
109
+
110
+ # assigns samples to experts whose gate is nonzero
111
+
112
+ # expand according to batch index so we can just split by _part_sizes
113
+ inp_exp = inp[self._batch_index].squeeze(1)
114
+ return torch.split(inp_exp, self._part_sizes, dim=0)
115
+
116
+
117
+ def combine(self, expert_out, multiply_by_gates=True):
118
+ """Sum together the expert output, weighted by the gates.
119
+ The slice corresponding to a particular batch element `b` is computed
120
+ as the sum over all experts `i` of the expert output, weighted by the
121
+ corresponding gate values. If `multiply_by_gates` is set to False, the
122
+ gate values are ignored.
123
+ Args:
124
+ expert_out: a list of `num_experts` `Tensor`s, each with shape
125
+ `[expert_batch_size_i, <extra_output_dims>]`.
126
+ multiply_by_gates: a boolean
127
+ Returns:
128
+ a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
129
+ """
130
+ # apply exp to expert outputs, so we are not longer in log space
131
+
132
+ #stitched = torch.cat(expert_out, 0).exp()
133
+ stitched = torch.cat(expert_out, 0)
134
+
135
+ if multiply_by_gates:
136
+ if len(stitched.shape) == 3:
137
+ stitched = stitched.mul(self._nonzero_gates.unsqueeze(1))
138
+ else:
139
+ stitched = stitched.mul(self._nonzero_gates)
140
+
141
+ if len(stitched.shape) == 3:
142
+ zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), expert_out[-1].size(-1), requires_grad=True, device=stitched.device)
143
+ else:
144
+ zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device)
145
+ # combine samples that have been processed by the same k experts
146
+ combined = zeros.index_add(0, self._batch_index, stitched.float())
147
+ # add eps to all zero values in order to avoid nans when going back to log space
148
+
149
+ #combined[combined == 0] = np.finfo(float).eps
150
+ # back to log space
151
+ #return combined.log()
152
+ return combined
153
+
154
+
155
+ def expert_to_gates(self):
156
+ """Gate values corresponding to the examples in the per-expert `Tensor`s.
157
+ Returns:
158
+ a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
159
+ and shapes `[expert_batch_size_i]`
160
+ """
161
+ # split nonzero gates for each expert
162
+ return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
163
+
164
+
165
+
166
+
167
+ class MoE(nn.Module):
168
+
169
+ """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
170
+ Args:
171
+ input_size: integer - size of the input
172
+ output_size: integer - size of the output
173
+ num_experts: an integer - number of experts
174
+ hidden_size: an integer - hidden size of the experts
175
+ noisy_gating: a boolean
176
+ k: an integer - how many experts to use for each batch element
177
+ """
178
+
179
+ def __init__(self,
180
+ noisy_gating = True,
181
+ ds_factor = 8.0,
182
+ num_experts = 4,
183
+ moe_input_size = 768,
184
+ top_k = 2,
185
+ dropout = 0.1,
186
+ gating = 'linear',
187
+ routing = None,
188
+ layer_id = 0
189
+ ):
190
+ super(MoE, self).__init__()
191
+ self.noisy_gating = noisy_gating
192
+ self.num_experts = num_experts
193
+ self.input_size = moe_input_size
194
+ self.k = top_k
195
+ self.layer_id = layer_id
196
+
197
+ # instantiate experts
198
+ #self.experts = nn.ModuleList([MLP(self.input_size, self.output_size, self.hidden_size) for i in range(self.num_experts)])
199
+ self.gating = gating
200
+ self.experts = nn.ModuleList([Adapter(ds_factor, moe_input_size, dropout=dropout) for i in range(self.num_experts)])
201
+ self.routing = routing
202
+ self.infer_expert = None
203
+
204
+
205
+ if self.routing != 'random':
206
+ if gating == 'linear':
207
+ #self.w_gate = nn.Linear(self.input_size, self.num_experts, bias=False)
208
+ self.w_gate = nn.Parameter(torch.zeros(self.input_size, num_experts), requires_grad=True)
209
+ elif gating == 'cosine':
210
+ self.w_gate = CosineTopKGate(self.input_size, self.num_experts)
211
+ self.w_noise = nn.Parameter(torch.zeros(self.input_size, self.num_experts), requires_grad=True)
212
+
213
+ self.softplus = nn.Softplus()
214
+ self.softmax = nn.Softmax(-1)
215
+ self.register_buffer("mean", torch.tensor([0.0]))
216
+ self.register_buffer("std", torch.tensor([1.0]))
217
+
218
+ assert(self.k <= self.num_experts)
219
+
220
+ def cv_squared(self, x):
221
+ """The squared coefficient of variation of a sample.
222
+ Useful as a loss to encourage a positive distribution to be more uniform.
223
+ Epsilons added for numerical stability.
224
+ Returns 0 for an empty Tensor.
225
+ Args:
226
+ x: a `Tensor`.
227
+ Returns:
228
+ a `Scalar`.
229
+ """
230
+ eps = 1e-10
231
+ # if only num_experts = 1
232
+ if x.shape[0] == 1:
233
+ return torch.Tensor([0])
234
+ if len(x.shape) == 2:
235
+ x = x.sum(dim=0)
236
+ return x.float().var() / (x.float().mean()**2 + eps)
237
+
238
+
239
+ def _gates_to_load(self, gates):
240
+ """Compute the true load per expert, given the gates.
241
+ The load is the number of examples for which the corresponding gate is >0.
242
+ Args:
243
+ gates: a `Tensor` of shape [batch_size, n]
244
+ Returns:
245
+ a float32 `Tensor` of shape [n]
246
+ """
247
+ return (gates > 0).sum(0)
248
+
249
+
250
+ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
251
+ """Helper function to NoisyTopKGating.
252
+ Computes the probability that value is in top k, given different random noise.
253
+ This gives us a way of backpropagating from a loss that balances the number
254
+ of times each expert is in the top k experts per example.
255
+ In the case of no noise, pass in None for noise_stddev, and the result will
256
+ not be differentiable.
257
+ Args:
258
+ clean_values: a `Tensor` of shape [batch, n].
259
+ noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
260
+ normally distributed noise with standard deviation noise_stddev.
261
+ noise_stddev: a `Tensor` of shape [batch, n], or None
262
+ noisy_top_values: a `Tensor` of shape [batch, m].
263
+ "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
264
+ Returns:
265
+ a `Tensor` of shape [batch, n].
266
+ """
267
+
268
+ batch = clean_values.size(0)
269
+ m = noisy_top_values.size(1)
270
+ top_values_flat = noisy_top_values.flatten() # (bs x m)
271
+ threshold_positions_if_in = torch.arange(batch) * m + self.k # bs
272
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in.to(top_values_flat.device)), 1)
273
+
274
+ if len(noisy_values.shape) == 3:
275
+ threshold_if_in = threshold_if_in.unsqueeze(1)
276
+
277
+ is_in = torch.gt(noisy_values, threshold_if_in)
278
+ threshold_positions_if_out = threshold_positions_if_in - 1
279
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat,0 , threshold_positions_if_out.to(top_values_flat.device)), 1)
280
+ if len(noisy_values.shape) == 3:
281
+ threshold_if_out = threshold_if_out.unsqueeze(1)
282
+
283
+ # is each value currently in the top k.
284
+
285
+ normal = Normal(self.mean.to(noise_stddev.device), self.std.to(noise_stddev.device))
286
+ prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
287
+ prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
288
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
289
+ return prob
290
+
291
+
292
+ def random_k_gating(self, features, train):
293
+ if train:
294
+ idx = torch.randint(0, self.num_experts, 1)
295
+ results = self.experts[idx](features)
296
+
297
+ else:
298
+ results = []
299
+ for i in range(self.num_experts):
300
+ tmp = self.num_experts[i](features)
301
+ results.append(tmp)
302
+
303
+ results = torch.stack(results, dim=0).mean(dim=0)
304
+
305
+ return results
306
+
307
+
308
+
309
+ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
310
+ """Noisy top-k gating.
311
+ See paper: https://arxiv.org/abs/1701.06538.
312
+ Args:
313
+ x: input Tensor with shape [batch_size, input_size]
314
+ train: a boolean - we only add noise at training time.
315
+ noise_epsilon: a float
316
+ Returns:
317
+ gates: a Tensor with shape [batch_size, num_experts]
318
+ load: a Tensor with shape [num_experts]
319
+ """
320
+ #clean_logits = self.w_gate(x)
321
+ if self.gating == 'linear':
322
+ clean_logits = x @ self.w_gate
323
+ elif self.gating == 'cosine':
324
+ clean_logits = self.w_gate(x)
325
+
326
+ if self.noisy_gating and train:
327
+ raw_noise_stddev = x @ self.w_noise
328
+ noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon) * train)
329
+ noisy_logits = clean_logits + ( torch.randn_like(clean_logits) * noise_stddev)
330
+ logits = noisy_logits
331
+ else:
332
+ logits = clean_logits
333
+
334
+ # logits (bs, n): 表示选择n中每个expert的概率
335
+
336
+ # 选k个experts,返回相应的下标以及logit
337
+ top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim= -1)
338
+
339
+ top_k_logits = top_logits[:, :self.k] if len(top_logits.shape) == 2 else top_logits[:, :, :self.k]
340
+ top_k_indices = top_indices[:, :self.k] if len(top_indices.shape) == 2 else top_indices[:, :, :self.k]
341
+
342
+ top_k_gates = self.softmax(top_k_logits)
343
+
344
+ zeros = torch.zeros_like(logits, requires_grad=True)
345
+ # 将经过softmax后的weight分配给相应的expert,未选定的expert的weight则为0
346
+ gates = zeros.scatter(-1, top_k_indices, top_k_gates)
347
+
348
+ if self.noisy_gating and self.k < self.num_experts and train:
349
+ load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
350
+ else:
351
+ load = self._gates_to_load(gates)
352
+ return gates, load
353
+
354
+
355
+     def forward(self, x, frame_features, train=True, loss_coef=1e-2):
+         """Args:
+             x: tensor shape [batch_size, input_size]
+             train: a boolean scalar.
+             loss_coef: a scalar - multiplier on load-balancing losses
+         Returns:
+             y: a tensor with shape [batch_size, output_size].
+             extra_training_loss: a scalar. This should be added into the overall
+                 training loss of the model. The backpropagation of this loss
+                 encourages all experts to be approximately equally used across a batch.
+         """
+
+         if self.routing == 'random':
+             loss = None
+             load = None
+             if train:
+                 gates = torch.zeros(x.shape[0], self.num_experts)
+                 random_idx = torch.randint(0, self.num_experts, (x.shape[0],))
+                 gates[torch.arange(x.shape[0]), random_idx] = 1
+                 gates = gates.to(x.device)
+                 dispatcher = SparseDispatcher(self.num_experts, gates)
+
+                 expert_inputs = dispatcher.dispatch(frame_features)  # gather each expert's share of the inputs
+                 gates = dispatcher.expert_to_gates()  # per-expert gate values
+                 expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
+                 y = dispatcher.combine(expert_outputs)
+             else:
+                 if self.infer_expert is None:
+                     weights = [self.experts[i].state_dict() for i in range(self.num_experts)]
+                     merge_weights = OrderedDict()
+                     for idx, it in enumerate(weights):
+                         for k, v in it.items():
+                             merge_weights[k] = v / self.num_experts if idx == 0 else merge_weights[k] + v / self.num_experts
+
+                     self.infer_expert = deepcopy(self.experts[0])
+                     self.infer_expert.load_state_dict(merge_weights)
+
+                 y = self.infer_expert(frame_features)
+
+             return y, loss, load
+
+         else:
+             if len(x.shape) == 1:
+                 x = x.unsqueeze(0)
+
+             gates, load = self.noisy_top_k_gating(x, train)
+
+             # calculate importance loss
+             importance = gates.sum(dim=0)
+
+             # calculate load-balancing loss
+             loss = self.cv_squared(importance) + self.cv_squared(load)
+             loss *= loss_coef
+
+             dispatcher = SparseDispatcher(self.num_experts, gates)
+
+             expert_inputs = dispatcher.dispatch(frame_features)  # gather each expert's share of the inputs
+             gates = dispatcher.expert_to_gates()  # per-expert gate values
+             expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
+             y = dispatcher.combine(expert_outputs)
+             return y, loss, load
+
+ class CosineTopKGate(torch.nn.Module):
+     def __init__(self, model_dim, num_global_experts, proj_dim=256, init_t=0.5):
+         super(CosineTopKGate, self).__init__()
+         self.temperature = torch.nn.Parameter(torch.log(torch.full([1], 1.0 / init_t)), requires_grad=True)
+         self.cosine_projector = torch.nn.Linear(model_dim, proj_dim)
+         self.sim_matrix = torch.nn.Parameter(torch.randn(size=(proj_dim, num_global_experts)), requires_grad=True)
+         self.clamp_max = torch.log(torch.tensor(1. / 0.01)).item()
+         torch.nn.init.normal_(self.sim_matrix, 0, 0.01)
+
+     def forward(self, x):
+         cosine_projector = self.cosine_projector
+         sim_matrix = self.sim_matrix
+         logits = torch.matmul(F.normalize(cosine_projector(x), dim=1),
+                               F.normalize(sim_matrix, dim=0))
+         logit_scale = torch.clamp(self.temperature, max=self.clamp_max).exp()
+         logits = logits * logit_scale
+         return logits
+
+ '''
+ model = MoE()
+
+ inputs = torch.randn((32, 1, 768))
+ frame_features = torch.randn((32, 10, 768))
+
+ model(inputs, frame_features)
+ '''
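Note (not part of the commit): a minimal, self-contained sketch of the noisy top-k gating computed above, useful for sanity-checking the routing shapes. The free-standing `w_gate` / `w_noise` matrices and the `noisy_top_k_gates` helper are illustrative stand-ins for the module's parameters, not names from this repository.

import torch
import torch.nn.functional as F

def noisy_top_k_gates(x, w_gate, w_noise, k, train=True, noise_epsilon=1e-2):
    # x: [batch, dim]; w_gate, w_noise: [dim, num_experts]
    clean_logits = x @ w_gate
    if train:
        # perturb the clean logits with input-dependent Gaussian noise
        noise_stddev = F.softplus(x @ w_noise) + noise_epsilon
        logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
    else:
        logits = clean_logits
    # keep the k largest logits, softmax them, and scatter back to a dense [batch, num_experts] gate matrix
    top_logits, top_indices = logits.topk(k, dim=-1)
    gates = torch.zeros_like(logits).scatter(-1, top_indices, top_logits.softmax(dim=-1))
    return gates

# example: route a batch of 4 vectors to 2 of 8 experts
gates = noisy_top_k_gates(torch.randn(4, 768), torch.randn(768, 8), torch.randn(768, 8), k=2)
print(gates.shape)  # torch.Size([4, 8]); at most two non-zero entries per row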
tmp.py ADDED
@@ -0,0 +1,30 @@
+ import pandas as pd
+ import json
+ import glob
+
+ result = json.load(open("/home/qinyixin/workspace/TgMoE/Frozenbilm/results/T_MoENet_NEXT-QA.json"))
+ video_dir = "/mnt/hdd3/qinyixin/nextqa/video"
+
+ cols = pd.read_csv("/mnt/hdd3/qinyixin/FrozenBilm/NEXT-QA/val.csv").columns.to_list()
+ nextqa = pd.read_csv("/mnt/hdd3/qinyixin/FrozenBilm/NEXT-QA/val.csv").values
+ qid_to_vidid = {}
+ for it in nextqa:
+     choices = [it[9 + idx] for idx in range(5)]
+     answer = choices[it[6]]
+     question = it[5]
+     qid = it[7]
+     vidid = str(it[1])
+     vid_path = glob.glob(video_dir + "/*/" + vidid + ".mp4")
+
+     qid_to_vidid[str(qid)] = {"vid_path": vid_path,
+                               "choices": str(choices),
+                               "question": question,
+                               "answer": answer
+                               }
+
+ correct = []
+ for k, v in result.items():
+     if v['acc']:
+         correct.append(qid_to_vidid[k])
+
+ json.dump(correct, open("demo/T-MoENet_result.json", "w"))
tmp2.py ADDED
@@ -0,0 +1,10 @@
+ import torch
+ from Infer import Infer
+
+ device = "cuda"
+ handler = Infer(device)
+ candidates = ['adjust the tree', 'get away the dust', 'dancing', 'pressed a button to activate', 'presents']
+ with torch.no_grad():
+     handler.generate("why did the boy clap his hands when he ran to the christmas tree?",
+                      "/home/qinyixin/workspace/TgMoE/Frozenbilm/demo/videos/4882821564.mp4",
+                      candidates)
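Note (not part of the commit): tmp.py exports the correctly answered NEXT-QA items to demo/T-MoENet_result.json, and tmp2.py shows one hard-coded call to the handler. A hedged sketch of how the two could be chained, assuming the exported file exists and each item's vid_path list is non-empty:

import ast
import json
import torch
from Infer import Infer

# pick one item produced by tmp.py
sample = json.load(open("demo/T-MoENet_result.json"))[0]

handler = Infer("cuda")
with torch.no_grad():
    handler.generate(sample["question"],
                     sample["vid_path"][0],
                     ast.literal_eval(sample["choices"]))  # choices were stored as a stringified list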
videos/3249402410.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4a6869517220132f2ac016009a8c309464cc76058b475c17977ef641818396c
+ size 2414513
videos/4882821564.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:49dd433b47f9e3c88c272332d5ae739fec1d9b4d96f5b93b8e648d2b45428b41
+ size 9316079
videos/6233408665.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3910609b3807f547bad0cc2375b471b1a409ef8742c723068bce0e1b48606aff
+ size 7806177