Skip to content

Commit

Permalink
fix(fix-file-input-and-sql-cell): Updated the file input to work prop…
Browse files Browse the repository at this point in the history
…erly when cleared, cancelled and updated the sql cell to work properly with variables in query
  • Loading branch information
priyakanabar-crest committed Nov 14, 2024
1 parent ed20ff2 commit 669295d
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 100 deletions.
191 changes: 148 additions & 43 deletions zt_backend/runner/code_cell_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import traceback
from zt_backend.config import settings
from datetime import datetime, date


logger = logging.getLogger("__name__")
Expand Down Expand Up @@ -80,52 +81,153 @@ def get_loaded_names(module, defined_names) -> List[str]:
] + aug_names


def generate_sql_code(cell, uuid_value, db_file="zt_db.db"):
"""Generate SQL code for the given cell."""

# Common import statements
base_code = "import duckdb\nimport zero_true as zt"

conn_uuid = str(uuid.uuid4())[:8]

# Initialize the DuckDB database connection to persist tables
db_init = f"conn_{conn_uuid} = duckdb.connect('{db_file}')"

# Extract all placeholders from the SQL string
placeholders = re.findall(r"\{(.*?)\}", cell.code)

# Replace placeholders with "?" for SQL parameterization
parametrized_sql = re.sub(r"\{(.*?)\}", "?", cell.code)

# SQL code execution
sql_execution = f'conn_{conn_uuid}.execute("""{parametrized_sql}""", [{", ".join(placeholders)}]).df()'
def format_value_for_sql(value: Any) -> str:
    """Render a Python value as a SQL literal that can be inlined into a query.

    Mapping:
      * ``None``                -> ``NULL``
      * ``bool``                -> ``true`` / ``false``
      * ``int`` / ``float``     -> their ``str()`` form
      * ``list``/``tuple``/``set`` -> parenthesised, comma-separated literals
        (``(NULL)`` when empty, so that ``IN (...)`` stays syntactically valid)
      * ``date`` / ``datetime`` -> quoted ISO-8601 string
      * anything else           -> quoted ``str(value)``

    Quoted values have embedded single quotes doubled (``'`` -> ``''``), the
    standard SQL escape, so a value such as ``O'Brien`` can neither break the
    statement nor terminate the literal early.
    """
    if value is None:
        return 'NULL'
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(value, bool):
        return str(value).lower()
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, (list, tuple, set)):
        if not value:
            return '(NULL)'
        return "(" + ", ".join(format_value_for_sql(item) for item in value) + ")"
    if isinstance(value, (date, datetime)):
        return f"'{value.isoformat()}'"

    text = value if isinstance(value, str) else str(value)
    # Double any single quote so the value stays inside its own literal.
    return "'" + text.replace("'", "''") + "'"

close_db = f"conn_{conn_uuid}.close()"
def extract_value_from_node(node: astroid.NodeNG) -> Any:
    """Statically evaluate a literal-like AST node to a Python value.

    Supported nodes:
      * ``Const``            -> the constant's value
      * ``List`` / ``Set``   -> a ``list`` of the evaluated elements
      * ``Tuple``            -> a ``tuple`` of the evaluated elements
      * ``Dict``             -> a ``dict`` of evaluated keys and values
      * ``date(y, m, d)``    -> ``datetime.date`` (exactly three positional args)
      * ``datetime(...)``    -> ``datetime.datetime`` (three or more positional args)

    Only positional arguments of ``date``/``datetime`` calls are considered.
    Returns ``None`` for any other node, and also when a ``date``/``datetime``
    call cannot be constructed from its arguments (a warning is logged).
    """
    if isinstance(node, astroid.Const):
        return node.value
    if isinstance(node, (astroid.List, astroid.Set)):
        return [extract_value_from_node(elt) for elt in node.elts]
    if isinstance(node, astroid.Tuple):
        return tuple(extract_value_from_node(elt) for elt in node.elts)
    if isinstance(node, astroid.Dict):
        if node.items:
            return {
                extract_value_from_node(key): extract_value_from_node(value)
                for key, value in node.items
            }
        return {}
    if isinstance(node, astroid.Call) and isinstance(node.func, astroid.Name):
        func_name = node.func.name
        if func_name == 'date' and len(node.args) == 3:
            constructor = date
        elif func_name == 'datetime' and len(node.args) >= 3:
            constructor = datetime
        else:
            return None
        try:
            return constructor(*[extract_value_from_node(arg) for arg in node.args])
        except (TypeError, ValueError, OverflowError):
            # Non-literal arguments evaluate to None (TypeError); out-of-range
            # components raise ValueError/OverflowError. Anything else is a
            # genuine bug and must not be silently swallowed.
            logger.warning("Failed to create %s from args: %s", func_name, node.args)
    return None

