File size: 4,904 Bytes
6eca12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import numpy as np
import os, sys
import time

class LyraChecker:
    """Compare pairs of ``.npy`` dumps stored under one directory and print a report.

    Attributes:
        dir_data: Directory that file names passed to the methods are joined onto.
        tol: Default absolute tolerance used when comparing two arrays.
    """

    def __init__(self, dir_data, tol):
        self.dir_data = dir_data
        self.tol = tol

    def cmp(self, fpath1, fpath2="", tol=0):
        """Load two arrays and print whether they agree within tolerance.

        Args:
            fpath1: Name of the first dump, relative to ``dir_data``. ``.npy`` is
                appended when missing.
            fpath2: Name of the second dump. When left empty, ``fpath1`` is treated
                as a common stem and ``<stem>_1`` / ``<stem>_2`` are compared.
            tol: Tolerance for this call only. ``0`` means "use ``self.tol``", so an
                exact-zero tolerance cannot be requested through this argument.

        Raises:
            Whatever ``np.load`` or the comparison raises (e.g. ``FileNotFoundError``).
        """
        tolbk = self.tol
        if tol != 0:
            self.tol = tol
        try:
            if fpath2 == "":
                fpath2 = fpath1 + "_2"
                fpath1 = fpath1 + "_1"
            v1 = self.get_npy(fpath1)
            v2 = self.get_npy(fpath2)
            name = fpath1
            if ".npy" in fpath1:
                # Report under the bare file stem rather than the full relative path.
                name = ".".join(os.path.basename(fpath1).split(".")[:-1])
            self._cmp_inner(v1, v2, name)
        finally:
            # Always undo the per-call override, even if loading or comparing failed.
            self.tol = tolbk

    def _cmp_inner(self, v1, v2, name):
        """Bring ``v2`` to the layout of ``v1`` when the shapes differ, then compare.

        If the shapes differ but axis 1 has the same length, ``v2`` is flattened to
        three dimensions (axes 0 and 1 kept, the rest merged). Otherwise ``v2`` is
        permuted with axis order ``(0, 3, 1, 2)``, which requires it to be 4-D.
        """
        print(v1.shape, v2.shape)
        if v1.shape != v2.shape:
            if v1.shape[1] == v2.shape[1]:
                v2 = v2.reshape([v2.shape[0], v2.shape[1], -1])
            else:
                v2 = np.transpose(v2, (0, 3, 1, 2))
            print(v1.shape, v2.shape)
        self._check_data(name, v1, v2)
        print(np.size(v1))

    def _check_data(self, stage, x_out, x_gt):
        """Print an OK/FAIL report for ``x_out`` against ``x_gt`` using ``self.tol``.

        The pass/fail decision uses ``np.allclose`` with ``atol=self.tol`` (NumPy's
        default ``rtol`` also applies); the reported diff count uses the plain
        absolute difference against ``self.tol``.
        """
        print(f"========== {stage} =============")
        print(x_out.shape, x_gt.shape)
        if np.allclose(x_gt, x_out, atol=self.tol):
            print(f"[OK] At {stage}, tol: {self.tol}")
        else:
            abs_diff = np.abs(x_gt - x_out)
            diff_cnt = np.count_nonzero(abs_diff > self.tol)
            print(f"[FAIL]At {stage}, not aligned. tol: {self.tol}")
            print("    [INFO]Max diff: ", np.max(abs_diff))
            print("    [INFO]Diff count: ", diff_cnt, ", ratio: ", round(diff_cnt/np.size(x_out), 2))
        print(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

    def cmp_query(self, fpath1, fpath2):
        """Compare separate query/key/value dumps against one packed array.

        Args:
            fpath1: File name of the query dump, relative to ``dir_data``. The key
                and value file names are derived by replacing ``"query"`` in this
                name with ``"key"`` / ``"value"``.
            fpath2: File name of the packed dump. It is indexed as a 5-D array whose
                axis 2 selects query (0), key (1) and value (2); each selected 4-D
                slice has its axes 1 and 2 swapped before the comparison.
        """
        # Substitute inside the file name only, so a data directory whose own path
        # happens to contain "query" is left intact.
        v1 = np.load(os.path.join(self.dir_data, fpath1))
        vk = np.load(os.path.join(self.dir_data, fpath1.replace("query", "key")))
        vv = np.load(os.path.join(self.dir_data, fpath1.replace("query", "value")))

        v2 = np.load(os.path.join(self.dir_data, fpath2))
        q2 = v2[:, :, 0, :, :].transpose([0, 2, 1, 3])
        self._check_data("query", v1, q2)
        k2 = v2[:, :, 1, :, :].transpose([0, 2, 1, 3])
        self._check_data("key", vk, k2)
        vv2 = v2[:, :, 2, :, :].transpose([0, 2, 1, 3])
        self._check_data("value", vv, vv2)

    def _get_data_fpath(self, fname):
        """Return ``fname`` joined onto ``dir_data``, with ``.npy`` appended if missing."""
        fpath = os.path.join(self.dir_data, fname)
        if not fpath.endswith(".npy"):
            fpath += ".npy"
        return fpath

    def get_npy(self, fname):
        """Load and return the array stored under ``fname`` inside ``dir_data``."""
        fpath = self._get_data_fpath(fname)
        return np.load(fpath)

        


class MkDataHelper:
    """Create and store ``.npy`` fixtures in sub-directories of ``data_dir``."""

    def __init__(self, data_dir="/data/home/kiokaxiao/data"):
        self.data_dir = data_dir

    def _npy_path(self, subdir, name):
        """Ensure ``data_dir/subdir`` exists and return the ``.npy`` path for ``name``."""
        target_dir = os.path.join(self.data_dir, subdir)
        os.makedirs(target_dir, exist_ok=True)
        return os.path.join(target_dir, name+".npy")

    def mkdata(self, subdir, name, shape, dtype=torch.float16):
        """Draw a random tensor, save a ``dtype``-cast copy, and return the tensor.

        The sample is always generated as float16; only the copy written to disk
        is cast to ``dtype``. The returned tensor is the un-cast float16 sample.
        """
        destination = self._npy_path(subdir, name)
        sample = torch.randn(shape, dtype=torch.float16)
        as_numpy = sample.to(dtype).numpy()
        np.save(destination, as_numpy)
        return sample

    def gen_out_with_func(self, func, inputs):
        """Return ``func(inputs)``."""
        return func(inputs)

    def savedata(self, subdir, name, data):
        """Write ``data`` (moved to CPU) to ``data_dir/subdir/name.npy``."""
        destination = self._npy_path(subdir, name)
        host_array = data.cpu().numpy()
        np.save(destination, host_array)


class TorchSaver:
    """Dump tensors as ``.npy`` files into ``data_dir``.

    ``save_v`` writes ``<name>_1.npy`` and ``save_v2`` writes ``<name>_2.npy``.
    Saving can be switched off at runtime by setting ``is_save`` to ``False``.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        os.makedirs(self.data_dir, exist_ok=True)
        self.is_save = True

    def _save(self, name, suffix, v):
        """Write ``v`` (detached, on CPU) to ``<name><suffix>.npy`` unless saving is off."""
        if not self.is_save:
            return
        fpath = os.path.join(self.data_dir, name + suffix + ".npy")
        np.save(fpath, v.detach().cpu().numpy())

    def save_v(self, name, v):
        """Save tensor ``v`` as ``<name>_1.npy``; does nothing when ``is_save`` is false."""
        self._save(name, "_1", v)

    def save_v2(self, name, v):
        """Save tensor ``v`` as ``<name>_2.npy``; does nothing when ``is_save`` is false.

        This previously wrote the ``_1`` suffix, making it an exact duplicate of
        ``save_v``; a second dump under the same name silently overwrote the first.
        """
        self._save(name, "_2", v)

def timer_annoc(funct):
    """Decorator that prints the wall-clock time of each call to ``funct``.

    When CUDA is available the device is synchronized before the clock starts
    and again before it stops, so the measurement covers the GPU work launched
    by ``funct`` and nothing queued earlier. On hosts without CUDA the
    synchronization is skipped instead of raising.

    Args:
        funct: Callable to wrap.

    Returns:
        A wrapper with the same signature, name and docstring as ``funct`` that
        returns ``funct``'s result unchanged.
    """
    import functools

    @functools.wraps(funct)
    def inner(*args, **kwargs):
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            # Flush work queued before this call so it is not billed to funct.
            torch.cuda.synchronize()
        start = time.perf_counter()
        res = funct(*args, **kwargs)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()
        end = time.perf_counter()
        print("torch cost: ", end-start)
        return res
    return inner

def get_mem_use():
    """Return the ninth space-separated field of ``nvidia-smi | grep MiB``.

    The command output is stripped, split on single spaces with empty fields
    discarded, and the field at index 8 is returned as a string.
    NOTE(review): on the usual nvidia-smi table layout that field is the
    used-memory column of the first GPU (e.g. ``"500MiB"``) - confirm against
    the nvidia-smi version in use.

    Raises:
        IndexError: If the output has fewer than nine fields, e.g. because
            ``nvidia-smi`` is not installed or reports no GPU.
    """
    # Use the pipe as a context manager so it is closed after reading.
    with os.popen("nvidia-smi | grep MiB") as f:
        line = f.read().strip()
    # Dropping empty fields is equivalent to collapsing runs of spaces first.
    fields = [part for part in line.split(" ") if part]
    if len(fields) <= 8:
        raise IndexError(
            "could not parse memory usage from `nvidia-smi | grep MiB` output: "
            + repr(line)
        )
    return fields[8]

if __name__ == "__main__":
    # Command line: <dir_data> <fname_v1> <fname_v2> [tol]
    # Compares the two named dumps inside dir_data; tol defaults to 0.01.
    if len(sys.argv) < 4:
        # Print a usage line instead of failing with a bare IndexError traceback.
        print(
            f"usage: {sys.argv[0]} <dir_data> <fname_v1> <fname_v2> [tol]",
            file=sys.stderr,
        )
        sys.exit(2)
    dir_data = sys.argv[1]
    fname_v1 = sys.argv[2]
    fname_v2 = sys.argv[3]
    tol = 0.01
    if len(sys.argv) > 4:
        tol = float(sys.argv[4])
    checker = LyraChecker(dir_data, tol)
    checker.cmp(fname_v1, fname_v2)