-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.t
296 lines (259 loc) · 6.23 KB
/
util.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
-- C declarations compiled by Terra's embedded Clang. Brings in the
-- stdio/stdlib/string functions used below (printf, sprintf, system, exit, ...)
-- plus one portable wall-clock helper, __currentTimeInSeconds():
--   * non-Windows: gettimeofday(); seconds + microseconds as a double.
--   * _WIN32: time(NULL); whole-second resolution only.
-- NOTE(review): C reserves identifiers that begin with "__"; a name without
-- the leading double underscore would avoid any clash with the C library.
local C = terralib.includecstring [[
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifndef _WIN32
#include <sys/time.h>
double __currentTimeInSeconds() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec + tv.tv_usec / 1000000.0;
}
#else
#include <time.h>
double __currentTimeInSeconds() {
return time(NULL);
}
#endif
]]
local U = {}

--- Current wall-clock time in seconds, as a double (cross platform).
-- Forwards to the C helper __currentTimeInSeconds compiled at the top of this
-- file: gettimeofday() off Windows, time(NULL) (whole seconds) on Windows.
terra U.currentTimeInSeconds() : double
  return C.__currentTimeInSeconds()
end
--- Shallow copy of a table.
-- Every key/value pair (array and hash parts alike) is copied into a fresh
-- table; nested tables are shared with the source, not cloned.
-- @tparam table tab source table
-- @treturn table the copy
function U.copytable(tab)
  local copy = {}
  for key, value in pairs(tab) do
    copy[key] = value
  end
  return copy
