-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy path_device.pyx
153 lines (122 loc) · 4.63 KB
/
_device.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
from libc.stdint cimport uintptr_t, int64_t
from nanoarrow_device_c cimport (
ArrowDevice,
ArrowDeviceCpu,
ArrowDeviceResolve
)
from nanoarrow_macros cimport (
NANOARROW_OK,
ARROW_DEVICE_CPU,
ARROW_DEVICE_CUDA,
ARROW_DEVICE_CUDA_HOST,
ARROW_DEVICE_OPENCL,
ARROW_DEVICE_VULKAN,
ARROW_DEVICE_METAL,
ARROW_DEVICE_VPI,
ARROW_DEVICE_ROCM,
ARROW_DEVICE_ROCM_HOST,
ARROW_DEVICE_EXT_DEV,
ARROW_DEVICE_CUDA_MANAGED,
ARROW_DEVICE_ONEAPI,
ARROW_DEVICE_WEBGPU,
ARROW_DEVICE_HEXAGON,
)
from nanoarrow._utils cimport Error
from enum import Enum
from nanoarrow import _repr_utils
class DeviceType(Enum):
"""
An enumerator providing access to the device constant values
defined in the Arrow C Device interface. Unlike the other enum
accessors, this Python Enum is defined in Cython so that we can use
the bulit-in functionality to do better printing of device identifiers
for classes defined in Cython. Unlike the other enums, users don't
typically need to specify these (but would probably like them printed
nicely).
"""
CPU = ARROW_DEVICE_CPU
CUDA = ARROW_DEVICE_CUDA
CUDA_HOST = ARROW_DEVICE_CUDA_HOST
OPENCL = ARROW_DEVICE_OPENCL
VULKAN = ARROW_DEVICE_VULKAN
METAL = ARROW_DEVICE_METAL
VPI = ARROW_DEVICE_VPI
ROCM = ARROW_DEVICE_ROCM
ROCM_HOST = ARROW_DEVICE_ROCM_HOST
EXT_DEV = ARROW_DEVICE_EXT_DEV
CUDA_MANAGED = ARROW_DEVICE_CUDA_MANAGED
ONEAPI = ARROW_DEVICE_ONEAPI
WEBGPU = ARROW_DEVICE_WEBGPU
HEXAGON = ARROW_DEVICE_HEXAGON
cdef class Device:
"""ArrowDevice wrapper
The ArrowDevice structure is a nanoarrow internal struct (i.e.,
not ABI stable) that contains callbacks for device operations
beyond its type and identifier (e.g., copy buffers to or from
a device).
"""
def __cinit__(self, object base, uintptr_t addr):
self._base = base,
self._ptr = <ArrowDevice*>addr
def __eq__(self, other) -> bool:
return (
isinstance(other, Device) and
other.device_type == self.device_type and
other.device_id == self.device_id
)
def __repr__(self):
return _repr_utils.device_repr(self)
@property
def device_type(self):
return DeviceType(self._ptr.device_type)
@property
def device_type_id(self):
return self._ptr.device_type
@property
def device_id(self):
return self._ptr.device_id
@staticmethod
def resolve(device_type, int64_t device_id):
if int(device_type) == ARROW_DEVICE_CPU:
return DEVICE_CPU
cdef ArrowDevice* c_device = ArrowDeviceResolve(device_type, device_id)
if c_device == NULL:
raise ValueError(f"Device not found for type {device_type}/{device_id}")
return Device(None, <uintptr_t>c_device)
# Cache the CPU device
# The CPU device is statically allocated (so base is None)
DEVICE_CPU = Device(None, <uintptr_t>ArrowDeviceCpu())
cdef class CSharedSyncEvent:
def __cinit__(self, Device device, uintptr_t sync_event=0):
self.device = device
self.sync_event = <void*>sync_event
cdef synchronize(self):
if self.sync_event == NULL:
return
cdef Error error = Error()
cdef ArrowDevice* c_device = self.device._ptr
cdef int code = c_device.synchronize_event(c_device, self.sync_event, NULL, &error.c_error)
error.raise_message_not_ok("ArrowDevice::synchronize_event", code)
self.sync_event = NULL
cdef synchronize_stream(self, uintptr_t stream):
cdef Error error = Error()
cdef ArrowDevice* c_device = self.device._ptr
cdef int code = c_device.synchronize_event(c_device, self.sync_event, <void*>stream, &error.c_error)
error.raise_message_not_ok("ArrowDevice::synchronize_event with stream", code)