generated from bit-bots/bitbots_template_repository
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(resampling): import data resampled
with different converters by data type and different resampling strategies: - game_states are converted with original rate - images are converted with a max sample rate, meaning they can have a lesser rate, but not more while not having consitent time deltas - joint_state and joint_command are resampled in sync with an interpolation using the last available value at any sample point The initial db data point at `t_0`, having a relative timestamp of `0.0` is set only when all synced data becomes available (joint_state, joint_command, later also imu). The last available game_state will also be saved at this time point `t_0`. Before this no data will be saved to the database.
- Loading branch information
1 parent
404f90d
commit d585d15
Showing
9 changed files
with
418 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData | ||
from ddlitlab2024.dataset.models import Recording | ||
|
||
|
||
class Converter(ABC): | ||
def __init__(self, resampler) -> None: | ||
self.resampler = resampler | ||
|
||
@abstractmethod | ||
def populate_recording_metadata(self, data: InputData, recording: Recording): | ||
""" | ||
Different converters of specific data/topics might need to extract | ||
information about a recording in general and update its metadata | ||
e.g. from a bitbots /gamestate message we extract the team's color | ||
Args: | ||
data: The input data to extract metadata from (e.g. a gamestate message) | ||
recording: The recording db model to update | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: | ||
"""_summary_ | ||
Args: | ||
data (InputData): The input data to convert to a model (e.g. a gamestate ros message) | ||
relative_timestamp (float): The timestamp of the data relative to the start of the recording | ||
recording (Recording): The recording db model the created model will be associated with | ||
Returns: | ||
ModelData: Dataclass containing list of models to be created from the data (fields can be empty) | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from enum import Enum | ||
|
||
from ddlitlab2024.dataset import logger | ||
from ddlitlab2024.dataset.converters.converter import Converter | ||
from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData | ||
from ddlitlab2024.dataset.models import GameState, Recording, RobotState, TeamColor | ||
from ddlitlab2024.dataset.resampling.original_rate_resampler import OriginalRateResampler | ||
|
||
|
||
class GameStateMessage(Enum): | ||
INITIAL = 0 | ||
READY = 1 | ||
SET = 2 | ||
PLAYING = 3 | ||
FINISHED = 4 | ||
|
||
|
||
class GameStateConverter(Converter): | ||
def __init__(self, resampler: OriginalRateResampler) -> None: | ||
self.resampler = resampler | ||
|
||
def populate_recording_metadata(self, data, recording: Recording): | ||
team_color = TeamColor.BLUE if data.game_state.team_color == 0 else TeamColor.RED | ||
if recording.team_color is None: | ||
recording.team_color = team_color | ||
|
||
team_color_changed = recording.team_color != team_color | ||
|
||
if team_color_changed: | ||
logger.warning("The team color changed, during one recording! This will be ignored.") | ||
|
||
def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: | ||
models = ModelData() | ||
|
||
for sample in self.resampler.resample(data, relative_timestamp): | ||
if not sample.was_sampled_already: | ||
models.game_states.append(self._create_game_state(sample.data.game_state, sample.timestamp, recording)) | ||
|
||
return models | ||
|
||
def _create_game_state(self, msg, sampling_timestamp: float, recording: Recording) -> GameState: | ||
return GameState(stamp=sampling_timestamp, recording=recording, state=self._robot_state_from_msg(msg)) | ||
|
||
def _robot_state_from_msg(self, msg) -> RobotState: | ||
if msg.penalized: | ||
return RobotState.STOPPED | ||
|
||
match msg.game_state: | ||
case GameStateMessage.INITIAL: | ||
return RobotState.STOPPED | ||
case GameStateMessage.READY: | ||
return RobotState.POSITIONING | ||
case GameStateMessage.SET: | ||
return RobotState.STOPPED | ||
case GameStateMessage.PLAYING: | ||
return RobotState.PLAYING | ||
case GameStateMessage.FINISHED: | ||
return RobotState.STOPPED | ||
case _: | ||
return RobotState.UNKNOWN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import cv2 | ||
import numpy as np | ||
|
||
from ddlitlab2024.dataset import logger | ||
from ddlitlab2024.dataset.converters.converter import Converter | ||
from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData | ||
from ddlitlab2024.dataset.models import DEFAULT_IMG_SIZE, Image, Recording | ||
from ddlitlab2024.dataset.resampling.max_rate_resampler import MaxRateResampler | ||
|
||
|
||
class ImageConverter(Converter): | ||
def __init__(self, resampler: MaxRateResampler) -> None: | ||
self.resampler = resampler | ||
|
||
def populate_recording_metadata(self, data: InputData, recording: Recording): | ||
img_scaling = (DEFAULT_IMG_SIZE[0] / data.image.width, DEFAULT_IMG_SIZE[1] / data.image.height) | ||
if recording.img_width_scaling == 0.0: | ||
recording.img_width_scaling = img_scaling[0] | ||
if recording.img_height_scaling == 0.0: | ||
recording.img_height_scaling = img_scaling[1] | ||
|
||
img_scaling_changed = ( | ||
recording.img_width_scaling != img_scaling[0] or recording.img_height_scaling != img_scaling[1] | ||
) | ||
|
||
if img_scaling_changed: | ||
logger.error( | ||
"The image sizes changed, during one recording! All images of a recording must have the same size." | ||
) | ||
|
||
def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: | ||
models = ModelData() | ||
|
||
for sample in self.resampler.resample(data, relative_timestamp): | ||
if not sample.was_sampled_already: | ||
models.images.append(self._create_image(sample.data.image, sample.timestamp, recording)) | ||
|
||
return models | ||
|
||
def _create_image(self, data, sampling_timestamp: float, recording: Recording) -> Image: | ||
img_array = np.frombuffer(data.data, np.uint8).reshape((data.height, data.width, 3)) | ||
|
||
will_img_be_upscaled = recording.img_width_scaling > 1.0 or recording.img_height_scaling > 1.0 | ||
interpolation = cv2.INTER_AREA | ||
if will_img_be_upscaled: | ||
interpolation = cv2.INTER_CUBIC | ||
|
||
resized_img = cv2.resize(img_array, (recording.img_width, recording.img_height), interpolation=interpolation) | ||
match data.encoding: | ||
case "rgb8": | ||
resized_rgb_img = resized_img | ||
case "bgr8": | ||
resized_rgb_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) | ||
case _: | ||
raise AssertionError(f"Unsupported image encoding: {data.encoding}") | ||
|
||
return Image( | ||
stamp=sampling_timestamp, | ||
recording=recording, | ||
image=resized_rgb_img, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from ddlitlab2024.dataset.converters.converter import Converter | ||
from ddlitlab2024.dataset.imports.model_importer import InputData, ModelData | ||
from ddlitlab2024.dataset.models import JointCommands, JointStates, Recording | ||
from ddlitlab2024.dataset.resampling.previous_interpolation_resampler import PreviousInterpolationResampler | ||
from ddlitlab2024.utils.utils import camelcase_to_snakecase, shift_radian_to_positive_range | ||
|
||
|
||
def joints_dict_from_msg_data(joints_data: list[tuple[str, float]]) -> dict[str, float]: | ||
joints_dict = {} | ||
|
||
for name, position in joints_data: | ||
key = camelcase_to_snakecase(name) | ||
value = shift_radian_to_positive_range(position) | ||
joints_dict[key] = value | ||
|
||
return joints_dict | ||
|
||
|
||
class SyncedDataConverter(Converter): | ||
def __init__(self, resampler: PreviousInterpolationResampler) -> None: | ||
self.resampler = resampler | ||
|
||
def populate_recording_metadata(self, data: InputData, recording: Recording): | ||
pass | ||
|
||
def convert_to_model(self, data: InputData, relative_timestamp: float, recording: Recording) -> ModelData: | ||
assert data.joint_state is not None, "joint_state are required in synced resampling data" | ||
assert data.joint_command is not None, "joint_command are required in synced resampling data" | ||
|
||
models = ModelData() | ||
|
||
for sample in self.resampler.resample(data, relative_timestamp): | ||
if not sample.was_sampled_already: | ||
models.joint_states.append( | ||
self._create_joint_states(sample.data.joint_state, sample.timestamp, recording) | ||
) | ||
models.joint_commands.append( | ||
self._create_joint_commands(sample.data.joint_command, sample.timestamp, recording) | ||
) | ||
|
||
return models | ||
|
||
def _create_joint_states(self, msg, sampling_timestamp: float, recording: Recording) -> JointStates: | ||
if msg is None: | ||
return JointStates(stamp=sampling_timestamp, recording=recording) | ||
else: | ||
joint_states_data = list(zip(msg.name, msg.position)) | ||
|
||
return JointStates( | ||
stamp=sampling_timestamp, recording=recording, **joints_dict_from_msg_data(joint_states_data) | ||
) | ||
|
||
def _create_joint_commands(self, msg, sampling_timestamp: float, recording: Recording) -> JointCommands: | ||
if msg is None: | ||
return JointCommands(stamp=sampling_timestamp, recording=recording) | ||
else: | ||
joint_commands_data = list(zip(msg.joint_names, msg.positions)) | ||
|
||
return JointCommands( | ||
stamp=sampling_timestamp, recording=recording, **joints_dict_from_msg_data(joint_commands_data) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.