File size: 9,284 Bytes
39f3704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init

# pylint: disable=protected-access, missing-function-docstring, line-too-long


def ipex_init():
    """Redirect the ``torch.cuda`` API to Intel XPU devices via IPEX.

    Aliases ``torch.cuda`` attributes to their ``torch.xpu`` counterparts and
    patches a few CUDA-only queries with XPU-backed or fixed answers. It then
    runs the IPEX hijacks, the attention patches and, best-effort, the
    diffusers patches.

    Returns:
        ``(True, None)`` on success, or ``(False, exc)`` carrying the first
        exception raised. Patches applied before the failure are not rolled
        back.
    """
    try:
        _alias_cuda_core()
        _alias_memory()
        _alias_rng()
        _init_amp()
        _patch_c_bindings()
        _patch_device_queries()
        ipex_hijacks()
        attention_init()
        _init_diffusers()
    except Exception as e:  # pylint: disable=broad-exception-caught
        return False, e
    return True, None


def _copy_attrs(dst, src, names):
    """Set ``dst.<name> = src.<name>`` for each entry of *names*, in order.

    Each entry is either a plain attribute name (same on both sides) or a
    ``(dst_name, src_name)`` pair. A missing source attribute raises
    ``AttributeError``; entries copied before it stay in place.
    """
    for entry in names:
        dst_name, src_name = (entry, entry) if isinstance(entry, str) else entry
        setattr(dst, dst_name, getattr(src, src_name))


def _is_wsl2():
    """Return True on a Linux kernel whose uname information mentions ``WSL2``."""
    # os.uname() reports the same kernel fields as ``uname -a`` without
    # spawning a shell subprocess or leaving a pipe open.
    return "linux" in sys.platform and "WSL2" in " ".join(os.uname())


def _alias_cuda_core():
    """Point torch.cuda device/stream/tensor-type/module attributes at torch.xpu."""
    _copy_attrs(
        torch.cuda,
        torch.xpu,
        (
            "current_device",
            "current_stream",
            "device",
            "device_count",
            "device_of",
            "get_device_name",
            "get_device_properties",
            "init",
            "is_available",
            "is_initialized",
        ),
    )
    # Always report that no stream capture is in progress.
    torch.cuda.is_current_stream_capturing = lambda: False
    _copy_attrs(
        torch.cuda,
        torch.xpu,
        ("set_device", "stream", "synchronize", "Event", "Stream", "FloatTensor"),
    )
    torch.Tensor.cuda = torch.Tensor.xpu
    torch.Tensor.is_cuda = torch.Tensor.is_xpu
    # Lazy-initialization state is read from torch.xpu.lazy_init, not torch.xpu.
    _copy_attrs(
        torch.cuda,
        torch.xpu.lazy_init,
        (
            "_initialization_lock",
            "_initialized",
            "_lazy_seed_tracker",
            "_queued_calls",
            "_tls",
            "threading",
            "traceback",
        ),
    )
    # Storage/tensor types, typing helpers, private helpers and module metadata.
    _copy_attrs(
        torch.cuda,
        torch.xpu,
        (
            "Optional",
            "__cached__",
            "__loader__",
            "ComplexFloatStorage",
            "Tuple",
            "streams",
            "_lazy_new",
            "FloatStorage",
            "Any",
            "__doc__",
            "default_generators",
            "HalfTensor",
            "_get_device_index",
            "__path__",
            "Device",
            "IntTensor",
            "ByteStorage",
            "set_stream",
            "BoolStorage",
            "os",
            "torch",
            "BFloat16Storage",
            "Union",
            "DoubleTensor",
            "ShortTensor",
            "LongTensor",
            "IntStorage",
            "LongStorage",
            "__annotations__",
            "__package__",
            "__builtins__",
            "CharTensor",
            "List",
            "_lazy_init",
            "BFloat16Tensor",
            "DoubleStorage",
            "ByteTensor",
            "StreamContext",
            "ComplexDoubleStorage",
            "ShortStorage",
            "_lazy_call",
            "HalfStorage",
            "random",
            "_device",
            "classproperty",
            "__name__",
            "_device_t",
            "warnings",
            "__spec__",
            "BoolTensor",
            "CharStorage",
            "__file__",
        ),
    )
    _copy_attrs(torch.cuda, torch.xpu.lazy_init, ("_is_in_bad_fork",))


