Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from torch.nn.parallel._functions import _get_stream | |
def scatter(input, devices, streams=None):
    """Scatter a tensor, or a (nested) list of tensors, across devices.

    Args:
        input: A ``torch.Tensor`` or a list whose items are tensors or
            further lists of tensors.
        devices: List of target device ids. The value ``[-1]`` selects the
            non-CUDA branch, where the tensor is only unsqueezed.
        streams: Optional list with one entry per device, used as the CUDA
            stream for the copy. Defaults to ``None`` for every device.

    Returns:
        For a tensor input, the moved (or unsqueezed) tensor. For a list
        input, a list of the same length with every item scattered
        recursively.

    Raises:
        Exception: If ``input`` is neither a list nor a ``torch.Tensor``.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # Ceil division: consecutive runs of ``per_device`` items share a
        # device, and only the last run may be shorter.
        per_device = (len(input) - 1) // len(devices) + 1
        scattered = []
        for index, item in enumerate(input):
            slot = index // per_device
            scattered.append(
                scatter(item, [devices[slot]], [streams[slot]]))
        return scattered

    if isinstance(input, torch.Tensor):
        tensor = input.contiguous()
        # TODO: copy to a pinned buffer first (if copying from CPU)
        # Empty tensors are copied without a dedicated stream.
        copy_stream = streams[0] if tensor.numel() > 0 else None
        if devices != [-1]:
            target = devices[0]
            with torch.cuda.device(target), torch.cuda.stream(copy_stream):
                return tensor.cuda(target, non_blocking=True)
        # unsqueeze the first dimension thus the tensor's shape is the
        # same as those scattered with GPU.
        return tensor.unsqueeze(0)

    raise Exception(f'Unknown type {type(input)}.')
def synchronize_stream(output, devices, streams):
    """Make each device's current stream wait for its copy stream.

    Walks ``output`` (a tensor or a possibly nested list of tensors). For
    every non-empty tensor, the current stream of the matching device waits
    on the matching entry of ``streams``, and the tensor is recorded on that
    current stream.

    Args:
        output: A ``torch.Tensor`` or a (nested) list of tensors, laid out
            the way ``scatter`` returns them.
        devices: List of device ids, one per chunk of ``output``.
        streams: List of streams, parallel to ``devices``.

    Raises:
        Exception: If ``output`` (or any nested item) is neither a list nor
            a ``torch.Tensor``.
    """
    if isinstance(output, list):
        # Use the same ceil-division chunking as ``scatter``. Floor division
        # here would skip trailing items and pair items with the wrong
        # device whenever len(output) is not a multiple of len(devices).
        chunk_size = (len(output) - 1) // len(devices) + 1
        for i, item in enumerate(output):
            slot = i // chunk_size
            synchronize_stream(item, [devices[slot]], [streams[slot]])
    elif isinstance(output, torch.Tensor):
        # Empty tensors were copied without a dedicated stream by
        # ``scatter``, so there is nothing to wait for.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise Exception(f'Unknown type {type(output)}.')
def get_input_device(input):
    """Find the device index of the first CUDA tensor inside ``input``.

    Args:
        input: A ``torch.Tensor`` or a (possibly nested) list of tensors.

    Returns:
        The value of ``Tensor.get_device()`` for the first CUDA tensor
        found in depth-first order, or ``-1`` when no CUDA tensor is
        present (including for an empty list).

    Raises:
        Exception: If ``input`` (or any nested item) is neither a list nor
            a ``torch.Tensor``.
    """
    if isinstance(input, torch.Tensor):
        if input.is_cuda:
            return input.get_device()
        return -1

    if not isinstance(input, list):
        raise Exception(f'Unknown type {type(input)}.')

    # Depth-first search; stop at the first item that reports a device.
    for element in input:
        found = get_input_device(element)
        if found != -1:
            return found
    return -1
class Scatter:
    """Entry point that scatters inputs and synchronizes the copy streams."""

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` to ``target_gpus`` and return the result.

        Declared as a static method because it takes no ``self``; it can be
        called on the class or on an instance.

        Args:
            target_gpus: List of target device ids; ``[-1]`` means no CUDA
                copy is performed.
            input: A ``torch.Tensor`` or a (nested) list of tensors.

        Returns:
            ``tuple(outputs)``, where ``outputs`` is whatever ``scatter``
            returned for ``input``.
        """
        input_device = get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            # Perform CPU to GPU copies in a background stream
            # NOTE(review): some PyTorch releases may expect a torch.device
            # argument for _get_stream rather than an int id -- confirm
            # against the installed version.
            streams = [_get_stream(device) for device in target_gpus]

        outputs = scatter(input, target_gpus, streams)
        # Synchronize with the copy stream
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs)