-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththready.lua
393 lines (327 loc) · 11.1 KB
/
thready.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
--- "Thread" (coroutine) handling system that allows different systems to run their main loops in separate threads.
local expect = require "cc.expect".expect
local logging = require "logging"
local thread_context = logging.create_context("Thready")
---@class thread_data
---@field thread thread The coroutine that runs the thread's function.
---@field id integer The unique ID of the thread.
---@field set_name string The name of the set that owns the thread.
---@field event_filter string|nil The first value the thread last yielded (its event filter); after a failed resume it holds the error value instead.
---@field status thread_status The cached status of the thread; "new" until it is first resumed.
---@field alive boolean The success flag returned by the last `coroutine.resume` (false if the thread errored).
---@field init_args table The arguments passed to the thread function, packed with `table.pack`.

---@class listener_data
---@field event string The event to listen for.
---@field callback fun(event:string, ...:any) The function spawned as a thread when the event is received.
---@field id integer The ID of the listener.
---@field set_name string The name of the set that owns the listener.

---@alias thread_status "running"|"suspended"|"dead"|"new"

---@class thready
---@field coroutines table<string, thread_data[]> The threads of each set, keyed by set name.
---@field listeners table<string, listener_data[]> The listeners of each set, keyed by set name.
---@field stop_on_error boolean Whether to stop the entire system on error.
---@field kill_set_on_error boolean Whether to kill all threads for a set on error.
---@field running boolean Whether the system is currently running.
local thready = {
  -- Runtime flag: set by `main_loop`, cleared when `stop_on_error` halts the system.
  running = false,
  -- Error policy.
  kill_set_on_error = true,
  stop_on_error = false,
  -- Registries, keyed by set name.
  listeners = {},
  coroutines = {},
}
-- IDs currently handed out to threads and listeners (id -> true).
local used_ids = {}

--- Pick a random ID that is not currently in use and reserve it in `used_ids`.
---@return integer id The reserved ID, in the range `[1, 2^31 - 1]`.
local function gen_unique_id()
  local candidate
  repeat
    candidate = math.random(1, 2^31 - 1)
  until not used_ids[candidate]
  used_ids[candidate] = true
  return candidate
end
--- Check whether a thread's last resume raised an error, and handle it according to the
--- `stop_on_error` / `kill_set_on_error` settings.
---
--- `coro_data.alive` holds the success flag returned by `coroutine.resume`; when it is false,
--- `coro_data.event_filter` holds the error value rather than an event filter.
---@param coro_data thread_data The coroutine data to check.
---@return boolean kill_all Whether to kill all threads for the set.
local function check_errored(coro_data)
  if coro_data.alive then
    return false
  end
  -- The error value may be a table, nil, etc.; `tostring` keeps the `%s` formats below
  -- from failing on Lua implementations whose `%s` only accepts strings/numbers.
  local reason = tostring(coro_data.event_filter)
  if thready.stop_on_error then
    thready.running = false
    error(("%s thread %d errored: %s"):format(coro_data.set_name, coro_data.id, reason), 0)
  end
  if thready.kill_set_on_error then
    thread_context.error(("%s thread %d errored: %s"):format(coro_data.set_name, coro_data.id, reason))
    return true
  end
  thread_context.warn(("Ignoring error in %s thread %d: %s"):format(coro_data.set_name, coro_data.id, reason))
  return false
end
--- Refresh the cached `status` field of a thread from its coroutine.
---@param coro_data thread_data The thread whose status is refreshed.
local function update_status(coro_data)
  local current = coroutine.status(coro_data.thread)
  coro_data.status = current
end
--- Remove a thread from its set and release its ID.
--- Does nothing if the set does not exist or holds no thread with that ID.
---@param set_name string The set the thread belongs to.
---@param id integer The ID of the thread to remove.
local function remove_thread(set_name, id)
  local coros = thready.coroutines[set_name]
  if not coros then
    return
  end
  for index, coro in ipairs(coros) do
    if coro.id == id then
      table.remove(coros, index)
      used_ids[id] = nil
      return
    end
  end