end
--- Concatenate the array parts of several tables into a new table.
-- The result starts as a shallow copy of the first table (all of its keys,
-- including non-integer ones). Then the array elements of each following
-- table are appended in order; iteration uses ipairs, so it stops at the
-- first nil.
-- @param ... tables to concatenate (at least one)
-- @treturn table
function U.concattables(...)
  local total = select("#", ...)
  local result = U.copytable((...))
  for argIndex = 2, total do
    local source = select(argIndex, ...)
    for _, element in ipairs(source) do
      result[#result + 1] = element
    end
  end
  return result
end
--- Merge several tables into a new one.
-- Starts from a shallow copy of the first table, then copies every key/value
-- pair of each later table on top, so later tables win on key collisions.
-- @param ... tables to merge (at least one)
-- @treturn table
function U.joinTables(...)
  local merged = U.copytable((select(1, ...)))
  local total = select("#", ...)
  for argIndex = 2, total do
    local overrides = select(argIndex, ...)
    for key, value in pairs(overrides) do
      merged[key] = value
    end
  end
  return merged
end
--- Gather tbl[k] for each key k in the sequence `indices`, in order.
-- Lookups that yield nil are not stored, so later results move up a slot.
-- The output position then no longer matches the position in `indices`.
-- @tparam table tbl table to read from
-- @tparam table indices sequence of keys
-- @treturn table sequence of looked-up values
function U.index(tbl, indices)
  local picked = {}
  for _, key in ipairs(indices) do
    picked[#picked + 1] = tbl[key]
  end
  return picked
end
--- Mark every definition of a Terra function as inlined.
-- @param terrafn Terra function object
-- @return the same `terrafn`, so the call can wrap a definition expression
function U.inline(terrafn)
  for _, definition in ipairs(terrafn:getdefinitions()) do
    definition:setinlined(true)
  end
  return terrafn
end
--- Build a macro that tests whether its argument's Terra type equals `typ`.
-- The comparison runs while the macro is expanded; the macro's result is the
-- resulting Lua boolean.
-- @param typ Terra type to compare against
function U.istype(typ)
  local function hasType(x)
    return x:gettype() == typ
  end
  return macro(hasType)
end
--- Build a macro that checks, while it is expanded, that its argument has
-- Terra type `expected`.
-- On mismatch it calls U.luaAssertWithTrace(false, msg), which prints a Lua
-- traceback and raises `msg`. On success it expands to an empty quote.
-- The parameter is named `expected` rather than `type` so that Lua's builtin
-- type() is not shadowed here. Callers pass it positionally, so they are
-- unaffected.
-- @param expected Terra type the argument must have
-- @tparam[opt] string msg error message on mismatch
function U.assertIsType(expected, msg)
  return macro(function(x)
    U.luaAssertWithTrace(x:gettype() == expected, msg)
    return quote end
  end)
end
--- Macro: expands to a string constant holding tostring() of the argument's
-- Terra type.
U.getTypeAsString = macro(function(expr)
  local typ = expr:gettype()
  return tostring(typ)
end)

--- Macro: prints the argument's Terra type with Lua print() at the moment the
-- macro is expanded, and expands to an empty quote.
U.printType = macro(function(expr)
  print(expr:gettype())
  return quote end
end)
--- Run a shell command and return everything it writes to stdout.
-- The command string goes to the shell verbatim; never build it from
-- untrusted input.
-- @tparam string procstr shell command line
-- @treturn string the command's captured stdout (possibly empty)
-- @raise if io.popen cannot start the command
function U.wait(procstr)
  local pipe = assert(io.popen(procstr))
  local output = pipe:read("*a")
  -- Close explicitly instead of leaving the handle to the garbage collector.
  -- Otherwise the file descriptor (and the child process) stays around until
  -- a collection happens to run.
  pipe:close()
  return output
end
--- Macro: format a command printf-style and run it with C system().
-- The command is written into a fixed-size stack buffer with snprintf. A
-- command longer than BUFSIZE - 1 bytes is therefore truncated, instead of
-- overflowing the buffer as sprintf would.
-- NOTE: the formatted text is executed by the shell — do not format
-- untrusted data into it.
-- @param formatstr printf format string (Terra expression)
-- @param ... printf arguments
U.systemf = macro(function(formatstr, ...)
  local args = {...}
  local BUFSIZE = 1024
  return quote
    var buf : int8[BUFSIZE]
    C.snprintf(buf, BUFSIZE, formatstr, [args])
    C.system(buf)
  end
end)
--- Split a string into the maximal non-empty runs of characters not in `sep`.
-- `sep` is spliced into a Lua pattern character class ("[^" .. sep .. "]").
-- It is therefore a *set* of characters (it may also contain classes such as
-- "%s"), not a literal multi-character delimiter. It defaults to " ".
-- Consecutive separators yield no empty fields.
-- NOTE: this extends the global `string` library, so every string gets it.
-- @tparam[opt=" "] string sep separator character set
-- @treturn table sequence of fields
function string:split(sep)
  local charset = sep or " "
  local pieces = {}
  for piece in self:gmatch(string.format("([^%s]+)", charset)) do
    pieces[#pieces + 1] = piece
  end
  return pieces
end
--- Build a quote that loops over an iterator object until it reports done.
-- The generated code is:
--   while not iterator:done() do <codeblock>; iterator:next() end
-- @param iterator Terra expression/quote whose value has :done() and :next()
--   methods (presumably done() yields a bool — confirm per iterator type)
-- @param codeblock quote spliced in as the loop body; it runs before next()
-- NOTE(review): `iterator` is spliced in twice, once for done() and once for
-- next(). If it is an expression with side effects, it may be re-evaluated
-- each time.
function U.foreach(iterator, codeblock)
return quote
while not [iterator]:done() do
[codeblock]
[iterator]:next()
end
end
end
--- Macro: exchange the values of two assignable Terra expressions.
-- The expansion is:  var tmp = a; a = b; b = tmp
-- NOTE(review): each operand appears twice in the expansion, so operands with
-- side effects (e.g. arr[f()]) may be evaluated more than once.
-- `tmp` is declared inside the generated quote. It presumably cannot clash
-- with caller variables because Terra quotes are hygienic — confirm.
U.swap = macro(function(a, b)
return quote
var tmp = a
a = b
b = tmp
end
end)
--- Publish every entry of `ns` as a global variable.
-- Uses rawset, so any __newindex guard on _G (e.g. a strict-mode checker) is
-- bypassed.
-- @tparam table ns namespace table whose entries become globals
function U.openModule(ns)
  for name, value in pairs(ns) do
    rawset(_G, name, value)
  end
end
-- Render a table or function as the text tostring() gives it with the
-- "<type>: " prefix removed (its address), ignoring any __tostring metamethod.
local function stringifyRawAddress(value, typ)
  local mt = nil
  if typ == "table" then
    -- debug.getmetatable returns the real metatable even when a __metatable
    -- field makes getmetatable() return something else.
    mt = debug.getmetatable(value)
  end
  -- tostring() looks up __tostring with a raw access on the metatable. So
  -- temporarily clear exactly that raw field, then put back exactly what was
  -- there (never a value inherited through __index).
  local saved = nil
  if mt ~= nil then
    saved = rawget(mt, "__tostring")
    if saved ~= nil then rawset(mt, "__tostring", nil) end
  end
  local raw = tostring(value)
  if saved ~= nil then rawset(mt, "__tostring", saved) end
  return (raw:gsub(typ .. ": ", ""))
end

--- Build a string that identifies a list of values.
-- Each argument is rendered and followed by a comma. Values other than tables
-- and functions use tostring(). Tables and functions are rendered by their
-- memory address, ignoring any __tostring metamethod, and that metamethod is
-- left exactly as it was afterwards.
-- @param ... values to render (trailing nils are included as "nil")
-- @treturn string e.g. "1,hello,0x7f12...,"; "" when called with no arguments
function U.stringify(...)
  local pieces = {}
  for i = 1, select("#", ...) do
    local value = (select(i, ...))
    local typ = type(value)
    local rendered
    if typ == "table" or typ == "function" then
      rendered = stringifyRawAddress(value, typ)
    else
      rendered = tostring(value)
    end
    pieces[i] = rendered .. ","
  end
  return table.concat(pieces)
end
--- Name of the host OS as printed by `uname` (e.g. "Linux", "Darwin").
-- Trailing whitespace, including the newline uname prints, is stripped so the
-- result can be compared directly against names like "Linux".
-- Presumably "" where `uname` is unavailable (e.g. Windows) — confirm.
-- @treturn string
function U.osName()
  return (U.wait("uname"):gsub("%s+$", ""))
end
-- Memoized result of U.isPosix(); nil until the first call.
local isPosixCache = nil

--- True when running on Darwin or Linux, as reported by `uname`.
-- uname's output ends with a newline, which must be stripped before comparing.
-- Without that, the check could never succeed.
-- The answer is computed once and cached, because the macros in this file
-- call this on every expansion and each check spawns a process.
-- @treturn boolean
function U.isPosix()
  if isPosixCache == nil then
    local uname = U.wait("uname"):gsub("%s+$", "")
    isPosixCache = (uname == "Darwin" or uname == "Linux")
  end
  return isPosixCache
end
--- Conditionally produce code.
-- When `flag` is falsy, returns an empty quote. Otherwise returns `code`
-- itself, or, if `code` is a function, the result of calling it with the
-- extra arguments (expected to be a quote).
-- @param flag condition deciding whether any code is generated
-- @param code quote, or Lua function returning a quote
-- @param ... arguments forwarded to `code` when it is a function
function U.optionally(flag, code, ...)
  if not flag then
    return quote end
  end
  if type(code) == "function" then
    return code(...)
  end
  return code
end
--- Macro (cross platform): print "[Fatal Error] " and a printf-style message,
-- then exit(1).
-- A terralib traceback is emitted between the two only when U.isPosix() holds
-- at expansion time; tracebacks are supported on POSIX systems only.
-- @param ... arguments passed verbatim to C printf (format string first)
U.fatalError = macro(function(...)
  local printfArgs = {...}
  local traceback
  if U.isPosix() then
    traceback = quote terralib.traceback(nil) end
  else
    traceback = quote end
  end
  return quote
    C.printf("[Fatal Error] ")
    C.printf([printfArgs])
    [traceback]
    C.exit(1)
  end
end)
--- Macro: runtime assertion.
-- If `condition` is false at run time, the generated code prints
-- "[Assertion Failed] ", then the optional printf-style message, then (only
-- when U.isPosix() held at expansion time) a terralib traceback, and finally
-- calls exit(1).
-- @param condition Terra boolean expression
-- @param ... optional arguments passed verbatim to C printf
U.assert = macro(function(condition, ...)
  local messageArgs = {...}
  local printMessage = quote end
  if #messageArgs > 0 then
    printMessage = quote C.printf([messageArgs]) end
  end
  -- Tracebacks are only supported on POSIX systems.
  local traceback = quote end
  if U.isPosix() then
    traceback = quote terralib.traceback(nil) end
  end
  return quote
    if not condition then
      C.printf("[Assertion Failed] ")
      [printMessage]
      [traceback]
      C.exit(1)
    end
  end
end)
--- Lua-side assert that prints a full stack traceback before failing.
-- Does nothing when `condition` is truthy. Otherwise it prints
-- debug.traceback() and then raises via assert(condition, msg).
-- @param condition value to test
-- @tparam[opt] string msg error message (assert's default when nil)
function U.luaAssertWithTrace(condition, msg)
  if condition then return end
  print(debug.traceback())
  assert(condition, msg)
end
--- Find the definition (overload) of a Terra function with given parameters.
-- A definition matches when its parameter list has the same length as
-- `paramTypes` and every entry compares equal with ==.
-- @param terrafn Terra function object
-- @tparam table paramTypes sequence of Terra types
-- @return the first matching definition, or nil if none matches
function U.findDefWithParamTypes(terrafn, paramTypes)
  local function sameParams(actual)
    if #actual ~= #paramTypes then
      return false
    end
    for i = 1, #actual do
      if actual[i] ~= paramTypes[i] then
        return false
      end
    end
    return true
  end

  for _, definition in ipairs(terrafn:getdefinitions()) do
    if sameParams(definition:gettype().parameters) then
      return definition
    end
  end
  return nil
end
--- Wrap `fn` so it can be called with one table of named arguments.
-- `argdefs` is a sequence of {name, default} pairs giving fn's positional
-- parameters in order. For each pair the wrapper passes args[name], or
-- `default` when that entry is nil (a `false` value is passed through).
-- Arguments are placed by index and unpacked with an explicit count. A nil
-- argument therefore keeps its slot instead of shifting later arguments left.
-- @tparam function fn function taking positional arguments
-- @tparam table argdefs sequence of {name, default}
-- @treturn function wrapper taking an optional table of named arguments
function U.fnWithDefaultArgs(fn, argdefs)
  -- LuaJIT / Lua 5.1 have a global `unpack`; Lua 5.2+ moved it to table.unpack.
  local unpackArgs = unpack or table.unpack
  return function(args)
    args = args or {}
    local arglist = {}
    local count = 0
    for i, argdef in ipairs(argdefs) do
      local argval = args[argdef[1]]
      if argval == nil then argval = argdef[2] end
      arglist[i] = argval
      count = i
    end
    return fn(unpackArgs(arglist, 1, count))
  end
end
--- terralib.includec with an extra include directory.
-- The directory comes from $C_INCLUDE_PATH, or "." when it is unset, and is
-- passed to Clang as "-I", dir.
-- NOTE(review): C_INCLUDE_PATH is conventionally a ':'-separated list, but it
-- is passed here as one -I argument. Only a single-directory value is likely
-- to work — confirm.
-- @tparam string filename C header to include
function U.includec_path(filename)
  local includeDir = os.getenv("C_INCLUDE_PATH")
  if includeDir == nil then
    includeDir = "."
  end
  return terralib.includec(filename, "-I", includeDir)
end
--- terralib.includecstring with an extra include directory.
-- The directory comes from $C_INCLUDE_PATH, or "." when it is unset, and is
-- passed to Clang as "-I", dir.
-- NOTE(review): C_INCLUDE_PATH is conventionally a ':'-separated list, but it
-- is passed here as one -I argument. Only a single-directory value is likely
-- to work — confirm.
-- @tparam string str C source text to compile
function U.includecstring_path(str)
  local includeDir = os.getenv("C_INCLUDE_PATH")
  if includeDir == nil then
    includeDir = "."
  end
  return terralib.includecstring(str, "-I", includeDir)
end
--- Make every entry of `source` visible as a variable in the calling function.
-- The caller's environment is replaced with a fresh table that holds the
-- entries and falls back (via __index) to the previous environment. Global
-- assignments the caller makes afterwards land in that new table.
-- Lua 5.1 / LuaJIT only (getfenv/setfenv).
-- @tparam table source entries to import
function U.importAll(source)
  local callerEnv = getfenv(2)
  local imported = setmetatable({}, {__index = callerEnv})
  for name, value in pairs(source) do
    rawset(imported, name, value)
  end
  setfenv(2, imported)
end
--- Make selected entries of `source` visible as variables in the calling
-- function.
-- The caller's environment is replaced with a fresh table that holds the named
-- entries and falls back (via __index) to the previous environment.
-- Lua 5.1 / LuaJIT only (getfenv/setfenv).
-- @tparam table source table to import from (named so it does not shadow
--   the `table` library)
-- @param ... names of the entries to import
-- @raise if `source` has no entry for one of the names (a `false` value is a
--   valid entry); the error is reported at the caller's line
function U.importEntries(source, ...)
  local names = {...}
  local callerEnv = getfenv(2)
  local imported = setmetatable({}, {__index = callerEnv})
  for _, name in ipairs(names) do
    local value = source[name]
    -- Compare against nil rather than testing truthiness, so entries whose
    -- value is `false` can be imported as well.
    if value == nil then
      error(string.format("import - table does not have an entry named '%s'", name), 2)
    end
    rawset(imported, name, value)
  end
  setfenv(2, imported)
end
-- Module export: `require` callers receive the utility table U.
return U