def extract_variables(code: str) -> Dict[str, Any]:
    """Collect statically evaluable top-level assignments from Python source.

    Only ``Assign`` statements directly in the module body are inspected.
    The right-hand side is evaluated with ``extract_value_from_node``; when
    that yields ``None`` the assignment is skipped. Each plain-name target
    of the assignment is mapped to the evaluated value, later assignments
    overriding earlier ones.

    Returns an empty dict (after logging the error) if parsing or
    evaluation raises.
    """
    try:
        found = {}
        for statement in astroid.parse(code).body:
            if not isinstance(statement, astroid.Assign):
                continue
            literal = extract_value_from_node(statement.value)
            if literal is None:
                continue
            found.update(
                {
                    target.name: literal
                    for target in statement.targets
                    if isinstance(target, astroid.AssignName)
                }
            )
        return found
    except Exception as e:
        logger.error(f"Error extracting variables: {str(e)}")
        return {}

def resolve_sql_variables(sql_code: str, cells: List['Cell'], current_cell_id: str) -> Tuple[str, List[str]]:
    """Inline known Python variables into a SQL cell and list referenced tables.

    Variables are gathered, in notebook order, from every ``code`` cell that
    precedes the cell whose id is ``current_cell_id`` (via
    ``extract_variables``; later cells override earlier ones). Each
    ``{name}`` placeholder in ``sql_code`` whose name is a known variable is
    replaced by ``format_value_for_sql(value)``; unknown placeholders are
    left untouched.

    Substitution is done in a single pass over the original SQL so that text
    introduced by one substitution is never rescanned. Otherwise a string
    value that happens to contain ``{other_var}`` would be substituted a
    second time.

    Returns:
        ``(resolved_sql, table_names)`` where ``table_names`` are the
        identifiers (optionally ``schema.table``) following FROM / JOIN /
        INTO / UPDATE / TABLE, de-duplicated in order of first appearance.
    """
    variables = {}
    for cell in cells:
        if cell.id == current_cell_id:
            break
        if cell.cellType == "code":
            variables.update(extract_variables(cell.code))

    logger.debug("Found variables: %s", variables)

    def _substitute(match):
        name = match.group(1)
        if name in variables:
            return format_value_for_sql(variables[name])
        # Not a known variable: keep the placeholder exactly as written.
        return match.group(0)

    # A function replacement is inserted literally (no backslash processing).
    resolved_sql = re.sub(r"\{(\w+)\}", _substitute, sql_code)

    logger.debug("Resolved SQL: %s", resolved_sql)

    # Extract table names
    table_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE|TABLE)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)"
    table_names = re.findall(table_pattern, resolved_sql)
    unique_tables = list(dict.fromkeys(table_names))

    return resolved_sql, unique_tables

def generate_sql_code(cell, uuid_value: str, db_file: str = "zt_db.db") -> str:
    """Build the Python source that runs a SQL cell against DuckDB.

    ``cell.code`` is assumed to have all variables already resolved. The
    generated source:

      1. imports ``duckdb`` and ``zero_true``,
      2. opens a uniquely named connection to ``db_file``,
      3. executes the SQL and converts the result with ``.df()``; on a
         ``duckdb.CatalogException`` mentioning "does not exist" for a
         statement containing CREATE or INSERT, re-executes it without
         fetching; any other error is printed and re-raised,
      4. always closes the connection (``finally``),
      5. if a result was produced, binds it to ``cell.variable_name`` and/or
         wraps it in ``zt.DataFrame.from_dataframe`` (skipped in app mode
         when ``cell.showTable`` is false and no variable name is set).

    The SQL and the database path are embedded with ``repr()`` so that quotes,
    triple quotes and backslashes in them cannot corrupt the generated source.
    The CREATE/INSERT check is evaluated here and emitted as a literal,
    because ``cell`` is a local of this function, not a name in the generated
    source.

    Returns:
        The generated Python source as a single newline-joined string.
    """
    conn_name = f"conn_{str(uuid.uuid4())[:8]}"
    sql_literal = repr(cell.code)
    sql_upper = cell.code.upper()
    is_create_or_insert = 'CREATE' in sql_upper or 'INSERT' in sql_upper

    code_parts = [
        "import duckdb\nimport zero_true as zt",
        f"{conn_name} = duckdb.connect({db_file!r})",
        "query_result = None",  # Initialize query_result
        "try:",
        f"    query_result = {conn_name}.execute({sql_literal}).df()",
        "except duckdb.CatalogException as e:",
        f"    if 'does not exist' in str(e) and {is_create_or_insert!r}:",
        "        # For CREATE/INSERT queries, we don't need to fetch results",
        f"        {conn_name}.execute({sql_literal})",
        "    else:",
        "        raise",
        "except Exception as e:",
        "    print(f'Error executing SQL: {str(e)}')",  # Add error logging
        "    raise",
        "finally:",
        # The DataFrame is fully materialised by .df(), so the connection can
        # be released before the result is used, and is released on errors too.
        f"    {conn_name}.close()",
        # Only handle the result if we got one
        "if query_result is not None:",
    ]

    if cell.variable_name:
        code_parts.extend([
            f"    {cell.variable_name} = query_result",
            f"    zt.DataFrame.from_dataframe(id='{uuid_value}', df={cell.variable_name})",
        ])
    elif settings.run_mode == "app" and not cell.showTable:
        code_parts.append("    query_result")
    else:
        code_parts.append(
            f"    zt.DataFrame.from_dataframe(id='{uuid_value}', df=query_result)"
        )

    return "\n".join(code_parts)

