Opensearch async test (#371)
* Implement asynchronous Elasticsearch client creation and update query execution to support async operations

* Refactor run_active_sigma_queries_endpoint to execute queries concurrently using asyncio (see the sketch below)
taylorwalton authored Dec 13, 2024
1 parent 6394dad commit 9db0af5
Showing 3 changed files with 104 additions and 12 deletions.
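The second commit bullet is the heart of the change. As a minimal sketch of the pattern (the coroutine body and rule names here are illustrative placeholders, not the committed code): calling an async function without awaiting it returns a coroutine object, so the endpoint can collect one coroutine per active query and hand the whole batch to asyncio.gather, which runs them concurrently and waits for every result.

import asyncio

async def run_query(rule_name: str) -> str:
    # Placeholder for one OpenSearch round-trip; sleep stands in for network I/O.
    await asyncio.sleep(0.1)
    return f"ran {rule_name}"

async def run_all(rule_names: list[str]) -> list[str]:
    # Build un-awaited coroutines first...
    tasks = [run_query(name) for name in rule_names]
    # ...then schedule and await them all together.
    return await asyncio.gather(*tasks)

print(asyncio.run(run_all(["rule_a", "rule_b"])))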
69 changes: 61 additions & 8 deletions backend/app/connectors/wazuh_indexer/routes/sigma.py
@@ -7,6 +7,7 @@
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
import asyncio
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession

@@ -272,6 +273,54 @@ async def deactivate_all_sigma_queries_endpoint(
)


# @wazuh_indexer_sigma_router.post("/run-active-queries", response_model=SigmaQueryOutResponse)
# async def run_active_sigma_queries_endpoint(
# index_name: str = Query(default="wazuh*"),
# db: AsyncSession = Depends(get_db),
# ):
# """
# Runs the active Sigma queries.

# Args:
# db (AsyncSession): The database session.

# Returns:
# SigmaQueryOutResponse: The Sigma queries response.
# """
# active_sigma_queries = await list_active_sigma_queries(db)
# for query in active_sigma_queries:
# time_interval_delta = parse_time_interval(query.time_interval)
# logger.info(f"Time interval delta: {time_interval_delta}")
# current_time = datetime.now()
# logger.info(f"Current time: {current_time}")
# logger.info(f"Last execution time: {query.last_execution_time}")

# # Check if the current time is less than the last execution time
# if current_time < query.last_execution_time or current_time - query.last_execution_time >= time_interval_delta:
# logger.info(f"Running Sigma query: {query.rule_name}")
# await execute_query(
# RunActiveSigmaQueries(
# query=query.rule_query,
# time_interval=query.time_interval,
# last_execution_time=query.last_execution_time,
# rule_name=query.rule_name,
# index=index_name,
# ),
# session=db,
# )
# # Update the last execution time to the current time and commit the changes
# # ! Remove commented out code after testing ! #
# query.last_execution_time = current_time
# await db.commit()
# else:
# time_comparison = current_time - query.last_execution_time
# logger.info(f"Time comparison: {time_comparison}")
# logger.info(f"Skipping Sigma query because the time interval has not passed: {query.rule_name}")
# return SigmaQueryOutResponse(
# success=True,
# message="Successfully ran the active Sigma queries.",
# )

@wazuh_indexer_sigma_router.post("/run-active-queries", response_model=SigmaQueryOutResponse)
async def run_active_sigma_queries_endpoint(
index_name: str = Query(default="wazuh*"),
@@ -287,6 +336,8 @@ async def run_active_sigma_queries_endpoint(
SigmaQueryOutResponse: The Sigma queries response.
"""
active_sigma_queries = await list_active_sigma_queries(db)
tasks = []

for query in active_sigma_queries:
time_interval_delta = parse_time_interval(query.time_interval)
logger.info(f"Time interval delta: {time_interval_delta}")
@@ -297,7 +348,7 @@
# Run this query if the last execution time is in the future or the time interval has elapsed
if current_time < query.last_execution_time or current_time - query.last_execution_time >= time_interval_delta:
logger.info(f"Running Sigma query: {query.rule_name}")
- await execute_query(
+ task = execute_query(
RunActiveSigmaQueries(
query=query.rule_query,
time_interval=query.time_interval,
@@ -307,14 +358,16 @@
),
session=db,
)
- # Update the last execution time to the current time and commit the changes
- # ! Remove commented out code after testing ! #
+ tasks.append(task)
+ # Update the last execution time to the current time
query.last_execution_time = current_time
- await db.commit()
else:
time_comparison = current_time - query.last_execution_time
logger.info(f"Time comparison: {time_comparison}")
logger.info(f"Skipping Sigma query because the time interval has not passed: {query.rule_name}")

# Run all tasks concurrently
results = await asyncio.gather(*tasks)

# Commit the changes to the database
await db.commit()

return SigmaQueryOutResponse(
success=True,
message="Successfully ran the active Sigma queries.",
7 changes: 4 additions & 3 deletions
@@ -4,9 +4,10 @@
from fastapi import HTTPException
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
import asyncio

from app.connectors.wazuh_indexer.schema.sigma import RunActiveSigmaQueries
- from app.connectors.wazuh_indexer.utils.universal import create_wazuh_indexer_client
+ from app.connectors.wazuh_indexer.utils.universal import create_wazuh_indexer_client, create_wazuh_indexer_client_async
from app.incidents.schema.incident_alert import CreatedAlertPayload
from app.incidents.services.incident_alert import add_asset_to_copilot_alert
from app.incidents.services.incident_alert import build_alert_context_payload
@@ -105,7 +106,7 @@ async def send_query_to_opensearch(
session: AsyncSession = None,
) -> List[dict]:
try:
- response = es_client.search(index=index, body=query)
+ response = await es_client.search(index=index, body=query)
logger.info(f"Response: {response}")
hits = response["hits"]["hits"]
return await process_hits(hits, rule_name, session)
@@ -148,7 +149,7 @@ async def process_hits(hits, rule_name, session: AsyncSession):


async def execute_query(payload: RunActiveSigmaQueries, session: AsyncSession = None):
- client = await create_wazuh_indexer_client()
+ client = await create_wazuh_indexer_client_async()
formatted_query = await format_opensearch_query(payload.query, payload.time_interval, payload.last_execution_time)
logger.info(f"Executing query: {formatted_query}")
results = await send_query_to_opensearch(client, formatted_query, payload.rule_name, index=payload.index, session=session)
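For context on the send_query_to_opensearch change, a hedged end-to-end sketch of an async search round-trip with the elasticsearch7 async client (the URL, credentials, and query body are made up, and the library's async extras must be installed):

import asyncio
from elasticsearch7 import AsyncElasticsearch

async def main() -> None:
    # Illustrative connection details; verify_certs=False mirrors the commit's client setup.
    client = AsyncElasticsearch(
        ["https://indexer.example:9200"],
        http_auth=("user", "pass"),
        verify_certs=False,
    )
    try:
        # On the async client, search() must be awaited (the one-line fix above).
        response = await client.search(index="wazuh*", body={"query": {"match_all": {}}})
        for hit in response["hits"]["hits"]:
            print(hit["_source"])
    finally:
        # Close the underlying aiohttp session explicitly.
        await client.close()

asyncio.run(main())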
40 changes: 39 additions & 1 deletion backend/app/connectors/wazuh_indexer/utils/universal.py
@@ -6,7 +6,7 @@
from typing import Iterable
from typing import Tuple

- from elasticsearch7 import Elasticsearch
+ from elasticsearch7 import Elasticsearch, AsyncElasticsearch
from fastapi import HTTPException
from loguru import logger

@@ -112,6 +112,44 @@ async def create_wazuh_indexer_client(connector_name: str = "Wazuh-Indexer") ->
detail=f"Failed to create Elasticsearch client: {e}",
)

async def create_wazuh_indexer_client_async(connector_name: str = "Wazuh-Indexer") -> AsyncElasticsearch:
"""
Returns an asynchronous Elasticsearch client for the Wazuh Indexer service.

Returns:
AsyncElasticsearch: Async Elasticsearch client for the Wazuh Indexer service.
"""
# attributes = get_connector_info_from_db(connector_name)
async with get_db_session() as session: # This will correctly enter the context manager
attributes = await get_connector_info_from_db(connector_name, session)
if attributes is None:
raise HTTPException(
status_code=500,
detail=f"No {connector_name} connector found in the database",
)
if attributes["connector_url"] == "https://127.1.1.1:9200":
raise HTTPException(
status_code=500,
detail=f"Please update the {connector_name} connector URL",
)
try:
return AsyncElasticsearch(
[attributes["connector_url"]],
http_auth=(
attributes["connector_username"],
attributes["connector_password"],
),
verify_certs=False,
timeout=15,
max_retries=10,
retry_on_timeout=False,
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to create Elasticsearch client: {e}",
)


async def format_node_allocation(node_allocation):
"""
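One hedged note on the new async helper: an AsyncElasticsearch instance owns an aiohttp session, and nothing in the hunks shown here calls close() on it, so callers presumably need to do so themselves (otherwise the event loop warns about unclosed sessions). A caller-side sketch, assuming the commit's module path:

from app.connectors.wazuh_indexer.utils.universal import create_wazuh_indexer_client_async

async def search_once(index: str, body: dict) -> dict:
    client = await create_wazuh_indexer_client_async()
    try:
        return await client.search(index=index, body=body)
    finally:
        # Release the aiohttp session held by the async transport.
        await client.close()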
