Skip to content

Commit

Permalink
Authorization, Thread management
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Jan 28, 2025
1 parent 2339d09 commit df5a506
Showing 1 changed file with 119 additions and 59 deletions.
178 changes: 119 additions & 59 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,31 @@
import pyarrow as pa
import pyarrow.flight as flight

import signal
import threading
import logging
import sys

# Logging setup: records from DEBUG upwards go both to stdout and to the
# 'server.log' file, each tagged with timestamp, thread name and level.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('server.log'),
    ],
    format='%(asctime)s - %(threadName)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
# Named logger used by the code in this file.
logger = logging.getLogger('server')

# Module-level run flag; signal_handler() sets it to False.
running = True

def signal_handler(signum, frame):
    """Signal callback: log the shutdown request and clear ``running``.

    Args:
        signum: Signal number passed in by the ``signal`` module (unused).
        frame: Interrupted stack frame passed in by the ``signal`` module
            (unused).
    """
    global running
    logger.info("Shutdown signal received")
    # Only flips the module-level flag; whoever polls `running` does the rest.
    running = False


# Always export VITE_SELFSERVICE=true so the UI is forced into self-service mode.
os.environ.update(VITE_SELFSERVICE='true')
# Default path for temp databases
Expand Down Expand Up @@ -257,96 +282,131 @@ def handle_404(e):

# Listener settings, each overridable through an environment variable.
# HOST / PORT are passed to Flask's app.run().
host = os.getenv('HOST', '0.0.0.0')
port = int(os.getenv('PORT', 8123))
# FLIGHT_HOST / FLIGHT_PORT form the Flight server's grpc:// location.
# A single assignment: the earlier '0.0.0.0' default was immediately
# overwritten by this one and therefore never took effect.
flight_host = os.getenv('FLIGHT_HOST', 'localhost')
flight_port = int(os.getenv('FLIGHT_PORT', 8815))
# DATA names a directory (default '.duckdb_data').
path = os.getenv('DATA', '.duckdb_data')

if __name__ == '__main__':
    # Route Ctrl-C (SIGINT) and termination requests (SIGTERM) through the
    # same handler, so both clear the `running` flag instead of SIGTERM
    # killing the process with its default action.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

def run_flask():
    """Run the Flask app until it exits, logging start, failure and stop.

    Serves ``app`` on the module-level ``host`` and ``port``. Any exception
    raised by ``app.run`` is logged with its traceback and swallowed, so the
    caller never sees it.
    """
    logger.info("Starting Flask server")
    try:
        # Single, guarded app.run(): a second unguarded call placed before
        # the try block would have blocked first and bypassed the logging.
        # NOTE(review): use_reloader=False is presumably required because
        # this function is started in a non-main thread — confirm.
        app.run(host=host, port=port, use_reloader=False)
    except Exception:
        logger.exception("Flask server error")
    finally:
        logger.info("Flask server stopped")

def run_flight_server():
"""Run Flight server"""
class HeaderMiddleware(flight.ServerMiddleware):
    """Per-call middleware object that carries an authorization value."""

    def __init__(self):
        # Set by the code that creates the instance; stays None when the
        # call carried no authorization header.
        self.authorization = None

    def call_completed(self, exception=None):
        """Do nothing when the call finishes; there is no state to release."""
        return None

class HeaderMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create one HeaderMiddleware per call, capturing its authorization header."""

    def start_call(self, info, headers):
        """Build the middleware instance for a new call.

        Args:
            info: Call information supplied by the Flight framework (unused).
            headers: Mapping of header name to a list of values.

        Returns:
            A ``HeaderMiddleware`` whose ``authorization`` attribute is the
            first "authorization" value, or ``None`` when the header is
            absent or has no values.
        """
        # Log header names only: the values may contain credentials, and the
        # logger writes to both stdout and server.log.
        logger.debug(f"Headers received: {list(headers)}")
        middleware = HeaderMiddleware()
        values = headers.get("authorization")
        if values:
            # Only the first value is used.
            middleware.authorization = values[0]
            logger.info("Authorization header found")
        return middleware

class DuckDBFlightServer(flight.FlightServerBase):
    """Arrow Flight server that runs client-supplied SQL on a DuckDB connection.

    The server starts on ``db_path`` (in-memory by default). When a call
    carries a ``user:password`` authorization value, ``self.conn`` is
    re-pointed at a database file named after a SHA-256 digest of the
    credentials.

    NOTE(review): ``self.conn`` is one attribute shared by every handler, so
    switching it for one caller also affects later calls made by other
    callers (and ``do_put`` / ``get_flight_info``). Confirm this is intended
    before relying on it for isolation between users.
    """

    def __init__(self, location=f"grpc://{flight_host}:{flight_port}", db_path=":memory:"):
        """Start the Flight base server and open the initial DuckDB connection.

        Args:
            location: gRPC URI to listen on; also advertised in the
                endpoints returned by ``get_flight_info``.
            db_path: DuckDB database to open initially.
        """
        # Initialise the base class exactly once, with the "auth" middleware
        # that do_get() reads back via context.get_middleware("auth").
        super().__init__(location=location, middleware={"auth": HeaderMiddlewareFactory()})
        self._location = location
        logger.info(f"Initializing Flight server at {location}")
        # One connection only; opening a second one would orphan the first.
        self.conn = duckdb.connect(db_path)
        self.conn.install_extension("chsql", repository="community")
        self.conn.install_extension("chsql_native", repository="community")
        self._load_extensions()

    def _load_extensions(self):
        """Load the chsql extensions into the current connection."""
        self.conn.load_extension("chsql")
        self.conn.load_extension("chsql_native")

    def _switch_database(self, db_file):
        """Point ``self.conn`` at ``db_file`` and load the extensions into it."""
        self.conn = duckdb.connect(db_file)
        self._load_extensions()

    @staticmethod
    def _credentials_db_file(username, password):
        """Return the database file path derived from a credential pair."""
        digest = hashlib.sha256((username + password).encode()).hexdigest()
        return os.path.join(dbpath, f"{digest}.db")

    def do_get(self, context, ticket):
        """Run the SQL carried in ``ticket`` and stream the result as Arrow.

        Args:
            context: Server call context; its "auth" middleware may hold a
                ``user:password`` authorization string.
            ticket: Flight ticket whose bytes are the UTF-8 SQL text.

        Returns:
            A ``flight.RecordBatchStream`` over the query result.

        Raises:
            Exception: Whatever the query execution raises, after logging it.
        """
        logger.debug("do_get called")

        # Best effort: a problem while resolving credentials must not abort
        # the request, so the query then runs on the current connection.
        try:
            middleware = context.get_middleware("auth")
            auth_header = middleware.authorization if middleware else None
            if isinstance(auth_header, str) and ':' in auth_header:
                # Never log the header value itself: it contains the password.
                logger.info("Using authorization from middleware")
                username, password = auth_header.split(':', 1)
                db_file = self._credentials_db_file(username, password)
                logger.info(f'Using database file: {db_file}')
                self._switch_database(db_file)
            elif auth_header:
                logger.debug("Authorization value is not a 'user:password' string; keeping current connection")
        except Exception as e:
            # Warning rather than debug: after a failure here the query runs
            # against whatever database self.conn already pointed at.
            logger.warning(f"Middleware access error: {e}")

        query = ticket.ticket.decode("utf-8")
        logger.info(f"Executing query: {query}")
        try:
            result_table = self.conn.execute(query).fetch_arrow_table()
            # Re-chunk the result into batches of at most 1024 rows.
            batches = result_table.to_batches(max_chunksize=1024)
            if not batches:
                logger.debug("No data in result")
                # No batches to infer a schema from, so pass it explicitly.
                return flight.RecordBatchStream(pa.Table.from_batches([], result_table.schema))
            logger.debug(f"Returning {len(batches)} batches")
            return flight.RecordBatchStream(pa.Table.from_batches(batches))
        except Exception as e:
            logger.exception(f"Query execution error: {str(e)}")
            raise

    def do_put(self, context, descriptor, reader, writer):
        """Append an uploaded Arrow table to an existing DuckDB table.

        Args:
            context: Server call context (unused).
            descriptor: Flight descriptor; ``path[0]`` is the UTF-8 name of
                the target table.
            reader: Stream reader providing the uploaded data.
            writer: Metadata writer (unused).
        """
        table = reader.read_all()
        table_name = descriptor.path[0].decode('utf-8')
        self.conn.register("temp_table", table)
        try:
            # NOTE(review): table_name is interpolated into the SQL text
            # as-is; clients can already submit arbitrary SQL via do_get.
            self.conn.execute(f"INSERT INTO {table_name} SELECT * FROM temp_table")
        finally:
            # Drop the registration so the uploaded table is not kept
            # referenced by the connection after the insert.
            self.conn.unregister("temp_table")

    def get_flight_info(self, context, descriptor):
        """Describe the result of the SQL command carried in ``descriptor``.

        The command is executed in full to obtain its schema; the returned
        info holds one endpoint at this server's location whose ticket is
        the same SQL text.

        Raises:
            flight.FlightUnavailableError: If the descriptor has no command.
        """
        if descriptor.command is None:
            raise flight.FlightUnavailableError("No command provided in the descriptor.")
        query = descriptor.command.decode("utf-8")
        schema = self.conn.execute(query).fetch_arrow_table().schema
        endpoint = flight.FlightEndpoint(
            ticket=flight.Ticket(query.encode("utf-8")),
            locations=[self._location],
        )
        # -1, -1: total records and total bytes are reported as unknown.
        return flight.FlightInfo(schema, descriptor, [endpoint], -1, -1)

    def do_action(self, context, action):
        """Execute the SQL in a "query" action; reject every other action type.

        Returns:
            An empty list (the statement's result is discarded).

        Raises:
            NotImplementedError: If ``action.type`` is not "query".
        """
        if action.type != "query":
            raise NotImplementedError(f"Unknown action type: {action.type}")
        query = action.body.to_pybytes().decode("utf-8")
        self.conn.execute(query)
        return []

    def apply_auth(self, context):
        """Select the database from an 'authorization' entry in call metadata.

        A value without both a user and a password keeps the current
        connection ("stateless"); otherwise the connection is switched to
        the credential-derived database file ("stateful").

        NOTE(review): nothing in the visible part of this file calls this
        method — confirm whether it is still needed.

        Returns:
            True if an authorization entry was found, False otherwise.
        """
        metadata = context.method().call_metadata()
        if not metadata:
            return False
        for key, value in metadata:
            if key.lower() != 'authorization':
                continue
            # partition() yields an empty password when there is no ':'.
            username, _, password = value.decode('utf-8').partition(':')
            if not (username and password):
                logger.info('stateless flight session')
                return True
            logger.info('stateful flight session')
            # NOTE(review): this creates `path`, while the database file is
            # placed under `dbpath` — confirm both directories exist.
            os.makedirs(path, exist_ok=True)
            db_file = self._credentials_db_file(username, password)
            logger.info(f'stateful session {db_file}')
            self._switch_database(db_file)
            return True
        return False

# Build the server with its default location and database, announce it
# once through the logger (which already writes to stdout), then block in
# serve().
server = DuckDBFlightServer()
logger.info(f"Starting DuckDB Flight server on {flight_host}:{flight_port}")
server.serve()

# Run Flask and the Flight server side by side, each in its own daemon
# thread, so neither keeps the process alive once the main thread returns.
# Each thread is created and started exactly once, and never joined: a
# join() here would block forever and the wait loop below would never run.
flask_thread = threading.Thread(target=run_flask, daemon=True)
flask_thread.start()

flight_thread = threading.Thread(target=run_flight_server, daemon=True)
flight_thread.start()

# Keep the main thread alive, polling once a second, until signal_handler
# clears `running`.
try:
    while running:
        time.sleep(1)
except KeyboardInterrupt:
    logger.info("KeyboardInterrupt received")
finally:
    logger.info("Shutting down...")

0 comments on commit df5a506

Please sign in to comment.