WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 566bd19

Browse files
committed
oneAPI-aware MPI
1 parent f6012ff commit 566bd19

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@ Requires = "~0.5, 1.0"
3434
Serialization = "1"
3535
Sockets = "1"
3636
julia = "1.6"
37+
oneAPI = "2.1"
3738

3839
[extensions]
3940
AMDGPUExt = "AMDGPU"
4041
CUDAExt = "CUDA"
42+
OneAPIExt = "oneAPI"
4143

4244
[extras]
4345
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4446
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
47+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
4548

4649
[weakdeps]
4750
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4851
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
52+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

ext/OneAPIExt.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module OneAPIExt
2+
3+
import MPI
4+
isdefined(Base, :get_extension) ? (import oneAPI) : (import ..oneAPI)
5+
import MPI: MPIPtr, Buffer, Datatype
6+
7+
function Base.cconvert(::Type{MPIPtr}, A::oneAPI.oneArray{T}) where T
8+
A
9+
end
10+
11+
function Base.unsafe_convert(::Type{MPIPtr}, X::oneAPI.oneArray{T}) where T
12+
reinterpret(MPIPtr, Base.unsafe_convert(oneAPI.ZePtr{T}, X))
13+
end
14+
15+
# only need to define this for strided arrays: all others can be handled by generic machinery
16+
function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:oneAPI.oneArray,I}
17+
X = parent(V)
18+
pX = Base.unsafe_convert(oneAPI.ZePtr{T}, X)
19+
pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T)
20+
return reinterpret(MPIPtr, pV)
21+
end
22+
23+
function Buffer(arr::oneAPI.oneArray)
24+
Buffer(arr, Cint(length(arr)), Datatype(eltype(arr)))
25+
end
26+
27+
end # OneAPIExt

0 commit comments

Comments
 (0)