Spaces:
Runtime error
Runtime error
import contextlib
import sys
import time

import torch
if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU inside the ``with`` body.

        Useful as a temporary context manager to find sweet spots of code
        suitable for async implementation.

        Args:
            trace_name: Prefix printed first in the report line.
            name: Label of the profiled region, printed after ``trace_name``.
            enabled: When false, nothing is measured or printed and the
                context manager simply yields.
            stream: CUDA stream on which the start event is recorded.
                Defaults to ``torch.cuda.current_stream()``.
            end_stream: CUDA stream on which the end event is recorded.
                Defaults to ``stream``.

        Yields:
            None. Profiling is skipped entirely when ``enabled`` is false or
            CUDA is not available.

        The report line is printed even if the ``with`` body raises; the
        exception then propagates to the caller unchanged.
        """
        if not enabled or not torch.cuda.is_available():
            yield
            return

        if stream is None:
            stream = torch.cuda.current_stream()
        if end_stream is None:
            end_stream = stream

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        # Read the CPU clock before entering ``try`` so that ``cpu_start`` is
        # always bound when the ``finally`` clause uses it.
        cpu_start = time.monotonic()
        try:
            yield
        finally:
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            # Wait for the end event to complete so elapsed_time() is valid.
            end.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000  # seconds -> ms
            gpu_time = start.elapsed_time(end)  # already in ms
            msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
            msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
            # end_stream is passed as a second print() argument, so it is
            # appended after the message, separated by a space.
            print(msg, end_stream)