File size: 2,059 Bytes
a476bbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch


class Masker:
    """Produce masked copies of token-id tensors for MLM-style objectives.

    The wrapped tokenizer must expose:
      - ``mask_token_id``: int id used to replace masked positions,
      - ``is_special_token(input_ids)``: boolean tensor, same shape as input,
      - ``is_ptm_token(input_ids)``: boolean tensor, same shape as input.
    Special tokens are never masked by any strategy.
    """

    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer
        # Cached so masking does not re-read the tokenizer attribute per call.
        self.mask_token_id = self.tokenizer.mask_token_id

    def _maskable(self, input_ids):
        """Boolean tensor of positions allowed to be masked (non-special)."""
        return torch.logical_not(self.tokenizer.is_special_token(input_ids))

    def _random_selection(self, input_ids, mask_prob):
        """Bernoulli(mask_prob) selection restricted to maskable positions."""
        # Draw directly on the input's device; avoids a CPU alloc + transfer.
        rand = torch.rand(input_ids.shape, device=input_ids.device)
        return (rand < mask_prob) & self._maskable(input_ids)

    def _ptm_selection(self, input_ids):
        """Selection of PTM-token positions, excluding special tokens."""
        is_ptm = self.tokenizer.is_ptm_token(input_ids).to(input_ids.device)
        return is_ptm & self._maskable(input_ids)

    def _apply(self, input_ids, mask):
        """Return (masked copy of input_ids, mask); input is not mutated."""
        masked_input_ids = input_ids.clone()
        masked_input_ids[mask] = self.mask_token_id
        return masked_input_ids, mask

    def random_mask(self, input_ids, mask_prob=0.15):
        """Mask each non-special position independently with prob `mask_prob`.

        Returns (masked_input_ids, mask) where `mask` marks masked positions.
        """
        return self._apply(input_ids, self._random_selection(input_ids, mask_prob))

    def mask_ptm_tokens(
        self,
        input_ids,
    ):
        """Mask exactly the PTM-token positions (special tokens excluded).

        Returns (masked_input_ids, mask).
        """
        return self._apply(input_ids, self._ptm_selection(input_ids))

    def random_and_ptm_mask(self, input_ids, mask_prob=0.15):
        """Mask the union of a random selection and all PTM-token positions.

        Returns (masked_input_ids, mask).
        """
        mask = self._random_selection(input_ids, mask_prob) | self._ptm_selection(
            input_ids
        )
        return self._apply(input_ids, mask)

    def random_or_random_and_ptm_mask(
        self, input_ids, ranom_mask_prob=0.15, alternate_prob=0.2
    ):
        """Alternate between (1) random mask and (2) random+PTM mask.

        With probability `alternate_prob` applies the pure random mask,
        otherwise the combined random+PTM mask. Returns (masked_input_ids,
        mask).

        NOTE(review): `ranom_mask_prob` is a typo for `random_mask_prob`,
        kept as-is so keyword callers are not broken.
        """
        p = torch.rand(1).item()
        if p < alternate_prob:
            return self.random_mask(input_ids, ranom_mask_prob)
        else:
            return self.random_and_ptm_mask(input_ids, ranom_mask_prob)