Skip to content

Commit

Permalink
Merge branch 'master' into l1_axpy
Browse files Browse the repository at this point in the history
  • Loading branch information
kballeda committed Nov 8, 2022
2 parents bdd43a3 + d714a56 commit 0c0882a
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 2 deletions.
46 changes: 46 additions & 0 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,52 @@ extern "C" void onemklCcopy(syclQueue_t device_queue, int64_t n, const float _Co
reinterpret_cast<std::complex<float> *>(y), incy);
}

extern "C" void onemklDamax(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
status.wait();
}
extern "C" void onemklSamax(syclQueue_t device_queue, int64_t n, const float *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
status.wait();
}
extern "C" void onemklZamax(syclQueue_t device_queue, int64_t n, const double _Complex *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n,
reinterpret_cast<const std::complex<double> *>(x), incx, result);
status.wait();
}
extern "C" void onemklCamax(syclQueue_t device_queue, int64_t n, const float _Complex *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n,
reinterpret_cast<const std::complex<float> *>(x), incx, result);
status.wait();
}

extern "C" void onemklDamin(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
status.wait();
}
extern "C" void onemklSamin(syclQueue_t device_queue, int64_t n, const float *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
status.wait();
}
extern "C" void onemklZamin(syclQueue_t device_queue, int64_t n, const double _Complex *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n,
reinterpret_cast<const std::complex<double> *>(x), incx, result);
status.wait();
}
extern "C" void onemklCamin(syclQueue_t device_queue, int64_t n, const float _Complex *x,
int64_t incx, int64_t *result){
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n,
reinterpret_cast<const std::complex<float> *>(x), incx, result);
status.wait();
}

// other

// oneMKL keeps a cache of SYCL queues and tries to destroy them when unloading the library.
Expand Down
18 changes: 18 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ void onemklZcopy(syclQueue_t device_queue, int64_t n, const double _Complex *x,
void onemklCcopy(syclQueue_t device_queue, int64_t n, const float _Complex *x,
int64_t incx, float _Complex *y, int64_t incy);

void onemklDamax(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx,
int64_t *result);
void onemklSamax(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx,
int64_t *result);
void onemklZamax(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx,
int64_t *result);
void onemklCamax(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx,
int64_t *result);

void onemklDamin(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx,
int64_t *result);
void onemklSamin(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx,
int64_t *result);
void onemklZamin(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx,
int64_t *result);
void onemklCamin(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx,
int64_t *result);

void onemklDestroy();
#ifdef __cplusplus
}
Expand Down
40 changes: 40 additions & 0 deletions lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,43 @@ function onemklCcopy(device_queue, n, x, incx, y, incy)
x::ZePtr{ComplexF32}, incx::Int64,
y::ZePtr{ComplexF32}, incy::Int64)::Cvoid
end

function onemklSamax(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklSamax(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cfloat}, incx::Int64, result::ZePtr{Int64})::Cvoid
end

function onemklDamax(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklDamax(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cdouble}, incx::Int64, result::ZePtr{Int64})::Cvoid
end

function onemklCamax(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklCamax(device_queue::syclQueue_t, n::Int64,
x::ZePtr{ComplexF32}, incx::Int64,result::ZePtr{Int64})::Cvoid
end

function onemklZamax(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklZamax(device_queue::syclQueue_t, n::Int64,
x::ZePtr{ComplexF64}, incx::Int64, result::ZePtr{Int64})::Cvoid
end

function onemklSamin(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklSamin(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cfloat}, incx::Int64, result::ZePtr{Int64})::Cvoid
end

function onemklDamin(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklDamin(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cdouble}, incx::Int64, result::ZePtr{Int64})::Cvoid
end

function onemklCamin(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklCamin(device_queue::syclQueue_t, n::Int64,
x::ZePtr{ComplexF32}, incx::Int64,result::ZePtr{Int64})::Cvoid
end

function onemklZamin(device_queue, n, x, incx, result)
@ccall liboneapi_support.onemklZamin(device_queue::syclQueue_t, n::Int64,
x::ZePtr{ComplexF64}, incx::Int64, result::ZePtr{Int64})::Cvoid
end
33 changes: 33 additions & 0 deletions lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,39 @@ for (fname, elty) in
end
end

## iamax
for (fname, elty) in
((:onemklDamax,:Float64),
(:onemklSamax,:Float32),
(:onemklZamax,:ComplexF64),
(:onemklCamax,:ComplexF32))
@eval begin
function iamax(x::oneStridedArray{$elty})
n = length(x)
queue = global_queue(context(x), device(x))
result = oneArray{Int64}([0]);
$fname(sycl_queue(queue), n, x, stride(x, 1), result)
return Array(result)[1]+1
end
end
end

## iamin
for (fname, elty) in
((:onemklDamin,:Float64),
(:onemklSamin,:Float32),
(:onemklZamin,:ComplexF64),
(:onemklCamin,:ComplexF32))
@eval begin
function iamin(x::StridedArray{$elty})
n = length(x)
result = oneArray{Int64}([0]);
queue = global_queue(context(x), device(x))
$fname(sycl_queue(queue),n, x, stride(x, 1), result)
return Array(result)[1]+1
end
end
end

# level 3

Expand Down
13 changes: 11 additions & 2 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using oneAPI.oneMKL
using LinearAlgebra

m = 20
n = 35
k = 13

############################################################################################
@testset "level 1" begin
Expand All @@ -22,5 +20,16 @@ k = 13
alpha = rand(T,1)
@test testf(axpy!, alpha[1], rand(T,m), rand(T,m))
end

A = oneArray(rand(T, m))
B = oneArray{T}(undef, m)
oneMKL.copy!(m,A,B)
@test Array(A) == Array(B)

# testing oneMKL max and min
a = convert.(T, [1.0, 2.0, -0.8, 5.0, 3.0])
ca = oneArray(a)
@test BLAS.iamax(a) == oneMKL.iamax(ca)
@test oneMKL.iamin(ca) == 3
end # level 1 testset
end

0 comments on commit 0c0882a

Please sign in to comment.