Skip to content

Commit

Permalink
feat: video
Browse files Browse the repository at this point in the history
* replace multiprocessing with pyav

* add config.RECORD_WINDOW_DATA

* video_write_q

* fix get_timestamp; extract_frames_to_pil_images with pyav

* add video.py; ActionEvent.original_timestamp

* use global SCT in get_monitor_dims

* fix tests

* fix window._windows.get_active_window_state (missing type)

* add tests/openadapt/test_video.py

* flake8

* black

* poetry lock
  • Loading branch information
abrichr authored Feb 29, 2024
1 parent 7759996 commit 20e08b8
Show file tree
Hide file tree
Showing 17 changed files with 2,256 additions and 1,642 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ src

# MacOS file
.DS_Store

*.pyc
*.pt
5 changes: 4 additions & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"OPENAI_API_KEY": "<set your api key in .env>",
# "OPENAI_MODEL_NAME": "gpt-4",
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
"RECORD_WINDOW_DATA": False,
# may incur significant performance penalty
"RECORD_READ_ACTIVE_ELEMENT_STATE": False,
# TODO: remove?
Expand Down Expand Up @@ -99,7 +100,6 @@
"key_vk",
"children",
],
"PLOT_PERFORMANCE": True,
# VISUALIZATION CONFIGURATIONS
"VISUALIZE_DARK_MODE": False,
"VISUALIZE_RUN_NATIVELY": True,
Expand All @@ -111,6 +111,9 @@
"SAVE_SCREENSHOT_DIFF": False,
"SPACY_MODEL_NAME": "en_core_web_trf",
"PRIVATE_AI_API_KEY": "<set your api key in .env>",
"RECORD_VIDEO": False,
"RECORD_IMAGES": True,
"VIDEO_PIXEL_FORMAT": "rgb24",
}

