diff --git a/zt_backend/runner/code_cell_parser.py b/zt_backend/runner/code_cell_parser.py index 4754d0db..fdd56494 100644 --- a/zt_backend/runner/code_cell_parser.py +++ b/zt_backend/runner/code_cell_parser.py @@ -7,6 +7,7 @@ import logging import traceback from zt_backend.config import settings +from datetime import datetime, date logger = logging.getLogger("__name__") @@ -80,52 +81,153 @@ def get_loaded_names(module, defined_names) -> List[str]: ] + aug_names -def generate_sql_code(cell, uuid_value, db_file="zt_db.db"): - """Generate SQL code for the given cell.""" - - # Common import statements - base_code = "import duckdb\nimport zero_true as zt" - - conn_uuid = str(uuid.uuid4())[:8] - - # Initialize the DuckDB database connection to persist tables - db_init = f"conn_{conn_uuid} = duckdb.connect('{db_file}')" - - # Extract all placeholders from the SQL string - placeholders = re.findall(r"\{(.*?)\}", cell.code) - - # Replace placeholders with "?" for SQL parameterization - parametrized_sql = re.sub(r"\{(.*?)\}", "?", cell.code) - - # SQL code execution - sql_execution = f'conn_{conn_uuid}.execute("""{parametrized_sql}""", [{", ".join(placeholders)}]).df()' +def format_value_for_sql(value: Any) -> str: + """ + Formats a Python value into its SQL-compatible string representation. + """ + if value is None: + return 'NULL' + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, (int, float)): + return str(value) + elif isinstance(value, (list, tuple, set)): + if not value: + return '(NULL)' + elements = [format_value_for_sql(v) for v in value] + return f"({', '.join(elements)})" + elif isinstance(value, (date, datetime)): + return f"'{value.isoformat()}'" + elif isinstance(value, str): + return f"'{value}'" + else: + return f"'{str(value)}'" - close_db = f"conn_{conn_uuid}.close()" +def extract_value_from_node(node: astroid.NodeNG) -> Any: + """ + Extracts value from an AST node. + """ + if isinstance(node, astroid.Const): + return node.value + elif isinstance(node, (astroid.List, astroid.Set)): + return [extract_value_from_node(elt) for elt in node.elts] + elif isinstance(node, astroid.Tuple): + return tuple(extract_value_from_node(elt) for elt in node.elts) + elif isinstance(node, astroid.Dict): + if node.items: + return { + extract_value_from_node(key): extract_value_from_node(value) + for key, value in node.items + } + return {} + elif isinstance(node, astroid.Call): + if isinstance(node.func, astroid.Name): + if node.func.name == 'date' and len(node.args) == 3: + try: + year, month, day = [extract_value_from_node(arg) for arg in node.args] + return date(year, month, day) + except: + logger.warning(f"Failed to create date from args: {node.args}") + elif node.func.name == 'datetime' and len(node.args) >= 3: + try: + args = [extract_value_from_node(arg) for arg in node.args] + return datetime(*args) + except: + logger.warning(f"Failed to create datetime from args: {node.args}") + return None + +def extract_variables(code: str) -> Dict[str, Any]: + """ + Extracts variable assignments from Python code. + """ + try: + module = astroid.parse(code) + variables = {} + + for node in module.body: + if isinstance(node, astroid.Assign): + value = extract_value_from_node(node.value) + if value is not None: + for target in node.targets: + if isinstance(target, astroid.AssignName): + variables[target.name] = value + + return variables + except Exception as e: + logger.error(f"Error extracting variables: {str(e)}") + return {} +def resolve_sql_variables(sql_code: str, cells: List['Cell'], current_cell_id: str) -> Tuple[str, List[str]]: + """ + Resolves variables in SQL and returns modified SQL with resolved values. + """ + variables = {} + for cell in cells: + if cell.id == current_cell_id: + break + if cell.cellType == "code": + cell_vars = extract_variables(cell.code) + variables.update(cell_vars) + + logger.debug(f"Found variables: {variables}") + resolved_sql = sql_code + + # Handle variable substitutions + for var_name, value in variables.items(): + placeholder = f"{{{var_name}}}" + if placeholder in resolved_sql: + resolved_sql = resolved_sql.replace(placeholder, format_value_for_sql(value)) + + logger.debug(f"Resolved SQL: {resolved_sql}") + + # Extract table names + table_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE|TABLE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)" + table_names = re.findall(table_pattern, resolved_sql) + unique_tables = list(dict.fromkeys(table_names)) + + return resolved_sql, unique_tables + +def generate_sql_code(cell, uuid_value: str, db_file: str = "zt_db.db") -> str: + """ + Generates executable SQL code with proper connection handling and DataFrame conversion. + SQL code is assumed to have all variables already resolved. + """ + conn_id = str(uuid.uuid4())[:8] + code_parts = [ + "import duckdb\nimport zero_true as zt", + f"conn_{conn_id} = duckdb.connect('{db_file}')", + "query_result = None", # Initialize query_result + ] + + # Log the SQL being executed for debugging + code_parts.extend([ + "try:", + f' query_result = conn_{conn_id}.execute("""{cell.code}""").df()', + "except duckdb.CatalogException as e:", + " if 'does not exist' in str(e) and ('CREATE' in cell.code.upper() or 'INSERT' in cell.code.upper()):", + " # For CREATE/INSERT queries, we don't need to fetch results", + f' conn_{conn_id}.execute("""{cell.code}""")', + " else:", + " raise", + "except Exception as e:", + " print(f'Error executing SQL: {str(e)}')", # Add error logging + " raise" + ]) + + # Only handle the result if we got one + code_parts.append("if query_result is not None:") if cell.variable_name: - # If variable_name is provided, use it for assignment - variable_assignment = f"{cell.variable_name} = {sql_execution}" - - # Convert the result to a custom DataFrame - data_frame_conversion = ( - f"zt.DataFrame.from_dataframe(id='{uuid_value}', df={cell.variable_name})" - ) - - full_code = f"{base_code}\n{db_init}\n{variable_assignment}\n{data_frame_conversion}\n{close_db}" - + code_parts.extend([ + f" {cell.variable_name} = query_result", + f" zt.DataFrame.from_dataframe(id='{uuid_value}', df={cell.variable_name})" + ]) else: - # If variable_name is not provided, directly use the SQL execution - if settings.run_mode == "app" and not cell.showTable: - data_frame_conversion = sql_execution - else: - data_frame_conversion = ( - f"zt.DataFrame.from_dataframe(id='{uuid_value}', df={sql_execution})" - ) - - full_code = f"{base_code}\n{db_init}\n{data_frame_conversion}\n{close_db}" - - return full_code - + conversion = (" query_result" if settings.run_mode == "app" and not cell.showTable + else f" zt.DataFrame.from_dataframe(id='{uuid_value}', df=query_result)") + code_parts.append(conversion) + + code_parts.append(f"conn_{conn_id}.close()") + return "\n".join(code_parts) def parse_cells(request: Request) -> CodeDict: cell_dict = {} @@ -135,7 +237,10 @@ def parse_cells(request: Request) -> CodeDict: table_names = [] if cell.cellType == "sql" and cell.code: try: - table_names = duckdb.get_table_names(re.sub(r"\{.*?\}", "1", cell.code)) + # Get SQL with resolved table names and list of tables + resolved_sql, table_names = resolve_sql_variables(cell.code, request.cells, cell.id) + # Update cell code with resolved table names + cell.code = resolved_sql except Exception as e: logger.error("Error getting table names: %s", traceback.format_exc()) uuid_value = str(uuid.uuid4()) diff --git a/zt_frontend/src/components/ComponentWrapper.vue b/zt_frontend/src/components/ComponentWrapper.vue index 45036997..9e01b752 100644 --- a/zt_frontend/src/components/ComponentWrapper.vue +++ b/zt_frontend/src/components/ComponentWrapper.vue @@ -17,24 +17,8 @@ v-bind="componentBind(component)" :error="errors[component.id]?.hasError || false" :error-messages="errors[component.id]?.message || ''" - @update:model-value=" - async (newValue: any) => { - if (!newValue) return; - const files = Array.isArray(newValue) ? newValue : [newValue]; - const totalSize = files.reduce((acc, file) => acc + file.size, 0); - const maxSize = 50 * 1024 * 1024; // 50 MB in bytes - - if (totalSize > maxSize) { - setError(component.id, 'Total file size must not exceed 50 MB'); - return; - } - // Clear any existing error - clearError(component.id); - - component.value = await createFormData(files); - runCode(true, component.id, component.value); - } - " + @update:model-value="(newValue: any)=> handleFileInput(component, newValue)" + @click:clear="() => handleFileClear(component)" /> , + }; }, methods: { - componentBind(component: any) { + + async handleFileInput(component: ZTComponent, newValue: File | File[] | null): Promise { + // If no files, exit early + if (!newValue) { + return; + } + + const files = Array.isArray(newValue) ? newValue : [newValue]; + + // Validate file size (50MB limit) + const totalSize = files.reduce((sum, file) => sum + file.size, 0); + if (totalSize > 50 * 1024 * 1024) { + this.setError(component.id, 'Total file size must not exceed 50 MB'); + return; + } + + this.clearError(component.id); + + try { + component.value = await this.processFiles(files); + this.$emit('runCode', true, component.id, component.value); + } catch (error) { + console.error('Error processing files:', error); + this.setError(component.id, 'Error processing files'); + } + }, + + handleFileClear(component: ZTComponent): void { + component.value = {}; + this.$emit('runCode', true, component.id, component.value); + }, + + async processFiles(files: File[]): Promise> { + const fileList: Record = {}; + + for (const file of files) { + if (file) { + const base64Content = await this.fileToBase64(file); + fileList[file.name] = base64Content; + } + } + + return fileList; + }, + + async fileToBase64(file: File): Promise { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + + reader.onload = () => { + const result = reader.result as string; + const base64 = result.split(',')[1]; + // Add padding if needed + resolve(base64.padEnd( + base64.length + ((4 - (base64.length % 4)) % 4), + '=' + )); + }; + + reader.onerror = () => reject(new Error('Failed to read file')); + reader.readAsDataURL(file); + }); + }, + + setError(componentId: string, message: string): void { + this.errors[componentId] = { + hasError: true, + message + }; + }, + + clearError(componentId: string): void { + if (this.errors[componentId]) { + this.errors[componentId] = { + hasError: false, + message: '' + }; + } + }, + + componentBind(component: ZTComponent) { if ( component.component && (component.component === "v-autocomplete" || @@ -128,12 +193,11 @@ export default { return this.convertUnderscoresToHyphens(component); }, - convertUnderscoresToHyphens(obj: any) { - return Object.entries(obj).reduce((newObj: any, [key, value]) => { - const modifiedKey = key.replace(/_/g, "-"); - newObj[modifiedKey] = value; - return newObj; - }, {}); + convertUnderscoresToHyphens(obj: Record): Record { + return Object.entries(obj).reduce((acc, [key, value]) => ({ + ...acc, + [key.replace(/_/g, '-')]: value + }), {}); }, getEventBindings(component: any) { @@ -180,38 +244,6 @@ export default { } this.$emit("runCode", fromComponent, componentId, componentValue); }, - - setError(componentId: string, message: string) { - this.errors[componentId] = { - hasError: true, - message: message, - }; - }, - - clearError(componentId: string) { - if (this.errors[componentId]) { - this.errors[componentId] = { - hasError: false, - message: "", - }; - } - }, - - async fileToBase64(file: File) { - const reader = new FileReader(); - reader.readAsDataURL(file); - return new Promise((resolve) => { - reader.onload = () => { - let base64String = (reader.result as string).split(",")[1]; - base64String = base64String.padEnd( - base64String.length + ((4 - (base64String.length % 4)) % 4), - "=" - ); - resolve(base64String); - }; - }); - }, - async createFormData(files: Array) { const fileList: { [key: string]: any } = {}; for (const file of files) {