Skip to content

Commit

Permalink
update add_field_to function for improved error handling (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrai2 authored Nov 14, 2024
1 parent 5c45241 commit 205045e
Show file tree
Hide file tree
Showing 45 changed files with 563 additions and 539 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@

## next release
### Breaking

* `CriticalInputError` is raised when the input preprocessor values can't be set; previously this was only true
for the hmac preprocessor, but it now applies to all other preprocessors as well.
* fix `delimiter` typo in `StringSplitterRule` configuration

### Features
### Improvements

* replace `BaseException` with `Exception` for custom errors
* refactor `generic_resolver` to validate rules on startup instead of application of each rule
* rewrite the helper method `add_field_to` such that it always raises a `FieldExistsWarning` instead of returning a bool.
* add new helper method `add_fields_to` to directly add multiple fields to one event
* refactor some processors to make use of the new helper methods


### Bugfix

Expand Down
71 changes: 37 additions & 34 deletions logprep/abc/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from logprep.abc.connector import Connector
from logprep.abc.exceptions import LogprepException
from logprep.metrics.metrics import Metric
from logprep.util.helper import add_field_to, get_dotted_field_value
from logprep.processor.base.exceptions import FieldExistsWarning
from logprep.util.helper import add_fields_to, get_dotted_field_value
from logprep.util.time import UTC, TimeParser
from logprep.util.validators import dict_structure_validator

Expand Down Expand Up @@ -280,16 +281,19 @@ def get_next(self, timeout: float) -> dict | None:
self.metrics.number_of_processed_events += 1
if not isinstance(event, dict):
raise CriticalInputError(self, "not a dict", event)
if self._add_hmac:
event = self._add_hmac_to(event, raw_event)
if self._add_version_info:
self._add_version_information_to_event(event)
if self._add_log_arrival_time_information:
self._add_arrival_time_information_to_event(event)
if self._add_log_arrival_timedelta_information:
self._add_arrival_timedelta_information_to_event(event)
if self._add_env_enrichment:
self._add_env_enrichment_to_event(event)
try:
if self._add_hmac:
event = self._add_hmac_to(event, raw_event)
if self._add_version_info:
self._add_version_information_to_event(event)
if self._add_log_arrival_time_information:
self._add_arrival_time_information_to_event(event)
if self._add_log_arrival_timedelta_information:
self._add_arrival_timedelta_information_to_event(event)
if self._add_env_enrichment:
self._add_env_enrichment_to_event(event)
except FieldExistsWarning as error:
raise CriticalInputError(self, error.args[0], event) from error
return event

def batch_finished_callback(self):
Expand All @@ -300,13 +304,19 @@ def _add_env_enrichment_to_event(self, event: dict):
enrichments = self._config.preprocessing.get("enrich_by_env_variables")
if not enrichments:
return
for target_field, variable_name in enrichments.items():
add_field_to(event, target_field, os.environ.get(variable_name, ""))
fields = {
target: os.environ.get(variable_name, "")
for target, variable_name in enrichments.items()
}
add_fields_to(event, fields)

def _add_arrival_time_information_to_event(self, event: dict):
now = TimeParser.now()
target_field = self._config.preprocessing.get("log_arrival_time_target_field")
add_field_to(event, target_field, now.isoformat())
new_field = {
self._config.preprocessing.get(
"log_arrival_time_target_field"
): TimeParser.now().isoformat()
}
add_fields_to(event, new_field)

def _add_arrival_timedelta_information_to_event(self, event: dict):
log_arrival_timedelta_config = self._config.preprocessing.get("log_arrival_timedelta")
Expand All @@ -322,16 +332,16 @@ def _add_arrival_timedelta_information_to_event(self, event: dict):
TimeParser.from_string(log_arrival_time).astimezone(UTC)
- TimeParser.from_string(time_reference).astimezone(UTC)
).total_seconds()
add_field_to(event, target_field, delta_time_sec)
add_fields_to(event, fields={target_field: delta_time_sec})

def _add_version_information_to_event(self, event: dict):
"""Add the version information to the event"""
target_field = self._config.preprocessing.get("version_info_target_field")
# pylint: disable=protected-access
add_field_to(event, target_field, self._config._version_information)
add_fields_to(event, fields={target_field: self._config._version_information})
# pylint: enable=protected-access

def _add_hmac_to(self, event_dict, raw_event) -> Tuple[dict, str]:
def _add_hmac_to(self, event_dict, raw_event) -> dict:
"""
Calculates an HMAC (Hash-based message authentication code) based on a given target field
and adds it to the given event. If the target field has the value '<RAW_MSG>' the full raw
Expand All @@ -357,7 +367,7 @@ def _add_hmac_to(self, event_dict, raw_event) -> Tuple[dict, str]:
------
CriticalInputError
If the hmac could not be added to the event because the desired output field already
exists or cant't be found.
exists or can't be found.
"""
hmac_options = self._config.preprocessing.get("hmac", {})
hmac_target_field_name = hmac_options.get("target")
Expand All @@ -381,18 +391,11 @@ def _add_hmac_to(self, event_dict, raw_event) -> Tuple[dict, str]:
digestmod=hashlib.sha256,
).hexdigest()
compressed = zlib.compress(received_orig_message, level=-1)
hmac_output = {"hmac": hmac, "compressed_base64": base64.b64encode(compressed).decode()}
add_was_successful = add_field_to(
event_dict,
hmac_options.get("output_field"),
hmac_output,
)
if not add_was_successful:
raise CriticalInputError(
self,
f"Couldn't add the hmac to the input event as the desired "
f"output field '{hmac_options.get('output_field')}' already "
f"exist.",
event_dict,
)
new_field = {
hmac_options.get("output_field"): {
"hmac": hmac,
"compressed_base64": base64.b64encode(compressed).decode(),
}
}
add_fields_to(event_dict, new_field)
return event_dict
20 changes: 9 additions & 11 deletions logprep/abc/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
from logprep.framework.rule_tree.rule_tree import RuleTree, RuleTreeType
from logprep.metrics.metrics import Metric
from logprep.processor.base.exceptions import (
FieldExistsWarning,
ProcessingCriticalError,
ProcessingError,
ProcessingWarning,
)
from logprep.util import getter
from logprep.util.helper import (
add_and_overwrite,
add_field_to,
add_fields_to,
get_dotted_field_value,
pop_dotted_field_value,
)
Expand Down Expand Up @@ -357,13 +356,15 @@ def _handle_warning_error(self, event, rule, error, failure_tags=None):
if failure_tags is None:
failure_tags = rule.failure_tags
if tags is None:
add_and_overwrite(event, "tags", sorted(list({*failure_tags})))
new_field = {"tags": sorted(list({*failure_tags}))}
else:
add_and_overwrite(event, "tags", sorted(list({*tags, *failure_tags})))
new_field = {"tags": sorted(list({*tags, *failure_tags}))}
add_and_overwrite(event, new_field, rule)
if isinstance(error, ProcessingWarning):
if error.tags:
tags = tags if tags else []
add_and_overwrite(event, "tags", sorted(list({*error.tags, *tags, *failure_tags})))
new_field = {"tags": sorted(list({*error.tags, *tags, *failure_tags}))}
add_and_overwrite(event, new_field, rule)
self.result.warnings.append(error)
else:
self.result.warnings.append(ProcessingWarning(str(error), rule, event))
Expand All @@ -381,15 +382,12 @@ def _has_missing_values(self, event, rule, source_field_dict):
return False

