
Commit a174b82

fix: independent event loops & streams (#51)
Co-authored-by: Théo Monnom <[email protected]>
1 parent: caba2ea · commit: a174b82

15 files changed: +328, -366 lines changed

examples/basic_room.py

Lines changed: 2 additions & 10 deletions
@@ -59,20 +59,12 @@ def on_track_subscribed(track: livekit.Track,
         if track.kind == livekit.TrackKind.KIND_VIDEO:
             nonlocal video_stream
             video_stream = livekit.VideoStream(track)
-
-            @video_stream.on("frame_received")
-            def on_video_frame(frame: livekit.VideoFrame):
-                # received a video frame from the track
-                pass
+            # video_stream is an async iterator that yields VideoFrame
         elif track.kind == livekit.TrackKind.KIND_AUDIO:
             print("Subscribed to an Audio Track")
             nonlocal audio_stream
             audio_stream = livekit.AudioStream(track)
-
-            @audio_stream.on('frame_received')
-            def on_audio_frame(frame: livekit.AudioFrame):
-                # received an audio frame from the track
-                pass
+            # audio_stream is an async iterator that yields AudioFrame
 
     @room.listens_to("track_unsubscribed")
     def on_track_unsubscribed(track: livekit.Track,
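
With this change the frame_received callbacks are gone: VideoStream and AudioStream are async iterators, each meant to be drained in its own task so one track never stalls another. A minimal sketch of the new consumption pattern (the coroutine names and frame handling are illustrative, not part of this commit):

import asyncio
import livekit

async def consume_video(video_stream: livekit.VideoStream) -> None:
    # each iteration yields the next livekit.VideoFrame
    async for frame in video_stream:
        buffer = frame.buffer
        print("video frame:", buffer.width, "x", buffer.height)

async def consume_audio(audio_stream: livekit.AudioStream) -> None:
    # each iteration yields the next livekit.AudioFrame
    async for frame in audio_stream:
        samples = frame.data  # raw int16 samples, as in the whisper example below

Inside on_track_subscribed each stream is then handed to asyncio.create_task(...), which is exactly what the face_landmark and whisper diffs below do.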

examples/e2ee.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ async def publish_frames(source: livekit.VideoSource):
 
         source.capture_frame(frame)
 
-        hue += framerate/3 # 3s for a full cycle
+        hue += framerate / 3  # 3s for a full cycle
         if hue >= 1.0:
             hue = 0.0
 

examples/face_landmark/face_landmark.py

Lines changed: 39 additions & 54 deletions
@@ -13,8 +13,7 @@
 URL = 'ws://localhost:7880'
 TOKEN = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE5MDY2MTMyODgsImlzcyI6IkFQSVRzRWZpZFpqclFvWSIsIm5hbWUiOiJuYXRpdmUiLCJuYmYiOjE2NzI2MTMyODgsInN1YiI6Im5hdGl2ZSIsInZpZGVvIjp7InJvb20iOiJ0ZXN0Iiwicm9vbUFkbWluIjp0cnVlLCJyb29tQ3JlYXRlIjp0cnVlLCJyb29tSm9pbiI6dHJ1ZSwicm9vbUxpc3QiOnRydWV9fQ.uSNIangMRu8jZD5mnRYoCHjcsQWCrJXgHCs0aNIgBFY'
 
-frame_queue = Queue()
-argb_frame = None
+tasks = set()
 
 # You can download a face landmark model file from https://developers.google.com/mediapipe/solutions/vision/face_landmarker#models
 model_file = 'face_landmarker.task'
@@ -36,8 +35,7 @@ def draw_landmarks_on_image(rgb_image, detection_result):
     face_landmarks_list = detection_result.face_landmarks
 
     # Loop through the detected faces to visualize.
-    for idx in range(len(face_landmarks_list)):
-        face_landmarks = face_landmarks_list[idx]
+    for face_landmarks in face_landmarks_list:
 
         # Draw the face landmarks.
         face_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
@@ -68,72 +66,59 @@ def draw_landmarks_on_image(rgb_image, detection_result):
           .get_default_face_mesh_iris_connections_style())
 
 
-async def room() -> None:
-    room = livekit.Room()
-    await room.connect(URL, TOKEN)
-    print("connected to room: " + room.name)
-
-    video_stream = None
-
-    @room.on("track_subscribed")
-    def on_track_subscribed(track: livekit.Track,
-                            publication: livekit.RemoteTrackPublication,
-                            participant: livekit.RemoteParticipant):
-        if track.kind == livekit.TrackKind.KIND_VIDEO:
-            nonlocal video_stream
-            video_stream = livekit.VideoStream(track)
-
-            @video_stream.on("frame_received")
-            def on_video_frame(frame: livekit.VideoFrame):
-                frame_queue.put(frame)
-
-    await room.run()
-
-
-def display_frames() -> None:
+async def frame_loop(video_stream: livekit.VideoStream) -> None:
+    landmarker = FaceLandmarker.create_from_options(options)
+    argb_frame = None
     cv2.namedWindow('livekit_video', cv2.WINDOW_AUTOSIZE)
     cv2.startWindowThread()
+    async for frame in video_stream:
+        buffer = frame.buffer
 
-    global argb_frame
-
-    with FaceLandmarker.create_from_options(options) as landmarker:
-        while True:
-            frame = frame_queue.get()
-            buffer = frame.buffer
+        if argb_frame is None or argb_frame.width != buffer.width or argb_frame.height != buffer.height:
+            argb_frame = livekit.ArgbFrame(
+                livekit.VideoFormatType.FORMAT_ABGR, buffer.width, buffer.height)
 
-            if argb_frame is None or argb_frame.width != buffer.width or argb_frame.height != buffer.height:
-                argb_frame = livekit.ArgbFrame(
-                    livekit.VideoFormatType.FORMAT_ABGR, buffer.width, buffer.height)
+        buffer.to_argb(argb_frame)
 
-            buffer.to_argb(argb_frame)
+        arr = np.ctypeslib.as_array(argb_frame.data)
+        arr = arr.reshape((argb_frame.height, argb_frame.width, 4))
+        arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
 
-            arr = np.ctypeslib.as_array(argb_frame.data)
-            arr = arr.reshape((argb_frame.height, argb_frame.width, 4))
-            arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
+        mp_image = mp.Image(
+            image_format=mp.ImageFormat.SRGB, data=arr)
 
-            mp_image = mp.Image(
-                image_format=mp.ImageFormat.SRGB, data=arr)
+        detection_result = landmarker.detect_for_video(
+            mp_image, frame.timestamp_us)
 
-            detection_result = landmarker.detect_for_video(
-                mp_image, frame.timestamp)
+        draw_landmarks_on_image(arr, detection_result)
 
-            draw_landmarks_on_image(arr, detection_result)
+        arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
 
-            arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
-
-            cv2.imshow('livekit_video', arr)
-            if cv2.waitKey(1) & 0xFF == ord('q'):
-                break
+        cv2.imshow('livekit_video', arr)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
 
+    landmarker.close()
     cv2.destroyAllWindows()
 
 
 async def main() -> None:
-    loop = asyncio.get_event_loop()
-    future = loop.run_in_executor(None, asyncio.run, room())
+    room = livekit.Room()
+    await room.connect(URL, TOKEN)
+    print("connected to room: " + room.name)
+
+    video_stream = None
 
-    display_frames()
-    await future
+    @room.on("track_subscribed")
+    def on_track_subscribed(track: livekit.Track, *_):
+        if track.kind == livekit.TrackKind.KIND_VIDEO:
+            nonlocal video_stream
+            video_stream = livekit.VideoStream(track)
+            task = asyncio.create_task(frame_loop(video_stream))
+            tasks.add(task)
+            task.add_done_callback(tasks.remove)
+
+    await room.run()
 
 if __name__ == "__main__":
     asyncio.run(main())
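
A note on the task bookkeeping in main: asyncio.create_task only gives the event loop a weak reference to the task, so storing it in the module-level tasks set (and removing it via add_done_callback once finished) is what keeps a per-stream frame_loop task from being garbage-collected mid-flight. A self-contained sketch of the same pattern, with a stand-in coroutine body:

import asyncio

tasks = set()  # strong references keep pending tasks alive

async def worker(name: str) -> None:
    await asyncio.sleep(0.1)  # stand-in for 'async for frame in video_stream'
    print("finished:", name)

async def main() -> None:
    for name in ("cam-1", "cam-2"):
        task = asyncio.create_task(worker(name))
        tasks.add(task)                       # keep the task alive
        task.add_done_callback(tasks.remove)  # drop the reference when done
    await asyncio.gather(*tasks)

asyncio.run(main())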

examples/publish_hue.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ async def publish_frames(source: livekit.VideoSource):
             0, livekit.VideoRotation.VIDEO_ROTATION_0, argb_frame.to_i420())
 
         rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
-        rgb = [(x * 255) for x in rgb] # type: ignore
+        rgb = [(x * 255) for x in rgb]  # type: ignore
 
         argb_color = np.array(rgb + [255], dtype=np.uint8)
         arr.flat[::4] = argb_color[0]
@@ -35,7 +35,7 @@ async def publish_frames(source: livekit.VideoSource):
 
         source.capture_frame(frame)
 
-        hue += framerate/3 # 3s for a full cycle
+        hue += framerate / 3  # 3s for a full cycle
         if hue >= 1.0:
             hue = 0.0
 
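
The comment on the hue line only adds up if framerate holds the per-frame time step in seconds rather than frames per second; its definition sits outside this hunk, so that reading is an assumption. Under it, each captured frame advances hue by framerate / 3, and a quick check with an assumed 30 fps interval confirms the 3-second cycle:

frame_interval = 1 / 30            # assumed value of 'framerate' (seconds per frame)
hue_step = frame_interval / 3      # increment applied once per captured frame
frames_per_cycle = 1.0 / hue_step  # 90 frames to sweep hue from 0 to 1
print(frames_per_cycle / 30)       # 3.0 seconds at 30 fps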

examples/whisper/whisper.py

Lines changed: 50 additions & 50 deletions
@@ -98,56 +98,56 @@ class WhisperFullParams(ctypes.Structure):
 whisper.whisper_full_get_segment_text.restype = ctypes.c_char_p
 ctx = whisper.whisper_init_from_file(fname_model.encode('utf-8'))
 
-data_30_secs = np.zeros(SAMPLES_30_SECS, dtype=np.float32)
-written_samples = 0  # nb. of samples written to data_30_secs for the cur. inference
 
-
-def on_audio_frame(frame: livekit.AudioFrame):
-    global data_30_secs, written_samples
-
-    # whisper requires 16kHz mono, so resample the data
-    # also convert the samples from int16 to float32
-    frame = frame.remix_and_resample(
-        WHISPER_SAMPLE_RATE, 1)
-
-    data = np.array(frame.data, dtype=np.float32) / 32768.0
-
-    # write the data inside data_30_secs at written_samples
-    data_start = SAMPLES_KEEP + written_samples
-    data_30_secs[data_start:data_start+len(data)] = data
-    written_samples += len(data)
-
-    if written_samples >= SAMPLES_STEP:
-        params = whisper.whisper_full_default_params(
-            WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY)
-        params.print_realtime = False
-        params.print_progress = False
-
-        ctx_ptr = ctypes.c_void_p(ctx)
-        data_ptr = data_30_secs.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
-        res = whisper.whisper_full(ctx_ptr,
-                                   params,
-                                   data_ptr,
-                                   written_samples + SAMPLES_KEEP)
-
-        if res != 0:
-            logging.error("error while running inference: %s", res)
-            return
-
-        n_segments = whisper.whisper_full_n_segments(ctx_ptr)
-        for i in range(n_segments):
-            t0 = whisper.whisper_full_get_segment_t0(ctx_ptr, i)
-            t1 = whisper.whisper_full_get_segment_t1(ctx_ptr, i)
-            txt = whisper.whisper_full_get_segment_text(ctx_ptr, i)
-
-            logging.info(
-                f"{t0/1000.0:.3f} - {t1/1000.0:.3f} : {txt.decode('utf-8')}")
-
-        # write old data to the beginning of the buffer (SAMPLES_KEEP)
-        data_30_secs[:SAMPLES_KEEP] = data_30_secs[data_start +
-                                                   written_samples - SAMPLES_KEEP:
-                                                   data_start + written_samples]
-        written_samples = 0
+async def whisper_task(stream: livekit.AudioStream):
+    data_30_secs = np.zeros(SAMPLES_30_SECS, dtype=np.float32)
+    written_samples = 0  # nb. of samples written to data_30_secs for the cur. inference
+
+    async for frame in stream:
+        # whisper requires 16kHz mono, so resample the data
+        # also convert the samples from int16 to float32
+        frame = frame.remix_and_resample(
+            WHISPER_SAMPLE_RATE, 1)
+
+        data = np.array(frame.data, dtype=np.float32) / 32768.0
+
+        # write the data inside data_30_secs at written_samples
+        data_start = SAMPLES_KEEP + written_samples
+        data_30_secs[data_start:data_start+len(data)] = data
+        written_samples += len(data)
+
+        if written_samples >= SAMPLES_STEP:
+            params = whisper.whisper_full_default_params(
+                WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY)
+            params.print_realtime = False
+            params.print_progress = False
+
+            ctx_ptr = ctypes.c_void_p(ctx)
+            data_ptr = data_30_secs.ctypes.data_as(
+                ctypes.POINTER(ctypes.c_float))
+            res = whisper.whisper_full(ctx_ptr,
+                                       params,
+                                       data_ptr,
+                                       written_samples + SAMPLES_KEEP)
+
+            if res != 0:
+                logging.error("error while running inference: %s", res)
+                return
+
+            n_segments = whisper.whisper_full_n_segments(ctx_ptr)
+            for i in range(n_segments):
+                t0 = whisper.whisper_full_get_segment_t0(ctx_ptr, i)
+                t1 = whisper.whisper_full_get_segment_t1(ctx_ptr, i)
+                txt = whisper.whisper_full_get_segment_text(ctx_ptr, i)
+
+                logging.info(
+                    f"{t0/1000.0:.3f} - {t1/1000.0:.3f} : {txt.decode('utf-8')}")
+
+            # write old data to the beginning of the buffer (SAMPLES_KEEP)
+            data_30_secs[:SAMPLES_KEEP] = data_30_secs[data_start +
+                                                       written_samples - SAMPLES_KEEP:
+                                                       data_start + written_samples]
+            written_samples = 0
 
 
 async def main():
@@ -172,7 +172,7 @@ def on_track_subscribed(track: livekit.Track,
         logging.info("starting listening to: %s", participant.identity)
         nonlocal audio_stream
         audio_stream = livekit.AudioStream(track)
-        audio_stream.add_listener('frame_received', on_audio_frame)
+        asyncio.create_task(whisper_task(audio_stream))
 
     try:
         logging.info("connecting to %s", URL)
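
A note on the sample conversion inside the loop: np.array(frame.data, dtype=np.float32) / 32768.0 maps signed 16-bit PCM onto the [-1.0, 1.0) float range whisper.cpp expects, since the most negative int16 value is -32768. A standalone illustration with arbitrary sample values:

import numpy as np

pcm = np.array([0, 16384, -32768, 32767], dtype=np.int16)
print(pcm.astype(np.float32) / 32768.0)  # -> 0.0, 0.5, -1.0, ~0.99997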
