"""Flask blueprint for the SQL explorer: ad-hoc queries, saved queries, schema views and plots."""
import html

from flask import Blueprint, render_template, request, current_app
from jinja2_fragments import render_block
from flask_htmx import HTMX

from extensions import db
from utils import generate_plot

sql_explorer_bp = Blueprint('sql_explorer', __name__, url_prefix='/sql')
htmx = HTMX()


# --- Database Helper Functions (Moved from features/sql_explorer.py) ---

def _get_schema_info(schema='public'):
    """Collect table, column, primary-key and foreign-key metadata for *schema*.

    Args:
        schema: Name of the database schema to inspect (default ``'public'``).

    Returns:
        A dict mapping each base-table name (in name order) to a dict with:
        ``'columns'`` -- list of ``(column_name, data_type)`` in ordinal order;
        ``'primary_keys'`` -- list of PK column names in key order;
        ``'foreign_keys'`` -- list of ``(fk_column, referenced_table, referenced_column)``.
    """
    # ORDER BY keeps the generated diagram / CREATE script deterministic.
    tables_result = db.execute("""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name;
    """, [schema])
    tables = [row['table_name'] for row in tables_result]

    schema_info = {}
    for table in tables:
        columns_result = db.execute("""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            ORDER BY ordinal_position;
        """, [schema, table])
        columns = [(row['column_name'], row['data_type']) for row in columns_result]

        # Without ORDER BY the column order of a composite primary key is unspecified.
        primary_keys_result = db.execute("""
            SELECT kcu.column_name
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
              ON tc.constraint_name = kcu.constraint_name
             AND tc.table_schema = kcu.table_schema
            WHERE tc.constraint_type = 'PRIMARY KEY'
              AND tc.table_schema = %s
              AND tc.table_name = %s
            ORDER BY kcu.ordinal_position;
        """, [schema, table])
        primary_keys = [row['column_name'] for row in primary_keys_result]

        # NOTE(review): constraint_column_usage is joined on constraint name only,
        # so a multi-column FK yields every fk/referenced column combination.
        # Consider information_schema.referential_constraints if composite FKs matter.
        foreign_keys_result = db.execute("""
            SELECT
                kcu.column_name AS fk_column,
                ccu.table_name AS referenced_table,
                ccu.column_name AS referenced_column
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
              ON tc.constraint_name = kcu.constraint_name
             AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage AS ccu
              ON ccu.constraint_name = tc.constraint_name
             AND ccu.table_schema = tc.table_schema
            WHERE tc.constraint_type = 'FOREIGN KEY'
              AND tc.table_schema = %s
              AND tc.table_name = %s
            ORDER BY kcu.constraint_name, kcu.ordinal_position;
        """, [schema, table])
        foreign_keys = [
            (row['fk_column'], row['referenced_table'], row['referenced_column'])
            for row in foreign_keys_result
        ]

        schema_info[table] = {
            'columns': columns,
            'primary_keys': primary_keys,
            'foreign_keys': foreign_keys,
        }
    return schema_info


def _map_data_type_for_sql(postgres_type):
    """Map a PostgreSQL ``information_schema`` data type to a standard SQL type name.

    Unknown types fall back to the upper-cased PostgreSQL name.
    """
    return {
        'character varying': 'VARCHAR',
        'varchar': 'VARCHAR',
        'text': 'TEXT',
        'integer': 'INTEGER',
        'bigint': 'BIGINT',
        'boolean': 'BOOLEAN',
        'timestamp without time zone': 'TIMESTAMP',
        'timestamp with time zone': 'TIMESTAMPTZ',
        'numeric': 'NUMERIC',
        'real': 'REAL',
        'date': 'DATE',
        # Add more as needed
    }.get(postgres_type, postgres_type.upper())


def _map_data_type(postgres_type):
    """Map a PostgreSQL data type to a Mermaid ER-diagram attribute type.

    Unknown types fall back to ``'string'``.
    """
    type_mapping = {
        'integer': 'int',
        'bigint': 'int',
        'smallint': 'int',
        'character varying': 'string',
        'varchar': 'string',
        'text': 'string',
        'date': 'date',
        'timestamp without time zone': 'datetime',
        'timestamp with time zone': 'datetime',
        'boolean': 'bool',
        'numeric': 'float',
        'real': 'float',
        'double precision': 'float',
    }
    return type_mapping.get(postgres_type, 'string')