def _write_target_field(self, event: dict, rule: "Rule", result: any) -> None:
add_successful = add_field_to(
add_fields_to(
event,
output_field=rule.target_field,
content=result,
fields={rule.target_field: result},
extends_lists=rule.extend_target_list,
overwrite_output_field=rule.overwrite_target,
overwrite_target_field=rule.overwrite_target,
)
if not add_successful:
raise FieldExistsWarning(rule, event, [rule.target_field])

def setup(self):
super().setup()
Expand Down
10 changes: 6 additions & 4 deletions logprep/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
from attrs import define, field, validators
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram

from logprep.util.helper import add_field_to
from logprep.util.helper import add_fields_to


@define(kw_only=True, slots=False)
Expand Down Expand Up @@ -222,12 +222,14 @@ def inner(self, *args, **kwargs): # nosemgrep
if hasattr(self, "rule_type"):
event = args[0]
if event:
add_field_to(event, f"processing_times.{self.rule_type}", duration)
add_fields_to(
event, fields={f"processing_times.{self.rule_type}": duration}
)
if hasattr(self, "_logprep_config"): # attribute of the Pipeline class
event = args[0]
if event:
add_field_to(event, "processing_times.pipeline", duration)
add_field_to(event, "processing_times.hostname", gethostname())
add_fields_to(event, fields={"processing_times.pipeline": duration})
add_fields_to(event, fields={"processing_times.hostname": gethostname()})
return result

return inner
Expand Down
11 changes: 7 additions & 4 deletions logprep/processor/base/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,20 @@ def __init__(self, message: str, rule: "Rule"):
class ProcessingWarning(Warning):
"""A warning occurred - log the warning, but continue processing the event."""

def __init__(self, message: str, rule: "Rule", event: dict, tags: List[str] = None):
def __init__(self, message: str, rule: "Rule | None", event: dict, tags: List[str] = None):
self.tags = tags if tags else []
rule.metrics.number_of_warnings += 1
message = f"{message}, {rule.id=}, {rule.description=}, {event=}"
if rule:
rule.metrics.number_of_warnings += 1
message += f", {rule.id=}, {rule.description=}"
message += f", {event=}"
super().__init__(f"{self.__class__.__name__}: {message}")


class FieldExistsWarning(ProcessingWarning):
"""Raised if field already exists."""

def __init__(self, rule: "Rule", event: dict, skipped_fields: List[str]):
def __init__(self, rule: "Rule | None", event: dict, skipped_fields: List[str]):
self.skipped_fields = skipped_fields
message = (
"The following fields could not be written, because "
"one or more subfields existed and could not be extended: "
Expand Down
9 changes: 4 additions & 5 deletions logprep/processor/clusterer/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
SignaturePhaseStreaming,
)
from logprep.processor.field_manager.processor import FieldManager
from logprep.util.helper import add_field_to, get_dotted_field_value
from logprep.util.helper import add_fields_to, get_dotted_field_value


class Clusterer(FieldManager):
Expand Down Expand Up @@ -138,12 +138,11 @@ def _cluster(self, event: dict, rule: ClustererRule):
)
else:
cluster_signature = cluster_signature_based_on_message
add_field_to(
add_fields_to(
event,
self._config.output_field_name,
cluster_signature,
fields={self._config.output_field_name: cluster_signature},
extends_lists=rule.extend_target_list,
overwrite_output_field=rule.overwrite_target,
overwrite_target_field=rule.overwrite_target,
)
self._last_non_extracted_signature = sig_text

Expand Down
16 changes: 11 additions & 5 deletions logprep/processor/dissector/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
.. automodule:: logprep.processor.dissector.rule
"""

from typing import Callable, List, Tuple
from typing import TYPE_CHECKING, Callable, List, Tuple

from logprep.processor.dissector.rule import DissectorRule
from logprep.processor.field_manager.processor import FieldManager
from logprep.util.helper import add_field_to, get_dotted_field_value
from logprep.util.helper import add_fields_to, get_dotted_field_value

if TYPE_CHECKING:
from logprep.processor.base.rule import Rule


class Dissector(FieldManager):
Expand All @@ -51,7 +54,7 @@ def _apply_mapping(self, event, rule):
for action, *args, _ in action_mappings_sorted_by_position:
action(*args)

def _get_mappings(self, event, rule) -> List[Tuple[Callable, dict, str, str, str, int]]:
def _get_mappings(self, event, rule) -> List[Tuple[Callable, dict, dict, str, "Rule", int]]:
current_field = None
target_field_mapping = {}
for rule_action in rule.actions:
Expand Down Expand Up @@ -84,12 +87,15 @@ def _get_mappings(self, event, rule) -> List[Tuple[Callable, dict, str, str, str
target_field = target_field_mapping.get(target_field.lstrip("&"))
if strip_char:
content = content.strip(strip_char)
yield rule_action, event, target_field, content, separator, position
field = {target_field: content}
yield rule_action, event, field, separator, rule, position

def _apply_convert_datatype(self, event, rule):
for target_field, converter in rule.convert_actions:
try:
target_value = converter(get_dotted_field_value(event, target_field))
add_field_to(event, target_field, target_value, overwrite_output_field=True)
add_fields_to(
event, {target_field: target_value}, rule, overwrite_target_field=True
)
except ValueError as error:
self._handle_warning_error(event, rule, error)
28 changes: 12 additions & 16 deletions logprep/processor/domain_label_extractor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@
from filelock import FileLock
from tldextract import TLDExtract

from logprep.processor.base.exceptions import FieldExistsWarning
from logprep.processor.domain_label_extractor.rule import DomainLabelExtractorRule
from logprep.processor.field_manager.processor import FieldManager
from logprep.util.getter import GetterFactory
from logprep.util.helper import add_and_overwrite, add_field_to, get_dotted_field_value
from logprep.util.helper import add_and_overwrite, add_fields_to, get_dotted_field_value
from logprep.util.validators import list_of_urls_validator

logger = logging.getLogger("DomainLabelExtractor")
Expand Down Expand Up @@ -131,27 +130,24 @@ def _apply_rules(self, event, rule: DomainLabelExtractorRule):

if self._is_valid_ip(domain):
tagging_field.append(f"ip_in_{rule.source_fields[0].replace('.', '_')}")
add_and_overwrite(event, self._config.tagging_field_name, tagging_field)
add_and_overwrite(
event, fields={self._config.tagging_field_name: tagging_field}, rule=rule
)
return

labels = self._tld_extractor(domain)
if labels.suffix != "":
labels_dict = {
"registered_domain": labels.domain + "." + labels.suffix,
"top_level_domain": labels.suffix,
"subdomain": labels.subdomain,
fields = {
f"{rule.target_field}.registered_domain": f"{labels.domain}.{labels.suffix}",
f"{rule.target_field}.top_level_domain": labels.suffix,
f"{rule.target_field}.subdomain": labels.subdomain,
}
for label, value in labels_dict.items():
output_field = f"{rule.target_field}.{label}"
add_successful = add_field_to(
event, output_field, value, overwrite_output_field=rule.overwrite_target
)

if not add_successful:
raise FieldExistsWarning(rule, event, [output_field])
add_fields_to(event, fields, rule, overwrite_target_field=rule.overwrite_target)
else:
tagging_field.append(f"invalid_domain_in_{rule.source_fields[0].replace('.', '_')}")
add_and_overwrite(event, self._config.tagging_field_name, tagging_field)
add_and_overwrite(
event, fields={self._config.tagging_field_name: tagging_field}, rule=rule
)

@staticmethod
def _is_valid_ip(domain):
Expand Down
10 changes: 6 additions & 4 deletions logprep/processor/domain_resolver/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from logprep.util.cache import Cache
from logprep.util.getter import GetterFactory
from logprep.util.hasher import SHA256Hasher
from logprep.util.helper import add_field_to, get_dotted_field_value
from logprep.util.helper import add_fields_to, get_dotted_field_value
from logprep.util.validators import list_of_urls_validator

logger = logging.getLogger("DomainResolver")
Expand Down Expand Up @@ -222,7 +222,9 @@ def _resolve_ip(self, domain, hash_string=None):

def _store_debug_infos(self, event, requires_storing):
event_dbg = {
"obtained_from_cache": not requires_storing,
"cache_size": len(self._domain_ip_map.keys()),
"resolved_ip_debug": {
"obtained_from_cache": not requires_storing,
"cache_size": len(self._domain_ip_map.keys()),
}
}
add_field_to(event, "resolved_ip_debug", event_dbg, overwrite_output_field=True)
add_fields_to(event, event_dbg, overwrite_target_field=True)
Loading

0 comments on commit 205045e

Please sign in to comment.