Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sqlite3 interface #17

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/wxflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .fsutils import chdir, cp, mkdir, mkdir_p, rm_p, rmdir
from .jinja import Jinja
from .logger import Logger, logit
from .sqlitedb import SQLiteDB
from .task import Task
from .template import Template, TemplateConstants
from .timetools import *
Expand Down
188 changes: 188 additions & 0 deletions src/wxflow/sqlitedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import sqlite3
from typing import Any, List, Optional, Tuple

__all__ = ["SQLiteDB"]


class SQLiteDB:
"""
A class for interacting with an SQLite3 database.

Parameters:
db_name (str): The name of the SQLite database file.

Attributes:
db_name (str): The name of the SQLite database file.
connection (sqlite3.Connection): The connection object for the database.

"""

def __init__(self, db_name: str) -> None:
self.db_name = db_name
self.connection: Optional[sqlite3.Connection] = None

def connect(self) -> None:
"""
Connects to the SQLite database.

"""

try:
self.connection = sqlite3.connect(self.db_name)
except sqlite3.OperationalError as exc:
raise sqlite3.OperationalError(exc)

Check warning on line 33 in src/wxflow/sqlitedb.py

View check run for this annotation

Codecov / codecov/patch

src/wxflow/sqlitedb.py#L32-L33

Added lines #L32 - L33 were not covered by tests

def disconnect(self) -> None:
"""
Disconnects from the SQLite database.

"""

if self.connection:
self.connection.close()

def execute_query(self, query: str, params: Optional[Tuple[Any, ...]] = None) -> sqlite3.Cursor:
"""
Executes an SQL query.

Parameters:
query (str): The SQL query to execute.
params (tuple, optional): The parameters to be passed to the query.

Returns:
cursor (sqlite3.Cursor): The cursor object.

"""

cursor = self.connection.cursor()
if params:
cursor.execute(query, params)
else:
cursor.execute(query)
self.connection.commit()
return cursor

def create_table(self, table_name: str, columns: List[str]) -> None:
"""
Creates a table in the database.

Parameters:
table_name (str): The name of the table to create.
columns (list): The list of column definitions.

"""

query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"
self.execute_query(query)

def add_column(self, table_name: str, column_name: str, column_type: str) -> None:
"""
Adds a column to an existing table.

Parameters:
table_name (str): The name of the table.
column_name (str): The name of the column to add.
column_type (str): The data type of the column.

"""

query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
self.execute_query(query)

def remove_column(self, table_name: str, column_name: str) -> None:
"""
Removes a column from an existing table.

Parameters:
table_name (str): The name of the table.
column_name (str): The name of the column to remove.

"""

try:
query = f"ALTER TABLE {table_name} DROP COLUMN {column_name}"
self.execute_query(query)
except sqlite3.OperationalError as exc:
query = f"PRAGMA table_info({table_name})"
cursor = self.execute_query(query)
columns = [column[1] for column in cursor.fetchall()]
if column_name not in columns:
raise ValueError(f"Column '{column_name}' does not exist in table '{table_name}'")
raise sqlite3.OperationalError(exc)

def update_data(
self,
table_name: str,
column_name: str,
new_value: Any,
condition_column: str,
condition_value: Any
) -> None:
"""
Updates data in a table.

Parameters:
table_name (str): The name of the table.
column_name (str): The name of the column to update.
new_value (any): The new value for the column.
condition_column (str): The column to use for the condition.
condition_value (any): The value to use in the condition.

"""

query = f"UPDATE {table_name} SET {column_name} = ? WHERE {condition_column} = ?"
self.execute_query(query, (new_value, condition_value))

def insert_data(self, table_name: str, values: List[Any]) -> None:
"""
Inserts data into a table.

Parameters:
table_name (str): The name of the table.
values (list): The values to insert.

"""

placeholders = ", ".join(["?"] * len(values))
query = f"INSERT INTO {table_name} VALUES ({placeholders})"
self.execute_query(query, values)

