File size: 4,077 Bytes
58f667f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import asyncio
import contextlib
import logging
import os
import time
from typing import List

import torch

logger = logging.getLogger(__name__)

# Opt-in switch for the timing instrumentation in this module; read once at
# import time from the DEBUG_COMPLETED_TIME environment variable.
# The value is parsed explicitly because bool() of any non-empty string is
# True, which would make DEBUG_COMPLETED_TIME=0 or =false enable the flag.
DEBUG_COMPLETED_TIME = os.environ.get(
    'DEBUG_COMPLETED_TIME', '').strip().lower() not in ('', '0', 'false',
                                                        'no', 'off')


@contextlib.asynccontextmanager
async def completed(trace_name='',
                    name='',
                    sleep_interval=0.05,
                    streams: List[torch.cuda.Stream] = None):
    """Async context manager that waits for work to complete on given CUDA
    streams.

    When the managed body finishes (normally or with an exception), one event
    is recorded on every stream and the events are polled, yielding to the
    event loop between polls, until all of them report completion. If CUDA is
    not available the manager does nothing besides running the body.

    Args:
        trace_name (str): First label inserted into the log messages.
        name (str): Second label inserted into the log messages.
        sleep_interval (float): Seconds handed to ``asyncio.sleep`` between
            two consecutive polls of the events.
        streams (list[torch.cuda.Stream] | None): Streams to wait for. A
            falsy value selects the stream that is current on entry; falsy
            items inside the list are replaced by that same stream.

    Raises:
        AssertionError: If the current CUDA stream on exit differs from the
            one seen on entry, or if ``torch.is_grad_enabled()`` changed
            while the body was running.
    """
    if not torch.cuda.is_available():
        yield
        return

    entry_stream = torch.cuda.current_stream()
    streams = ([s or entry_stream for s in streams]
               if streams else [entry_stream])

    # One completion marker per stream; timing is only enabled in debug mode.
    done_events = [
        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
    ]

    if DEBUG_COMPLETED_TIME:
        # Reference point for the per-stream elapsed times reported on exit.
        start = torch.cuda.Event(enable_timing=True)
        entry_stream.record_event(start)

        cpu_start = time.monotonic()
    logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
    grad_mode_on_entry = torch.is_grad_enabled()
    try:
        yield
    finally:
        assert torch.cuda.current_stream() == entry_stream

        if DEBUG_COMPLETED_TIME:
            cpu_end = time.monotonic()
        for stream, marker in zip(streams, done_events):
            stream.record_event(marker)

        grad_mode_on_exit = torch.is_grad_enabled()

        # observed change of torch.is_grad_enabled() during concurrent run of
        # async_test_bboxes code
        assert (grad_mode_on_entry == grad_mode_on_exit
                ), 'Unexpected is_grad_enabled() value change'

        finished = [marker.query() for marker in done_events]
        logger.debug('%s %s completed: %s streams: %s', trace_name, name,
                     finished, streams)
        # Poll without blocking the event loop; the surrounding stream
        # context keeps the entry stream pinned across the awaits.
        with torch.cuda.stream(entry_stream):
            while not all(finished):
                await asyncio.sleep(sleep_interval)
                finished = [marker.query() for marker in done_events]
                logger.debug('%s %s completed: %s streams: %s', trace_name,
                             name, finished, streams)

        assert torch.cuda.current_stream() == entry_stream

        if DEBUG_COMPLETED_TIME:
            cpu_time = (cpu_end - cpu_start) * 1000
            stream_times_ms = ''.join(
                f' {stream} {start.elapsed_time(marker):.2f} ms'
                for stream, marker in zip(streams, done_events))
            logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
                        stream_times_ms)


@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
                     trace_name='concurrent',
                     name='stream'):
    """Run the managed body on a CUDA stream borrowed from a queue.

    A stream is awaited from ``streamqueue``, made the current stream while
    the body runs, and handed back to the queue afterwards, also when the
    body raises. If CUDA is not available the body runs and the queue is
    not touched.

    :param streamqueue: asyncio.Queue instance. Its items define the pool of
        streams used for concurrent execution; each item must be a
        ``torch.cuda.Stream``.
    :param trace_name: first label inserted into the debug log messages.
    :param name: second label inserted into the debug log messages.
    """
    if not torch.cuda.is_available():
        yield
        return

    entry_stream = torch.cuda.current_stream()

    # NOTE(review): the outer stream context presumably keeps the entry
    # stream current across the awaits below - confirm against torch docs.
    with torch.cuda.stream(entry_stream):
        pooled_stream = await streamqueue.get()
        assert isinstance(pooled_stream, torch.cuda.Stream)

        try:
            with torch.cuda.stream(pooled_stream):
                logger.debug('%s %s is starting, stream: %s', trace_name,
                             name, pooled_stream)
                yield
                # The body must leave the borrowed stream as the current one.
                assert torch.cuda.current_stream() == pooled_stream
                logger.debug('%s %s has finished, stream: %s', trace_name,
                             name, pooled_stream)
        finally:
            # Mark the borrowed item as processed, then return it to the pool.
            streamqueue.task_done()
            streamqueue.put_nowait(pooled_stream)