-
Notifications
You must be signed in to change notification settings - Fork 1
/
blas.t
101 lines (83 loc) · 2.7 KB
/
blas.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
-- Detect the host OS by running `uname`, then include the matching CBLAS
-- header and link the matching shared library:
--   Darwin -> <cblas.h>          + libcblas.dylib
--   Linux  -> <openblas/cblas.h> + libopenblas.so
-- Any other output raises an error.
local uname
do
    -- assert() turns a failed popen into a readable error instead of an
    -- "attempt to index a nil value"; the handle is closed explicitly so the
    -- pipe is not left open until the garbage collector runs.
    local pipe = assert(io.popen("uname", "r"))
    -- The raw output is kept, including its trailing newline; the
    -- comparisons below rely on it.
    uname = pipe:read("*a")
    pipe:close()
end

-- Table of C symbols imported from the CBLAS header.
local C
if uname == "Darwin\n" then
    C = terralib.includecstring([[
#include <cblas.h>
]])
    terralib.linklibrary("libcblas.dylib")
elseif uname == "Linux\n" then
    C = terralib.includecstring([[
#include <openblas/cblas.h>
]])
    terralib.linklibrary("libopenblas.so")
else
    error("Not implemented for this OS.")
end
local complex = require("complex")
local complexFloat = complex.complex(float)
local complexDouble = complex.complex(double)
local wrapper = require("wrapper")
-- Module table returned at the end of the file. It starts out holding
-- selected Cblas* symbols from the included header under shorter names.
local S = {
    RowMajor = C.CblasRowMajor,
    ColMajor = C.CblasColMajor,
    NoTrans = C.CblasNoTrans,
    Trans = C.CblasTrans,
    ConjTrans = C.CblasConjTrans,
    Upper = C.CblasUpper,
    Lower = C.CblasLower,
    Left = C.CblasLeft,
    Right = C.CblasRight,
    Unit = C.CblasUnit,
    NonUnit = C.CblasNonUnit,
}
-- Element types in the canonical order shared by every implementation table
-- in this file: float, double, complex float, complex double.
-- NOTE: this file-level local shadows Lua's built-in `type` function for the
-- remainder of the file.
local type = { float, double, complexFloat, complexDouble }
--- Build the name of a CBLAS symbol.
-- @param pre  prefix placed directly after "cblas_" (e.g. "s", "d", "c", "z")
-- @param name routine name appended after the prefix
-- @return the string "cblas_" .. pre .. name
local function blas_name(pre, name)
    local template = "cblas_%s%s"
    return template:format(pre, name)
end
--- Look up the "s", "d", "c" and "z" variants of a CBLAS routine.
-- For each prefix the symbol cblas_<prefix><name> is tried first; if it is
-- not stored directly in C, cblas_<prefix><alt_name> is used instead.
-- @param C        table of imported C symbols
-- @param name     routine name without prefix
-- @param alt_name optional fallback routine name; defaults to `name`
-- @return a terralib list with one entry per prefix, in s, d, c, z order
local function default_blas(C, name, alt_name)
    local fallback = alt_name or name
    local variants = terralib.newlist{}
    for i, pre in ipairs({"s", "d", "c", "z"}) do
        local symbol = blas_name(pre, name)
        -- rawget checks for the key without triggering any __index
        -- metamethod on C.
        if not rawget(C, symbol) then
            symbol = blas_name(pre, fallback)
        end
        variants[i] = C[symbol]
    end
    return variants
end
-- Routines to wrap. Every entry is a pair {exported name, implementations},
-- where the implementations are listed in the float, double, complexFloat,
-- complexDouble order.
local blas = {
    -- BLAS level 1
    { "swap", default_blas(C, "swap") },
    { "scal", default_blas(C, "scal") },
    { "copy", default_blas(C, "copy") },
    { "axpy", default_blas(C, "axpy") },
    -- For any prefix without a plain "dot" symbol, "dotc_sub" is used.
    { "dot", default_blas(C, "dot", "dotc_sub") },
    -- The next three do not follow the cblas_<prefix><name> scheme, so their
    -- symbols are listed explicitly.
    {
        "nrm2",
        { C.cblas_snrm2, C.cblas_dnrm2, C.cblas_scnrm2, C.cblas_dznrm2 },
    },
    {
        "asum",
        { C.cblas_sasum, C.cblas_dasum, C.cblas_scasum, C.cblas_dzasum },
    },
    {
        "iamax",
        { C.cblas_isamax, C.cblas_idamax, C.cblas_icamax, C.cblas_izamax },
    },
    -- BLAS level 2
    { "gemv", default_blas(C, "gemv") },
    { "trsv", default_blas(C, "trsv") },
    { "trmv", default_blas(C, "trmv") },
    -- BLAS level 3
    { "gemm", default_blas(C, "gemm") },
    { "trsm", default_blas(C, "trsm") },
}
-- Register one overloaded function per routine in `blas` on the module table
-- S, adding one definition for each entry of the `type` table.
-- `blas` is a sequence, so ipairs is used to walk it in a defined order
-- (pairs makes no ordering guarantee).
for _, func in ipairs(blas) do
    local name, c_func = func[1], func[2]
    local overloaded = terralib.overloadedfunction(name)
    S[name] = overloaded
    -- `type` here is the file-level table of element types, not Lua's
    -- built-in function; its length replaces the hard-coded count.
    for i = 1, #type do
        -- Use float implementation as reference for function signature
        overloaded:adddefinition(
            wrapper.generate_wrapper(type[i], c_func[i], type[1], c_func[1])
        )
    end
end
return S