diff --git a/src/database.hpp b/src/database.hpp index 3bffb51..2df9ebb 100644 --- a/src/database.hpp +++ b/src/database.hpp @@ -2,14 +2,28 @@ #include "statement.hpp" #include "table.hpp" #include "utils.hpp" +#include +#include #include +#include #include #include #include class Database { public: - explicit Database(const std::string &name) : name(name) {} + explicit Database(const std::string &name, + const std::string &data_dir = "data") + : name(name), data_dir(data_dir) {} + + ~Database() { + namespace fs = std::filesystem; + if (!fs::exists(data_dir)) { + fs::create_directory(data_dir); + } + fs::path db_path = fs::path(data_dir) / (name + ".db"); + serialize(db_path.string()); + } void executeStatement(SQLStatement *stmt) { switch (stmt->type) { @@ -32,7 +46,6 @@ class Database { } case SQLStatementType::INSERT: { auto insert_stmt = static_cast(stmt); - // debug(insert_stmt->table_name, insert_stmt->values.size()); if (tables.find(insert_stmt->table_name) == tables.end()) { throw DatabaseError("Table does not exist", insert_stmt->line_number); } @@ -79,7 +92,72 @@ class Database { } } + std::string getName() const { return name; } + + // Serialize database to a file + void serialize(const std::string &filepath) const { + std::ofstream out(filepath); + if (!out) { + throw DatabaseError("Failed to open file for serialization", 0); + } + + // Write database name + out << "DATABASE " << name << "\n"; + + // Write number of tables + out << "TABLES " << tables.size() << "\n\n"; + + // Write each table + for (const auto &[table_name, table] : tables) { + table->serialize(out); + out << "\n"; // Add a blank line between tables + } + } + + // Deserialize database from a file + static std::unique_ptr deserialize(const std::string &filepath) { + std::ifstream in(filepath); + if (!in) { + throw DatabaseError("Failed to open file for deserialization", 0); + } + + std::string line, word; + + // Read database name + std::getline(in, line); + std::istringstream iss(line); + iss >> word; // Skip "DATABASE" + std::string db_name; + iss >> db_name; + + // Create database instance + auto db = std::make_unique(db_name); + + // Read number of tables + std::getline(in, line); + iss.clear(); + iss.str(line); + iss >> word; // Skip "TABLES" + size_t num_tables; + iss >> num_tables; + + // Skip the blank line + std::getline(in, line); + + // Read each table + for (size_t i = 0; i < num_tables; ++i) { + auto table = Table::deserialize(in); + db->tables[table->getName()] = std::move(table); + + // Skip the blank line between tables + std::getline(in, line); + } + + return db; + } + private: std::string name; + std::string data_dir; std::unordered_map> tables; }; \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 7633c78..925585b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2,14 +2,34 @@ #include "parser.hpp" #include "statement.hpp" #include "utils.hpp" +#include #include #include -FileWriter file_writer; +OutputWriter file_writer; +namespace fs = std::filesystem; + +void loadDatabases( + std::unordered_map> &databases, + const std::string &data_dir) { + if (!fs::exists(data_dir)) { + fs::create_directory(data_dir); + return; + } + + for (const auto &entry : fs::directory_iterator(data_dir)) { + if (entry.path().extension() == ".db") { + auto db = Database::deserialize(entry.path().string()); + databases[db->getName()] = std::move(db); + } + } +} int main(int argc, char *argv[]) { std::unordered_map> databases; Database *current_database = nullptr; + const std::string data_dir = "data"; + try { if (argc != 3) { throw ArgumentError("Argument number error"); @@ -17,6 +37,9 @@ int main(int argc, char *argv[]) { throw ArgumentError("Input file must be a SQL file"); } + // Load existing databases + loadDatabases(databases, data_dir); + // Read input SQL file and split into statements std::ifstream input_file(argv[1]); if (!input_file.is_open()) { @@ -38,7 +61,8 @@ int main(int argc, char *argv[]) { parsed_statement->line_number); } databases[parsed_statement->getDatabaseName()] = - std::make_unique(parsed_statement->getDatabaseName()); + std::make_unique(parsed_statement->getDatabaseName(), + data_dir); } else if (parsed_statement->type == SQLStatementType::USE_DATABASE) { if (databases.find(parsed_statement->getDatabaseName()) == databases.end()) { diff --git a/src/table.hpp b/src/table.hpp index 91992c8..583e1ba 100644 --- a/src/table.hpp +++ b/src/table.hpp @@ -3,20 +3,29 @@ #include "statement.hpp" #include "utils.hpp" #include +#include #include +#include #include #include +#include class Table { public: Table(const std::string &name, const std::vector &columns) : name(name), columns(columns) { - // Build column index + rebuildColumnIndex(); + } + + void rebuildColumnIndex() { + column_index.clear(); for (size_t i = 0; i < columns.size(); i++) { column_index[columns[i].name] = i; } } + std::string getName() const { return name; } + void insert(const std::vector &row) { // Validate number of values matches number of columns if (row.size() != columns.size()) { @@ -299,6 +308,156 @@ class Table { file_writer.write(results); } + // Helper function to quote string + static std::string quoteString(const std::string& str) { + return "\"" + str + "\""; + } + + // Helper function to read a quoted string + static std::string readQuotedString(std::istream& in) { + char c; + std::string result; + + // Skip leading whitespace + while (in.get(c) && std::isspace(c)) {} + in.unget(); + + if (in.peek() != '"') { + throw TableError("Expected quoted string"); + } + + in.get(); // Skip opening quote + while (in.get(c) && c != '"') { + result += c; + } + + if (c != '"') { + throw TableError("Unterminated quoted string"); + } + + return result; + } + + void serialize(std::ofstream &out) const { + // Write table name + out << "TABLE " << quoteString(name) << "\n"; + + // Write columns + out << "COLUMNS " << columns.size() << "\n"; + for (const auto &col : columns) { + out << quoteString(col.name) << " " << TOKEN_STR.find(col.type)->second << "\n"; + } + + // Write rows + out << "ROWS " << rows.size() << "\n"; + for (const auto &row : rows) { + for (size_t i = 0; i < row.size(); ++i) { + if (i > 0) out << " "; + const auto &value = row[i]; + if (std::holds_alternative(value)) { + out << "INT " << std::get(value); + } else if (std::holds_alternative(value)) { + out << "FLOAT " << std::get(value); + } else if (std::holds_alternative(value)) { + out << "TEXT " << quoteString(std::get(value)); + } + } + out << "\n"; + } + } + + static std::unique_ptr deserialize(std::ifstream &in) { + std::string line, word; + + // Read table name + std::getline(in, line); + std::istringstream iss(line); + iss >> word; // Skip "TABLE" + std::string table_name = readQuotedString(iss); + + // Read columns + std::getline(in, line); + iss.clear(); + iss.str(line); + iss >> word; // Skip "COLUMNS" + size_t num_columns; + iss >> num_columns; + + std::vector columns; + columns.reserve(num_columns); + + for (size_t i = 0; i < num_columns; ++i) { + std::getline(in, line); + iss.clear(); + iss.str(line); + + ColumnDefinition col; + col.name = readQuotedString(iss); + std::string type_str; + iss >> type_str; + + // Convert type string to TokenType + bool found = false; + for (const auto &[token_type, str] : TOKEN_STR) { + if (str == type_str) { + col.type = token_type; + found = true; + break; + } + } + + if (!found) { + throw TableError("Invalid column type: " + type_str); + } + + columns.push_back(col); + } + + // Create table + auto table = std::make_unique
(table_name, columns); + + // Read rows + std::getline(in, line); + iss.clear(); + iss.str(line); + iss >> word; // Skip "ROWS" + size_t num_rows; + iss >> num_rows; + + for (size_t i = 0; i < num_rows; ++i) { + std::getline(in, line); + iss.clear(); + iss.str(line); + + std::vector row; + row.reserve(num_columns); + + for (size_t j = 0; j < num_columns; ++j) { + std::string type; + iss >> type; + + if (type == "INT") { + int value; + iss >> value; + row.push_back(value); + } else if (type == "FLOAT") { + double value; + iss >> value; + row.push_back(value); + } else if (type == "TEXT") { + std::string value = readQuotedString(iss); + row.push_back(value); + } else { + throw TableError("Invalid value type: " + type); + } + } + + table->rows.push_back(std::move(row)); + } + + return table; + } + private: std::string name; std::vector columns; diff --git a/src/utils.hpp b/src/utils.hpp index f08fe97..9933ff5 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -1,5 +1,6 @@ #pragma once +// Standard library includes #include #include #include @@ -9,13 +10,17 @@ #include #include -// Debugging macro +//----------------------------------------------------------------------------- +// Debug utilities +//----------------------------------------------------------------------------- void debug_out() { std::cerr << '\n'; } + template void debug_out(Head H, Tail... T) { std::cerr << ' ' << H; debug_out(T...); } + #ifdef DEBUG #define debug(...) \ std::cerr << '[' << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__) @@ -23,7 +28,9 @@ template void debug_out(Head H, Tail... T) { #define debug(...) #endif -// Error classes +//----------------------------------------------------------------------------- +// Custom exception classes +//----------------------------------------------------------------------------- class ArgumentError : public std::runtime_error { public: @@ -69,11 +76,16 @@ class TableError : public SQLError { line) {} }; -// Token types and value types +//----------------------------------------------------------------------------- +// SQL types and tokens +//----------------------------------------------------------------------------- +// Value type for storing SQL data (can be string, int, or double) using Value = std::variant; + +// All possible SQL token types enum class TokenType { - EOF_TOKEN, + // Keywords CREATE, DATABASE, USE, @@ -93,13 +105,19 @@ enum class TokenType { ON, AND, OR, + + // Data types INTEGER, FLOAT, TEXT, + + // Literals and identifiers IDENTIFIER, INTEGER_LITERAL, FLOAT_LITERAL, STRING_LITERAL, + + // Operators and punctuation COMMA, SEMICOLON, LEFT_PAREN, @@ -111,9 +129,13 @@ enum class TokenType { DOT, ASTERISK, PLUS, - MINUS + MINUS, + + // Special tokens + EOF_TOKEN }; +// Mapping from string to token type (for lexical analysis) const std::unordered_map TOKEN_MAP = { {"CREATE", TokenType::CREATE}, {"DATABASE", TokenType::DATABASE}, {"USE", TokenType::USE}, {"TABLE", TokenType::TABLE}, @@ -131,9 +153,9 @@ const std::unordered_map TOKEN_MAP = { {"=", TokenType::EQUALS}, {">", TokenType::GREATER_THAN}, {"<", TokenType::LESS_THAN}, {".", TokenType::DOT}, {"*", TokenType::ASTERISK}, {"+", TokenType::PLUS}, - {"-", TokenType::MINUS}, {"!=", TokenType::INEQUALS}, -}; + {"-", TokenType::MINUS}, {"!=", TokenType::INEQUALS}}; +// Mapping from token type to string (for error messages and serialization) const std::unordered_map TOKEN_STR = { {TokenType::CREATE, "CREATE"}, {TokenType::DATABASE, "DATABASE"}, @@ -164,18 +186,22 @@ const std::unordered_map TOKEN_STR = { {TokenType::COMMA, "COMMA"}, {TokenType::SEMICOLON, "SEMICOLON"}, {TokenType::LEFT_PAREN, "LEFT_PAREN"}, - {TokenType::RIGHT_PAREN, "LEFT_PAREN"}, + {TokenType::RIGHT_PAREN, "RIGHT_PAREN"}, {TokenType::EQUALS, "EQUALS"}, + {TokenType::INEQUALS, "INEQUALS"}, {TokenType::GREATER_THAN, "GREATER_THAN"}, {TokenType::LESS_THAN, "LESS_THAN"}, {TokenType::DOT, "DOT"}, {TokenType::ASTERISK, "ASTERISK"}, - {TokenType::EOF_TOKEN, "EOF_TOKEN"}, {TokenType::PLUS, "PLUS"}, {TokenType::MINUS, "MINUS"}, - {TokenType::INEQUALS, "INEQUALS"}, -}; + {TokenType::EOF_TOKEN, "EOF_TOKEN"}}; + +//----------------------------------------------------------------------------- +// SQL statement types and structures +//----------------------------------------------------------------------------- +// Token structure for lexical analysis struct Token { TokenType type; std::string value; @@ -185,47 +211,7 @@ struct Token { : type(t), value(v), line_number(line) {} }; -inline Token recognizeToken(std::string token) { - if (TOKEN_MAP.find(token) != TOKEN_MAP.end()) { - return Token(TOKEN_MAP.at(token), token); - } - - // Check if token is an integer literal - bool is_integer = true; - for (size_t i = 0; i < token.length(); i++) { - char c = token[i]; - if (i == 0 && c == '-') - continue; - if (!std::isdigit(c)) { - is_integer = false; - break; - } - } - if (is_integer) { - return Token(TokenType::INTEGER_LITERAL, token); - } - - // Check if token is a float literal - bool is_float = true; - int decimal_count = 0; - for (size_t i = 0; i < token.length(); i++) { - char c = token[i]; - if (i == 0 && c == '-') - continue; - if (c == '.') { - ++decimal_count; - } else if (!std::isdigit(c)) { - is_float = false; - break; - } - } - if (is_float && decimal_count == 1) { - return Token(TokenType::FLOAT_LITERAL, token); - } - - return Token(TokenType::IDENTIFIER, token); -} - +// SQL statement types enum class SQLStatementType { CREATE_DATABASE, USE_DATABASE, @@ -238,19 +224,7 @@ enum class SQLStatementType { INNER_JOIN }; -const std::unordered_map SQL_STATEMENT_TYPE_STR = - { - {SQLStatementType::CREATE_DATABASE, "CREATE_DATABASE"}, - {SQLStatementType::USE_DATABASE, "USE_DATABASE"}, - {SQLStatementType::CREATE_TABLE, "CREATE_TABLE"}, - {SQLStatementType::DROP_TABLE, "DROP_TABLE"}, - {SQLStatementType::INSERT, "INSERT"}, - {SQLStatementType::SELECT, "SELECT"}, - {SQLStatementType::UPDATE, "UPDATE"}, - {SQLStatementType::DELETE, "DELETE"}, - {SQLStatementType::INNER_JOIN, "INNER_JOIN"}, -}; - +// Structure for WHERE conditions in SQL statements struct WhereCondition { TokenType logic_operator = TokenType::EOF_TOKEN; std::string column_name_a; @@ -261,6 +235,7 @@ struct WhereCondition { Token value_b; }; +// Structure for SET conditions in UPDATE statements struct SetCondition { std::string column_name_a; std::string column_name_b; @@ -268,25 +243,31 @@ struct SetCondition { Token value; }; +// Structure for column definitions in CREATE TABLE statements struct ColumnDefinition { std::string name; TokenType type; }; +// Base class for SQL statements struct Statement { std::string content; int start_line; - Statement(const std::string &content, int line) : content(content), start_line(line) {} }; -class FileWriter { +//----------------------------------------------------------------------------- +// Output handling +//----------------------------------------------------------------------------- + +// Class for handling output to files +class OutputWriter { private: std::ofstream file; public: - explicit FileWriter() = default; + explicit OutputWriter() = default; void open(const std::string &filename) { file.open(filename); @@ -305,9 +286,7 @@ class FileWriter { } void write(const std::string &data) { file << "\"" << data << "\""; } - void write(int data) { file << data; } - void write(double data) { file << data; } void write(const Value &data) { @@ -344,81 +323,120 @@ class FileWriter { } }; -extern FileWriter file_writer; +extern OutputWriter file_writer; +//----------------------------------------------------------------------------- // Utility functions +//----------------------------------------------------------------------------- -inline Value convertTokenToValue(const Token &token) { - if (token.type == TokenType::INTEGER_LITERAL) { +// Convert a token to its corresponding value +Value convertTokenToValue(const Token &token) { + switch (token.type) { + case TokenType::INTEGER_LITERAL: return std::stoi(token.value); - } else if (token.type == TokenType::FLOAT_LITERAL) { - // debug(token.value); + case TokenType::FLOAT_LITERAL: return std::stod(token.value); + case TokenType::STRING_LITERAL: + return token.value; + default: + throw ParseError("Invalid token type for value conversion"); } - return token.value; // For STRING_LITERAL and other types } -inline std::vector splitStatements(std::ifstream &input_file) { +// Split input into individual SQL statements +std::vector splitStatements(std::ifstream &input_file) { std::vector statements; std::string current_statement; - int current_line = 1; - int statement_start_line = 1; - char ch; - - // Read file character by character and build statements - while (input_file.get(ch)) { - if (ch == '\n') { - current_line++; + std::string line; + int line_number = 0; + int start_line = 1; + + while (std::getline(input_file, line)) { + line_number++; + + // Skip empty lines and comments + if (line.empty() || line[0] == '#') { + continue; } - current_statement += ch; - if (ch == ';' && !current_statement.empty()) { - statements.emplace_back(current_statement, statement_start_line); + // Remove trailing whitespace + size_t end = line.find_last_not_of(" \t\r\n"); + if (end != std::string::npos) { + line = line.substr(0, end + 1); + } + + current_statement += line; + + // If line ends with semicolon, we have a complete statement + if (line.back() == ';') { + statements.emplace_back(current_statement, start_line); current_statement.clear(); - statement_start_line = current_line; + start_line = line_number + 1; + } else { + current_statement += ' '; } } - // check if there is a statement left - bool empty_statement = true; - for (char c : current_statement) { - if (!isspace(c)) - empty_statement = false; - } - if (!empty_statement) { - throw ParseError("Unterminated statement", current_line); + if (!current_statement.empty()) { + throw ParseError("Missing semicolon at end of statement", line_number); } return statements; } -// Add stream output operator for TokenType +// Token recognition +inline Token recognizeToken(std::string token) { + // Check for keywords first + if (TOKEN_MAP.find(token) != TOKEN_MAP.end()) { + return Token(TOKEN_MAP.at(token), token); + } + + // Check for integer literal + bool is_integer = true; + for (size_t i = 0; i < token.length(); i++) { + char c = token[i]; + if (i == 0 && c == '-') + continue; + if (!std::isdigit(c)) { + is_integer = false; + break; + } + } + if (is_integer) { + return Token(TokenType::INTEGER_LITERAL, token); + } + + // Check for float literal + bool is_float = true; + bool has_dot = false; + for (size_t i = 0; i < token.length(); i++) { + char c = token[i]; + if (i == 0 && c == '-') + continue; + if (c == '.' && !has_dot) { + has_dot = true; + continue; + } + if (!std::isdigit(c)) { + is_float = false; + break; + } + } + if (is_float && has_dot) { + return Token(TokenType::FLOAT_LITERAL, token); + } + + // If not a keyword or number, it's an identifier + return Token(TokenType::IDENTIFIER, token); +} + +// Stream output operator for TokenType inline std::ostream &operator<<(std::ostream &os, const TokenType &type) { auto it = TOKEN_STR.find(type); if (it != TOKEN_STR.end()) { os << it->second; } else { - os << "UNKNOWN_TOKEN_TYPE"; - } - return os; -} - -// Add stream output operator for Value -inline std::ostream &operator<<(std::ostream &os, const Value &value) { - std::visit([&os](const auto &v) { os << v; }, value); - return os; -} - -// Add stream output operator for vector of Value -inline std::ostream &operator<<(std::ostream &os, - const std::vector &values) { - os << "["; - for (size_t i = 0; i < values.size(); ++i) { - os << values[i]; - if (i < values.size() - 1) { - os << ", "; - } + os << "UNKNOWN"; } - os << "]"; return os; } \ No newline at end of file