From 14a75b0374306e238f282af8b0ae169851a048b4 Mon Sep 17 00:00:00 2001 From: dtrai2 Date: Thu, 7 Nov 2024 18:20:54 +0100 Subject: [PATCH] fix geoip_enricher tests --- logprep/processor/geoip_enricher/processor.py | 31 +++++++++---------- logprep/util/helper.py | 18 ++++++++--- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/logprep/processor/geoip_enricher/processor.py b/logprep/processor/geoip_enricher/processor.py index 3e83ca044..b83791218 100644 --- a/logprep/processor/geoip_enricher/processor.py +++ b/logprep/processor/geoip_enricher/processor.py @@ -38,11 +38,10 @@ from geoip2 import database from geoip2.errors import AddressNotFoundError -from logprep.processor.base.exceptions import FieldExistsWarning from logprep.processor.field_manager.processor import FieldManager from logprep.processor.geoip_enricher.rule import GEOIP_DATA_STUBS, GeoipEnricherRule from logprep.util.getter import GetterFactory -from logprep.util.helper import add_field_to, get_dotted_field_value +from logprep.util.helper import get_dotted_field_value, add_batch_to logger = logging.getLogger("GeoipEnricher") @@ -129,18 +128,16 @@ def _apply_rules(self, event, rule): geoip_data = self._try_getting_geoip_data(ip_string) if not geoip_data: return - for target_subfield, value in geoip_data.items(): - if value is None: - continue - full_output_field = f"{rule.target_field}.{target_subfield}" - if target_subfield in rule.customize_target_subfields: - full_output_field = rule.customize_target_subfields.get(target_subfield) - adding_was_successful = add_field_to( - event=event, - target_field=full_output_field, - content=value, - extends_lists=False, - overwrite_output_field=rule.overwrite_target, - ) - if not adding_was_successful: - raise FieldExistsWarning(rule, event, [full_output_field]) + filtered_geoip_data = {k: v for k, v in geoip_data.items() if v is not None} + targets, contents = zip(*filtered_geoip_data.items()) + targets = [ + rule.customize_target_subfields.get(target, f"{rule.target_field}.{target}") + for target in targets + ] + add_batch_to( + event, + targets, + contents, + extends_lists=False, + overwrite_output_field=rule.overwrite_target, + ) diff --git a/logprep/util/helper.py b/logprep/util/helper.py index 900abd6da..34ef10d05 100644 --- a/logprep/util/helper.py +++ b/logprep/util/helper.py @@ -41,6 +41,14 @@ def print_fcolor(fore: AnsiFore, message: str): color_print_line(None, fore, message) +def _add_and_overwrite_key(sub_dict, key): + current_value = sub_dict.get(key) + if isinstance(current_value, dict): + return current_value + sub_dict.update({key: {}}) + return sub_dict.get(key) + + def _add_and_not_overwrite_key(sub_dict, key): current_value = sub_dict.get(key) if isinstance(current_value, dict): @@ -106,13 +114,15 @@ def add_field_to( raise ValueError("An output field can't be overwritten and extended at the same time") field_path = [event, *get_dotted_field_list(target_field)] target_key = field_path.pop() - try: - target_parent = reduce(_add_and_not_overwrite_key, field_path) - except KeyError as error: - raise FieldExistsWarning(event, [target_field]) from error + if overwrite_output_field: + target_parent = reduce(_add_and_overwrite_key, field_path) target_parent[target_key] = content else: + try: + target_parent = reduce(_add_and_not_overwrite_key, field_path) + except KeyError as error: + raise FieldExistsWarning(event, [target_field]) from error existing_value = target_parent.get(target_key) if existing_value is None: target_parent[target_key] = content