Skip to content

Commit

Permalink
feat: add serialization of data
Browse files Browse the repository at this point in the history
  • Loading branch information
huaruoji committed Nov 30, 2024
1 parent e56e382 commit e118564
Show file tree
Hide file tree
Showing 4 changed files with 403 additions and 124 deletions.
82 changes: 80 additions & 2 deletions src/database.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,28 @@
#include "statement.hpp"
#include "table.hpp"
#include "utils.hpp"
#include <filesystem>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

class Database {
public:
explicit Database(const std::string &name) : name(name) {}
explicit Database(const std::string &name,
const std::string &data_dir = "data")
: name(name), data_dir(data_dir) {}

~Database() {
namespace fs = std::filesystem;
if (!fs::exists(data_dir)) {
fs::create_directory(data_dir);
}
fs::path db_path = fs::path(data_dir) / (name + ".db");
serialize(db_path.string());
}

void executeStatement(SQLStatement *stmt) {
switch (stmt->type) {
Expand All @@ -32,7 +46,6 @@ class Database {
}
case SQLStatementType::INSERT: {
auto insert_stmt = static_cast<InsertStatement *>(stmt);
// debug(insert_stmt->table_name, insert_stmt->values.size());
if (tables.find(insert_stmt->table_name) == tables.end()) {
throw DatabaseError("Table does not exist", insert_stmt->line_number);
}
Expand Down Expand Up @@ -79,7 +92,72 @@ class Database {
}
}

std::string getName() const { return name; }

// Serialize database to a file
void serialize(const std::string &filepath) const {
std::ofstream out(filepath);
if (!out) {
throw DatabaseError("Failed to open file for serialization", 0);
}

// Write database name
out << "DATABASE " << name << "\n";

// Write number of tables
out << "TABLES " << tables.size() << "\n\n";

// Write each table
for (const auto &[table_name, table] : tables) {
table->serialize(out);
out << "\n"; // Add a blank line between tables
}
}

// Deserialize database from a file
static std::unique_ptr<Database> deserialize(const std::string &filepath) {
std::ifstream in(filepath);
if (!in) {
throw DatabaseError("Failed to open file for deserialization", 0);
}

std::string line, word;

// Read database name
std::getline(in, line);
std::istringstream iss(line);
iss >> word; // Skip "DATABASE"
std::string db_name;
iss >> db_name;

// Create database instance
auto db = std::make_unique<Database>(db_name);

// Read number of tables
std::getline(in, line);
iss.clear();
iss.str(line);
iss >> word; // Skip "TABLES"
size_t num_tables;
iss >> num_tables;

// Skip the blank line
std::getline(in, line);

// Read each table
for (size_t i = 0; i < num_tables; ++i) {
auto table = Table::deserialize(in);
db->tables[table->getName()] = std::move(table);

// Skip the blank line between tables
std::getline(in, line);
}

return db;
}

private:
std::string name;
std::string data_dir;
std::unordered_map<std::string, std::unique_ptr<Table>> tables;
};
28 changes: 26 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,44 @@
#include "parser.hpp"
#include "statement.hpp"
#include "utils.hpp"
#include <filesystem>
#include <fstream>
#include <iostream>

FileWriter file_writer;
OutputWriter file_writer;
namespace fs = std::filesystem;

void loadDatabases(
std::unordered_map<std::string, std::unique_ptr<Database>> &databases,
const std::string &data_dir) {
if (!fs::exists(data_dir)) {
fs::create_directory(data_dir);
return;
}

for (const auto &entry : fs::directory_iterator(data_dir)) {
if (entry.path().extension() == ".db") {
auto db = Database::deserialize(entry.path().string());
databases[db->getName()] = std::move(db);
}
}
}

int main(int argc, char *argv[]) {
std::unordered_map<std::string, std::unique_ptr<Database>> databases;
Database *current_database = nullptr;
const std::string data_dir = "data";

try {
if (argc != 3) {
throw ArgumentError("Argument number error");
} else if (std::string(argv[1]).find(".sql") == std::string::npos) {
throw ArgumentError("Input file must be a SQL file");
}

// Load existing databases
loadDatabases(databases, data_dir);

// Read input SQL file and split into statements
std::ifstream input_file(argv[1]);
if (!input_file.is_open()) {
Expand All @@ -38,7 +61,8 @@ int main(int argc, char *argv[]) {
parsed_statement->line_number);
}
databases[parsed_statement->getDatabaseName()] =
std::make_unique<Database>(parsed_statement->getDatabaseName());
std::make_unique<Database>(parsed_statement->getDatabaseName(),
data_dir);
} else if (parsed_statement->type == SQLStatementType::USE_DATABASE) {
if (databases.find(parsed_statement->getDatabaseName()) ==
databases.end()) {
Expand Down
161 changes: 160 additions & 1 deletion src/table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@
#include "statement.hpp"
#include "utils.hpp"
#include <algorithm>
#include <fstream>
#include <list>
#include <memory>
#include <string>
#include <vector>
#include <sstream>

class Table {
public:
Table(const std::string &name, const std::vector<ColumnDefinition> &columns)
: name(name), columns(columns) {
// Build column index
rebuildColumnIndex();
}

void rebuildColumnIndex() {
column_index.clear();
for (size_t i = 0; i < columns.size(); i++) {
column_index[columns[i].name] = i;
}
}

std::string getName() const { return name; }

void insert(const std::vector<Value> &row) {
// Validate number of values matches number of columns
if (row.size() != columns.size()) {
Expand Down Expand Up @@ -299,6 +308,156 @@ class Table {
file_writer.write(results);
}

// Helper function to quote string
static std::string quoteString(const std::string& str) {
return "\"" + str + "\"";
}

// Helper function to read a quoted string
static std::string readQuotedString(std::istream& in) {
char c;
std::string result;

// Skip leading whitespace
while (in.get(c) && std::isspace(c)) {}
in.unget();

if (in.peek() != '"') {
throw TableError("Expected quoted string");
}

in.get(); // Skip opening quote
while (in.get(c) && c != '"') {
result += c;
}

if (c != '"') {
throw TableError("Unterminated quoted string");
}

return result;
}

void serialize(std::ofstream &out) const {
// Write table name
out << "TABLE " << quoteString(name) << "\n";

// Write columns
out << "COLUMNS " << columns.size() << "\n";
for (const auto &col : columns) {
out << quoteString(col.name) << " " << TOKEN_STR.find(col.type)->second << "\n";
}

// Write rows
out << "ROWS " << rows.size() << "\n";
for (const auto &row : rows) {
for (size_t i = 0; i < row.size(); ++i) {
if (i > 0) out << " ";
const auto &value = row[i];
if (std::holds_alternative<int>(value)) {
out << "INT " << std::get<int>(value);
} else if (std::holds_alternative<double>(value)) {
out << "FLOAT " << std::get<double>(value);
} else if (std::holds_alternative<std::string>(value)) {
out << "TEXT " << quoteString(std::get<std::string>(value));
}
}
out << "\n";
}
}

static std::unique_ptr<Table> deserialize(std::ifstream &in) {
std::string line, word;

// Read table name
std::getline(in, line);
std::istringstream iss(line);
iss >> word; // Skip "TABLE"
std::string table_name = readQuotedString(iss);

// Read columns
std::getline(in, line);
iss.clear();
iss.str(line);
iss >> word; // Skip "COLUMNS"
size_t num_columns;
iss >> num_columns;

std::vector<ColumnDefinition> columns;
columns.reserve(num_columns);

for (size_t i = 0; i < num_columns; ++i) {
std::getline(in, line);
iss.clear();
iss.str(line);

ColumnDefinition col;
col.name = readQuotedString(iss);
std::string type_str;
iss >> type_str;

// Convert type string to TokenType
bool found = false;
for (const auto &[token_type, str] : TOKEN_STR) {
if (str == type_str) {
col.type = token_type;
found = true;
break;
}
}

if (!found) {
throw TableError("Invalid column type: " + type_str);
}

columns.push_back(col);
}

// Create table
auto table = std::make_unique<Table>(table_name, columns);

// Read rows
std::getline(in, line);
iss.clear();
iss.str(line);
iss >> word; // Skip "ROWS"
size_t num_rows;
iss >> num_rows;

for (size_t i = 0; i < num_rows; ++i) {
std::getline(in, line);
iss.clear();
iss.str(line);

std::vector<Value> row;
row.reserve(num_columns);

for (size_t j = 0; j < num_columns; ++j) {
std::string type;
iss >> type;

if (type == "INT") {
int value;
iss >> value;
row.push_back(value);
} else if (type == "FLOAT") {
double value;
iss >> value;
row.push_back(value);
} else if (type == "TEXT") {
std::string value = readQuotedString(iss);
row.push_back(value);
} else {
throw TableError("Invalid value type: " + type);
}
}

table->rows.push_back(std::move(row));
}

return table;
}

private:
std::string name;
std::vector<ColumnDefinition> columns;
Expand Down
Loading

0 comments on commit e118564

Please sign in to comment.