def parse_cells(request: Request) -> CodeDict:
cell_dict = {}
Expand All @@ -135,7 +237,10 @@ def parse_cells(request: Request) -> CodeDict:
table_names = []
if cell.cellType == "sql" and cell.code:
try:
table_names = duckdb.get_table_names(re.sub(r"\{.*?\}", "1", cell.code))
# Get SQL with resolved table names and list of tables
resolved_sql, table_names = resolve_sql_variables(cell.code, request.cells, cell.id)
# Update cell code with resolved table names
cell.code = resolved_sql
except Exception as e:
logger.error("Error getting table names: %s", traceback.format_exc())
uuid_value = str(uuid.uuid4())
Expand Down
146 changes: 89 additions & 57 deletions zt_frontend/src/components/ComponentWrapper.vue
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,8 @@
v-bind="componentBind(component)"
:error="errors[component.id]?.hasError || false"
:error-messages="errors[component.id]?.message || ''"
@update:model-value="
async (newValue: any) => {
if (!newValue) return;
const files = Array.isArray(newValue) ? newValue : [newValue];
const totalSize = files.reduce((acc, file) => acc + file.size, 0);
const maxSize = 50 * 1024 * 1024; // 50 MB in bytes
if (totalSize > maxSize) {
setError(component.id, 'Total file size must not exceed 50 MB');
return;
}
// Clear any existing error
clearError(component.id);
component.value = await createFormData(files);
runCode(true, component.id, component.value);
}
"
@update:model-value="(newValue: any)=> handleFileInput(component, newValue)"
@click:clear="() => handleFileClear(component)"
/>
<component
v-else
Expand Down Expand Up @@ -113,10 +97,91 @@ export default {
data() {
return {
errors: {} as Record<string, { hasError: boolean; message: string }>,
};
},
methods: {
componentBind(component: any) {
/**
 * Handle a new selection in a file-input component.
 *
 * Does nothing when `newValue` is falsy. Otherwise validates that the
 * combined size of the selected files is at most 50 MB, encodes them via
 * `processFiles`, stores the result on `component.value` and emits
 * `runCode`. Size violations and processing failures are reported through
 * `setError` instead of being thrown.
 */
async handleFileInput(component: ZTComponent, newValue: File | File[] | null): Promise<void> {
  if (!newValue) return;

  const selected = Array.isArray(newValue) ? newValue : [newValue];

  // 50 MB cap on the combined size of the selection.
  const maxBytes = 50 * 1024 * 1024;
  let combinedBytes = 0;
  for (const entry of selected) {
    combinedBytes = combinedBytes + entry.size;
  }
  if (combinedBytes > maxBytes) {
    this.setError(component.id, 'Total file size must not exceed 50 MB');
    return;
  }

  // The selection is valid: drop any error left by a previous attempt.
  this.clearError(component.id);

  try {
    component.value = await this.processFiles(selected);
    this.$emit('runCode', true, component.id, component.value);
  } catch (error) {
    console.error('Error processing files:', error);
    this.setError(component.id, 'Error processing files');
  }
},
/**
 * Handle the file input being cleared.
 *
 * Resets the component value to an empty object, removes any validation
 * error still shown for this component (for example a previous
 * "must not exceed 50 MB" message, which no longer applies once the
 * selection is empty) and emits `runCode` with the empty value.
 */
handleFileClear(component: ZTComponent): void {
  this.clearError(component.id);
  component.value = {};
  this.$emit('runCode', true, component.id, component.value);
},
/**
 * Encode each file as base64 and return a map of file name -> base64 text.
 *
 * Files are read one after another; falsy entries are skipped. When two
 * files share a name, the later one overwrites the earlier entry.
 */
async processFiles(files: File[]): Promise<Record<string, string>> {
  const encodedByName: Record<string, string> = {};
  for (const entry of files) {
    if (!entry) {
      continue;
    }
    encodedByName[entry.name] = await this.fileToBase64(entry);
  }
  return encodedByName;
},
/**
 * Read `file` and resolve with its content as a base64 string.
 *
 * The file is read as a data URL and the part after the first comma is
 * returned, right-padded with `=` to a multiple of four characters.
 * Resolves with an empty string when the data URL carries no payload
 * (NOTE(review): some browsers appear to return such a URL for empty
 * files — confirm). Rejects when the read fails or the result cannot be
 * processed, so the returned promise always settles.
 */
async fileToBase64(file: File): Promise<string> {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      try {
        const result = typeof reader.result === 'string' ? reader.result : '';
        // Guard against a data URL without a comma-separated payload;
        // an exception thrown here would otherwise leave the promise pending.
        const base64 = result.split(',')[1] || '';
        // Add padding if needed
        resolve(base64.padEnd(
          base64.length + ((4 - (base64.length % 4)) % 4),
          '='
        ));
      } catch (error) {
        reject(error);
      }
    };
    reader.onerror = () => reject(new Error('Failed to read file'));
    reader.readAsDataURL(file);
  });
},
/** Record an active error with the given message for `componentId`. */
setError(componentId: string, message: string): void {
  const entry = { hasError: true, message: message };
  this.errors[componentId] = entry;
},
/**
 * Reset the error entry for `componentId` to "no error".
 * Does nothing when no entry exists for that component.
 */
