# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch as th


def get_generator(generator, num_samples=0, seed=0):
    if generator == "dummy":
        return DummyGenerator()
    elif generator == "determ":
        return DeterministicGenerator(num_samples, seed)
    elif generator == "determ-indiv":
        return DeterministicIndividualGenerator(num_samples, seed)
    else:
        raise NotImplementedError
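
# Illustrative usage (assumed call site, not part of the original module): the
# string argument selects the wrapper, e.g. "dummy" for torch's global RNG or
# "determ" / "determ-indiv" for sampling that is reproducible across batch sizes.
#
#     rng = get_generator("determ", num_samples=100, seed=42)
#     noise = rng.randn(16, 3, 32, 32)  # first 16 of the 100 deterministic samples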


class DummyGenerator:
    """Thin wrapper around torch's global RNG; offers no determinism guarantees."""

    def randn(self, *args, **kwargs):
        return th.randn(*args, **kwargs)

    def randint(self, *args, **kwargs):
        return th.randint(*args, **kwargs)

    def randn_like(self, *args, **kwargs):
        return th.randn_like(*args, **kwargs)


class DeterministicGenerator:
    """
    RNG that deterministically draws num_samples samples, independent of batch size
    or the number of MPI machines.
    Uses a single RNG: every call draws a num_samples-sized pool of randomness and
    subsamples the indices belonging to the current batch.
    """

    def __init__(self, num_samples, seed=0):
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = th.Generator()
        if th.cuda.is_available():
            # Single-rank setting, so the current CUDA device is used directly.
            self.rng_cuda = th.Generator(th.device("cuda"))
        self.set_seed(seed)

    def get_global_size_and_indices(self, size):
        # The full pool keeps the requested per-sample shape but spans all samples.
        global_size = (self.num_samples, *size[1:])
        # Samples are interleaved across ranks: this rank owns done_samples + rank,
        # done_samples + rank + world_size, ... for the current batch.
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda
def randn(self, *size, dtype=th.float, device="cpu"): | |
global_size, indices = self.get_global_size_and_indices(size) | |
generator = self.get_generator(device) | |
return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[ | |
indices | |
] | |
def randint(self, low, high, size, dtype=th.long, device="cpu"): | |
global_size, indices = self.get_global_size_and_indices(size) | |
generator = self.get_generator(device) | |
return th.randint( | |
low, high, generator=generator, size=global_size, dtype=dtype, device=device | |
)[indices] | |

    def randn_like(self, tensor):
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        self.done_samples = done_samples
        self.set_seed(self.seed)

    def get_seed(self):
        return self.seed

    def set_seed(self, seed):
        self.rng_cpu.manual_seed(seed)
        if th.cuda.is_available():
            self.rng_cuda.manual_seed(seed)
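
# Sketch of the pool-and-subsample behaviour documented above (hypothetical values,
# single rank): with num_samples=6 and a batch of size 3 starting after 3 finished
# samples, get_global_size_and_indices((3, 2)) returns global_size=(6, 2) and
# indices=[3, 4, 5], so the full 6x2 pool is drawn and only rows 3-5 are returned.
#
#     g = DeterministicGenerator(num_samples=6, seed=0)
#     g.set_done_samples(3)
#     g.get_global_size_and_indices((3, 2))  # -> ((6, 2), tensor([3, 4, 5]))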


class DeterministicIndividualGenerator:
    """
    RNG that deterministically draws num_samples samples, independent of batch size
    or the number of MPI machines.
    Uses a separate RNG for each sample to reduce memory usage.
    """

    def __init__(self, num_samples, seed=0):
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = [th.Generator() for _ in range(num_samples)]
        if th.cuda.is_available():
            # Single-rank setting, so the current CUDA device is used directly.
            self.rng_cuda = [
                th.Generator(th.device("cuda")) for _ in range(num_samples)
            ]
        self.set_seed(seed)

    def get_size_and_indices(self, size):
        # This rank owns samples done_samples + rank, done_samples + rank + world_size, ...
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        # Each sample is drawn one at a time, hence the leading dimension of 1.
        return (1, *size[1:]), indices

    def get_generator(self, device):
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda
def randn(self, *size, dtype=th.float, device="cpu"): | |
size, indices = self.get_size_and_indices(size) | |
generator = self.get_generator(device) | |
return th.cat( | |
[ | |
th.randn(*size, generator=generator[i], dtype=dtype, device=device) | |
for i in indices | |
], | |
dim=0, | |
) | |
def randint(self, low, high, size, dtype=th.long, device="cpu"): | |
size, indices = self.get_size_and_indices(size) | |
generator = self.get_generator(device) | |
return th.cat( | |
[ | |
th.randint( | |
low, | |
high, | |
generator=generator[i], | |
size=size, | |
dtype=dtype, | |
device=device, | |
) | |
for i in indices | |
], | |
dim=0, | |
) | |

    def randn_like(self, tensor):
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        self.done_samples = done_samples

    def get_seed(self):
        return self.seed

    def set_seed(self, seed):
        # Give each per-sample generator its own seed derived from the base seed.
        for i, rng_cpu in enumerate(self.rng_cpu):
            rng_cpu.manual_seed(i + self.num_samples * seed)
        if th.cuda.is_available():
            for i, rng_cuda in enumerate(self.rng_cuda):
                rng_cuda.manual_seed(i + self.num_samples * seed)
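

if __name__ == "__main__":
    # Minimal CPU self-check added for illustration (not part of the original
    # module): both deterministic generators should produce the same tensor for a
    # given sample index regardless of how the samples are split into batches.
    for name in ("determ", "determ-indiv"):
        whole = get_generator(name, num_samples=4, seed=7).randn(4, 5)
        gen = get_generator(name, num_samples=4, seed=7)
        first = gen.randn(2, 5)
        gen.set_done_samples(2)
        second = gen.randn(2, 5)
        assert th.equal(whole, th.cat([first, second], dim=0)), name
    print("batch-size independence check passed")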