Skip to content

Commit

Permalink
Migrations (#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchop authored Nov 29, 2024
1 parent f1dcd42 commit de59609
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 1 deletion.
16 changes: 16 additions & 0 deletions core/database_arango.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

from .interfaces import AbstractYetiConnector

CODE_DB_VERSION = 2

LINK_TYPE_TO_GRAPH = {
"tagged": "tags",
"stix": "stix",
Expand Down Expand Up @@ -58,6 +60,7 @@ def connect(
username: str = None,
password: str = None,
database: str = None,
check_db_sync: bool = False,
):
host = host or yeti_config.get("arangodb", "host")
port = port or yeti_config.get("arangodb", "port")
Expand Down Expand Up @@ -88,6 +91,8 @@ def connect(
sys_db.create_database(database)

self.db = client.db(database, username=username, password=password)
if check_db_sync:
self.check_database_version()

self.create_edge_definition(
self.graph("tags"),
Expand Down Expand Up @@ -120,6 +125,17 @@ def connect(
self.create_analyzers()
self.create_views()

def check_database_version(self, skip_if_testing: bool = True):
if TESTING and skip_if_testing:
return
system = list(self.db.collection("system").all())
if not system:
raise RuntimeError("Database version not found, please run migrations.")
if system[0]["db_version"] != CODE_DB_VERSION:
raise RuntimeError(
f"Database version mismatch. Expected {CODE_DB_VERSION}, got {system[0]['db_version']}"
)

def create_analyzers(self):
self.db.create_analyzer(
name="norm",
Expand Down
Empty file added core/migrations/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions core/migrations/arangodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import time

from core.database_arango import ASYNC_JOB_WAIT_TIME, ArangoDatabase
from core.migrations import migration


class ArangoMigrationManager(migration.MigrationManager):
DB_TYPE = "arangodb"

def connect_to_db(self):
self.db = ArangoDatabase()
self.db.connect(check_db_sync=False)

system_coll = self.db.collection("system")
job = system_coll.all()
while job.status() != "done":
time.sleep(ASYNC_JOB_WAIT_TIME)
migrations = list(job.result())
if not migrations:
job = system_coll.insert(
{"db_version": 0, "db_type": self.DB_TYPE},
)
while job.status() != "done":
time.sleep(ASYNC_JOB_WAIT_TIME)

job = system_coll.all()
while job.status() != "done":
time.sleep(ASYNC_JOB_WAIT_TIME)
migrations = list(job.result())

db_version = migrations[0]["db_version"]
db_type = migrations[0]["db_type"]

self.db_version = db_version
self.db_type = db_type

def update_db_version(self, version: int):
job = self.db.collection("system").update_match(
{"db_version": self.db_version, "db_type": self.DB_TYPE},
{"db_version": version},
)
while job.status() != "done":
time.sleep(ASYNC_JOB_WAIT_TIME)
self.db_version = version


def migration_0():
pass


def migration_1():
from core.schemas import observable

for obs in observable.Observable.list():
obs.save()


ArangoMigrationManager.register_migration(migration_0)
ArangoMigrationManager.register_migration(migration_1)

if __name__ == "__main__":
migration_manager = ArangoMigrationManager()
migration_manager.migrate_to_latest()
30 changes: 30 additions & 0 deletions core/migrations/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Callable


class MigrationManager:
MIGRATIONS: list[Callable] = []

def __init__(self):
self.connect_to_db()

def connect_to_db(self):
raise NotImplementedError

def update_db_version(self, version: int):
raise NotImplementedError

def migrate_to_latest(self, stop_at: int | None = None):
for idx, migration in enumerate(self.MIGRATIONS):
if stop_at is not None and idx >= stop_at:
print(f"Stopping at migration {idx}")
elif idx >= self.db_version and (stop_at is None or idx < stop_at):
print(f"Running migration {idx} -> {idx + 1}")
migration()
self.update_db_version(idx + 1)
else:
print(f"Skipping migration {idx}, current version is {self.db_version}")
continue

@classmethod
def register_migration(cls, migration):
cls.MIGRATIONS.append(migration)
2 changes: 2 additions & 0 deletions extras/docker/docker-entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ elif [[ "$1" = 'toggle-user' ]]; then
poetry run python yetictl/cli.py toggle-user "${@:2}"
elif [[ "$1" = 'toggle-admin' ]]; then
poetry run python yetictl/cli.py toggle-admin "${@:2}"
elif [[ "$1" = 'migrate-arangodb' ]]; then
poetry run python yetictl/cli.py migrate-arangodb "${@:2}"
elif [[ "$1" = 'envshell' ]]; then
poetry shell
else
Expand Down
48 changes: 48 additions & 0 deletions tests/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import time
import unittest

from core.migrations import arangodb


class ArangoMigrationTest(unittest.TestCase):
def setUp(self):
self.migration_manager = arangodb.ArangoMigrationManager()
self.migration_manager.update_db_version(0)

def test_migration_init(self):
self.assertEqual(self.migration_manager.db_version, 0)

def test_migration_0(self):
self.migration_manager.migrate_to_latest(stop_at=1)
self.assertEqual(self.migration_manager.db_version, 1)

def test_migration_1(self):
observable_col = self.migration_manager.db.collection("observables")
observable_col.truncate()
observable_col.insert(
{
"value": "test.com",
"type": "hostname",
"root_type": "observable",
"created": "2024-11-14T11:58:49.757379Z",
}
)
observable_col.insert(
{
"value": "test.com123",
"type": "hostname",
"root_type": "observable",
"created": "2024-11-14T11:58:49.757379Z",
}
)
self.migration_manager.migrate_to_latest(stop_at=2)
self.assertEqual(self.migration_manager.db_version, 2)
job = observable_col.all()
while job.status() != "done":
time.sleep(0.1)
obs = list(job.result())
self.assertEqual(len(obs), 2)
self.assertEqual(obs[0]["value"], "test.com")
self.assertEqual(obs[0]["is_valid"], True)
self.assertEqual(obs[1]["value"], "test.com123")
self.assertEqual(obs[1]["is_valid"], False)
12 changes: 11 additions & 1 deletion yetictl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def list_tasks(task_type="") -> None:
@cli.command()
@click.argument("task_name")
@click.argument("task_params", required=False)
def run_task(task_name: str, task_params: dict = None) -> None:
def run_task(task_name: str, task_params: dict | None = None) -> None:
"""Runs a task."""
# Load all tasks. Take into account new tasks that have not been registered
logging.getLogger().setLevel(logging.INFO)
Expand All @@ -149,5 +149,15 @@ def run_task(task_name: str, task_params: dict = None) -> None:
click.echo(traceback.format_exc())


@cli.command()
@click.argument("stop_at", required=False)
def migrate_arangodb(stop_at: int | None = None) -> None:
"""Runs the database migrations."""
from core.migrations.arangodb import ArangoMigrationManager

migration_manager = ArangoMigrationManager()
migration_manager.migrate_to_latest(stop_at=stop_at)


if __name__ == "__main__":
cli()

0 comments on commit de59609

Please sign in to comment.