def fetch_data(
self,
table_name: str,
columns: Optional[List[str]] = None,
condition: Optional[str] = None
) -> List[Tuple]:
"""
Fetches data from a table.

Parameters:
table_name (str): The name of the table.
columns (list, optional): The list of columns to fetch.
condition (str, optional): The condition to use in the query.

Returns:
result (list): The fetched data.

"""

column_names = "*" if not columns else ", ".join(columns)
query = f"SELECT {column_names} FROM {table_name}"
if condition:
query += f" WHERE {condition}"
cursor = self.execute_query(query)
return cursor.fetchall()

def remove_data(self, table_name: str, condition_column: str, condition_value: Any) -> None:
"""
Removes data from a table.

Parameters:
table_name (str): The name of the table.
condition_column (str): The column to use for the condition.
condition_value (any): The value to use in the condition.

"""

query = f"DELETE FROM {table_name} WHERE {condition_column} = ?"
self.execute_query(query, (condition_value,))
122 changes: 122 additions & 0 deletions tests/test_sqlitedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest

from wxflow import SQLiteDB


@pytest.fixture(scope="module")
def db():
# Create an in-memory SQLite database for testing
db = SQLiteDB(":memory:")
db.connect()

# Create a test table
table_name = "test_table"
columns = ["id INTEGER PRIMARY KEY", "name TEXT", "age INTEGER"]
db.create_table(table_name, columns)

yield db

# Disconnect from the database
db.disconnect()


def test_create_table(db):
# Verify that the test table exists
assert table_exists(db, "test_table")


def test_add_column(db):
# Add a new column to the test table
column_name = "address"
column_type = "TEXT"
db.add_column("test_table", column_name, column_type)

# Verify that the column exists in the test table
assert column_exists(db, "test_table", column_name)


def test_update_data(db):
# Insert test data into the table
values = [1, "Alice", 25, 'Apt 101']
db.insert_data("test_table", values)

# Update the age of the record
new_age = 30
db.update_data("test_table", "age", new_age, "name", "Alice")

# Fetch the updated data
result = db.fetch_data("test_table", condition="name='Alice'")

# Verify that the age is updated correctly
assert result[0][2] == new_age


def test_remove_column(db):
# Removes a column from the test table
column_name = "address"
db.remove_column("test_table", column_name)

# Verify that the column exists in the test table
assert not column_exists(db, "test_table", column_name)


def test_remove_column_raises_error_when_column_not_exists(db):
table_name = "test_table"
column_name = "vacation address"

with pytest.raises(ValueError, match=f"Column '{column_name}' does not exist in table '{table_name}'"):
db.remove_column("test_table", column_name)


def test_insert_data(db):
# Insert test data into the table
values = [2, "Bob", 35]
db.insert_data("test_table", values)

# Fetch all data from the table
result = db.fetch_data("test_table")

# Verify that the inserted data is present in the table
assert len(result) == 2


def test_fetch_data(db):
# Insert test data into the table
values = [3, "Charlie", 40]
db.insert_data("test_table", values)

# Fetch data from the table
result = db.fetch_data("test_table", condition="age > 30")

# Verify that the fetched data meets the condition
assert len(result) == 2


def test_remove_data(db):
# Insert test data into the table
values = [4, "David", 45]
db.insert_data("test_table", values)

# Remove a record from the table
db.remove_data("test_table", "name", "David")

# Fetch all data from the table
result = db.fetch_data("test_table")

# Verify that the removed data is not present in the table
assert len(result) == 3


# Helper functions

def table_exists(db, table_name):
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'"
cursor = db.execute_query(query)
return cursor.fetchone() is not None


def column_exists(db, table_name, column_name):
query = f"PRAGMA table_info({table_name})"
cursor = db.execute_query(query)
columns = [column[1] for column in cursor.fetchall()]
return column_name in columns