def _alias_memory():
    """Point torch.cuda's memory-management API at torch.xpu."""
    torch.cuda.memory = torch.xpu.memory
    if _is_wsl2():
        # Under WSL2, empty_cache becomes a no-op (NOTE(review): the reason is
        # not visible here — confirm). It is patched on torch.xpu first so the
        # alias below picks up the no-op.
        torch.xpu.empty_cache = lambda: None
    # Legacy CUDA names (memory_cached, reset_max_memory_*) map onto the
    # corresponding reserved/peak-stats functions.
    _copy_attrs(
        torch.cuda,
        torch.xpu,
        (
            "empty_cache",
            "memory_stats",
            "memory_summary",
            "memory_snapshot",
            "memory_allocated",
            "max_memory_allocated",
            "memory_reserved",
            ("memory_cached", "memory_reserved"),
            "max_memory_reserved",
            ("max_memory_cached", "max_memory_reserved"),
            "reset_peak_memory_stats",
            ("reset_max_memory_cached", "reset_peak_memory_stats"),
            ("reset_max_memory_allocated", "reset_peak_memory_stats"),
            "memory_stats_as_nested_dict",
            "reset_accumulated_memory_stats",
        ),
    )


def _alias_rng():
    """Point torch.cuda's random-number-generator API at torch.xpu."""
    _copy_attrs(
        torch.cuda,
        torch.xpu,
        (
            "get_rng_state",
            "get_rng_state_all",
            "set_rng_state",
            "set_rng_state_all",
            "manual_seed",
            "manual_seed_all",
            "seed",
            "seed_all",
            "initial_seed",
        ),
    )


def _init_amp():
    """Alias torch.cuda.amp to torch.xpu.amp and install a usable GradScaler."""
    torch.cuda.amp = torch.xpu.amp
    if not hasattr(torch.cuda.amp, "common"):
        # Only used as an object to hang amp_definitely_not_available on.
        torch.cuda.amp.common = contextlib.nullcontext()
    torch.cuda.amp.common.amp_definitely_not_available = lambda: False
    # Fallback chain: torch.xpu.amp.GradScaler if present, else the local
    # gradscaler module's implementation, else IPEX's CPU GradScaler.
    try:
        torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
    except Exception:  # pylint: disable=broad-exception-caught
        try:
            from .gradscaler import gradscaler_init  # pylint: disable=import-outside-toplevel, import-error

            gradscaler_init()
            torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
        except Exception:  # pylint: disable=broad-exception-caught
            torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler


def _patch_c_bindings():
    """Patch low-level torch._C / ipex._C hooks."""
    # Requests for the current raw CUDA stream are served by IPEX's current stream.
    torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
    # Fixed major/minor values on IPEX's _DeviceProperties type.
    # NOTE(review): the meaning of 2023.2 is not visible here — confirm.
    ipex._C._DeviceProperties.major = 2023
    ipex._C._DeviceProperties.minor = 2


def _patch_device_queries():
    """Answer CUDA memory/capability/version queries and stub CUDA-only helpers."""

    def mem_get_info(device=None):
        # (free, total) tuple like torch.cuda.mem_get_info; "free" is estimated
        # as total device memory minus memory currently allocated by tensors.
        total = torch.xpu.get_device_properties(device).total_memory
        return total - torch.xpu.memory_allocated(device), total

    torch.cuda.mem_get_info = mem_get_info
    torch._utils._get_available_device_type = lambda: "xpu"
    # Advertise CUDA 11.7 with fp16/bf16 support; presumably so code gated on
    # CUDA availability/version takes its CUDA path — TODO confirm version choice.
    torch.has_cuda = True
    torch.cuda.has_half = True
    torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
    torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
    torch.version.cuda = "11.7"
    # A (major, minor) tuple like the real API, so comparisons such as
    # ``get_device_capability() >= (8, 0)`` work (a list would raise TypeError).
    torch.cuda.get_device_capability = lambda *args, **kwargs: (11, 7)
    # torch.cuda.get_device_properties is torch.xpu.get_device_properties at
    # this point, so these set attributes on that function object itself.
    torch.cuda.get_device_properties.major = 11
    torch.cuda.get_device_properties.minor = 7
    torch.cuda.ipc_collect = lambda *args, **kwargs: None
    torch.cuda.utilization = lambda *args, **kwargs: 0
    # torch.xpu provides this helper under one of two names; alias both CUDA
    # spellings to whichever exists.
    if hasattr(torch.xpu, "getDeviceIdListForCard"):
        id_list_fn = torch.xpu.getDeviceIdListForCard
    else:
        id_list_fn = torch.xpu.get_device_id_list_per_card
    torch.cuda.getDeviceIdListForCard = id_list_fn
    torch.cuda.get_device_id_list_per_card = id_list_fn


def _init_diffusers():
    """Apply the diffusers patches, ignoring any failure (best-effort)."""
    try:
        from .diffusers import ipex_diffusers  # pylint: disable=import-outside-toplevel

        ipex_diffusers()
    except Exception:  # pylint: disable=broad-exception-caught
        # Deliberately ignored so a diffusers problem does not mark the whole
        # XPU setup as failed.
        pass