# each string in STOP_STRS should only contain strings
Expand Down
60 changes: 45 additions & 15 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def _insert(


def insert_action_event(
recording_timestamp: int, event_timestamp: int, event_data: dict[str, Any]
recording_timestamp: float, event_timestamp: int, event_data: dict[str, Any]
) -> None:
"""Insert an action event into the database.
Args:
recording_timestamp (int): The timestamp of the recording.
recording_timestamp (float): The timestamp of the recording.
event_timestamp (int): The timestamp of the event.
event_data (dict): The data of the event.
"""
Expand All @@ -88,12 +88,12 @@ def insert_action_event(


def insert_screenshot(
recording_timestamp: int, event_timestamp: int, event_data: dict[str, Any]
recording_timestamp: float, event_timestamp: int, event_data: dict[str, Any]
) -> None:
"""Insert a screenshot into the database.
Args:
recording_timestamp (int): The timestamp of the recording.
recording_timestamp (float): The timestamp of the recording.
event_timestamp (int): The timestamp of the event.
event_data (dict): The data of the event.
"""
Expand All @@ -106,14 +106,14 @@ def insert_screenshot(


def insert_window_event(
recording_timestamp: int,
recording_timestamp: float,
event_timestamp: int,
event_data: dict[str, Any],
) -> None:
"""Insert a window event into the database.
Args:
recording_timestamp (int): The timestamp of the recording.
recording_timestamp (float): The timestamp of the recording.
event_timestamp (int): The timestamp of the event.
event_data (dict): The data of the event.
"""
Expand All @@ -126,15 +126,15 @@ def insert_window_event(


def insert_perf_stat(
recording_timestamp: int,
recording_timestamp: float,
event_type: str,
start_time: float,
end_time: float,
) -> None:
"""Insert an event performance stat into the database.
Args:
recording_timestamp (int): The timestamp of the recording.
recording_timestamp (float): The timestamp of the recording.
event_type (str): The type of the event.
start_time (float): The start time of the event.
end_time (float): The end time of the event.
Expand All @@ -148,11 +148,11 @@ def insert_perf_stat(
_insert(event_perf_stat, PerformanceStat, performance_stats)


def get_perf_stats(recording_timestamp: int) -> list[PerformanceStat]:
def get_perf_stats(recording_timestamp: float) -> list[PerformanceStat]:
"""Get performance stats for a given recording.
Args:
recording_timestamp (int): The timestamp of the recording.
recording_timestamp (float): The timestamp of the recording.
Returns:
list[PerformanceStat]: A list of performance stats for the recording.
Expand All @@ -166,7 +166,7 @@ def get_perf_stats(recording_timestamp: int) -> list[PerformanceStat]:


def insert_memory_stat(
recording_timestamp: int, memory_usage_bytes: int, timestamp: int
recording_timestamp: float, memory_usage_bytes: int, timestamp: int
) -> None:
"""Insert memory stat into db."""
memory_stat = {
Expand All @@ -177,7 +177,7 @@ def insert_memory_stat(
_insert(memory_stat, MemoryStat, memory_stats)


def get_memory_stats(recording_timestamp: int) -> None:
def get_memory_stats(recording_timestamp: float) -> None:
"""Return memory stats for a given recording."""
return (
db.query(MemoryStat)
Expand All @@ -196,7 +196,7 @@ def insert_recording(recording_data: Recording) -> Recording:
return db_obj


def delete_recording(recording_timestamp: int) -> None:
def delete_recording(recording_timestamp: float) -> None:
"""Remove the recording from the db."""
db.query(Recording).filter(Recording.timestamp == recording_timestamp).delete()
db.commit()
Expand Down Expand Up @@ -241,12 +241,12 @@ def get_recording(timestamp: int) -> Recording:
return db.query(Recording).filter(Recording.timestamp == timestamp).first()


def _get(table: BaseModel, recording_timestamp: int) -> list[BaseModel]:
def _get(table: BaseModel, recording_timestamp: float) -> list[BaseModel]:
"""Retrieve records from the database table based on the recording timestamp.
Args:
table (BaseModel): The database table to query.
recording_timestamp (int): The recording timestamp to filter the records.
recording_timestamp (float): The recording timestamp to filter the records.
Returns:
list[BaseModel]: A list of records retrieved from the database table,
Expand Down Expand Up @@ -420,3 +420,33 @@ def new_session() -> None:
if db:
db.close()
db = Session()


def update_video_start_time(
recording_timestamp: float, video_start_time: float
) -> None:
"""Update the video start time of a specific recording.
Args:
recording_timestamp (float): The timestamp of the recording to update.
video_start_time (float): The new video start time to set.
"""
# Find the recording by its timestamp
recording = (
db.query(Recording).filter(Recording.timestamp == recording_timestamp).first()
)

if not recording:
logger.error(f"No recording found with timestamp {recording_timestamp}.")
return

# Update the video start time
recording.video_start_time = video_start_time

# Commit the changes to the database
db.commit()

logger.info(
f"Updated video start time for recording {recording_timestamp} to"
f" {video_start_time}."
)
51 changes: 42 additions & 9 deletions openadapt/deprecated/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
from bokeh.models.widgets import Div
from loguru import logger
from tqdm import tqdm
import fire

from openadapt import config
from openadapt import config, video
from openadapt.db.crud import get_latest_recording
from openadapt.events import get_events
from openadapt.models import Recording
from openadapt.privacy.providers.presidio import PresidioScrubbingProvider
from openadapt.utils import (
EMPTY,
compute_diff,
configure_logging,
display_event,
evenly_spaced,
Expand Down Expand Up @@ -184,11 +186,17 @@ def dict2html(


@logger.catch
def main(recording: Recording = None) -> bool:
def main(
recording: Recording = None,
diff_video: bool = False,
cleanup: bool = True,
) -> bool:
"""Visualize a recording.
Args:
recording (Recording, optional): The recording to visualize.
diff_video (bool): Whether to diff Screenshots against video frames.
cleanup (bool): Whether to remove the HTML file after it is displayed.
Returns:
bool: True if visualization was successful, None otherwise.
Expand All @@ -199,7 +207,8 @@ def main(recording: Recording = None) -> bool:
recording = get_latest_recording()
if SCRUB:
scrub.scrub_text(recording.task_description)
logger.debug(f"{recording=}")
logger.info(f"{recording=}")
logger.info(f"{diff_video=}")

meta = {}
action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta)
Expand Down Expand Up @@ -233,6 +242,14 @@ def main(recording: Recording = None) -> bool:
]
logger.info(f"{len(action_events)=}")

if diff_video:
video_file_name = video.get_video_file_name(recording.timestamp)
timestamps = [
action_event.screenshot.timestamp - recording.video_start_time
for action_event in action_events
]
frames = video.extract_frames(video_file_name, timestamps)

num_events = (
min(MAX_EVENTS, len(action_events))
if MAX_EVENTS is not None
Expand All @@ -248,9 +265,24 @@ def main(recording: Recording = None) -> bool:
for idx, action_event in enumerate(action_events):
if idx == MAX_EVENTS:
break
image = display_event(action_event)
diff = display_event(action_event, diff=True)
mask = action_event.screenshot.diff_mask

try:
image = display_event(action_event)
except TypeError as exc:
# https://github.com/moses-palmer/pynput/issues/481
logger.warning(exc)
continue

if diff_video:
frame_image = frames[idx]
diff_image = compute_diff(frame_image, action_event.screenshot.image)

# TODO: rename
diff = frame_image
mask = diff_image
else:
diff = display_event(action_event, diff=True)
mask = action_event.screenshot.diff_mask

if SCRUB:
image = scrub.scrub_image(image)
Expand Down Expand Up @@ -323,14 +355,15 @@ def main(recording: Recording = None) -> bool:
)
)

def cleanup() -> None:
def _cleanup() -> None:
os.remove(fname_out)
removed = not os.path.exists(fname_out)
logger.info(f"{removed=}")

Timer(1, cleanup).start()
if cleanup:
Timer(1, _cleanup).start()
return True


if __name__ == "__main__":
main()
fire.Fire(main)
13 changes: 13 additions & 0 deletions openadapt/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,18 @@ def get_merged_events(
)


def remove_invalid_keyboard_events(
events: list[models.ActionEvent],
) -> list[models.ActionEvent]:
"""Remove invalid keyboard events."""
return [
event
for event in events
# https://github.com/moses-palmer/pynput/issues/481
if not str(event.key) == "<0>"
]


def merge_consecutive_keyboard_events(
events: list[models.ActionEvent],
group_named_keys: bool = KEYBOARD_EVENTS_MERGE_GROUP_NAMED_KEYS,
Expand Down Expand Up @@ -717,6 +729,7 @@ def process_events(
f"{num_total=}"
)
process_fns = [
remove_invalid_keyboard_events,
merge_consecutive_keyboard_events,
merge_consecutive_mouse_move_events,
merge_consecutive_mouse_scroll_events,
Expand Down
1 change: 1 addition & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Recording(db.Base):
double_click_distance_pixels = sa.Column(sa.Numeric(asdecimal=False))
platform = sa.Column(sa.String)
task_description = sa.Column(sa.String)
video_start_time = sa.Column(ForceFloat)

action_events = sa.orm.relationship(
"ActionEvent",
Expand Down
Loading

0 comments on commit 20e08b8

Please sign in to comment.