|
import torch |
|
from tqdm import tqdm |
|
import math |
|
|
|
|
|
def outlier_channel_split(W, H, mu, to_size):
    """Widen W, H and mu to ``to_size`` columns by outlier channel splitting.

    Each of the ``to_size - W.shape[-1]`` iterations locates the entry of W
    with the largest absolute value, halves that entry's column, and copies
    the halved column into the next new column. H gets the matching row and
    column duplicated and mu the matching entry, so every new channel mirrors
    the channel it was split from.

    Args:
        W: 2-D weight tensor; its last dimension (columns) is widened.
        H: square tensor, expected to have side ``W.shape[-1]``; it is
            zero-padded to ``to_size x to_size`` and must be symmetric
            (checked at the end).
        mu: 1-D tensor with one entry per column of W.
        to_size: target number of columns; must be >= ``W.shape[-1]``.

    Returns:
        Tuple ``(W, H, mu, extra_inds, dupe_inds)`` where W, H and mu are new
        tensors (the inputs are not modified) with the same dtype/device as
        the inputs; ``extra_inds[k]`` is the column that was split to create
        column ``W.shape[-1] + k``; ``dupe_inds[j]`` is the original column
        that column ``j`` ultimately derives from.

    Raises:
        ValueError: if ``to_size`` is smaller than ``W.shape[-1]``.
        AssertionError: if the widened H is not symmetric.
    """
    old_dim = W.shape[-1]
    remaining = to_size - old_dim
    if remaining < 0:
        raise ValueError(
            f'to_size ({to_size}) must be >= W.shape[-1] ({old_dim})')

    # Report the range of the real weights, before zero padding can skew it.
    print('old drange', torch.max(W.flatten()) - torch.min(W.flatten()))

    # new_zeros keeps each tensor's dtype and device; torch.zeros(...) would be
    # float32 and silently up/down-cast W, H and mu.
    W = torch.cat([W, W.new_zeros(W.shape[0], remaining)], dim=-1)
    new_H = H.new_zeros(to_size, to_size)
    new_H[:H.shape[0], :H.shape[1]] = H
    H = new_H
    mu = torch.cat([mu, mu.new_zeros(remaining)], dim=0)

    extra_inds = []
    dupe_inds = list(range(old_dim))

    for i in tqdm(range(old_dim, to_size), desc='outlier channel splitting'):
        # argmax over |W| yields a flat row-major index; only the column is
        # needed to decide which channel to split.
        col = torch.argmax(W.abs()).item() % W.shape[-1]

        extra_inds.append(col)
        # A split of an already-duplicated column still traces back to the
        # same original channel.
        dupe_inds.append(dupe_inds[col])

        # Halve the outlier column and place an identical copy in column i.
        W[:, col] /= 2
        W[:, i] = W[:, col]

        # Duplicate row/column `col` of H into row/column i (col < i, so the
        # source entries are already final).
        H[i, :i] = H[col, :i]
        H[:i, i] = H[:i, col]
        H[i, i] = H[col, col]

        mu[i] = mu[col]

    print('new drange', torch.max(W.flatten()) - torch.min(W.flatten()))

    H_cpu = H.cpu()
    assert torch.allclose(H_cpu, H_cpu.T)

    return W, H, mu, extra_inds, dupe_inds
|
|
|
|
|
def fuse_W(W, extra_inds):
    """Fold columns appended by channel splitting back into their sources.

    Column ``orig_dim + k`` of ``W`` (where ``orig_dim = W.shape[-1] -
    len(extra_inds)``) is added into column ``extra_inds[k]``, and the
    appended columns are then dropped. The input tensor is not modified.

    Args:
        W: 2-D tensor whose last ``len(extra_inds)`` columns were appended by
            splitting (indexing below assumes exactly two dimensions).
        extra_inds: sequence where ``extra_inds[k]`` is the column that was
            split to produce appended column ``orig_dim + k``.

    Returns:
        A tensor with the first ``orig_dim`` columns of the fused result.
    """
    n_extra = len(extra_inds)
    # Work on a copy so the caller's (split) weights are left intact.
    fused = W.clone()
    orig_dim = fused.shape[-1] - n_extra

    # Walk backwards: an appended column may duplicate an earlier appended
    # column (extra_inds[k] can be >= orig_dim), so later columns must be
    # folded in first for such chains to collapse correctly.
    for k in range(n_extra - 1, -1, -1):
        fused[:, extra_inds[k]] += fused[:, orig_dim + k]

    return fused[:, :orig_dim]
|
|