diff --git a/.travis.yml b/.travis.yml index 607c72b520..981e8727f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,17 +4,31 @@ # There's a docker image pupillabs/pupil-docker-ubuntu:latest which contains all # required dependencies for the test-suite to run. -os: minimal -services: docker +jobs: + include: + - name: pytest + os: minimal + services: docker + before_install: + - docker pull pupillabs/pupil-docker-ubuntu:latest + - chmod +x ./.travis/*.sh + script: + - > + docker run --rm + -v `pwd`:/repo + -w /repo + pupillabs/pupil-docker-ubuntu:latest + /bin/bash /repo/.travis/run_tests.sh -before_install: - - docker pull pupillabs/pupil-docker-ubuntu:latest - - chmod +x ./.travis/*.sh - -script: - - > - docker run --rm - -v `pwd`:/repo - -w /repo - pupillabs/pupil-docker-ubuntu:latest - /bin/bash /repo/.travis/run_tests.sh + - name: black formatting check + language: python + before_script: + - pip install -U pip + - pip install black + script: + - > + black . --check --exclude pupil_src/tests || ( + echo -e "\033[0;31m PLEASE RUN THE BLACK FORMATTER ON YOUR CODE: \033[0m" && + echo "See https://github.com/psf/black for details" && + false + ) diff --git a/pupil_src/launchables/service.py b/pupil_src/launchables/service.py index 81cc2ed77b..6666c21c51 100644 --- a/pupil_src/launchables/service.py +++ b/pupil_src/launchables/service.py @@ -193,7 +193,6 @@ def get_dt(): "min_calibration_confidence", 0.8 ) - audio.set_audio_mode( session_settings.get("audio_mode", audio.get_default_audio_mode()) ) @@ -226,7 +225,12 @@ def get_dt(): def handle_notifications(n): subject = n["subject"] if subject == "start_plugin": - g_pool.plugins.add(plugin_by_name[n["name"]], args=n.get("args", {})) + try: + g_pool.plugins.add( + plugin_by_name[n["name"]], args=n.get("args", {}) + ) + except KeyError as err: + logger.error(f"Attempt to load unknown plugin: {err}") elif subject == "service_process.should_stop": g_pool.service_should_run = False elif subject.startswith("meta.should_doc"): @@ -272,7 +276,6 @@ def handle_notifications(n): gaze_pub.send(gaze_datum) events["gaze"].append(gaze_datum) - for plugin in g_pool.plugins: plugin.recent_events(events=events) diff --git a/pupil_src/shared_modules/accuracy_visualizer.py b/pupil_src/shared_modules/accuracy_visualizer.py index 2c6015f810..e1f305119a 100644 --- a/pupil_src/shared_modules/accuracy_visualizer.py +++ b/pupil_src/shared_modules/accuracy_visualizer.py @@ -26,7 +26,7 @@ ) from plugin import Plugin -from gaze_mapping import registered_gazer_classes_by_class_name +from gaze_mapping import gazer_classes_by_class_name, registered_gazer_classes from gaze_mapping.notifications import ( CalibrationSetupNotification, CalibrationResultNotification, @@ -80,7 +80,9 @@ def clear(self): self.__gazer_class = None self.__gazer_params = None - def update(self, gazer_class_name: str, gazer_params=..., pupil_list=..., ref_list=...): + def update( + self, gazer_class_name: str, gazer_params=..., pupil_list=..., ref_list=... + ): if ( self.gazer_class_name is not None and self.gazer_class_name != gazer_class_name @@ -107,7 +109,7 @@ def __gazer_class_from_name(gazer_class_name: str) -> T.Optional[T.Any]: logger.info("Accuracy visualization is disabled for HMD calibration") return None - gazers_by_name = registered_gazer_classes_by_class_name() + gazers_by_name = gazer_classes_by_class_name(registered_gazer_classes()) try: gazer_cls = gazers_by_name[gazer_class_name] diff --git a/pupil_src/shared_modules/calibration_choreography/base_plugin.py b/pupil_src/shared_modules/calibration_choreography/base_plugin.py index 82654d0498..b43f183a97 100644 --- a/pupil_src/shared_modules/calibration_choreography/base_plugin.py +++ b/pupil_src/shared_modules/calibration_choreography/base_plugin.py @@ -315,8 +315,8 @@ def on_choreography_successfull( calib_data = {"ref_list": ref_list, "pupil_list": pupil_list} self._start_plugin(self.selected_gazer_class, calib_data=calib_data) elif mode == ChoreographyMode.VALIDATION: + assert self.g_pool.active_gaze_mapping_plugin is not None gazer_class = self.g_pool.active_gaze_mapping_plugin.__class__ - assert gazer_class == self.selected_gazer_class gazer_params = self.g_pool.active_gaze_mapping_plugin.get_params() self._start_plugin("Accuracy_Visualizer") @@ -448,8 +448,7 @@ def update_ui(self): ) if self.shows_action_buttons: self.__ui_button_validation.read_only = ( - self.selected_gazer_class - is not self.g_pool.active_gaze_mapping_plugin.__class__ + self.g_pool.active_gaze_mapping_plugin is None ) def deinit_ui(self): diff --git a/pupil_src/shared_modules/calibration_choreography/screen_marker_plugin.py b/pupil_src/shared_modules/calibration_choreography/screen_marker_plugin.py index 330e709ae8..c59b12e163 100644 --- a/pupil_src/shared_modules/calibration_choreography/screen_marker_plugin.py +++ b/pupil_src/shared_modules/calibration_choreography/screen_marker_plugin.py @@ -67,6 +67,14 @@ def selection_label(cls) -> str: def selection_order(cls) -> float: return 1.0 + @staticmethod + def get_list_of_markers_to_show(mode: ChoreographyMode) -> list: + if ChoreographyMode.CALIBRATION == mode: + return [(0.5, 0.5), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)] + if ChoreographyMode.VALIDATION == mode: + return [(0.5, 1.0), (1.0, 0.5), (0.5, 0.0), (0.0, 0.5)] + raise ValueError(f"Unknown mode {mode}") + def __init__( self, g_pool, @@ -269,7 +277,7 @@ def _perform_start(self): ) return - self.__current_list_of_markers_to_show = self.__get_list_of_markers_to_show( + self.__current_list_of_markers_to_show = self.get_list_of_markers_to_show( mode=self.current_mode, ) self.__currently_shown_marker_position = None @@ -289,14 +297,6 @@ def _perform_stop(self): ### Private - @staticmethod - def __get_list_of_markers_to_show(mode: ChoreographyMode,) -> list: - if ChoreographyMode.CALIBRATION == mode: - return [(0.5, 0.5), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)] - if ChoreographyMode.VALIDATION == mode: - return [(0.5, 1.0), (1.0, 0.5), (0.5, 0.0), (0.0, 0.5)] - raise ValueError(f"Unknown mode {mode}") - def _on_window_did_close(self): self._signal_should_stop(mode=self.current_mode) diff --git a/pupil_src/shared_modules/csv_utils.py b/pupil_src/shared_modules/csv_utils.py index cbd888896b..b80a803edb 100644 --- a/pupil_src/shared_modules/csv_utils.py +++ b/pupil_src/shared_modules/csv_utils.py @@ -14,15 +14,16 @@ import typing as t -CSV_EXPORT_RAW_TYPE = t.TypeVar('CSV_EXPORT_RAW_TYPE') +CSV_EXPORT_RAW_TYPE = t.TypeVar("CSV_EXPORT_RAW_TYPE") CSV_EXPORT_LABEL_TYPE = t.AnyStr CSV_EXPORT_VALUE_TYPE = t.Any CSV_EXPORT_VALUE_GETTER_TYPE = t.Callable[[CSV_EXPORT_RAW_TYPE], CSV_EXPORT_VALUE_TYPE] -CSV_EXPORT_SCHEMA_TYPE = t.List[t.Tuple[CSV_EXPORT_LABEL_TYPE, CSV_EXPORT_VALUE_GETTER_TYPE]] +CSV_EXPORT_SCHEMA_TYPE = t.List[ + t.Tuple[CSV_EXPORT_LABEL_TYPE, CSV_EXPORT_VALUE_GETTER_TYPE] +] class CSV_Exporter(abc.ABC, t.Generic[CSV_EXPORT_RAW_TYPE]): - @classmethod @abc.abstractmethod def csv_export_schema(cls) -> CSV_EXPORT_SCHEMA_TYPE: @@ -33,10 +34,17 @@ def csv_export_labels(cls) -> t.Iterable[CSV_EXPORT_LABEL_TYPE]: return tuple(label for label, _ in cls.csv_export_schema()) @classmethod - def csv_export_values(cls, raw_value: CSV_EXPORT_RAW_TYPE) -> t.Iterable[CSV_EXPORT_VALUE_TYPE]: + def csv_export_values( + cls, raw_value: CSV_EXPORT_RAW_TYPE + ) -> t.Iterable[CSV_EXPORT_VALUE_TYPE]: return tuple(getter(raw_value) for _, getter in cls.csv_export_schema()) - def csv_export(self, raw_values: t.Iterable[CSV_EXPORT_RAW_TYPE], export_dir: str, export_name: str) -> str: + def csv_export( + self, + raw_values: t.Iterable[CSV_EXPORT_RAW_TYPE], + export_dir: str, + export_name: str, + ) -> str: export_path = os.path.abspath(os.path.join(export_dir, export_name)) @@ -63,7 +71,9 @@ def read_key_value_file(csvfile): if "key" not in first_line or "value" not in first_line: csvfile.seek(0) # Seek to start if first_line is not an header dialect = csv.Sniffer().sniff(first_line, delimiters=",\t") - reader = csv.reader(csvfile, dialect, quoting=csv.QUOTE_NONE, escapechar='\\') # create reader + reader = csv.reader( + csvfile, dialect, quoting=csv.QUOTE_NONE, escapechar="\\" + ) # create reader for row in reader: kvstore[row[0]] = row[1] return kvstore @@ -80,7 +90,7 @@ def write_key_value_file(csvfile, dictionary, append=False): Returns: None: No return """ - writer = csv.writer(csvfile, delimiter=",", quoting=csv.QUOTE_NONE, escapechar='\\') + writer = csv.writer(csvfile, delimiter=",", quoting=csv.QUOTE_NONE, escapechar="\\") if not append: writer.writerow(["key", "value"]) for key, val in dictionary.items(): diff --git a/pupil_src/shared_modules/file_methods.py b/pupil_src/shared_modules/file_methods.py index e5b41c1399..c916ec7060 100644 --- a/pupil_src/shared_modules/file_methods.py +++ b/pupil_src/shared_modules/file_methods.py @@ -353,17 +353,13 @@ def _deep_copy_serialized_dict(self): return Serialized_Dict(python_dict=dict_copy) def _deep_copy_dict(self): - def unpacking_ext_hook(self, code, data): if code == self.MSGPACK_EXT_CODE: return type(self)(msgpack_bytes=data)._deep_copy_dict() return msgpack.ExtType(code, data) return msgpack.unpackb( - self._ser_data, - raw=False, - use_list=False, - ext_hook=unpacking_ext_hook, + self._ser_data, raw=False, use_list=False, ext_hook=unpacking_ext_hook, ) diff --git a/pupil_src/shared_modules/gaze_mapping/__init__.py b/pupil_src/shared_modules/gaze_mapping/__init__.py index 5043b2f2f8..e9926e7587 100644 --- a/pupil_src/shared_modules/gaze_mapping/__init__.py +++ b/pupil_src/shared_modules/gaze_mapping/__init__.py @@ -19,12 +19,18 @@ def registered_gazer_classes() -> list: return gazer_base.GazerBase.registered_gazer_classes() -def registered_gazer_labels_by_class_names() -> dict: - return {cls.__name__: cls.label for cls in registered_gazer_classes()} +def user_selectable_gazer_classes() -> list: + gazers = registered_gazer_classes() + gazers = filter(lambda g: g is not GazerHMD3D, gazers) + return list(gazers) -def registered_gazer_classes_by_class_name() -> dict: - return {cls.__name__: cls for cls in registered_gazer_classes()} +def gazer_labels_by_class_names(gazers: list) -> dict: + return {cls.__name__: cls.label for cls in gazers} + + +def gazer_classes_by_class_name(gazers: list) -> dict: + return {cls.__name__: cls for cls in gazers} default_gazer_class = Gazer3D diff --git a/pupil_src/shared_modules/gaze_mapping/gazer_3d/gazer_hmd.py b/pupil_src/shared_modules/gaze_mapping/gazer_3d/gazer_hmd.py index f232bcc956..bd02e730b9 100644 --- a/pupil_src/shared_modules/gaze_mapping/gazer_3d/gazer_hmd.py +++ b/pupil_src/shared_modules/gaze_mapping/gazer_3d/gazer_hmd.py @@ -23,6 +23,7 @@ from gaze_mapping.gazer_base import ( GazerBase, Model, + CalibrationError, NotEnoughDataError, FitDidNotConvergeError, ) @@ -46,6 +47,13 @@ logger = logging.getLogger(__name__) +class MissingEyeTranslationsError(CalibrationError): + message = ( + "GazerHMD3D can only be calibrated if it is " + "initialised with valid eye translations." + ) + + class ModelHMD3D_Binocular(Model3D_Binocular): def __init__(self, *, intrinsics, eye_translations): self.intrinsics = intrinsics @@ -53,6 +61,8 @@ def __init__(self, *, intrinsics, eye_translations): self._is_fitted = False def _fit(self, X, Y): + if self.eye_translations is None: + raise MissingEyeTranslationsError() assert X.shape[1] == _BINOCULAR_FEATURE_COUNT, X unprojected_ref_points = Y @@ -109,7 +119,7 @@ class GazerHMD3D(Gazer3D): def _gazer_description_text(cls) -> str: return "Gaze mapping built specifically for HMD-Eyes." - def __init__(self, g_pool, *, eye_translations, calib_data=None, params=None): + def __init__(self, g_pool, *, eye_translations=None, calib_data=None, params=None): self.__eye_translations = eye_translations super().__init__(g_pool, calib_data=calib_data, params=params) diff --git a/pupil_src/shared_modules/gaze_mapping/gazer_base.py b/pupil_src/shared_modules/gaze_mapping/gazer_base.py index a4de243589..0abca2189d 100644 --- a/pupil_src/shared_modules/gaze_mapping/gazer_base.py +++ b/pupil_src/shared_modules/gaze_mapping/gazer_base.py @@ -44,6 +44,14 @@ class NotEnoughDataError(CalibrationError): message = "Not sufficient data available." +class NotEnoughPupilDataError(NotEnoughDataError): + message = "Not sufficient pupil data available." + + +class NotEnoughReferenceDataError(NotEnoughDataError): + message = "Not sufficient reference data available." + + class FitDidNotConvergeError(CalibrationError): message = "Model fit did not converge." @@ -194,18 +202,25 @@ def __init__( self._announce_calibration_setup(calib_data=calib_data) try: self.fit_on_calib_data(calib_data) - except CalibrationError: + except CalibrationError as err: if raise_calibration_error: raise # Let offline calibration handle this one! logger.error("Calibration Failed!") self.alive = False - self._announce_calibration_failure(reason=CalibrationError.__name__) + self._announce_calibration_failure(reason=err.message) except Exception as err: import traceback - self._announce_calibration_failure(reason=err.__class__.__name__) logger.debug(traceback.format_exc()) - raise CalibrationError() from err + if raise_calibration_error: + raise CalibrationError() from err # Let offline calibration handle this one! + logger.error("Calibration Failed!") + self.alive = False + try: + reason = err.args[0] + except (AttributeError, IndexError): + reason = err.__class__.__name__ + self._announce_calibration_failure(reason=reason) else: self._announce_calibration_success() self._announce_calibration_result(params=self.get_params()) @@ -214,8 +229,9 @@ def __init__( else: raise ValueError("Requires either `calib_data` or `params`") - # used by pupil_data_relay for gaze mapping - g_pool.active_gaze_mapping_plugin = self + if self.alive: + # Used by pupil_data_relay for gaze mapping. + g_pool.active_gaze_mapping_plugin = self def get_init_dict(self): return {"params": self.get_params()} @@ -257,7 +273,9 @@ def fit_on_calib_data(self, calib_data): pupil_data, self.g_pool.min_calibration_confidence ) if not pupil_data: - raise NotEnoughDataError + raise NotEnoughPupilDataError + if not ref_data: + raise NotEnoughReferenceDataError # match pupil to reference data (left, right, and binocular) matches = self.match_pupil_to_ref(pupil_data, ref_data) if matches.binocular[0]: diff --git a/pupil_src/shared_modules/gaze_mapping/utils.py b/pupil_src/shared_modules/gaze_mapping/utils.py index df8139ab34..56dbeb657c 100644 --- a/pupil_src/shared_modules/gaze_mapping/utils.py +++ b/pupil_src/shared_modules/gaze_mapping/utils.py @@ -32,8 +32,8 @@ def _filter_pupil_list_by_confidence(pupil_list, threshold): def _match_data_batch(pupil_list, ref_list): - assert pupil_list - assert ref_list + assert pupil_list, "No pupil data to match" + assert ref_list, "No reference data to match" pupil0 = [p for p in pupil_list if p["id"] == 0] pupil1 = [p for p in pupil_list if p["id"] == 1] @@ -45,9 +45,9 @@ def _match_data_batch(pupil_list, ref_list): num_mono_right = len(matched_pupil0_data[0]) num_mono_left = len(matched_pupil1_data[0]) - logger.info(f"Collected {num_bino} binocular references.") - logger.info(f"Collected {num_mono_right} right eye monocular references.") - logger.info(f"Collected {num_mono_left} left eye monocular references.") + logger.debug(f"Collected {num_bino} binocular references.") + logger.debug(f"Collected {num_mono_right} right eye monocular references.") + logger.debug(f"Collected {num_mono_left} left eye monocular references.") return ( matched_binocular_data, diff --git a/pupil_src/shared_modules/gaze_producer/controller/calculate_all_controller.py b/pupil_src/shared_modules/gaze_producer/controller/calculate_all_controller.py index 8545db742e..d9aac8e34e 100644 --- a/pupil_src/shared_modules/gaze_producer/controller/calculate_all_controller.py +++ b/pupil_src/shared_modules/gaze_producer/controller/calculate_all_controller.py @@ -29,25 +29,22 @@ def __init__( def calculate_all(self): """ - (Re)Calculate all calibrations and gaze mappings with their respective - current settings. If there are no reference locations in the storage, - first the current reference detector is run. + Detect reference locations if none available. Then (re)calculate all + calibrations and gaze mappers. """ - if self.does_detect_references: + if self._reference_location_storage.is_empty: task = self._reference_detection_controller.start_detection() task.add_observer("on_completed", self._on_reference_detection_completed) else: self._calculate_all_calibrations() - @property - def does_detect_references(self): + def calculate_all_if_references_available(self): """ - True if the controller would first detect reference locations in calculate_all() + (Re)calculate all calibrations and gaze mappers if reference locations are + available. """ - at_least_one_reference_location = any( - True for _ in self._reference_location_storage - ) - return not at_least_one_reference_location + if not self._reference_location_storage.is_empty: + self._calculate_all_calibrations() def _on_reference_detection_completed(self, _): self._calculate_all_calibrations() diff --git a/pupil_src/shared_modules/gaze_producer/controller/calibration_controller.py b/pupil_src/shared_modules/gaze_producer/controller/calibration_controller.py index 57349a678d..a2f6f951b6 100644 --- a/pupil_src/shared_modules/gaze_producer/controller/calibration_controller.py +++ b/pupil_src/shared_modules/gaze_producer/controller/calibration_controller.py @@ -51,7 +51,9 @@ def on_calibration_completed(status_and_result): ) task.add_observer("on_completed", on_calibration_completed) task.add_observer("on_exception", tasklib.raise_exception) - self._task_manager.add_task(task) + self._task_manager.add_task( + task, identifier=f"{calibration.unique_id}-calibration" + ) return task def on_calibration_computed(self, calibration): diff --git a/pupil_src/shared_modules/gaze_producer/controller/gaze_mapper_controller.py b/pupil_src/shared_modules/gaze_producer/controller/gaze_mapper_controller.py index 3482883974..cbed9189d1 100644 --- a/pupil_src/shared_modules/gaze_producer/controller/gaze_mapper_controller.py +++ b/pupil_src/shared_modules/gaze_producer/controller/gaze_mapper_controller.py @@ -53,15 +53,15 @@ def calculate(self, gaze_mapper): if calibration is None: self._abort_calculation( gaze_mapper, - "The calibration was not found for the gaze mapper '{}', " - "please select a different calibration!".format(gaze_mapper.name), + "The calibration was not found for the gaze mapper " + f"'{gaze_mapper.name}', please select a different calibration!", ) return None if calibration.params is None: self._abort_calculation( gaze_mapper, - "You first need to calculate calibration '{}' before calculating the " - "mapper '{}'".format(calibration.name, gaze_mapper.name), + f"You first need to calculate calibration '{calibration.name}' before " + f"calculating the mapper '{gaze_mapper.name}'", ) return None try: @@ -69,8 +69,8 @@ def calculate(self, gaze_mapper): except worker.map_gaze.NotEnoughPupilData: self._abort_calculation(gaze_mapper, "There is no pupil data to be mapped!") return None - self._task_manager.add_task(task) - logger.info("Start gaze mapping for '{}'".format(gaze_mapper.name)) + self._task_manager.add_task(task, identifier=f"{gaze_mapper.unique_id}-mapping") + logger.info(f"Start gaze mapping for '{gaze_mapper.name}'") def _abort_calculation(self, gaze_mapper, error_message): logger.error(error_message) @@ -92,7 +92,7 @@ def _create_mapping_task(self, gaze_mapper, calibration): task = worker.map_gaze.create_task(gaze_mapper, calibration) def on_yield_gaze(mapped_gaze_ts_and_data): - gaze_mapper.status = "Mapping {:.0f}% complete".format(task.progress * 100) + gaze_mapper.status = f"Mapping {task.progress * 100:.0f}% complete" for timestamp, gaze_datum in mapped_gaze_ts_and_data: gaze_mapper.gaze.append(gaze_datum) gaze_mapper.gaze_ts.append(timestamp) @@ -153,24 +153,28 @@ def on_gaze_mapping_calculated(self, gaze_mapper): def validate_gaze_mapper(self, gaze_mapper): def validation_completed(accuracy_and_precision): accuracy, precision = accuracy_and_precision - gaze_mapper.accuracy_result = "{:.1f}° from {} / {} samples".format( - accuracy.result, accuracy.num_used, accuracy.num_total + gaze_mapper.accuracy_result = ( + f"{accuracy.result:.1f}° from {accuracy.num_used} / " + f"{accuracy.num_total} samples" ) - gaze_mapper.precision_result = "{:.1f}° from {} / {} samples".format( - precision.result, precision.num_used, precision.num_total + gaze_mapper.precision_result = ( + f"{precision.result:.1f}° from {precision.num_used} / " + f"{precision.num_total} samples" ) calibration = self.get_valid_calibration_or_none(gaze_mapper) if calibration is None: logger.error( - f"Could not validate gaze mapper {gaze_mapper.name}; Calibration was not found, please select a different calibration." + f"Could not validate gaze mapper {gaze_mapper.name};" + " Calibration was not found, please select a different calibration." ) return if calibration.params is None: logger.error( - f"Could not validate gaze mapper {gaze_mapper.name}; You first need to calculate calibration '{calibration.name}'" + f"Could not validate gaze mapper {gaze_mapper.name};" + " You first need to calculate calibration '{calibration.name}'" ) return @@ -179,7 +183,9 @@ def validation_completed(accuracy_and_precision): ) task.add_observer("on_completed", validation_completed) task.add_observer("on_exception", tasklib.raise_exception) - self._task_manager.add_task(task) + self._task_manager.add_task( + task, identifier=f"{gaze_mapper.unique_id}-validating" + ) def get_valid_calibration_or_none(self, gaze_mapper): return self._calibration_storage.get_or_none(gaze_mapper.calibration_unique_id) diff --git a/pupil_src/shared_modules/gaze_producer/controller/reference_location_controllers.py b/pupil_src/shared_modules/gaze_producer/controller/reference_location_controllers.py index fd37a94539..19e141a326 100644 --- a/pupil_src/shared_modules/gaze_producer/controller/reference_location_controllers.py +++ b/pupil_src/shared_modules/gaze_producer/controller/reference_location_controllers.py @@ -41,7 +41,9 @@ def on_detection_completed(_): ) self._detection_task.add_observer("on_yield", on_detection_yields) self._detection_task.add_observer("on_completed", on_detection_completed) - self._task_manager.add_task(self._detection_task) + self._task_manager.add_task( + self._detection_task, identifier="reference_detection" + ) return self._detection_task def on_detection_started(self, detection_task): diff --git a/pupil_src/shared_modules/gaze_producer/gaze_from_offline_calibration.py b/pupil_src/shared_modules/gaze_producer/gaze_from_offline_calibration.py index 6b925f3247..731d0aa173 100644 --- a/pupil_src/shared_modules/gaze_producer/gaze_from_offline_calibration.py +++ b/pupil_src/shared_modules/gaze_producer/gaze_from_offline_calibration.py @@ -12,10 +12,9 @@ from gaze_producer import controller, model from gaze_producer import ui as plugin_ui from gaze_producer.gaze_producer_base import GazeProducerBase -from observable import Observable from plugin_timeline import PluginTimeline from pupil_recording import PupilRecording -from tasklib.manager import PluginTaskManager +from tasklib.manager import UniqueTaskManager # IMPORTANT: GazeProducerBase needs to be THE LAST in the list of bases, otherwise @@ -37,7 +36,7 @@ def __init__(self, g_pool): self.inject_plugin_dependencies() - self._task_manager = PluginTaskManager(plugin=self) + self._task_manager = UniqueTaskManager(plugin=self) self._recording_uuid = PupilRecording(g_pool.rec_dir).meta_info.recording_uuid @@ -50,26 +49,31 @@ def __init__(self, g_pool): "pupil_positions", g_pool.rec_dir, plugin=self ) self._pupil_changed_listener.add_observer( - "on_data_changed", self._calculate_all_controller.calculate_all + "on_data_changed", + self._calculate_all_controller.calculate_all_if_references_available, ) def _setup_storages(self): self._reference_location_storage = model.ReferenceLocationStorage( - self.g_pool.rec_dir, plugin=self + self.g_pool.rec_dir ) self._calibration_storage = model.CalibrationStorage( rec_dir=self.g_pool.rec_dir, - plugin=self, get_recording_index_range=self._recording_index_range, recording_uuid=self._recording_uuid, ) self._gaze_mapper_storage = model.GazeMapperStorage( self._calibration_storage, rec_dir=self.g_pool.rec_dir, - plugin=self, get_recording_index_range=self._recording_index_range, ) + def cleanup(self): + super().cleanup() + self._reference_location_storage.save_to_disk() + self._calibration_storage.save_to_disk() + self._gaze_mapper_storage.save_to_disk() + def _setup_controllers(self): self._reference_detection_controller = controller.ReferenceDetectionController( self._task_manager, self._reference_location_storage @@ -203,8 +207,9 @@ def _current_trim_mark_range(self): def _index_range_as_str(self, index_range): from_index, to_index = index_range - return "{} - {}".format( - self._index_time_as_str(from_index), self._index_time_as_str(to_index) + return ( + f"{self._index_time_as_str(from_index)} - " + f"{self._index_time_as_str(to_index)}" ) def _index_time_as_str(self, index): @@ -213,4 +218,4 @@ def _index_time_as_str(self, index): time = ts - min_ts minutes = abs(time // 60) # abs because it's sometimes -0 seconds = round(time % 60) - return "{:02.0f}:{:02.0f}".format(minutes, seconds) + return f"{minutes:02.0f}:{seconds:02.0f}" diff --git a/pupil_src/shared_modules/gaze_producer/model/calibration.py b/pupil_src/shared_modules/gaze_producer/model/calibration.py index 8f3e6f1df9..76c74a833b 100644 --- a/pupil_src/shared_modules/gaze_producer/model/calibration.py +++ b/pupil_src/shared_modules/gaze_producer/model/calibration.py @@ -51,14 +51,6 @@ def __init__( self.__is_offline_calibration = is_offline_calibration self.__calib_params = calib_params - # Assert all properties are consistent - try: - self.__assert_property_consistency() - except ValueError: - raise - except Exception as err: - raise ValueError(str(err)) - @property def is_offline_calibration(self) -> bool: return self.__is_offline_calibration @@ -79,12 +71,6 @@ def update( self.__is_offline_calibration = is_offline_calibration if calib_params is not ...: self.__calib_params = calib_params - try: - self.__assert_property_consistency() - except ValueError: - raise - except Exception as err: - raise ValueError(str(err)) @staticmethod def from_dict(dict_: dict) -> "Calibration": @@ -98,7 +84,6 @@ def from_dict(dict_: dict) -> "Calibration": @property def as_dict(self) -> dict: - self.__assert_property_consistency() # sanity check dict_ = {k: v(self) for (k, v) in self.__schema} return dict_ @@ -128,12 +113,3 @@ def as_tuple(self): ("is_offline_calibration", lambda self: self.__is_offline_calibration), ("calib_params", lambda self: self.__calib_params), ) - - def __assert_property_consistency(self): - if self.__is_offline_calibration: - pass - else: - if self.__calib_params is not None: - raise ValueError( - f"Unexpected calib_params argument for pre-recorded calibration" - ) diff --git a/pupil_src/shared_modules/gaze_producer/model/calibration_storage.py b/pupil_src/shared_modules/gaze_producer/model/calibration_storage.py index c1a03d5f46..04314f0916 100644 --- a/pupil_src/shared_modules/gaze_producer/model/calibration_storage.py +++ b/pupil_src/shared_modules/gaze_producer/model/calibration_storage.py @@ -20,7 +20,12 @@ from storage import Storage from gaze_producer import model from observable import Observable -from gaze_mapping import default_gazer_class, registered_gazer_labels_by_class_names +from gaze_mapping import ( + default_gazer_class, + registered_gazer_classes, + user_selectable_gazer_classes, + gazer_classes_by_class_name, +) from gaze_mapping.notifications import ( CalibrationSetupNotification, CalibrationResultNotification, @@ -33,8 +38,7 @@ class CalibrationStorage(Storage, Observable): _calibration_suffix = "plcal" - def __init__(self, rec_dir, plugin, get_recording_index_range, recording_uuid): - super().__init__(plugin) + def __init__(self, rec_dir, get_recording_index_range, recording_uuid): self._rec_dir = rec_dir self._get_recording_index_range = get_recording_index_range self._recording_uuid = str(recording_uuid) @@ -59,7 +63,7 @@ def create_default_calibration(self): ) def __create_prerecorded_calibration( - self, result_notification: CalibrationResultNotification + self, number: int, result_notification: CalibrationResultNotification ): timestamp = result_notification.timestamp @@ -68,7 +72,11 @@ def __create_prerecorded_calibration( # the easiest datum that differs between calibrations but is the same # for every start unique_id = model.Calibration.create_unique_id_from_string(str(timestamp)) - name = make_unique.by_number_at_end("Recorded Calibration", self.item_names) + + name = "Recorded Calibration" + if number > 1: + name += f" {number}" + return model.Calibration( unique_id=unique_id, name=name, @@ -76,7 +84,7 @@ def __create_prerecorded_calibration( gazer_class_name=result_notification.gazer_class_name, frame_index_range=self._get_recording_index_range(), minimum_confidence=0.8, - is_offline_calibration=True, + is_offline_calibration=False, status="Not calculated yet", calib_params=result_notification.params, ) @@ -89,14 +97,33 @@ def duplicate_calibration(self, calibration): new_calibration.unique_id = model.Calibration.create_new_unique_id() return new_calibration - def add(self, calibration): - if any(c.unique_id == calibration.unique_id for c in self._calibrations): + def add(self, calibration, overwrite=False): + for calib in self._calibrations.copy(): + if calib.unique_id != calibration.unique_id: + continue + if overwrite: + self._calibrations.remove(calib) + break logger.warning( f"Did not add calibration {calibration.name} ({calibration.unique_id})" " because it is already in the storage. Currently in storage:\n" + "\n".join(f"- {c.name} ({c.unique_id})" for c in self._calibrations) ) return + is_calib_editable = ( + self._from_same_recording(calibration) + and calibration.is_offline_calibration + ) + if is_calib_editable: + available_gazer_classes = user_selectable_gazer_classes() + else: + available_gazer_classes = registered_gazer_classes() + gazer_class_names = gazer_classes_by_class_name(available_gazer_classes).keys() + if calibration.gazer_class_name not in gazer_class_names: + logger.warning( + f"Did not add calibration {calibration.name} ({calibration.unique_id}) because gaze mapping method ({calibration.gazer_class_name}) is not available." + ) + return self._calibrations.append(calibration) self._calibrations.sort(key=lambda c: c.name) @@ -161,6 +188,7 @@ def _load_calibration_from_file(self, file_name): def _load_recorded_calibrations(self): notifications = fm.load_pldata_file(self._rec_dir, "notify") + counter = 1 for topic, data in zip(notifications.topics, notifications.data): if topic.startswith("notify."): # Remove "notify." prefix @@ -183,9 +211,10 @@ def _load_recorded_calibrations(self): logger.debug(str(err)) continue calibration = self.__create_prerecorded_calibration( - result_notification=note + number=counter, result_notification=note ) - self.add(calibration) + self.add(calibration, overwrite=True) + counter += 1 def save_to_disk(self): os.makedirs(self._calibration_folder, exist_ok=True) diff --git a/pupil_src/shared_modules/gaze_producer/model/gaze_mapper_storage.py b/pupil_src/shared_modules/gaze_producer/model/gaze_mapper_storage.py index 49df0809ab..4469dc180b 100644 --- a/pupil_src/shared_modules/gaze_producer/model/gaze_mapper_storage.py +++ b/pupil_src/shared_modules/gaze_producer/model/gaze_mapper_storage.py @@ -23,8 +23,8 @@ class GazeMapperStorage(SingleFileStorage, Observable): - def __init__(self, calibration_storage, rec_dir, plugin, get_recording_index_range): - super().__init__(rec_dir, plugin) + def __init__(self, calibration_storage, rec_dir, get_recording_index_range): + super().__init__(rec_dir) self._calibration_storage = calibration_storage self._get_recording_index_range = get_recording_index_range self._gaze_mappers = [] diff --git a/pupil_src/shared_modules/gaze_producer/model/legacy/calibration_storage_updater.py b/pupil_src/shared_modules/gaze_producer/model/legacy/calibration_storage_updater.py index c353808b91..6c7388322f 100644 --- a/pupil_src/shared_modules/gaze_producer/model/legacy/calibration_storage_updater.py +++ b/pupil_src/shared_modules/gaze_producer/model/legacy/calibration_storage_updater.py @@ -33,7 +33,7 @@ def update_offline_calibrations_to_latest_version(cls, rec_dir): if not calib_dir.is_dir(): return # TODO: Raise exception - "calibrations" must be a directory - for calib_path in sorted(calib_dir.glob("*.plcal")): + for calib_path in sorted(calib_dir.glob("[!.]*.plcal")): calib_dict = fm.load_object(calib_path) version = calib_dict.get("version", None) data = calib_dict.get("data", None) diff --git a/pupil_src/shared_modules/gaze_producer/model/reference_location_storage.py b/pupil_src/shared_modules/gaze_producer/model/reference_location_storage.py index fa39caf984..79531781ec 100644 --- a/pupil_src/shared_modules/gaze_producer/model/reference_location_storage.py +++ b/pupil_src/shared_modules/gaze_producer/model/reference_location_storage.py @@ -19,8 +19,8 @@ class ReferenceLocationStorage(SingleFileStorage, Observable): - def __init__(self, rec_dir, plugin): - super().__init__(rec_dir, plugin) + def __init__(self, rec_dir): + super().__init__(rec_dir) self._reference_locations = {} self._load_from_disk() diff --git a/pupil_src/shared_modules/gaze_producer/ui/calibration_menu.py b/pupil_src/shared_modules/gaze_producer/ui/calibration_menu.py index 07f59a19b2..8c6354ab5e 100644 --- a/pupil_src/shared_modules/gaze_producer/ui/calibration_menu.py +++ b/pupil_src/shared_modules/gaze_producer/ui/calibration_menu.py @@ -13,7 +13,11 @@ from pyglui import ui from gaze_producer import ui as plugin_ui -from gaze_mapping import registered_gazer_labels_by_class_names +from gaze_mapping import ( + gazer_labels_by_class_names, + registered_gazer_classes, + user_selectable_gazer_classes, +) logger = logging.getLogger(__name__) @@ -85,12 +89,15 @@ def _create_range_selector(self, calibration): ) def _create_mapping_method_selector(self, calibration): + gazers = user_selectable_gazer_classes() + gazers_map = gazer_labels_by_class_names(gazers) + return ui.Selector( "gazer_class_name", calibration, label="Gaze Mapping", - labels=list(registered_gazer_labels_by_class_names().values()), - selection=list(registered_gazer_labels_by_class_names().keys()), + labels=list(gazers_map.values()), + selection=list(gazers_map.keys()), ) def _create_min_confidence_slider(self, calibration): @@ -120,8 +127,10 @@ def _render_ui_calibration_from_other_recording(self, calibration, menu): ) def _info_text_for_calibration_from_other_recording(self, calibration): + gazers = registered_gazer_classes() gazer_class_name = calibration.gazer_class_name - gazer_label = registered_gazer_labels_by_class_names()[gazer_class_name] + gazer_label = gazer_labels_by_class_names(gazers)[gazer_class_name] + if calibration.params: return ( f"This {gazer_label} calibration was copied from another recording. " @@ -139,18 +148,39 @@ def _render_ui_online_calibration(self, calibration, menu): menu.append(ui.Info_Text(self._info_text_for_online_calibration(calibration))) def _info_text_for_online_calibration(self, calibration): + gazers = registered_gazer_classes() gazer_class_name = calibration.gazer_class_name - gazer_label = registered_gazer_labels_by_class_names()[gazer_class_name] + gazer_label = gazer_labels_by_class_names(gazers)[gazer_class_name] + return ( f"This {gazer_label} calibration was created before or during the " "recording. It is ready to be used in gaze mappers." ) def _on_click_duplicate_button(self): - if self._calibration_controller.is_from_same_recording(self.current_item): - super()._on_click_duplicate_button() - else: + if not self._calibration_controller.is_from_same_recording(self.current_item): logger.error("Cannot duplicate calibrations from other recordings!") + return + + if not self.current_item.is_offline_calibration: + logger.error("Cannot duplicate pre-recorded calibrations!") + return + + super()._on_click_duplicate_button() + + def _on_click_delete(self): + if self.current_item is None: + return + + if not self._calibration_controller.is_from_same_recording(self.current_item): + logger.error("Cannot delete calibrations from other recordings!") + return + + if not self.current_item.is_offline_calibration: + logger.error("Cannot delete pre-recorded calibrations!") + return + + super()._on_click_delete() def _on_name_change(self, new_name): self._calibration_storage.rename(self.current_item, new_name) diff --git a/pupil_src/shared_modules/gaze_producer/ui/on_top_menu.py b/pupil_src/shared_modules/gaze_producer/ui/on_top_menu.py index 2cdafc7329..0ab0479f9d 100644 --- a/pupil_src/shared_modules/gaze_producer/ui/on_top_menu.py +++ b/pupil_src/shared_modules/gaze_producer/ui/on_top_menu.py @@ -19,6 +19,7 @@ def __init__(self, calculate_all_controller, reference_location_storage): self._calculate_all_button = None self._calculate_all_controller = calculate_all_controller + self._reference_location_storage = reference_location_storage reference_location_storage.add_observer( "add", self._on_reference_storage_changed @@ -56,7 +57,7 @@ def _create_calculate_all_button(self): @property def _calculate_all_button_label(self): - if self._calculate_all_controller.does_detect_references: + if self._reference_location_storage.is_empty: return "Detect References, Calculate All Calibrations and Mappings" else: return "Calculate All Calibrations and Mappings" diff --git a/pupil_src/shared_modules/gaze_producer/ui/storage_edit_menu.py b/pupil_src/shared_modules/gaze_producer/ui/storage_edit_menu.py index 09eb7e129a..12b8e3284e 100644 --- a/pupil_src/shared_modules/gaze_producer/ui/storage_edit_menu.py +++ b/pupil_src/shared_modules/gaze_producer/ui/storage_edit_menu.py @@ -96,6 +96,8 @@ def _on_click_duplicate_button(self): self.render() def _on_click_delete(self): + if self.current_item is None: + return current_index = self.items.index(self.current_item) self._storage.delete(self.current_item) current_index = min(current_index, len(self.items) - 1) diff --git a/pupil_src/shared_modules/gaze_producer/worker/create_calibration.py b/pupil_src/shared_modules/gaze_producer/worker/create_calibration.py index df85b88986..2b3cf1c059 100644 --- a/pupil_src/shared_modules/gaze_producer/worker/create_calibration.py +++ b/pupil_src/shared_modules/gaze_producer/worker/create_calibration.py @@ -15,7 +15,11 @@ import player_methods as pm import tasklib.background from gaze_producer import model -from gaze_mapping import registered_gazer_classes_by_class_name, CalibrationError +from gaze_mapping import ( + gazer_classes_by_class_name, + registered_gazer_classes, + CalibrationError, +) from methods import normalize from .fake_gpool import FakeGPool @@ -42,6 +46,7 @@ def create_task(calibration, all_reference_locations): ] fake_gpool = FakeGPool.from_g_pool(g_pool) + fake_gpool.min_calibration_confidence = calibration.minimum_confidence args = ( fake_gpool, @@ -49,7 +54,7 @@ def create_task(calibration, all_reference_locations): ref_dicts_in_calib_range, pupil_pos_in_calib_range, ) - name = "Create calibration {}".format(calibration.name) + name = f"Create calibration {calibration.name}" return tasklib.background.create(name, _create_calibration, args=args) @@ -67,7 +72,7 @@ def _create_calibration( # This is needed to support user-provided gazers fake_gpool.import_runtime_plugins() - gazers_by_name = registered_gazer_classes_by_class_name() + gazers_by_name = gazer_classes_by_class_name(registered_gazer_classes()) try: gazer_class = gazers_by_name[gazer_class_name] diff --git a/pupil_src/shared_modules/gaze_producer/worker/detect_circle_markers.py b/pupil_src/shared_modules/gaze_producer/worker/detect_circle_markers.py index 633bd71346..5d74d5d014 100644 --- a/pupil_src/shared_modules/gaze_producer/worker/detect_circle_markers.py +++ b/pupil_src/shared_modules/gaze_producer/worker/detect_circle_markers.py @@ -95,12 +95,10 @@ def _receive_detections(self): return elif topic == "exception": logger.warning( - "Calibration marker detection raised exception:\n{}".format( - msg["reason"] - ) + f"Calibration marker detection raised exception:\n{msg['reason']}" ) logger.info("Marker detection was interrupted") - logger.debug("Reason: {}".format(msg.get("reason", "n/a"))) + logger.debug(f"Reason: {msg.get('reason', 'n/a')}") self.on_canceled_or_killed() return diff --git a/pupil_src/shared_modules/gaze_producer/worker/map_gaze.py b/pupil_src/shared_modules/gaze_producer/worker/map_gaze.py index 88c0dd8c9e..2c17309389 100644 --- a/pupil_src/shared_modules/gaze_producer/worker/map_gaze.py +++ b/pupil_src/shared_modules/gaze_producer/worker/map_gaze.py @@ -11,7 +11,7 @@ import file_methods as fm import player_methods as pm import tasklib -from gaze_mapping import registered_gazer_classes_by_class_name +from gaze_mapping import gazer_classes_by_class_name, registered_gazer_classes from .fake_gpool import FakeGPool @@ -44,7 +44,7 @@ def create_task(gaze_mapper, calibration): gaze_mapper.manual_correction_x, gaze_mapper.manual_correction_y, ) - name = "Create gaze mapper {}".format(gaze_mapper.name) + name = f"Create gaze mapper {gaze_mapper.name}" return tasklib.background.create( name, _map_gaze, args=args, pass_shared_memory=True, ) @@ -60,7 +60,7 @@ def _map_gaze( shared_memory, ): fake_gpool.import_runtime_plugins() - gazers_by_name = registered_gazer_classes_by_class_name() + gazers_by_name = gazer_classes_by_class_name(registered_gazer_classes()) gazer_cls = gazers_by_name[gazer_class_name] gazer = gazer_cls(fake_gpool, params=gazer_params) diff --git a/pupil_src/shared_modules/gaze_producer/worker/validate_gaze.py b/pupil_src/shared_modules/gaze_producer/worker/validate_gaze.py index 132aed7ccd..7d650ca301 100644 --- a/pupil_src/shared_modules/gaze_producer/worker/validate_gaze.py +++ b/pupil_src/shared_modules/gaze_producer/worker/validate_gaze.py @@ -13,7 +13,7 @@ from methods import normalize import player_methods as pm -from gaze_mapping import registered_gazer_classes_by_class_name +from gaze_mapping import gazer_classes_by_class_name, registered_gazer_classes from .fake_gpool import FakeGPool @@ -48,7 +48,7 @@ def create_bg_task(gaze_mapper, calibration, reference_location_storage): ) return tasklib.background.create( - "validate gaze mapper '{}'".format(gaze_mapper.name), validate, args=args, + f"validate gaze mapper '{gaze_mapper.name}'", validate, args=args, ) @@ -61,7 +61,7 @@ def validate( refs_in_validation_range, ): g_pool.import_runtime_plugins() - gazers_by_name = registered_gazer_classes_by_class_name() + gazers_by_name = gazer_classes_by_class_name(registered_gazer_classes()) gazer_class = gazers_by_name[gazer_class_name] pupil_list = pupils_in_validation_range diff --git a/pupil_src/shared_modules/gl_utils/utils.py b/pupil_src/shared_modules/gl_utils/utils.py index 96450e62fe..f761a276c9 100644 --- a/pupil_src/shared_modules/gl_utils/utils.py +++ b/pupil_src/shared_modules/gl_utils/utils.py @@ -188,6 +188,12 @@ class Coord_System(object): def __init__(self, left, right, bottom, top): super(Coord_System, self).__init__() + if left == right: + left -= 1 + right += 1 + if top == bottom: + top -= 1 + bottom += 1 self.bounds = left, right, bottom, top def __enter__(self): diff --git a/pupil_src/shared_modules/gprof2dot.py b/pupil_src/shared_modules/gprof2dot.py index 686854dee9..6d8f676e62 100644 --- a/pupil_src/shared_modules/gprof2dot.py +++ b/pupil_src/shared_modules/gprof2dot.py @@ -2402,9 +2402,14 @@ def parse_samples(self): self.consume() while not self.lookahead().startswith("CPU"): - rank, percent_self, percent_accum, count, traceid, method = ( - self.lookahead().split() - ) + ( + rank, + percent_self, + percent_accum, + count, + traceid, + method, + ) = self.lookahead().split() self.samples[int(traceid)] = (int(count), method) self.consume() diff --git a/pupil_src/shared_modules/head_pose_tracker/ui/offline_head_pose_tracker_timeline.py b/pupil_src/shared_modules/head_pose_tracker/ui/offline_head_pose_tracker_timeline.py index 72b4d974d3..26f055225d 100644 --- a/pupil_src/shared_modules/head_pose_tracker/ui/offline_head_pose_tracker_timeline.py +++ b/pupil_src/shared_modules/head_pose_tracker/ui/offline_head_pose_tracker_timeline.py @@ -95,9 +95,10 @@ def _create_progress_indication(self): return RangeElementFrameIdx() def _on_detection_started(self): - self._frame_start, frame_end = ( - self._general_settings.detection_frame_index_range - ) + ( + self._frame_start, + frame_end, + ) = self._general_settings.detection_frame_index_range self._frame_count = frame_end - self._frame_start + 1 def _on_storage_changed(self): @@ -169,9 +170,10 @@ def _on_localization_reset(self): self.render_parent_timeline() def _on_localization_started(self): - self._frame_start, frame_end = ( - self._general_settings.localization_frame_index_range - ) + ( + self._frame_start, + frame_end, + ) = self._general_settings.localization_frame_index_range self._frame_count = frame_end - self._frame_start + 1 def _on_storage_changed(self): diff --git a/pupil_src/shared_modules/hololens_relay.py b/pupil_src/shared_modules/hololens_relay.py index 9f178c76b7..20d06a19dc 100644 --- a/pupil_src/shared_modules/hololens_relay.py +++ b/pupil_src/shared_modules/hololens_relay.py @@ -391,9 +391,7 @@ def on_recv(self, socket, ipc_pub): calib_method = "HMD_Calibration_3D" ipc_pub.notify({"subject": "start_plugin", "name": calib_method}) - ipc_pub.notify( - {"subject": "set_pupil_detection_enabled", "value": True} - ) + ipc_pub.notify({"subject": "set_pupil_detection_enabled", "value": True}) ipc_pub.notify( { "subject": "eye_process.should_start.{}".format(0), diff --git a/pupil_src/shared_modules/network_api/controller/frame_publisher_controller.py b/pupil_src/shared_modules/network_api/controller/frame_publisher_controller.py index fb9bd8d5a7..b54b4ffa27 100644 --- a/pupil_src/shared_modules/network_api/controller/frame_publisher_controller.py +++ b/pupil_src/shared_modules/network_api/controller/frame_publisher_controller.py @@ -20,11 +20,8 @@ class FramePublisherController(Observable): - def on_frame_publisher_did_start(self, format: FrameFormat): - logger.debug(f"on_frame_publisher_did_start({format})") - - def on_frame_publisher_did_stop(self): - logger.debug(f"on_frame_publisher_did_stop") + def on_format_changed(self): + logger.debug(f"on_format_changed({self.__frame_format})") def __init__(self, format="jpeg", **kwargs): self.__frame_format = FrameFormat(format) @@ -33,9 +30,6 @@ def __init__(self, format="jpeg", **kwargs): def get_init_dict(self): return {"format": self.__frame_format.value} - def cleanup(self): - self.on_frame_publisher_did_stop() - @property def frame_format(self): return self.__frame_format @@ -43,7 +37,7 @@ def frame_format(self): @frame_format.setter def frame_format(self, value): self.__frame_format = FrameFormat(value) - self.on_frame_publisher_did_start(format=self.__frame_format) + self.on_format_changed() def create_world_frame_dicts_from_frame(self, frame) -> T.List[dict]: if not frame: diff --git a/pupil_src/shared_modules/network_api/network_api_plugin.py b/pupil_src/shared_modules/network_api/network_api_plugin.py index 488d9aeadc..2047dbd93e 100644 --- a/pupil_src/shared_modules/network_api/network_api_plugin.py +++ b/pupil_src/shared_modules/network_api/network_api_plugin.py @@ -11,9 +11,7 @@ import logging from plugin import Plugin -from pyglui import ui -from .model import FrameFormat from .controller import FramePublisherController from .controller import PupilRemoteController from .ui import FramePublisherMenu @@ -39,12 +37,12 @@ def __init__(self, g_pool, **kwargs): # Frame Publisher setup self.__frame_publisher = FramePublisherController(**kwargs) self.__frame_publisher.add_observer( - "on_frame_publisher_did_start", self.on_frame_publisher_did_start - ) - self.__frame_publisher.add_observer( - "on_frame_publisher_did_stop", self.on_frame_publisher_did_stop + "on_format_changed", self.frame_publisher_announce_current_format ) + # Let existing eye-processes know about current frame publishing format + self.frame_publisher_announce_current_format() + # Pupil Remote setup self.__pupil_remote = PupilRemoteController(g_pool, **kwargs) self.__pupil_remote.add_observer( @@ -65,7 +63,7 @@ def get_init_dict(self): } def cleanup(self): - self.__frame_publisher.cleanup() + self.frame_publisher_announce_stop() self.__frame_publisher = None self.__pupil_remote.cleanup() self.__pupil_remote = None @@ -105,16 +103,21 @@ def on_notify(self, notification): Any other notification received though the reqrepl port. """ if notification["subject"].startswith("eye_process.started"): - # trigger notification - self.__frame_publisher.frame_format = self.__frame_publisher.frame_format + # Let newly started eye-processes know about current frame publishing format + self.frame_publisher_announce_current_format() elif notification["subject"] == "frame_publishing.set_format": # update format and trigger notification self.__frame_publisher.frame_format = notification["format"] - def on_frame_publisher_did_start(self, format: FrameFormat): - self.notify_all({"subject": "frame_publishing.started", "format": format.value}) + def frame_publisher_announce_current_format(self, *_): + self.notify_all( + { + "subject": "frame_publishing.started", + "format": self.__frame_publisher.frame_format.value, + } + ) - def on_frame_publisher_did_stop(self): + def frame_publisher_announce_stop(self): self.notify_all({"subject": "frame_publishing.stopped"}) def on_pupil_remote_server_did_start(self, address: str): diff --git a/pupil_src/shared_modules/plugin.py b/pupil_src/shared_modules/plugin.py index b96b77c297..0d2746e737 100644 --- a/pupil_src/shared_modules/plugin.py +++ b/pupil_src/shared_modules/plugin.py @@ -493,7 +493,7 @@ def import_runtime_plugins(plugin_dir): and issubclass(member, Plugin) and member.__name__ != "Plugin" ): - logger.info("Added: {}".format(member)) + logger.debug("Added: {}".format(member)) runtime_plugins.append(member) except Exception as e: logger.warning("Failed to load '{}'. Reason: '{}' ".format(d, e)) diff --git a/pupil_src/shared_modules/pupil_detector_plugins/visualizer_2d.py b/pupil_src/shared_modules/pupil_detector_plugins/visualizer_2d.py index a7b50ee70b..4252e87a0c 100644 --- a/pupil_src/shared_modules/pupil_detector_plugins/visualizer_2d.py +++ b/pupil_src/shared_modules/pupil_detector_plugins/visualizer_2d.py @@ -29,11 +29,15 @@ def draw_ellipse( arcEnd=360, delta=8, ) - except ValueError: - # Happens when converting 'nan' to int - # TODO: Investigate why results are sometimes 'nan' - logger.debug(f"WARN: trying to draw ellipse with 'NaN' data: {ellipse}") - return + except Exception as e: + # Known issues: + # - There are reports of negative eye_ball axes when drawing the 3D eyeball + # outline, which will raise cv2.error. TODO: Investigate cause in detectors. + logger.debug( + "Error drawing ellipse! Skipping...\n" + f"ellipse: {ellipse}\n" + f"{type(e)}: {e}" + ) draw_polyline(pts, thickness, RGBA(*rgba)) if draw_center: @@ -43,6 +47,12 @@ def draw_ellipse( def draw_eyeball_outline(pupil_detection_result_3d): + if pupil_detection_result_3d["model_confidence"] <= 0.0: + # NOTE: if 'model_confidence' == 0, some values of the 'projected_sphere' might + # be 'nan', which will cause cv2.ellipse to crash. + # TODO: Fix in detectors. + return + draw_ellipse( ellipse=pupil_detection_result_3d["projected_sphere"], rgba=(0, 0.9, 0.1, pupil_detection_result_3d["model_confidence"]), diff --git a/pupil_src/shared_modules/pupil_recording/update/__init__.py b/pupil_src/shared_modules/pupil_recording/update/__init__.py index 831d6e05ce..d663ab3ca2 100644 --- a/pupil_src/shared_modules/pupil_recording/update/__init__.py +++ b/pupil_src/shared_modules/pupil_recording/update/__init__.py @@ -80,6 +80,7 @@ def update_recording(rec_dir: str): # Update offline calibrations to latest calibration model version from gaze_producer.model.legacy import update_offline_calibrations_to_latest_version + update_offline_calibrations_to_latest_version(rec_dir) diff --git a/pupil_src/shared_modules/recorder.py b/pupil_src/shared_modules/recorder.py index 527c27c508..bd85813abe 100644 --- a/pupil_src/shared_modules/recorder.py +++ b/pupil_src/shared_modules/recorder.py @@ -31,7 +31,10 @@ from pupil_recording.info import Version from pupil_recording.info import RecordingInfoFile -from gaze_mapping.notifications import CalibrationSetupNotification, CalibrationResultNotification +from gaze_mapping.notifications import ( + CalibrationSetupNotification, + CalibrationResultNotification, +) # from scipy.interpolate import UnivariateSpline from plugin import System_Plugin_Base @@ -350,7 +353,10 @@ def start(self): else: self.writer = MPEG_Writer(self.video_path, start_time_synced) - calibration_data_notification_classes = [CalibrationSetupNotification, CalibrationResultNotification] + calibration_data_notification_classes = [ + CalibrationSetupNotification, + CalibrationResultNotification, + ] writer = PLData_Writer(self.rec_path, "notify") for note_class in calibration_data_notification_classes: diff --git a/pupil_src/shared_modules/remote_recorder.py b/pupil_src/shared_modules/remote_recorder.py index d6fa0f2fd3..c1f3a18af2 100644 --- a/pupil_src/shared_modules/remote_recorder.py +++ b/pupil_src/shared_modules/remote_recorder.py @@ -74,8 +74,7 @@ def __init__(self, num_states_changed_callback): self._attached_rec_states = {} self._network = ndsi.Network( - formats={ndsi.DataFormat.V3}, - callbacks=(self.on_event,) + formats={ndsi.DataFormat.V3}, callbacks=(self.on_event,) ) self._network.start() diff --git a/pupil_src/shared_modules/stdlib_utils.py b/pupil_src/shared_modules/stdlib_utils.py index 16aed138bd..169f619b90 100644 --- a/pupil_src/shared_modules/stdlib_utils.py +++ b/pupil_src/shared_modules/stdlib_utils.py @@ -26,9 +26,13 @@ class sliceable_deque(collections.deque): """ deque subclass with support for slicing. """ + def __getitem__(self, index): if isinstance(index, slice): - return type(self)(itertools.islice(self, index.start, index.stop, index.step), maxlen=self.maxlen) + return type(self)( + itertools.islice(self, index.start, index.stop, index.step), + maxlen=self.maxlen, + ) return collections.deque.__getitem__(self, index) @@ -39,8 +43,12 @@ def __getitem__(self, index): class unique(collections.abc.Iterable): - - def __init__(self, it: typing.Iterable, key: Unique_Key_Getter=..., select: Unique_Select=...): + def __init__( + self, + it: typing.Iterable, + key: Unique_Key_Getter = ..., + select: Unique_Select = ..., + ): self._it = list(it) self._key = key if key is not ... else lambda elem: elem self._select = select if select is not ... else lambda x, y: x diff --git a/pupil_src/shared_modules/storage.py b/pupil_src/shared_modules/storage.py index d76691f08b..d3cc575d7b 100644 --- a/pupil_src/shared_modules/storage.py +++ b/pupil_src/shared_modules/storage.py @@ -60,12 +60,17 @@ def create_unique_id_from_string(string): class Storage(abc.ABC): - def __init__(self, plugin): - plugin.add_observer("cleanup", self._on_cleanup) - def __iter__(self): return iter(self.items) + @property + def is_empty(self): + try: + next(iter(self)) + return False + except StopIteration: + return True + @abc.abstractmethod def add(self, item): pass @@ -111,9 +116,6 @@ def _load_data_from_file(self, filepath): return None return dict_representation.get("data", None) - def _on_cleanup(self): - self.save_to_disk() - @staticmethod def get_valid_filename(file_name): """ @@ -138,8 +140,7 @@ class SingleFileStorage(Storage, abc.ABC): Storage that can save and load all items from / to a single file """ - def __init__(self, rec_dir, plugin): - super().__init__(plugin) + def __init__(self, rec_dir): self._rec_dir = rec_dir def save_to_disk(self): diff --git a/pupil_src/shared_modules/surface_tracker/background_tasks.py b/pupil_src/shared_modules/surface_tracker/background_tasks.py index 120f1210f9..7885a241e6 100644 --- a/pupil_src/shared_modules/surface_tracker/background_tasks.py +++ b/pupil_src/shared_modules/surface_tracker/background_tasks.py @@ -272,9 +272,10 @@ def save_surface_statisics_to_file(self): logger.warning("Could not make metrics dir {}".format(self.metrics_dir)) return - self.gaze_on_surfaces, self.fixations_on_surfaces = ( - self._map_gaze_and_fixations() - ) + ( + self.gaze_on_surfaces, + self.fixations_on_surfaces, + ) = self._map_gaze_and_fixations() self._export_surface_visibility() self._export_surface_gaze_distribution() diff --git a/pupil_src/shared_modules/surface_tracker/surface_file_store.py b/pupil_src/shared_modules/surface_tracker/surface_file_store.py index a30b2313a9..5463ee253d 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_file_store.py +++ b/pupil_src/shared_modules/surface_tracker/surface_file_store.py @@ -64,24 +64,34 @@ def read_surfaces_from_file(self, surface_class) -> typing.Iterator[Surface]: file_path=self.file_path, serializer=self.serializer, surface_class=surface_class, - should_skip_on_invalid=False + should_skip_on_invalid=False, ) def write_surfaces_to_file(self, surfaces: typing.Iterator[Surface]): dict_from_surface = self.serializer.dict_from_surface - serialized_surfaces = [ dict_from_surface(surface) for surface in surfaces if surface.defined] + serialized_surfaces = [ + dict_from_surface(surface) for surface in surfaces if surface.defined + ] surface_definitions = self._persistent_dict_class(self.file_path) surface_definitions["surfaces"] = serialized_surfaces surface_definitions.save() # Protected API - def _read_surfaces_from_file_path(self, file_path: str, serializer: _Surface_Serializer_Base, surface_class, should_skip_on_invalid: bool = False) -> typing.Iterator[Surface]: + def _read_surfaces_from_file_path( + self, + file_path: str, + serializer: _Surface_Serializer_Base, + surface_class, + should_skip_on_invalid: bool = False, + ) -> typing.Iterator[Surface]: # TODO: Assert surface_class is a class, and is a Surface subclass def surface_from_dict(surface_dict: dict) -> typing.Optional[Surface]: try: - return serializer.surface_from_dict(surface_definition=surface_dict, surface_class=surface_class) + return serializer.surface_from_dict( + surface_definition=surface_dict, surface_class=surface_class + ) except InvalidSurfaceDefinition: if should_skip_on_invalid: return None @@ -137,7 +147,7 @@ def read_surfaces_from_file(self, surface_class) -> typing.Iterator[Surface]: file_path=self.__legacy_file_path, serializer=self.serializer, surface_class=surface_class, - should_skip_on_invalid=True # Since this file might contain older surface definitions, those can be skipped + should_skip_on_invalid=True, # Since this file might contain older surface definitions, those can be skipped ) # Private @@ -165,7 +175,9 @@ def __init__(self, parent_dir, **kwargs): # Pre-computed properties self.__supported_versions = tuple(sorted(self.__versioned_file_stores.keys())) - self.__migration_step_sequence = tuple(zip(self.__supported_versions, self.__supported_versions[1:])) + self.__migration_step_sequence = tuple( + zip(self.__supported_versions, self.__supported_versions[1:]) + ) @property def file_name(self) -> str: @@ -183,10 +195,16 @@ def read_surfaces_from_file(self, surface_class) -> typing.Iterator[Surface]: # Perform all migrations for source_version, target_version in self.__migration_step_sequence: - migration_proc = self.__migration_procedure(surface_class=surface_class, source_version=source_version, target_version=target_version) + migration_proc = self.__migration_procedure( + surface_class=surface_class, + source_version=source_version, + target_version=target_version, + ) migration_proc() - return self.__file_store_latest.read_surfaces_from_file(surface_class=surface_class) + return self.__file_store_latest.read_surfaces_from_file( + surface_class=surface_class + ) def write_surfaces_to_file(self, surfaces: typing.Iterator[Surface]): self.__file_store_latest.write_surfaces_to_file(surfaces=surfaces) @@ -198,7 +216,9 @@ def __file_store_latest(self) -> _Surface_File_Store_Base: latest_version = self.__supported_versions[-1] return self.__versioned_file_stores[latest_version] - def __migration_procedure(self, surface_class, source_version: Version, target_version: Version) -> Migration_Procedure: + def __migration_procedure( + self, surface_class, source_version: Version, target_version: Version + ) -> Migration_Procedure: # Handle any special-case migrations here if (source_version, target_version) == (0, 1): return functools.partial( @@ -216,7 +236,9 @@ def __migration_procedure(self, surface_class, source_version: Version, target_v target_version=target_version, ) - def __simple_rewrite_migration(self, surface_class, source_version: Version, target_version: Version): + def __simple_rewrite_migration( + self, surface_class, source_version: Version, target_version: Version + ): source_file_store = self.__versioned_file_stores[source_version] target_file_store = self.__versioned_file_stores[target_version] @@ -225,10 +247,14 @@ def __simple_rewrite_migration(self, surface_class, source_version: Version, tar return # Otherwise, write the surfaces to the new location with the new format - surfaces = source_file_store.read_surfaces_from_file(surface_class=surface_class) + surfaces = source_file_store.read_surfaces_from_file( + surface_class=surface_class + ) target_file_store.write_surfaces_to_file(surfaces=surfaces) - def __migration_v00_v01(self, surface_class, source_version: Version, target_version: Version): + def __migration_v00_v01( + self, surface_class, source_version: Version, target_version: Version + ): assert source_version == 0 assert target_version == 1 diff --git a/pupil_src/shared_modules/surface_tracker/surface_marker.py b/pupil_src/shared_modules/surface_tracker/surface_marker.py index 001b1df9c4..7b6f1e7c38 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_marker.py +++ b/pupil_src/shared_modules/surface_tracker/surface_marker.py @@ -74,7 +74,7 @@ def create_surface_marker_uid( def _parse_surface_marker_uid_components( - uid: Surface_Marker_UID + uid: Surface_Marker_UID, ) -> typing.Tuple[Surface_Marker_Type, typing.Optional[str], Surface_Marker_TagID]: components = str(uid).split(":") if len(components) == 2: @@ -284,7 +284,7 @@ def from_square_tag_detection(detection: dict) -> "Surface_Marker": @staticmethod def from_apriltag_v3_detection( - detection: Apriltag_V3_Detection + detection: Apriltag_V3_Detection, ) -> "Surface_Marker": cls = _Apriltag_V3_Marker_Detection raw_marker = cls( diff --git a/pupil_src/shared_modules/surface_tracker/surface_marker_aggregate.py b/pupil_src/shared_modules/surface_tracker/surface_marker_aggregate.py index f64ea54613..e49b7d6551 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_marker_aggregate.py +++ b/pupil_src/shared_modules/surface_tracker/surface_marker_aggregate.py @@ -26,11 +26,14 @@ class Surface_Marker_Aggregate(object): """ @staticmethod - def property_equality(x: "Surface_Marker_Aggregate", y: "Surface_Marker_Aggregate") -> bool: + def property_equality( + x: "Surface_Marker_Aggregate", y: "Surface_Marker_Aggregate" + ) -> bool: def property_dict(x: Surface_Marker_Aggregate) -> dict: x_dict = x.__dict__.copy() x_dict["_verts_uv"] = x_dict["_verts_uv"].tolist() return x_dict + return property_dict(x) == property_dict(y) def __init__( diff --git a/pupil_src/shared_modules/surface_tracker/surface_serializer.py b/pupil_src/shared_modules/surface_tracker/surface_serializer.py index f374f27792..11e4091f89 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_serializer.py +++ b/pupil_src/shared_modules/surface_tracker/surface_serializer.py @@ -15,11 +15,15 @@ import os import typing -from .surface import Surface -from .surface import Surface_Marker_Aggregate -from .surface_marker import Surface_Marker_UID, Surface_Marker_Type, Surface_Marker_TagID -from .surface_marker import create_surface_marker_uid, parse_surface_marker_tag_id, parse_surface_marker_type - +from .surface import Surface, Surface_Marker_Aggregate +from .surface_marker import ( + Surface_Marker_TagID, + Surface_Marker_Type, + Surface_Marker_UID, + create_surface_marker_uid, + parse_surface_marker_tag_id, + parse_surface_marker_type, +) logger = logging.getLogger(__name__) @@ -29,18 +33,21 @@ class InvalidSurfaceDefinition(Exception): class _Surface_Serializer_Base(abc.ABC): - @property @abc.abstractmethod def version(self) -> int: pass @abc.abstractmethod - def dict_from_surface_marker_aggregate(self, surface_marker_aggregate: Surface_Marker_Aggregate) -> dict: + def dict_from_surface_marker_aggregate( + self, surface_marker_aggregate: Surface_Marker_Aggregate + ) -> dict: pass @abc.abstractmethod - def surface_marker_aggregate_from_dict(self, surface_marker_aggregate_dict: dict) -> Surface_Marker_Aggregate: + def surface_marker_aggregate_from_dict( + self, surface_marker_aggregate_dict: dict + ) -> Surface_Marker_Aggregate: pass def dict_from_surface(self, surface: Surface) -> dict: @@ -48,11 +55,11 @@ def dict_from_surface(self, surface: Surface) -> dict: reg_markers = [ dict_from_marker_aggregate(marker_aggregate) - for marker_aggregate in surface._registered_markers_undist.values() #TODO: Provide a public property for this + for marker_aggregate in surface._registered_markers_undist.values() # TODO: Provide a public property for this ] registered_markers_dist = [ dict_from_marker_aggregate(marker_aggregate) - for marker_aggregate in surface._registered_markers_dist.values() #TODO: Provide a public property for this + for marker_aggregate in surface._registered_markers_dist.values() # TODO: Provide a public property for this ] return { "version": self.version, @@ -65,8 +72,12 @@ def dict_from_surface(self, surface: Surface) -> dict: } def surface_from_dict(self, surface_class, surface_definition: dict) -> Surface: - assert isinstance(surface_class, type(object)), f"surface_class must be a class: {surface_class}" - assert issubclass(surface_class, Surface), f"surface_class must be a subclass of Surface: {surface_class}" + assert isinstance( + surface_class, type(object) + ), f"surface_class must be a class: {surface_class}" + assert issubclass( + surface_class, Surface + ), f"surface_class must be a subclass of Surface: {surface_class}" expected_version = self.version actual_version = surface_definition["version"] @@ -76,15 +87,17 @@ def surface_from_dict(self, surface_class, surface_definition: dict) -> Surface: marker_aggregate_from_dict = self.surface_marker_aggregate_from_dict - marker_aggregates_undist = [marker_aggregate_from_dict(d) for d in surface_definition["reg_markers" ]] - marker_aggregates_dist = [marker_aggregate_from_dict(d) for d in surface_definition["registered_markers_dist"]] + marker_aggregates_undist = [ + marker_aggregate_from_dict(d) for d in surface_definition["reg_markers"] + ] + marker_aggregates_dist = [ + marker_aggregate_from_dict(d) + for d in surface_definition["registered_markers_dist"] + ] + + deprecated_definition = surface_definition.get("deprecated", True) - deprecated_definition = None - try: - deprecated_definition = surface_definition["deprecated"] - except KeyError: - pass - else: + if deprecated_definition: logger.warning( "You have loaded an old and deprecated surface definition! " "Please re-define this surface for increased mapping accuracy!" @@ -104,7 +117,9 @@ class _Surface_Serializer_V00(_Surface_Serializer_Base): version = 0 - def dict_from_surface_marker_aggregate(self, surface_marker_aggregate: Surface_Marker_Aggregate) -> dict: + def dict_from_surface_marker_aggregate( + self, surface_marker_aggregate: Surface_Marker_Aggregate + ) -> dict: id = parse_surface_marker_tag_id(uid=surface_marker_aggregate.uid) marker_type = parse_surface_marker_type(uid=surface_marker_aggregate.uid) if marker_type != Surface_Marker_Type.SQUARE: @@ -115,12 +130,14 @@ def dict_from_surface_marker_aggregate(self, surface_marker_aggregate: Surface_M verts_uv = [v.tolist() for v in verts_uv] return {"id": id, "verts_uv": verts_uv} - def surface_marker_aggregate_from_dict(self, surface_marker_aggregate_dict: dict) -> Surface_Marker_Aggregate: + def surface_marker_aggregate_from_dict( + self, surface_marker_aggregate_dict: dict + ) -> Surface_Marker_Aggregate: tag_id = surface_marker_aggregate_dict["id"] uid = create_surface_marker_uid( marker_type=Surface_Marker_Type.SQUARE, tag_family=None, - tag_id=Surface_Marker_TagID(tag_id) + tag_id=Surface_Marker_TagID(tag_id), ) verts_uv = surface_marker_aggregate_dict["verts_uv"] return Surface_Marker_Aggregate(uid=uid, verts_uv=verts_uv) @@ -135,8 +152,7 @@ def surface_from_dict(self, surface_class, surface_definition: dict) -> Surface: # The format of v00 doesn't store any value for "version" key surface_definition["version"] = surface_definition.get("version", self.version) return super().surface_from_dict( - surface_class=surface_class, - surface_definition=surface_definition, + surface_class=surface_class, surface_definition=surface_definition ) @@ -144,14 +160,18 @@ class _Surface_Serializer_V01(_Surface_Serializer_Base): version = 1 - def dict_from_surface_marker_aggregate(self, surface_marker_aggregate: Surface_Marker_Aggregate) -> dict: + def dict_from_surface_marker_aggregate( + self, surface_marker_aggregate: Surface_Marker_Aggregate + ) -> dict: uid = str(surface_marker_aggregate.uid) verts_uv = surface_marker_aggregate.verts_uv if verts_uv is not None: verts_uv = [v.tolist() for v in verts_uv] return {"uid": uid, "verts_uv": verts_uv} - def surface_marker_aggregate_from_dict(self, surface_marker_aggregate_dict: dict) -> Surface_Marker_Aggregate: + def surface_marker_aggregate_from_dict( + self, surface_marker_aggregate_dict: dict + ) -> Surface_Marker_Aggregate: uid = surface_marker_aggregate_dict["uid"] verts_uv = surface_marker_aggregate_dict["verts_uv"] return Surface_Marker_Aggregate(uid=uid, verts_uv=verts_uv) diff --git a/pupil_src/shared_modules/surface_tracker/surface_tracker.py b/pupil_src/shared_modules/surface_tracker/surface_tracker.py index ad2f1921fc..db3292ba10 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_tracker.py +++ b/pupil_src/shared_modules/surface_tracker/surface_tracker.py @@ -74,6 +74,7 @@ def __init__( self._edit_surf_verts = [] self._last_mouse_pos = (0.0, 0.0) self.gui = gui.GUI(self) + self._ui_heatmap_mode_selector = None if not isinstance(marker_detector_mode, MarkerDetectorMode): # Here we ensure that we pass a proper MarkerDetectorMode @@ -161,20 +162,19 @@ def _update_ui(self): self._per_surface_ui(surface) def _update_ui_visualization_menu(self): + self._ui_heatmap_mode_selector = ui.Selector( + "heatmap_mode", + self.gui, + label="Heatmap Mode", + labels=[e.value for e in self.supported_heatmap_modes], + selection=[e for e in self.supported_heatmap_modes], + ) self.menu.append(ui.Info_Text(self.ui_info_text)) self.menu.append( ui.Switch("show_marker_ids", self.gui, label="Show Marker IDs") ) self.menu.append(ui.Switch("show_heatmap", self.gui, label="Show Heatmap")) - self.menu.append( - ui.Selector( - "heatmap_mode", - self.gui, - label="Heatmap Mode", - labels=[e.value for e in self.supported_heatmap_modes], - selection=[e for e in self.supported_heatmap_modes], - ) - ) + self.menu.append(self._ui_heatmap_mode_selector) def _update_ui_custom(self): pass @@ -560,4 +560,5 @@ def deinit_ui(self): self.remove_menu() def cleanup(self): + self._ui_heatmap_mode_selector = None self.save_surface_definitions_to_file() diff --git a/pupil_src/shared_modules/surface_tracker/surface_tracker_offline.py b/pupil_src/shared_modules/surface_tracker/surface_tracker_offline.py index ad89f0e050..4b4d90b9bf 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_tracker_offline.py +++ b/pupil_src/shared_modules/surface_tracker/surface_tracker_offline.py @@ -340,7 +340,11 @@ def _update_surface_heatmaps(self): surf_idx = self.surfaces.index(surface) gaze_on_surf = self.gaze_on_surf_buffer[surf_idx] gaze_on_surf = itertools.chain.from_iterable(gaze_on_surf) - gaze_on_surf = (g for g in gaze_on_surf if g["confidence"] >= self.g_pool.min_data_confidence) + gaze_on_surf = ( + g + for g in gaze_on_surf + if g["confidence"] >= self.g_pool.min_data_confidence + ) gaze_on_surf = list(gaze_on_surf) surface.update_heatmap(gaze_on_surf) diff --git a/pupil_src/shared_modules/surface_tracker/surface_tracker_online.py b/pupil_src/shared_modules/surface_tracker/surface_tracker_online.py index ea94fb3b9c..85335c8c73 100644 --- a/pupil_src/shared_modules/surface_tracker/surface_tracker_online.py +++ b/pupil_src/shared_modules/surface_tracker/surface_tracker_online.py @@ -61,7 +61,7 @@ def ui_info_text(self): @property def supported_heatmap_modes(self): - return [Heatmap_Mode.WITHIN_SURFACE, Heatmap_Mode.ACROSS_SURFACES] + return [Heatmap_Mode.WITHIN_SURFACE] def _update_ui_custom(self): def set_freeze_scene(val): @@ -95,6 +95,8 @@ def set_gaze_hist_len(val): ) def recent_events(self, events): + if self._ui_heatmap_mode_selector is not None: + self._ui_heatmap_mode_selector.read_only = True if self.freeze_scene: # If frozen, we overwrite the frame event with the last frame we have saved current_frame = events.get("frame") @@ -132,7 +134,11 @@ def _update_surface_corners(self): def _update_surface_heatmaps(self): for surface in self.surfaces: gaze_on_surf = surface.gaze_history - gaze_on_surf = (g for g in gaze_on_surf if g["confidence"] >= self.g_pool.min_data_confidence) + gaze_on_surf = ( + g + for g in gaze_on_surf + if g["confidence"] >= self.g_pool.min_data_confidence + ) gaze_on_surf = list(gaze_on_surf) surface.update_heatmap(gaze_on_surf) diff --git a/pupil_src/shared_modules/system_graphs.py b/pupil_src/shared_modules/system_graphs.py index 8f0be5b9ec..955deb99c8 100644 --- a/pupil_src/shared_modules/system_graphs.py +++ b/pupil_src/shared_modules/system_graphs.py @@ -117,10 +117,10 @@ def recent_events(self, events): if "frame" not in events or self.idx != events["frame"].index: for p in events["pupil"]: if p["topic"] == "pupil.0.2d": - assert p["id"] == 0 #sanity check + assert p["id"] == 0 # sanity check self.conf0_graph.add(p["confidence"]) if p["topic"] == "pupil.1.2d": - assert p["id"] == 1 #sanity check + assert p["id"] == 1 # sanity check self.conf1_graph.add(p["confidence"]) # update wprld fps graph diff --git a/pupil_src/shared_modules/tasklib/manager.py b/pupil_src/shared_modules/tasklib/manager.py index b1cebf67aa..1252961f9d 100644 --- a/pupil_src/shared_modules/tasklib/manager.py +++ b/pupil_src/shared_modules/tasklib/manager.py @@ -10,6 +10,9 @@ """ import tasklib.background +import logging + +logger = logging.getLogger(__name__) class PluginTaskManager: @@ -28,8 +31,7 @@ class PluginTaskManager: """ def __init__(self, plugin): - self._recently_added_tasks = [] - self._running_tasks = [] + self._tasks = [] plugin.add_observer("recent_events", self.on_recent_events) plugin.add_observer("cleanup", self.on_cleanup) @@ -84,7 +86,7 @@ def create_background_task( kwargs, patches, ) - self._recently_added_tasks.append(task) + self._tasks.append(task) return task def add_task(self, task): @@ -95,36 +97,46 @@ def add_task(self, task): Similarly to create_task(), you don't need to start the task before adding it, but you can if you want. """ - self._recently_added_tasks.append(task) + self._tasks.append(task) def on_recent_events(self, _): - self._start_recently_added_tasks() - self._update_running_tasks() - - def on_cleanup(self): - self._kill_all_running_tasks() - self._recently_added_tasks = [] - self._running_tasks = [] - - def _start_recently_added_tasks(self): - for task in self._recently_added_tasks: - # we test because the user might have already started it manually + for task in self._tasks.copy(): if not task.started: task.start() - self._recently_added_tasks.remove(task) - self._running_tasks.append(task) - - def _update_running_tasks(self): - for task in self._running_tasks: - if task.ended: - self._running_tasks.remove(task) - else: + if task.running: task.update() + if task.ended: + self._tasks.remove(task) + + def on_cleanup(self): + self._kill_all_running_tasks() + self._tasks = [] def _kill_all_running_tasks(self, grace_period_per_task=None): - for task in self._running_tasks: - # the test is for tasks that terminate between the last update and this - # method call, i.e. for tasks that had no chance to get removed from - # self._running_tasks + for task in self._tasks: if task.running: task.kill(grace_period=grace_period_per_task) + + +class UniqueTaskManager(PluginTaskManager): + """TaskManager ensuring tasks are unique by identifier""" + + def add_task(self, task_new, identifier: str): + UniqueTaskManager._patch_task(task_new, identifier) + task_duplicated = self._get_duplicated_task(identifier) + if task_duplicated is not None: + state = "running" if task_duplicated.running else "queued" + logger.debug(f"Replacing {state} task with ID '{identifier}'") + if task_duplicated.running: + task_duplicated.kill(grace_period=None) + self._tasks.remove(task_duplicated) + super().add_task(task_new) + + def _get_duplicated_task(self, identifier): + for task_prev in self._tasks: + if task_prev._unique_task_identifier == identifier: + return task_prev + + @staticmethod + def _patch_task(task, identifier: str): + task._unique_task_identifier = identifier diff --git a/pupil_src/shared_modules/video_overlay/controllers/overlay_manager.py b/pupil_src/shared_modules/video_overlay/controllers/overlay_manager.py index d3342b05c9..1ef8cee597 100644 --- a/pupil_src/shared_modules/video_overlay/controllers/overlay_manager.py +++ b/pupil_src/shared_modules/video_overlay/controllers/overlay_manager.py @@ -16,11 +16,10 @@ class OverlayManager(SingleFileStorage): - def __init__(self, rec_dir, plugin): - super().__init__(rec_dir, plugin) + def __init__(self, rec_dir): + super().__init__(rec_dir) self._overlays = [] self._load_from_disk() - self._patch_on_cleanup(plugin) @property def _storage_file_name(self): @@ -54,12 +53,3 @@ def most_recent(self): def remove_overlay(self, overlay): self._overlays.remove(overlay) self.save_to_disk() - - def _patch_on_cleanup(self, plugin): - """Patches cleanup observer to trigger on get_init_dict(). - - Save current settings to disk on get_init_dict() instead of cleanup(). - This ensures that the World Video Exporter loads the most recent settings. - """ - plugin.remove_observer("cleanup", self._on_cleanup) - plugin.add_observer("get_init_dict", self._on_cleanup) diff --git a/pupil_src/shared_modules/video_overlay/plugins/generic_overlay.py b/pupil_src/shared_modules/video_overlay/plugins/generic_overlay.py index d3e3c5a8e3..709a1234d8 100644 --- a/pupil_src/shared_modules/video_overlay/plugins/generic_overlay.py +++ b/pupil_src/shared_modules/video_overlay/plugins/generic_overlay.py @@ -27,7 +27,13 @@ class Video_Overlay(Observable, Plugin): def __init__(self, g_pool): super().__init__(g_pool) - self.manager = OverlayManager(g_pool.rec_dir, self) + self.manager = OverlayManager(g_pool.rec_dir) + + def get_init_dict(self): + # Save current settings to disk, ensures that the World Video Exporter + # loads the most recent settings. + self.manager.save_to_disk() + return super().get_init_dict() def recent_events(self, events): if "frame" in events: diff --git a/pupil_src/shared_modules/video_overlay/utils/image_manipulation.py b/pupil_src/shared_modules/video_overlay/utils/image_manipulation.py index 9696551015..9403611db2 100644 --- a/pupil_src/shared_modules/video_overlay/utils/image_manipulation.py +++ b/pupil_src/shared_modules/video_overlay/utils/image_manipulation.py @@ -10,10 +10,13 @@ """ import abc +import logging import cv2 import numpy as np +logger = logging.getLogger(__name__) + class ImageManipulator(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -73,6 +76,12 @@ def render_pupil_3d(self, image, pupil_position): conf = int(pupil_position["confidence"] * 255) self.render_ellipse(image, el, color=(0, 0, 255, conf)) + if pupil_position["model_confidence"] <= 0.0: + # NOTE: if 'model_confidence' == 0, some values of the 'projected_sphere' + # might be 'nan', which will cause cv2.ellipse to crash. + # TODO: Fix in detectors. + return + eye_ball = pupil_position.get("projected_sphere", None) if eye_ball is not None: try: @@ -86,10 +95,15 @@ def render_pupil_3d(self, image, pupil_position): color=(26, 230, 0, 255 * pupil_position["model_confidence"]), thickness=2, ) - except ValueError: - # Happens when converting 'nan' to int - # TODO: Investigate why results are sometimes 'nan' - pass + except Exception as e: + # Known issues: + # - There are reports of negative eye_ball axes, raising cv2.error. + # TODO: Investigate cause in detectors. + logger.debug( + "Error rendering 3D eye-ball outline! Skipping...\n" + f"eye_ball: {eye_ball}\n" + f"{type(e)}: {e}" + ) def render_ellipse(self, image, ellipse, color): outline = self.get_ellipse_points( diff --git a/pupil_src/shared_modules/vis_watermark.py b/pupil_src/shared_modules/vis_watermark.py index d679bc69c8..952faff2e6 100644 --- a/pupil_src/shared_modules/vis_watermark.py +++ b/pupil_src/shared_modules/vis_watermark.py @@ -35,7 +35,7 @@ def __init__(self, g_pool, selected_watermark_path=None, pos=(20, 20)): self.menu = None available_files = glob( - os.path.join(self.g_pool.user_dir, "*png") + os.path.join(self.g_pool.user_dir, "[!.]*png") ) # we only look for png's self.available_files = [ f for f in available_files if cv2.imread(f, -1).shape[2] == 4