From d585d15c2be2556b8d070c466fa2f12d959f8431 Mon Sep 17 00:00:00 2001
From: texhnolyze
Date: Thu, 28 Nov 2024 17:59:03 +0100
Subject: [PATCH] feat(resampling): import data resampled with different
 converters per data type and different resampling strategies:

- game_states are converted at their original rate
- images are converted with a maximum sample rate, meaning they may arrive at
  a lower rate, but never a higher one, and their time deltas are not
  necessarily consistent
- joint_state and joint_command are resampled in sync, interpolating with the
  last available value at each sample point

The initial db data point at `t_0`, with a relative timestamp of `0.0`, is
set only once all synced data has become available (joint_state,
joint_command, later also imu). The last available game_state is saved at
this time point `t_0` as well. Before that, no data is written to the
database.
---
 ddlitlab2024/dataset/converters/converter.py  |  36 ++++
 .../converters/game_state_converter.py        |  60 ++++++
 .../dataset/converters/image_converter.py     |  61 ++++++
 .../converters/synced_data_converter.py       |  61 ++++++
 .../dataset/imports/model_importer.py         |  25 ++-
 .../dataset/imports/strategies/bitbots.py     | 201 ++++++------------
 .../dataset/resampling/max_rate_resampler.py  |  46 ++++
 .../resampling/original_rate_resampler.py     |   6 +
 .../previous_interpolation_resampler.py       |  54 +++++
 9 files changed, 418 insertions(+), 132 deletions(-)
 create mode 100644 ddlitlab2024/dataset/converters/converter.py
 create mode 100644 ddlitlab2024/dataset/converters/game_state_converter.py
 create mode 100644 ddlitlab2024/dataset/converters/image_converter.py
 create mode 100644 ddlitlab2024/dataset/converters/synced_data_converter.py
 create mode 100644 ddlitlab2024/dataset/resampling/max_rate_resampler.py
 create mode 100644 ddlitlab2024/dataset/resampling/original_rate_resampler.py
 create mode 100644 ddlitlab2024/dataset/resampling/previous_interpolation_resampler.py

diff --git a/ddlitlab2024/dataset/converters/converter.py b/ddlitlab2024/dataset/converters/converter.py
new file mode 100644
index 0000000..a9b3ab1
--- /dev/null
+++ b/ddlitlab2024/dataset/converters/converter.py
+from abc import ABC, abstractmethod
+
+from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData
+from ddlitlab2024.dataset.models import Recording
+
+
+class Converter(ABC):
+    def __init__(self, resampler) -> None:
+        self.resampler = resampler
+
+    @abstractmethod
+    def populate_recording_metadata(self, data: InputData, recording: Recording):
+        """
+        Converters for specific data/topics might need to extract general
+        information about a recording and update its metadata,
+        e.g. from a bitbots /gamestate message we extract the team's color.
+
+        Args:
+            data: The input data to extract metadata from (e.g. a gamestate message)
+            recording: The recording db model to update
+        """
+        pass
+
+    @abstractmethod
+    def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData:
+        """Convert the input data at the given relative timestamp into database models.
+
+        Args:
+            data (InputData): The input data to convert to a model (e.g. 
a gamestate ros message) + relative_timestamp (float): The timestamp of the data relative to the start of the recording + recording (Recording): The recording db model the created model will be associated with + + Returns: + ModelData: Dataclass containing list of models to be created from the data (fields can be empty) + """ + pass diff --git a/ddlitlab2024/dataset/converters/game_state_converter.py b/ddlitlab2024/dataset/converters/game_state_converter.py new file mode 100644 index 0000000..d49775e --- /dev/null +++ b/ddlitlab2024/dataset/converters/game_state_converter.py @@ -0,0 +1,60 @@ +from enum import Enum + +from ddlitlab2024.dataset import logger +from ddlitlab2024.dataset.converters.converter import Converter +from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData +from ddlitlab2024.dataset.models import GameState, Recording, RobotState, TeamColor +from ddlitlab2024.dataset.resampling.original_rate_resampler import OriginalRateResampler + + +class GameStateMessage(Enum): + INITIAL = 0 + READY = 1 + SET = 2 + PLAYING = 3 + FINISHED = 4 + + +class GameStateConverter(Converter): + def __init__(self, resampler: OriginalRateResampler) -> None: + self.resampler = resampler + + def populate_recording_metadata(self, data, recording: Recording): + team_color = TeamColor.BLUE if data.game_state.team_color == 0 else TeamColor.RED + if recording.team_color is None: + recording.team_color = team_color + + team_color_changed = recording.team_color != team_color + + if team_color_changed: + logger.warning("The team color changed, during one recording! This will be ignored.") + + def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: + models = ModelData() + + for sample in self.resampler.resample(data, relative_timestamp): + if not sample.was_sampled_already: + models.game_states.append(self._create_game_state(sample.data.game_state, sample.timestamp, recording)) + + return models + + def _create_game_state(self, msg, sampling_timestamp: float, recording: Recording) -> GameState: + return GameState(stamp=sampling_timestamp, recording=recording, state=self._robot_state_from_msg(msg)) + + def _robot_state_from_msg(self, msg) -> RobotState: + if msg.penalized: + return RobotState.STOPPED + + match msg.game_state: + case GameStateMessage.INITIAL: + return RobotState.STOPPED + case GameStateMessage.READY: + return RobotState.POSITIONING + case GameStateMessage.SET: + return RobotState.STOPPED + case GameStateMessage.PLAYING: + return RobotState.PLAYING + case GameStateMessage.FINISHED: + return RobotState.STOPPED + case _: + return RobotState.UNKNOWN diff --git a/ddlitlab2024/dataset/converters/image_converter.py b/ddlitlab2024/dataset/converters/image_converter.py new file mode 100644 index 0000000..23901c0 --- /dev/null +++ b/ddlitlab2024/dataset/converters/image_converter.py @@ -0,0 +1,61 @@ +import cv2 +import numpy as np + +from ddlitlab2024.dataset import logger +from ddlitlab2024.dataset.converters.converter import Converter +from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData +from ddlitlab2024.dataset.models import DEFAULT_IMG_SIZE, Image, Recording +from ddlitlab2024.dataset.resampling.max_rate_resampler import MaxRateResampler + + +class ImageConverter(Converter): + def __init__(self, resampler: MaxRateResampler) -> None: + self.resampler = resampler + + def populate_recording_metadata(self, data: InputData, recording: Recording): + img_scaling = (DEFAULT_IMG_SIZE[0] / data.image.width, 
DEFAULT_IMG_SIZE[1] / data.image.height) + if recording.img_width_scaling == 0.0: + recording.img_width_scaling = img_scaling[0] + if recording.img_height_scaling == 0.0: + recording.img_height_scaling = img_scaling[1] + + img_scaling_changed = ( + recording.img_width_scaling != img_scaling[0] or recording.img_height_scaling != img_scaling[1] + ) + + if img_scaling_changed: + logger.error( + "The image sizes changed, during one recording! All images of a recording must have the same size." + ) + + def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: + models = ModelData() + + for sample in self.resampler.resample(data, relative_timestamp): + if not sample.was_sampled_already: + models.images.append(self._create_image(sample.data.image, sample.timestamp, recording)) + + return models + + def _create_image(self, data, sampling_timestamp: float, recording: Recording) -> Image: + img_array = np.frombuffer(data.data, np.uint8).reshape((data.height, data.width, 3)) + + will_img_be_upscaled = recording.img_width_scaling > 1.0 or recording.img_height_scaling > 1.0 + interpolation = cv2.INTER_AREA + if will_img_be_upscaled: + interpolation = cv2.INTER_CUBIC + + resized_img = cv2.resize(img_array, (recording.img_width, recording.img_height), interpolation=interpolation) + match data.encoding: + case "rgb8": + resized_rgb_img = resized_img + case "bgr8": + resized_rgb_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) + case _: + raise AssertionError(f"Unsupported image encoding: {data.encoding}") + + return Image( + stamp=sampling_timestamp, + recording=recording, + image=resized_rgb_img, + ) diff --git a/ddlitlab2024/dataset/converters/synced_data_converter.py b/ddlitlab2024/dataset/converters/synced_data_converter.py new file mode 100644 index 0000000..e937b84 --- /dev/null +++ b/ddlitlab2024/dataset/converters/synced_data_converter.py @@ -0,0 +1,61 @@ +from ddlitlab2024.dataset.converters.converter import Converter +from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData +from ddlitlab2024.dataset.models import JointCommands, JointStates, Recording +from ddlitlab2024.dataset.resampling.previous_interpolation_resampler import PreviousInterpolationResampler +from ddlitlab2024.utils.utils import camelcase_to_snakecase, shift_radian_to_positive_range + + +def joints_dict_from_msg_data(joints_data: list[tuple[str, float]]) -> dict[str, float]: + joints_dict = {} + + for name, position in joints_data: + key = camelcase_to_snakecase(name) + value = shift_radian_to_positive_range(position) + joints_dict[key] = value + + return joints_dict + + +class SyncedDataConverter(Converter): + def __init__(self, resampler: PreviousInterpolationResampler) -> None: + self.resampler = resampler + + def populate_recording_metadata(self, data: InputData, recording: Recording): + pass + + def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: + assert data.joint_state is not None, "joint_state are required in synced resampling data" + assert data.joint_command is not None, "joint_command are required in synced resampling data" + + models = ModelData() + + for sample in self.resampler.resample(data, relative_timestamp): + if not sample.was_sampled_already: + models.joint_states.append( + self._create_joint_states(sample.data.joint_state, sample.timestamp, recording) + ) + models.joint_commands.append( + self._create_joint_commands(sample.data.joint_command, sample.timestamp, recording) + ) + + 
return models + + def _create_joint_states(self, msg, sampling_timestamp: float, recording: Recording) -> JointStates: + if msg is None: + return JointStates(stamp=sampling_timestamp, recording=recording) + else: + joint_states_data = list(zip(msg.name, msg.position)) + + return JointStates( + stamp=sampling_timestamp, recording=recording, **joints_dict_from_msg_data(joint_states_data) + ) + + def _create_joint_commands(self, msg, sampling_timestamp: float, recording: Recording) -> JointCommands: + if msg is None: + return JointCommands(stamp=sampling_timestamp, recording=recording) + else: + joint_commands_data = list(zip(msg.joint_names, msg.positions)) + + return JointCommands( + stamp=sampling_timestamp, recording=recording, **joints_dict_from_msg_data(joint_commands_data) + ) diff --git a/ddlitlab2024/dataset/imports/model_importer.py b/ddlitlab2024/dataset/imports/model_importer.py index 03920d9..50fa83a 100644 --- a/ddlitlab2024/dataset/imports/model_importer.py +++ b/ddlitlab2024/dataset/imports/model_importer.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from pathlib import Path +from typing import Any from ddlitlab2024.dataset.db import Database from ddlitlab2024.dataset.models import GameState, Image, JointCommands, JointStates, Recording @@ -15,6 +16,21 @@ class ImportMetadata: simulated: bool +@dataclass +class Sample[T]: + data: T + timestamp: float + was_sampled_already: bool = False + + +@dataclass +class InputData: + image: Any = None + game_state: Any = None + joint_state: Any = None + joint_command: Any = None + + @dataclass class ModelData: recording: Recording | None = None @@ -24,7 +40,14 @@ class ModelData: images: list[Image] = field(default_factory=list) def model_instances(self): - return [self.recording] + self.game_states + self.joint_states + self.joint_commands + return [self.recording] + self.game_states + self.joint_states + self.joint_commands + self.images + + def merge(self, other: "ModelData") -> "ModelData": + self.game_states.extend(other.game_states) + self.joint_states.extend(other.joint_states) + self.joint_commands.extend(other.joint_commands) + self.images.extend(other.images) + return self class ImportStrategy(ABC): diff --git a/ddlitlab2024/dataset/imports/strategies/bitbots.py b/ddlitlab2024/dataset/imports/strategies/bitbots.py index 28c1461..457979b 100644 --- a/ddlitlab2024/dataset/imports/strategies/bitbots.py +++ b/ddlitlab2024/dataset/imports/strategies/bitbots.py @@ -1,43 +1,30 @@ from contextlib import contextmanager from datetime import datetime -from enum import Enum from pathlib import Path -import cv2 -import numpy as np from mcap.reader import make_reader from mcap.summary import Summary from mcap_ros2.decoder import DecoderFactory from ddlitlab2024.dataset import logger -from ddlitlab2024.dataset.imports.model_importer import ImportMetadata, ImportStrategy, ModelData +from ddlitlab2024.dataset.converters.converter import Converter +from ddlitlab2024.dataset.converters.game_state_converter import GameStateConverter +from ddlitlab2024.dataset.converters.image_converter import ImageConverter +from ddlitlab2024.dataset.converters.synced_data_converter import SyncedDataConverter +from ddlitlab2024.dataset.imports.model_importer import ImportMetadata, ImportStrategy, InputData, ModelData from ddlitlab2024.dataset.models import ( DEFAULT_IMG_SIZE, - GameState, - Image, - JointCommands, - JointStates, Recording, - RobotState, - TeamColor, ) -from ddlitlab2024.utils.utils import 
camelcase_to_snakecase, shift_radian_to_positive_range +from ddlitlab2024.dataset.resampling.max_rate_resampler import MaxRateResampler +from ddlitlab2024.dataset.resampling.original_rate_resampler import OriginalRateResampler +from ddlitlab2024.dataset.resampling.previous_interpolation_resampler import PreviousInterpolationResampler DATETIME_FORMAT = "%d.%m-%Y %H:%M:%S" - - -class GameStateMessage(Enum): - INITIAL = 0 - READY = 1 - SET = 2 - PLAYING = 3 - FINISHED = 4 - - USED_TOPICS = [ "/DynamixelController/command", - "/camera/camera_info", "/camera/image_proc", + "/camera/image_to_record", "/gamestate", "/imu/data", "/joint_states", @@ -45,91 +32,88 @@ class GameStateMessage(Enum): ] +RESAMPLE_RATE_HZ = 20 +IMAGE_MAX_RESAMPLE_RATE_HZ = 10 + + class BitBotsImportStrategy(ImportStrategy): def __init__(self, metadata: ImportMetadata): self.metadata = metadata + self.image_converter = ImageConverter(MaxRateResampler(IMAGE_MAX_RESAMPLE_RATE_HZ)) + self.game_state_converter = GameStateConverter(OriginalRateResampler()) + self.synced_data_converter = SyncedDataConverter(PreviousInterpolationResampler(RESAMPLE_RATE_HZ)) + + self.model_data = ModelData() + def convert_to_model_data(self, file_path: Path) -> ModelData: with self._mcap_reader(file_path) as reader: summary: Summary | None = reader.get_summary() if summary is None: logger.error("No summary found in the MCAP file, skipping processing.") - return ModelData() + return self.model_data + last_messages_by_topic = InputData() first_used_msg_time = None - model_data = ModelData(recording=self.create_recording(summary, file_path)) - assert model_data.recording is not None, "Recording is not set" - self._log_debug_info(summary, model_data.recording) + self.model_data.recording = self._create_recording(summary, file_path) + + self._log_debug_info(summary, self.model_data.recording) for _, channel, message, ros_msg in reader.iter_decoded_messages(topics=USED_TOPICS): - first_used_msg_time = first_used_msg_time or message.publish_time - relative_timestamp = (message.publish_time - first_used_msg_time) / 1e9 + converter: Converter | None = None match channel.topic: case "/gamestate": - team_color = TeamColor.BLUE if ros_msg.team_color == 0 else TeamColor.RED - if model_data.recording.team_color is None: - model_data.recording.team_color = team_color + last_messages_by_topic.game_state = ros_msg + converter = self.game_state_converter + case "/camera/image_proc" | "/camera/image_to_record": + last_messages_by_topic.image = ros_msg + converter = self.image_converter + case "/joint_states": + last_messages_by_topic.joint_state = ros_msg + converter = self.synced_data_converter + case "/DynamixelController/command": + last_messages_by_topic.joint_command = ros_msg + converter = self.synced_data_converter - team_color_changed = model_data.recording.team_color != team_color + if self._is_all_synced_data_available(last_messages_by_topic): + if first_used_msg_time is None: + first_used_msg_time = message.publish_time + self._initial_conversion(last_messages_by_topic) + else: + relative_msg_timestamp = (message.publish_time - first_used_msg_time) / 1e9 + if converter: + self._create_models(converter, last_messages_by_topic, relative_msg_timestamp) - if team_color_changed: - logger.warning("The team color changed, during one recording! 
This will be ignored.") + return self.model_data - model_data.game_states.append( - self.create_gamestate(ros_msg, relative_timestamp, model_data.recording) - ) - case "/joint_states": - model_data.joint_states.append( - self.create_joint_states(ros_msg, relative_timestamp, model_data.recording) - ) - case "/DynamixelController/command": - model_data.joint_commands.append( - self.create_joint_commands(ros_msg, relative_timestamp, model_data.recording) - ) - case "/camera/image_proc" | "/camera/image_raw": - img_scaling = (DEFAULT_IMG_SIZE[0] / ros_msg.width, DEFAULT_IMG_SIZE[1] / ros_msg.height) - if model_data.recording.img_width_scaling == 0.0: - model_data.recording.img_width_scaling = img_scaling[0] - if model_data.recording.img_height_scaling == 0.0: - model_data.recording.img_height_scaling = img_scaling[1] - - img_scaling_changed = ( - model_data.recording.img_width_scaling != img_scaling[0] - or model_data.recording.img_height_scaling != img_scaling[1] - ) - - if img_scaling_changed: - logger.error( - "The image sizes changed, during one recording! " - + "All images of a recording must have the same size." - ) - - model_data.images.append(self.create_image(ros_msg, relative_timestamp, model_data.recording)) - - return model_data - - def create_image(self, msg, relative_timestamp: float, recording: Recording) -> Image: - img_array = np.frombuffer(msg.data, np.uint8).reshape((msg.height, msg.width, 3)) - - will_img_be_upscaled = recording.img_width_scaling > 1.0 or recording.img_height_scaling > 1.0 - interpolation = cv2.INTER_AREA - if will_img_be_upscaled: - interpolation = cv2.INTER_CUBIC - - resized_img = cv2.resize(img_array, (recording.img_width, recording.img_height), interpolation=interpolation) - resized_rgb_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) - - return Image( - stamp=relative_timestamp, - recording=recording, - image=resized_rgb_img, - ) + def _initial_conversion(self, data: InputData): + assert self._is_all_synced_data_available(data), "All synced data must be available to create initial models" + + first_timestamp = 0.0 - def create_recording(self, summary: Summary, mcap_file_path: Path) -> Recording: - start_timestamp, end_timestamp = self.extract_timeframe(summary) + if data.game_state: + self._create_models(self.game_state_converter, data, first_timestamp) + + self._create_models(self.synced_data_converter, data, first_timestamp) + + def _create_models(self, converter: Converter, data: InputData, relative_timestamp: float) -> ModelData: + assert self.model_data.recording is not None, "Recording must be defined to create child models" + + converter.populate_recording_metadata(data, self.model_data.recording) + model_data = converter.convert_to_model(data, relative_timestamp, self.model_data.recording) + if model_data: + self.model_data = self.model_data.merge(model_data) + + return self.model_data + + def _is_all_synced_data_available(self, data: InputData) -> bool: + return data.joint_command is not None and data.joint_state is not None + + def _create_recording(self, summary: Summary, mcap_file_path: Path) -> Recording: + start_timestamp, end_timestamp = self._extract_timeframe(summary) return Recording( allow_public=self.metadata.allow_public, @@ -147,52 +131,7 @@ def create_recording(self, summary: Summary, mcap_file_path: Path) -> Recording: img_height_scaling=0.0, ) - def create_gamestate(self, msg, relative_timestamp: float, recording: Recording) -> GameState: - return GameState(stamp=relative_timestamp, recording=recording, 
state=self.robot_state_from_msg(msg)) - - def robot_state_from_msg(self, msg) -> RobotState: - if msg.penalized: - return RobotState.STOPPED - - match msg.game_state: - case GameStateMessage.INITIAL: - return RobotState.STOPPED - case GameStateMessage.READY: - return RobotState.POSITIONING - case GameStateMessage.SET: - return RobotState.STOPPED - case GameStateMessage.PLAYING: - return RobotState.PLAYING - case GameStateMessage.FINISHED: - return RobotState.STOPPED - case _: - return RobotState.UNKNOWN - - def create_joint_states(self, msg, relative_timestamp: float, recording: Recording) -> JointStates: - joint_states_data = list(zip(msg.name, msg.position)) - - return JointStates( - stamp=relative_timestamp, recording=recording, **self._joints_dict_from_msg_data(joint_states_data) - ) - - def create_joint_commands(self, msg, relative_timestamp: float, recording: Recording) -> JointStates: - joint_commands_data = list(zip(msg.joint_names, msg.positions)) - - return JointCommands( - stamp=relative_timestamp, recording=recording, **self._joints_dict_from_msg_data(joint_commands_data) - ) - - def _joints_dict_from_msg_data(self, joints_data: list[tuple[str, float]]) -> dict[str, float]: - joints_dict = {} - - for name, position in joints_data: - key = camelcase_to_snakecase(name) - value = shift_radian_to_positive_range(position) - joints_dict[key] = value - - return joints_dict - - def extract_timeframe(self, summary: Summary) -> tuple[int, int]: + def _extract_timeframe(self, summary: Summary) -> tuple[int, int]: first_msg_start_time = None last_msg_end_time = None diff --git a/ddlitlab2024/dataset/resampling/max_rate_resampler.py b/ddlitlab2024/dataset/resampling/max_rate_resampler.py new file mode 100644 index 0000000..dec9afb --- /dev/null +++ b/ddlitlab2024/dataset/resampling/max_rate_resampler.py @@ -0,0 +1,46 @@ +from ddlitlab2024.dataset.imports.model_importer import InputData, Sample + + +class MaxRateResampler: + def __init__(self, max_sample_rate_hz: int): + self.max_sample_rate_hz = max_sample_rate_hz + self.sampling_step_in_seconds = 1 / max_sample_rate_hz + + self.last_sampled_data = None + self.last_sampled_timestamp = None + self.last_sample_step_timestamp = None + + def resample(self, data: InputData, relative_timestamp: float) -> list[Sample[InputData]]: + if self.last_sample_step_timestamp is None: + return [self._initial_sample(data, relative_timestamp)] + else: + return self._samples_until(data, relative_timestamp) + + def _initial_sample(self, data: InputData, relative_timestamp: float) -> Sample[InputData]: + self.last_sampled_data = data + self.last_sampled_timestamp = relative_timestamp + self.last_sample_step_timestamp = relative_timestamp + + return Sample(data=self.last_sampled_data, timestamp=self.last_sampled_timestamp) + + def _samples_until(self, data: InputData, relative_timestamp: float) -> list[Sample[InputData]]: + assert ( + self.last_sampled_data is not None + and self.last_sampled_timestamp is not None + and self.last_sample_step_timestamp + ), "There must have been an initial sample" + + if self.is_timestamp_after_next_sampling_step(relative_timestamp): + self.last_sampled_data = data + self.last_sampled_timestamp = relative_timestamp + self.last_sample_step_timestamp = self.last_sample_step_timestamp + self.sampling_step_in_seconds + return [Sample(data=self.last_sampled_data, timestamp=self.last_sampled_timestamp)] + + return [Sample(data=self.last_sampled_data, timestamp=self.last_sampled_timestamp, was_sampled_already=True)] + + def 
is_timestamp_after_next_sampling_step(self, relative_timestamp: float) -> bool:
+        if self.last_sample_step_timestamp is None:
+            # There was no previous sample, so it is time to sample
+            return True
+
+        return relative_timestamp - self.last_sample_step_timestamp >= self.sampling_step_in_seconds
diff --git a/ddlitlab2024/dataset/resampling/original_rate_resampler.py b/ddlitlab2024/dataset/resampling/original_rate_resampler.py
new file mode 100644
index 0000000..7f56a26
--- /dev/null
+++ b/ddlitlab2024/dataset/resampling/original_rate_resampler.py
+from ddlitlab2024.dataset.imports.model_importer import InputData, Sample
+
+
+class OriginalRateResampler:
+    def resample(self, data: InputData, relative_timestamp: float) -> list[Sample[InputData]]:
+        return [Sample(data=data, timestamp=relative_timestamp)]
diff --git a/ddlitlab2024/dataset/resampling/previous_interpolation_resampler.py b/ddlitlab2024/dataset/resampling/previous_interpolation_resampler.py
new file mode 100644
index 0000000..6cdf556
--- /dev/null
+++ b/ddlitlab2024/dataset/resampling/previous_interpolation_resampler.py
+from ddlitlab2024.dataset.imports.model_importer import InputData, Sample
+
+
+class PreviousInterpolationResampler:
+    def __init__(self, sample_rate_hz: int):
+        self.sample_rate_hz = sample_rate_hz
+        self.sampling_step_in_seconds = 1 / sample_rate_hz
+
+        self.last_received_data = None
+        self.last_sampled_data = None
+        self.last_sample_step_timestamp = None
+
+    def resample(self, data: InputData, relative_timestamp: float) -> list[Sample[InputData]]:
+        if self.last_sample_step_timestamp is None:
+            return [self._initial_sample(data, relative_timestamp)]
+        else:
+            return self._samples_until(data, relative_timestamp)
+
+    def _initial_sample(self, data: InputData, relative_timestamp: float) -> Sample[InputData]:
+        self.last_received_data = data
+        self.last_sampled_data = data
+        self.last_sample_step_timestamp = relative_timestamp
+
+        return Sample(data=self.last_sampled_data, timestamp=self.last_sample_step_timestamp)
+
+    def _samples_until(self, data: InputData, relative_timestamp: float) -> list[Sample[InputData]]:
+        assert (
+            self.last_received_data is not None
+            and self.last_sampled_data is not None
+            and self.last_sample_step_timestamp is not None
+        ), "There must have been an initial sample"
+
+        samples = []
+        num_samples = self._num_passed_sampling_steps(relative_timestamp)
+
+        for _ in range(num_samples):
+            self.last_sampled_data = self.last_received_data
+            self.last_sample_step_timestamp = self.last_sample_step_timestamp + self.sampling_step_in_seconds
+            samples.append(Sample(data=self.last_sampled_data, timestamp=self.last_sample_step_timestamp))
+
+        self.last_received_data = data  # newest value becomes the basis for future sampling steps
+        if num_samples > 0:
+            return samples
+        else:
+            return [
+                Sample(data=self.last_sampled_data, timestamp=self.last_sample_step_timestamp, was_sampled_already=True)
+            ]
+
+    def _num_passed_sampling_steps(self, relative_timestamp: float) -> int:
+        if self.last_sample_step_timestamp is None:
+            # There was no previous sample, so it is time to sample once
+            return 1
+
+        return int((relative_timestamp - self.last_sample_step_timestamp) / self.sampling_step_in_seconds)
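
Usage sketch (illustrative only, not part of the patch): how the two stateful resamplers introduced above are expected to behave, assuming the modules are importable as laid out in this change. The timestamps and the string payloads ("js", "jc", "img") are made-up stand-ins for real ROS messages.

from ddlitlab2024.dataset.imports.model_importer import InputData
from ddlitlab2024.dataset.resampling.max_rate_resampler import MaxRateResampler
from ddlitlab2024.dataset.resampling.previous_interpolation_resampler import PreviousInterpolationResampler

synced = PreviousInterpolationResampler(sample_rate_hz=20)  # fixed 50 ms sampling grid
images = MaxRateResampler(max_sample_rate_hz=10)            # at most one image sample per 100 ms

# Joint data arriving slightly off-grid: each emitted sample reuses the last
# value received before the 50 ms step it represents.
for t in (0.0, 0.02, 0.06, 0.11):
    for sample in synced.resample(InputData(joint_state="js", joint_command="jc"), t):
        if not sample.was_sampled_already:
            print(f"synced sample at t={sample.timestamp:.2f}")  # 0.00, 0.05, 0.10

# Images arriving faster than 10 Hz are thinned out; emitted samples keep their
# original (non-equidistant) timestamps.
for t in (0.01, 0.04, 0.07, 0.12):
    for sample in images.resample(InputData(image="img"), t):
        if not sample.was_sampled_already:
            print(f"image sample at t={sample.timestamp:.2f}")  # 0.01, 0.12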