def _generate_mermaid_er(schema_info):
    """Build Mermaid ``erDiagram`` source from the output of ``_get_schema_info``.

    Each table becomes an entity listing its columns (primary-key columns
    marked ``PK``); each foreign key becomes a many-to-one relationship line.
    """
    mermaid_lines = ["erDiagram"]
    for table, info in schema_info.items():
        primary_keys = info.get('primary_keys', [])
        mermaid_lines.append(f"    {table} {{")
        for column_name, data_type in info['columns']:
            mermaid_data_type = _map_data_type(data_type)
            pk_marker = " PK" if column_name in primary_keys else ""
            mermaid_lines.append(f"        {mermaid_data_type} {column_name}{pk_marker}")
        mermaid_lines.append("    }")

    for table, info in schema_info.items():
        for fk_column, referenced_table, referenced_column in info['foreign_keys']:
            # '}}' is an escaped '}' in the f-string: renders as '}|--||' (many-to-one).
            relation = (
                f"    {table} }}|--|| {referenced_table} : "
                f"\"{fk_column} to {referenced_column}\""
            )
            mermaid_lines.append(relation)
    return "\n".join(mermaid_lines)


def _quote_ident(name):
    """Quote *name* as a SQL identifier, doubling any embedded double quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def _generate_create_script(schema_info):
    """Generate a SQL DDL script that recreates the tables in *schema_info*.

    All ``CREATE TABLE`` statements are emitted first and all foreign-key
    ``ALTER TABLE`` statements afterwards, so the script can be run top to
    bottom regardless of the order in which tables reference each other.
    """
    create_lines = []
    alter_lines = []
    for table, info in schema_info.items():
        quoted_table = _quote_ident(table)

        column_defs = [
            f'    {_quote_ident(column_name)} {_map_data_type_for_sql(data_type)}'
            for column_name, data_type in info['columns']
        ]
        pks = info.get('primary_keys', [])
        if pks:
            pk_columns = ", ".join(_quote_ident(pk) for pk in pks)
            column_defs.append(f'    PRIMARY KEY ({pk_columns})')

        # Join column definitions before interpolating into the f-string.
        columns_sql = ",\n".join(column_defs)
        create_lines.append(f'CREATE TABLE {quoted_table} (\n{columns_sql}\n);')
        create_lines.append("")

        for fk_column, ref_table, ref_col in info['foreign_keys']:
            constraint_name = _quote_ident(f"fk_{table}_{fk_column}")
            alter_lines.append(
                f'ALTER TABLE {quoted_table} ADD CONSTRAINT {constraint_name} '
                f'FOREIGN KEY ({_quote_ident(fk_column)}) '
                f'REFERENCES {_quote_ident(ref_table)} ({_quote_ident(ref_col)});'
            )
    return "\n".join(create_lines + alter_lines)


def _execute_sql(query):
    """Execute an arbitrary SQL statement supplied by the user.

    NOTE(review): the statement text is executed verbatim and no access
    control is applied in this module; the blueprint must only be reachable
    by trusted users.

    Returns:
        ``(results, columns, error)`` -- the rows returned by ``db.execute``
        (or ``None`` on error), the column names taken from the first row
        (empty if there are no rows), and an error message or ``None``.
    """
    if not query or not query.strip():
        return (None, [], "Must provide query")

    results, columns, error = None, [], None
    try:
        results = db.execute(query)
        if isinstance(results, list) and results:
            columns = list(results[0].keys())
    except Exception as e:
        error = str(e)
        # Roll back so the failed statement doesn't leave the transaction
        # aborted (assumes db.getDB() is the connection db.execute used).
        db.getDB().rollback()
    return (results, columns, error)


def _save_query(title, query):
    """Persist a named query in ``saved_query``.

    Returns:
        An error message string, or ``None`` on success.
    """
    if not title:
        return "Must provide title"
    if not query or not query.strip():
        return "Must provide query"
    error = None
    try:
        db.execute("INSERT INTO saved_query (title, query) VALUES (%s, %s)",
                   [title, query], commit=True)
    except Exception as e:
        error = str(e)
        db.getDB().rollback()  # Rollback on error
    return error


def _list_saved_queries():
    """Return all saved queries (``id``, ``title``, ``query``) ordered by title."""
    return db.execute("SELECT id, title, query FROM saved_query ORDER BY title")


def _get_saved_query(query_id):
    """Return ``(title, query)`` for *query_id*, or ``(None, None)`` if it does not exist."""
    result = db.execute("SELECT title, query FROM saved_query WHERE id=%s",
                        [query_id], one=True)
    return (result['title'], result['query']) if result else (None, None)


def _delete_saved_query(query_id):
    """Delete the saved query with id *query_id* (no-op if it does not exist)."""
    db.execute("DELETE FROM saved_query WHERE id=%s", [query_id], commit=True)


def _render_plot(query, title):
    """Run *query* into a DataFrame and render it with ``generate_plot``.

    Returns the plot markup, or an ``(escaped error message, 400)`` response
    if the query fails.
    """
    try:
        results_df = db.read_sql_as_df(query)  # Keep using db.py for pandas interaction
    except Exception as e:
        # Same recovery as _execute_sql (assumes db.getDB() is the connection used).
        db.getDB().rollback()
        return html.escape(str(e)), 400
    return generate_plot(results_df, title)


# --- Routes ---

@sql_explorer_bp.route("/explorer", methods=['GET'])
def sql_explorer():
    """Render the explorer page; htmx requests get only the ``content`` block."""
    saved_queries = _list_saved_queries()
    if htmx:
        return render_block(current_app.jinja_env, 'sql_explorer.html', 'content',
                            saved_queries=saved_queries)
    return render_template('sql_explorer.html', saved_queries=saved_queries)


@sql_explorer_bp.route("/query", methods=['POST'])
def sql_query():
    """Save the posted query under the posted title and re-render the query panel."""
    query = request.form.get('query')
    title = request.form.get('title')
    error = _save_query(title, query)
    saved_queries = _list_saved_queries()
    return render_template('partials/sql_explorer/sql_query.html', title=title,
                           query=query, error=error, saved_queries=saved_queries)


@sql_explorer_bp.route("/query/execute", methods=['POST'])
def execute_sql_query():
    """Execute the posted query and render its results (or error)."""
    query = request.form.get('query')
    (results, columns, error) = _execute_sql(query)
    return render_template('partials/sql_explorer/results.html', results=results,
                           columns=columns, error=error)


@sql_explorer_bp.route('/load_query/<int:query_id>', methods=['GET'])
def load_sql_query(query_id):
    """Load saved query *query_id* into the query panel."""
    (title, query) = _get_saved_query(query_id)
    saved_queries = _list_saved_queries()
    return render_template('partials/sql_explorer/sql_query.html', title=title,
                           query=query, saved_queries=saved_queries)


@sql_explorer_bp.route('/delete_query/<int:query_id>', methods=['DELETE'])
def delete_sql_query(query_id):
    """Delete saved query *query_id* and re-render an empty query panel."""
    _delete_saved_query(query_id)
    saved_queries = _list_saved_queries()
    return render_template('partials/sql_explorer/sql_query.html', title="",
                           query="", saved_queries=saved_queries)


@sql_explorer_bp.route("/schema", methods=['GET'])
def sql_schema():
    """Render the public schema as a Mermaid ER diagram plus a CREATE script."""
    schema_info = _get_schema_info()
    mermaid_code = _generate_mermaid_er(schema_info)
    create_sql = _generate_create_script(schema_info)
    return render_template('partials/sql_explorer/schema.html',
                           mermaid_code=mermaid_code, create_sql=create_sql)


@sql_explorer_bp.route("/plot/<int:query_id>", methods=['GET'])
def plot_query(query_id):
    """Plot the results of saved query *query_id* (404 if it does not exist)."""
    (title, query) = _get_saved_query(query_id)
    if not query:
        return "Query not found", 404
    return _render_plot(query, title)


@sql_explorer_bp.route("/plot/show", methods=['POST'])
def plot_unsaved_query():
    """Plot the results of the posted (unsaved) query."""
    query = request.form.get('query')
    title = request.form.get('title')
    if not query or not query.strip():
        return "Must provide query", 400
    return _render_plot(query, title)