clearError(componentId: string): void {
  if (!this.errors[componentId]) {
    return;
  }
  const cleared = { hasError: false, message: '' };
  this.errors[componentId] = cleared;
},
componentBind(component: ZTComponent) {
if (
component.component &&
(component.component === "v-autocomplete" ||
Expand All @@ -128,12 +193,11 @@ export default {
return this.convertUnderscoresToHyphens(component);
},
/**
 * Return a shallow copy of `obj` whose keys have every underscore replaced
 * by a hyphen (e.g. `max_width` -> `max-width`). Values are copied by
 * reference. If two keys collapse to the same name, the later one wins.
 *
 * Built with a single pass over the entries; spreading the accumulator on
 * every step would copy it each time and make the conversion quadratic.
 */
convertUnderscoresToHyphens(obj: Record<string, any>): Record<string, any> {
  const converted: Record<string, any> = {};
  for (const [key, value] of Object.entries(obj)) {
    converted[key.replace(/_/g, '-')] = value;
  }
  return converted;
},
getEventBindings(component: any) {
Expand Down Expand Up @@ -180,38 +244,6 @@ export default {
}
this.$emit("runCode", fromComponent, componentId, componentValue);
},
/** Mark `componentId` as having an error and store its message. */
setError(componentId: string, message: string) {
  const entry = { hasError: true, message };
  this.errors[componentId] = entry;
},
/**
 * Reset an existing error entry for `componentId` to the "no error" state;
 * a component without an entry is left untouched.
 */
clearError(componentId: string) {
  if (!this.errors[componentId]) {
    return;
  }
  const cleared = { hasError: false, message: "" };
  this.errors[componentId] = cleared;
},
/**
 * Read `file` as a data URL and resolve with the base64 payload (the part
 * after the first comma), right-padded with `=` to a multiple of four
 * characters. Resolves with an empty string when the data URL has no
 * payload, and rejects when the read fails, so the promise always settles.
 */
async fileToBase64(file: File) {
  return new Promise<string>((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      try {
        const dataUrl = typeof reader.result === "string" ? reader.result : "";
        // A missing payload must not throw inside the handler: that would
        // leave the promise pending forever.
        let base64String = dataUrl.split(",")[1] || "";
        base64String = base64String.padEnd(
          base64String.length + ((4 - (base64String.length % 4)) % 4),
          "="
        );
        resolve(base64String);
      } catch (error) {
        reject(error);
      }
    };
    // Without an error handler a failed read would never settle the promise.
    reader.onerror = () => reject(new Error("Failed to read file"));
    reader.readAsDataURL(file);
  });
},
async createFormData(files: Array<File>) {
const fileList: { [key: string]: any } = {};
for (const file of files) {
Expand Down

0 comments on commit 669295d

Please sign in to comment.