diff --git a/src/wxflow/__init__.py b/src/wxflow/__init__.py index 6da713d..70f6560 100644 --- a/src/wxflow/__init__.py +++ b/src/wxflow/__init__.py @@ -10,6 +10,7 @@ from .fsutils import chdir, cp, mkdir, mkdir_p, rm_p, rmdir from .jinja import Jinja from .logger import Logger, logit +from .sqlitedb import SQLiteDB from .task import Task from .template import Template, TemplateConstants from .timetools import * diff --git a/src/wxflow/sqlitedb.py b/src/wxflow/sqlitedb.py new file mode 100644 index 0000000..7099cd1 --- /dev/null +++ b/src/wxflow/sqlitedb.py @@ -0,0 +1,188 @@ +import sqlite3 +from typing import Any, List, Optional, Tuple + +__all__ = ["SQLiteDB"] + + +class SQLiteDB: + """ + A class for interacting with an SQLite3 database. + + Parameters: + db_name (str): The name of the SQLite database file. + + Attributes: + db_name (str): The name of the SQLite database file. + connection (sqlite3.Connection): The connection object for the database. + + """ + + def __init__(self, db_name: str) -> None: + self.db_name = db_name + self.connection: Optional[sqlite3.Connection] = None + + def connect(self) -> None: + """ + Connects to the SQLite database. + + """ + + try: + self.connection = sqlite3.connect(self.db_name) + except sqlite3.OperationalError as exc: + raise sqlite3.OperationalError(exc) + + def disconnect(self) -> None: + """ + Disconnects from the SQLite database. + + """ + + if self.connection: + self.connection.close() + + def execute_query(self, query: str, params: Optional[Tuple[Any, ...]] = None) -> sqlite3.Cursor: + """ + Executes an SQL query. + + Parameters: + query (str): The SQL query to execute. + params (tuple, optional): The parameters to be passed to the query. + + Returns: + cursor (sqlite3.Cursor): The cursor object. + + """ + + cursor = self.connection.cursor() + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + self.connection.commit() + return cursor + + def create_table(self, table_name: str, columns: List[str]) -> None: + """ + Creates a table in the database. + + Parameters: + table_name (str): The name of the table to create. + columns (list): The list of column definitions. + + """ + + query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})" + self.execute_query(query) + + def add_column(self, table_name: str, column_name: str, column_type: str) -> None: + """ + Adds a column to an existing table. + + Parameters: + table_name (str): The name of the table. + column_name (str): The name of the column to add. + column_type (str): The data type of the column. + + """ + + query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" + self.execute_query(query) + + def remove_column(self, table_name: str, column_name: str) -> None: + """ + Removes a column from an existing table. + + Parameters: + table_name (str): The name of the table. + column_name (str): The name of the column to remove. + + """ + + try: + query = f"ALTER TABLE {table_name} DROP COLUMN {column_name}" + self.execute_query(query) + except sqlite3.OperationalError as exc: + query = f"PRAGMA table_info({table_name})" + cursor = self.execute_query(query) + columns = [column[1] for column in cursor.fetchall()] + if column_name not in columns: + raise ValueError(f"Column '{column_name}' does not exist in table '{table_name}'") + raise sqlite3.OperationalError(exc) + + def update_data( + self, + table_name: str, + column_name: str, + new_value: Any, + condition_column: str, + condition_value: Any + ) -> None: + """ + Updates data in a table. + + Parameters: + table_name (str): The name of the table. + column_name (str): The name of the column to update. + new_value (any): The new value for the column. + condition_column (str): The column to use for the condition. + condition_value (any): The value to use in the condition. + + """ + + query = f"UPDATE {table_name} SET {column_name} = ? WHERE {condition_column} = ?" + self.execute_query(query, (new_value, condition_value)) + + def insert_data(self, table_name: str, values: List[Any]) -> None: + """ + Inserts data into a table. + + Parameters: + table_name (str): The name of the table. + values (list): The values to insert. + + """ + + placeholders = ", ".join(["?"] * len(values)) + query = f"INSERT INTO {table_name} VALUES ({placeholders})" + self.execute_query(query, values) + + def fetch_data( + self, + table_name: str, + columns: Optional[List[str]] = None, + condition: Optional[str] = None + ) -> List[Tuple]: + """ + Fetches data from a table. + + Parameters: + table_name (str): The name of the table. + columns (list, optional): The list of columns to fetch. + condition (str, optional): The condition to use in the query. + + Returns: + result (list): The fetched data. + + """ + + column_names = "*" if not columns else ", ".join(columns) + query = f"SELECT {column_names} FROM {table_name}" + if condition: + query += f" WHERE {condition}" + cursor = self.execute_query(query) + return cursor.fetchall() + + def remove_data(self, table_name: str, condition_column: str, condition_value: Any) -> None: + """ + Removes data from a table. + + Parameters: + table_name (str): The name of the table. + condition_column (str): The column to use for the condition. + condition_value (any): The value to use in the condition. + + """ + + query = f"DELETE FROM {table_name} WHERE {condition_column} = ?" + self.execute_query(query, (condition_value,)) diff --git a/tests/test_sqlitedb.py b/tests/test_sqlitedb.py new file mode 100644 index 0000000..2909987 --- /dev/null +++ b/tests/test_sqlitedb.py @@ -0,0 +1,122 @@ +import pytest + +from wxflow import SQLiteDB + + +@pytest.fixture(scope="module") +def db(): + # Create an in-memory SQLite database for testing + db = SQLiteDB(":memory:") + db.connect() + + # Create a test table + table_name = "test_table" + columns = ["id INTEGER PRIMARY KEY", "name TEXT", "age INTEGER"] + db.create_table(table_name, columns) + + yield db + + # Disconnect from the database + db.disconnect() + + +def test_create_table(db): + # Verify that the test table exists + assert table_exists(db, "test_table") + + +def test_add_column(db): + # Add a new column to the test table + column_name = "address" + column_type = "TEXT" + db.add_column("test_table", column_name, column_type) + + # Verify that the column exists in the test table + assert column_exists(db, "test_table", column_name) + + +def test_update_data(db): + # Insert test data into the table + values = [1, "Alice", 25, 'Apt 101'] + db.insert_data("test_table", values) + + # Update the age of the record + new_age = 30 + db.update_data("test_table", "age", new_age, "name", "Alice") + + # Fetch the updated data + result = db.fetch_data("test_table", condition="name='Alice'") + + # Verify that the age is updated correctly + assert result[0][2] == new_age + + +def test_remove_column(db): + # Removes a column from the test table + column_name = "address" + db.remove_column("test_table", column_name) + + # Verify that the column exists in the test table + assert not column_exists(db, "test_table", column_name) + + +def test_remove_column_raises_error_when_column_not_exists(db): + table_name = "test_table" + column_name = "vacation address" + + with pytest.raises(ValueError, match=f"Column '{column_name}' does not exist in table '{table_name}'"): + db.remove_column("test_table", column_name) + + +def test_insert_data(db): + # Insert test data into the table + values = [2, "Bob", 35] + db.insert_data("test_table", values) + + # Fetch all data from the table + result = db.fetch_data("test_table") + + # Verify that the inserted data is present in the table + assert len(result) == 2 + + +def test_fetch_data(db): + # Insert test data into the table + values = [3, "Charlie", 40] + db.insert_data("test_table", values) + + # Fetch data from the table + result = db.fetch_data("test_table", condition="age > 30") + + # Verify that the fetched data meets the condition + assert len(result) == 2 + + +def test_remove_data(db): + # Insert test data into the table + values = [4, "David", 45] + db.insert_data("test_table", values) + + # Remove a record from the table + db.remove_data("test_table", "name", "David") + + # Fetch all data from the table + result = db.fetch_data("test_table") + + # Verify that the removed data is not present in the table + assert len(result) == 3 + + +# Helper functions + +def table_exists(db, table_name): + query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'" + cursor = db.execute_query(query) + return cursor.fetchone() is not None + + +def column_exists(db, table_name, column_name): + query = f"PRAGMA table_info({table_name})" + cursor = db.execute_query(query) + columns = [column[1] for column in cursor.fetchall()] + return column_name in columns