end
--- Run a single step of all coroutines in the system given an event.
---
--- New threads are started with their spawn arguments. Suspended threads are resumed with the
--- event when they have no event filter, their filter matches the event, or the event is
--- `terminate`. Finished threads are removed afterwards, and sets for which `check_errored`
--- requests it are killed.
---
--- A resumed thread may itself spawn or kill threads (even whole sets, or clear everything).
--- Both the set list and each set's thread list are therefore snapshotted before iterating,
--- and threads or sets killed earlier in the same pass are skipped.
---@param event_name string The name of the event to run.
---@param ... any The arguments of the event.
local function run(event_name, ...)
  local to_remove = {}
  local to_kill = {}

  -- Snapshot the sets: spawning into a brand-new set while traversing `thready.coroutines`
  -- with `pairs` would add a key mid-traversal, which Lua leaves undefined.
  local sets = {}
  for set_name, coros in pairs(thready.coroutines) do
    sets[#sets + 1] = { name = set_name, coros = coros }
  end

  for _, set in ipairs(sets) do
    local set_name, coros = set.name, set.coros
    -- Snapshot the threads: a removal (thready.kill) during iteration would shift later
    -- threads down and make index-based access run past the end of the list.
    local snapshot = {}
    for i = 1, #coros do
      snapshot[i] = coros[i]
    end

    for i = 1, #snapshot do
      -- The whole set was killed (or the system cleared) by a thread resumed earlier.
      if thready.coroutines[set_name] ~= coros then
        break
      end
      local coro = snapshot[i]
      -- Removing a thread releases its ID, so a missing ID means it was killed this pass.
      if used_ids[coro.id] then
        if coro.status == "new" or coro.status == "suspended" then
          if coro.status == "new" then
            -- Initialize by running once with the spawn arguments.
            coro.alive, coro.event_filter = coroutine.resume(coro.thread, table.unpack(coro.init_args, 1, coro.init_args.n))
          elseif not coro.event_filter or coro.event_filter == event_name or event_name == "terminate" then
            -- Resume only when unfiltered, when the filter matches, or on `terminate`.
            coro.alive, coro.event_filter = coroutine.resume(coro.thread, event_name, ...)
          end
          if check_errored(coro) then
            to_kill[set_name] = true
            break -- stop executing this set's coroutines
          end
          update_status(coro)
          if coro.status == "dead" then
            to_remove[coro.id] = set_name
          end
        elseif coro.status == "dead" then
          to_remove[coro.id] = set_name
          if check_errored(coro) then
            to_kill[set_name] = true
            break -- stop executing this set's coroutines
          end
        end
      end
    end
  end

  -- Remove dead coroutines.
  for id, set_name in pairs(to_remove) do
    remove_thread(set_name, id)
  end
  -- Remove all threads for sets that errored.
  for set_name in pairs(to_kill) do
    thready.kill_all(set_name)
  end
end
--- Run the main loop of the thread system until `thready.running` becomes false.
--- For every pulled event, a new thread is first spawned for each listener registered for
--- that event, then the event is dispatched to the threads. Recommend using this with `parallel`.
---@see thready.parallelAny
---@see thready.parallelAll
function thready.main_loop()
  thready.running = true
  thread_context.debug("Thready started.")
  while thready.running do
    local event = table.pack(os.pullEvent())
    local event_name = event[1]
    -- Spawn a thread for every listener interested in this event.
    for _, listeners in pairs(thready.listeners) do
      for _, listener in ipairs(listeners) do
        if listener.event == event_name then
          thready.spawn(listener.set_name, listener.callback, table.unpack(event, 1, event.n))
        end
      end
    end
    run(table.unpack(event, 1, event.n))
  end
  thread_context.debug("Thready stopped.")
end
--- Run `runner(thready.main_loop, ...)` and mark the thread system as stopped afterwards.
---
--- `thready.main_loop` sets `thready.running` to true but never resets it itself. When `runner`
--- returns because another function finished (or an error escapes), the abandoned main loop
--- would otherwise leave the flag set, and any later start would fail with
--- "Thread system is already running.". Errors are re-raised unchanged.
---@param runner fun(...: function) The `parallel` function to use (`waitForAny` or `waitForAll`).
---@param ... function The main loop(s) of the program.
local function run_alongside_main_loop(runner, ...)
  local ok, err = pcall(runner, thready.main_loop, ...)
  thready.running = false
  if not ok then
    error(err, 0)
  end
end

--- Start the thread system in parallel with other functions. This is a shorthand to `parallel.waitForAny(thready.main_loop, ...)`.
--- When it returns (or errors), the system is marked as not running so it can be started again.
--- ## Usage
--- ```lua
--- thready.parallelAny(
--- your_main_loop,
--- other_main_loop,
--- ...
--- )
--- ```
---@param ... function The main loop(s) of the program.
function thready.parallelAny(...)
  if thready.running then
    error("Thread system is already running.", 2)
  end
  -- Check every argument, including any after a nil (where `ipairs` would stop).
  for i = 1, select("#", ...) do
    expect(i, (select(i, ...)), "function")
  end
  run_alongside_main_loop(parallel.waitForAny, ...)
end

--- Start the thread system in parallel with other functions. This is a shorthand to `parallel.waitForAll(thready.main_loop, ...)`.
--- When it returns (or errors), the system is marked as not running so it can be started again.
--- ## Usage
--- ```lua
--- thready.parallelAll(
--- your_main_loop,
--- other_main_loop,
--- ...
--- )
--- ```
---@param ... function The main loop(s) of the program.
function thready.parallelAll(...)
  if thready.running then
    error("Thread system is already running.", 2)
  end
  -- Check every argument, including any after a nil (where `ipairs` would stop).
  for i = 1, select("#", ...) do
    expect(i, (select(i, ...)), "function")
  end
  run_alongside_main_loop(parallel.waitForAll, ...)
end
--- Spawn a new thread in a set.
--- The thread is created in the "new" state and is first resumed, with `...` as its
--- arguments, the next time the scheduler runs. A `thready_spawn` event is queued so the
--- main loop wakes up promptly.
---@param set_name string The name of the set to spawn the thread in.
---@param thread_fun fun(...: any) The function to run in the thread.
---@param ... any The arguments to pass to the thread function.
---@return integer id The ID of the spawned thread.
function thready.spawn(set_name, thread_fun, ...)
  expect(1, set_name, "string")
  expect(2, thread_fun, "function")

  local id = gen_unique_id()
  local set = thready.coroutines[set_name]
  if not set then
    set = {}
    thready.coroutines[set_name] = set
  end

  thread_context.debug(("Spawning thread id %d in set %s."):format(id, set_name))
  set[#set + 1] = {
    thread = coroutine.create(thread_fun),
    id = id,
    set_name = set_name,
    event_filter = nil,
    status = "new",
    alive = true,
    init_args = table.pack(...),
  }
  os.queueEvent("thready_spawn") -- resume the main loop
  return id
end
--- Add a listener for a given event.
--- Whenever the main loop pulls an event with this name, `callback` is spawned as a new
--- thread in `set_name`, receiving the event name followed by the event's arguments.
---@param set_name string The name of the set that owns the listener and the threads it spawns.
---@param event string The event to listen for.
---@param callback fun(event:string, ...:any) The callback to run when the event is received.
---@return integer id The ID of the listener.
function thready.listen(set_name, event, callback)
  -- Indices match the real parameter positions. `set_name` is checked here because
  -- `thready.spawn` would otherwise reject it later, from inside the main loop.
  expect(1, set_name, "string")
  expect(2, event, "string")
  expect(3, callback, "function")

  local id = gen_unique_id() -- also reserves the ID in `used_ids`
  local set = thready.listeners[set_name]
  if not set then
    set = {}
    thready.listeners[set_name] = set
  end
  set[#set + 1] = {
    event = event,
    set_name = set_name,
    callback = callback,
    id = id,
  }
  thread_context.debug(("Listening for event %s in set %s with listener id %d."):format(event, set_name, id))
  return id
end
--- Remove a listener by its ID and release the ID.
--- Threads the listener already spawned keep running; a warning is logged if no listener
--- has the given ID.
---@param id integer The ID of the listener to remove.
function thready.remove_listener(id)
  expect(1, id, "number")

  for _, listeners in pairs(thready.listeners) do
    for index, listener in ipairs(listeners) do
      if listener.id == id then
        table.remove(listeners, index)
        used_ids[id] = nil
        thread_context.debug(("Removed listener id %d."):format(id))
        return
      end
    end
  end
  thread_context.warn(("Attempted to remove listener id %d, but it does not exist."):format(id))
end
--- Look up a thread by ID across all sets.
---@param id integer The ID of the thread to get information about.
---@return thread_data|nil data The thread's data table (not a copy), or nil if no thread has that ID.
function thready.get_thread(id)
  expect(1, id, "number")

  for _, coros in pairs(thready.coroutines) do
    for i = 1, #coros do
      local candidate = coros[i]
      if candidate.id == id then
        return candidate
      end
    end
  end
  return nil
end
--- Check if a thread is alive: it exists, is not dead, and its last resume did not error.
---@param id integer The ID of the thread to check.
---@return boolean alive Whether the thread is alive.
function thready.is_alive(id)
  expect(1, id, "number")

  local coro = thready.get_thread(id)
  if not coro or coro.status == "dead" then
    return false
  end
  return coro.alive or false
end
--- Kill a thread: remove it from its set (releasing its ID) so the scheduler no longer resumes it.
--- A warning is logged if no thread has the given ID.
---@param id integer The ID of the thread to kill.
function thready.kill(id)
  expect(1, id, "number")

  -- Find the set that currently owns the thread.
  local owner_set
  for set_name, coros in pairs(thready.coroutines) do
    for _, coro in ipairs(coros) do
      if coro.id == id then
        owner_set = set_name
        break
      end
    end
    if owner_set ~= nil then
      break
    end
  end

  if owner_set == nil then
    thread_context.warn(("Attempted to kill thread id %d, but it does not exist."):format(id))
    return
  end
  remove_thread(owner_set, id)
  thread_context.debug(("Killed thread id %d in set %s."):format(id, owner_set))
end
--- Kill all threads for a given set and remove its listeners, releasing all their IDs.
---@param set_name string The name of the set to kill all threads for.
function thready.kill_all(set_name)
  expect(1, set_name, "string")

  thread_context.debug(("Killing all threads and stopping listeners in set %s."):format(set_name))

  -- Release the IDs of the set's threads, then drop them.
  local coros = thready.coroutines[set_name]
  if coros then
    for i = 1, #coros do
      used_ids[coros[i].id] = nil
    end
    thready.coroutines[set_name] = nil
  end

  -- Listeners are removed even when the set currently has no threads, and their IDs are
  -- released as well so they do not stay reserved forever.
  local listeners = thready.listeners[set_name]
  if listeners then
    for i = 1, #listeners do
      used_ids[listeners[i].id] = nil
    end
    thready.listeners[set_name] = nil
  end
end
--- Clear the entire thread system: every thread and every listener in every set.
--- Listeners are cleared together with the threads because all IDs are released here.
--- Keeping the listeners would let new threads or listeners receive IDs that collide
--- with theirs.
function thready.clear()
  thready.coroutines = {}
  thready.listeners = {}
  used_ids = {}
end
return thready