Files
workout/routes/sql_explorer.py
Peter Stockings e947feb3e3 refactor(sql_explorer): Replace Plotly with SVG rendering for plots
Replaces the Plotly-based graph generation in the SQL Explorer with direct SVG rendering within an HTML template, similar to the exercise progress sparklines.

- Modifies `routes/sql_explorer.py` endpoints (`plot_query`, `plot_unsaved_query`) to fetch raw data instead of using pandas/Plotly.
- Adds `utils.prepare_svg_plot_data` to process raw SQL results, determine plot type (scatter, line, bar, table), normalize data, and prepare it for SVG.
- Creates `templates/partials/sql_explorer/svg_plot.html` to render the SVG plot with axes, ticks, labels, and basic tooltips.
- Removes the `generate_plot` function's usage for SQL Explorer and the direct dependency on Plotly for this feature.
2025-04-15 19:34:26 +10:00

339 lines
15 KiB
Python

import html
import json
import os
import re

import requests
from flask import Blueprint, render_template, request, current_app, jsonify
from flask_htmx import HTMX
from jinja2_fragments import render_block

from extensions import db
from utils import prepare_svg_plot_data  # builds the template context for svg_plot.html
# Blueprint that every SQL Explorer route below is registered on; mounted under /sql.
sql_explorer_bp = Blueprint(name='sql_explorer', import_name=__name__, url_prefix='/sql')
# HTMX helper; sql_explorer() truth-tests it to choose a partial (render_block) vs a full-page render.
htmx = HTMX()
def _get_schema_info(schema='public'):
    """Collect table, column, primary-key and foreign-key metadata for *schema*.

    Issues one ``information_schema`` query per kind of metadata (4 in total)
    rather than three queries per table, then groups the rows in Python.

    Args:
        schema: Name of the PostgreSQL schema to inspect.

    Returns:
        dict mapping each base-table name (ordered by ``table_name``) to a dict:
        ``'columns'``: list of ``(column_name, data_type)`` in ordinal order;
        ``'primary_keys'``: list of PK column names in key order;
        ``'foreign_keys'``: list of ``(fk_column, referenced_table, referenced_column)``.
    """
    tables_result = db.execute("""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name;
    """, [schema]) or []
    schema_info = {
        row['table_name']: {'columns': [], 'primary_keys': [], 'foreign_keys': []}
        for row in tables_result
    }
    if not schema_info:
        return schema_info

    columns_result = db.execute("""
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s
        ORDER BY table_name, ordinal_position;
    """, [schema]) or []
    for row in columns_result:
        # information_schema.columns also covers views; keep base tables only.
        info = schema_info.get(row['table_name'])
        if info is not None:
            info['columns'].append((row['column_name'], row['data_type']))

    # Matching kcu.table_name as well as the constraint name matters: Postgres only
    # requires constraint names to be unique per table, not per schema.
    primary_keys_result = db.execute("""
        SELECT tc.table_name, kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = %s
        ORDER BY tc.table_name, kcu.ordinal_position;
    """, [schema]) or []
    for row in primary_keys_result:
        info = schema_info.get(row['table_name'])
        if info is not None:
            info['primary_keys'].append(row['column_name'])

    # NOTE(review): constraint_column_usage has one row per referenced column, so a
    # composite FK yields every (local, referenced) column pairing — confirm whether
    # the schema uses composite foreign keys.
    foreign_keys_result = db.execute("""
        SELECT tc.table_name,
               kcu.column_name AS fk_column,
               ccu.table_name AS referenced_table,
               ccu.column_name AS referenced_column
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
          ON ccu.constraint_name = tc.constraint_name
         AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = %s
        ORDER BY tc.table_name, kcu.ordinal_position;
    """, [schema]) or []
    for row in foreign_keys_result:
        info = schema_info.get(row['table_name'])
        if info is not None:
            info['foreign_keys'].append(
                (row['fk_column'], row['referenced_table'], row['referenced_column'])
            )

    return schema_info
def _map_data_type_for_sql(postgres_type):
"""Maps PostgreSQL types to standard SQL types (simplified)."""
return {
'character varying': 'VARCHAR', 'varchar': 'VARCHAR', 'text': 'TEXT',
'integer': 'INTEGER', 'bigint': 'BIGINT', 'boolean': 'BOOLEAN',
'timestamp without time zone': 'TIMESTAMP', 'timestamp with time zone': 'TIMESTAMPTZ',
'numeric': 'NUMERIC', 'real': 'REAL', 'date': 'DATE'
}.get(postgres_type, postgres_type.upper())
def _map_data_type(postgres_type):
"""Maps PostgreSQL types to Mermaid ER diagram types."""
type_mapping = {
'integer': 'int', 'bigint': 'int', 'smallint': 'int',
'character varying': 'string', 'varchar': 'string', 'text': 'string',
'date': 'date', 'timestamp without time zone': 'datetime',
'timestamp with time zone': 'datetime', 'boolean': 'bool',
'numeric': 'float', 'real': 'float'
}
return type_mapping.get(postgres_type, 'string')
def _generate_mermaid_er(schema_info):
    """Render *schema_info* (as built by ``_get_schema_info``) as Mermaid ``erDiagram`` source.

    Each table becomes an entity listing its columns, with primary-key columns
    suffixed ``PK``. Each foreign key becomes a ``}|--||`` relationship labelled
    ``"<fk_column> to <referenced_column>"``. All entities come before all relationships.
    """
    entity_lines = []
    for table_name, table_info in schema_info.items():
        pk_columns = table_info.get('primary_keys', [])
        entity_lines.append(f" {table_name} {{")
        entity_lines.extend(
            f" {_map_data_type(col_type)} {col_name}{' PK' if col_name in pk_columns else ''}"
            for col_name, col_type in table_info['columns']
        )
        entity_lines.append(" }")

    relation_lines = [
        f" {table_name} }}|--|| {ref_table} : \"{fk_col} to {ref_col}\""
        for table_name, table_info in schema_info.items()
        for fk_col, ref_table, ref_col in table_info['foreign_keys']
    ]

    return "\n".join(["erDiagram", *entity_lines, *relation_lines])
def _quote_identifier(name):
    """Double-quote *name* as a PostgreSQL identifier, escaping embedded quotes by doubling them."""
    return '"' + str(name).replace('"', '""') + '"'


def _generate_create_script(schema_info):
    """Render *schema_info* (as built by ``_get_schema_info``) as a PostgreSQL DDL script.

    All ``CREATE TABLE`` statements are emitted first, and every foreign key is
    added afterwards with ``ALTER TABLE``. The script therefore runs top to
    bottom even when a table references one that appears later in *schema_info*.
    Previously each table's FKs immediately followed its own CREATE, which
    failed in that case.

    Returns:
        The script as one newline-terminated string, or ``''`` when
        *schema_info* is empty.
    """
    create_stmts = []
    fk_stmts = []
    for table, info in schema_info.items():
        column_defs = [
            f' {_quote_identifier(column_name)} {_map_data_type_for_sql(data_type)}'
            for column_name, data_type in info['columns']
        ]
        pks = info.get('primary_keys', [])
        if pks:
            pk_columns = ", ".join(_quote_identifier(pk) for pk in pks)
            column_defs.append(f' PRIMARY KEY ({pk_columns})')
        columns_sql = ",\n".join(column_defs)
        create_stmts.append(f'CREATE TABLE {_quote_identifier(table)} (\n{columns_sql}\n);')

        for fk_column, ref_table, ref_col in info.get('foreign_keys', []):
            constraint_name = _quote_identifier(f"fk_{table}_{fk_column}")
            fk_stmts.append(
                f'ALTER TABLE {_quote_identifier(table)} ADD CONSTRAINT {constraint_name} '
                f'FOREIGN KEY ({_quote_identifier(fk_column)}) '
                f'REFERENCES {_quote_identifier(ref_table)} ({_quote_identifier(ref_col)});'
            )

    if not create_stmts:
        return ""
    sections = ["\n\n".join(create_stmts)]
    if fk_stmts:
        sections.append("\n".join(fk_stmts))
    return "\n\n".join(sections) + "\n"
def _execute_sql(query):
    """Run an arbitrary SQL statement and report the outcome.

    Returns:
        A tuple ``(results, columns, error)``.
        *results* is whatever ``db.execute`` returned, or ``None`` if it raised.
        *columns* holds the keys of the first row when *results* is a non-empty
        list, otherwise ``[]``.
        *error* is ``None`` on success. On failure it is the exception text,
        after ``db.getDB().rollback()`` has been called.
    """
    rows = None
    try:
        rows = db.execute(query)
        column_names = list(rows[0].keys()) if isinstance(rows, list) and rows else []
    except Exception as exc:
        message = str(exc)
        db.getDB().rollback()
        return (rows, [], message)
    return (rows, column_names, None)
def _save_query(title, query):
    """Insert a titled query into the ``saved_query`` table.

    Returns:
        ``None`` on success. ``"Must provide title"`` when *title* is falsy.
        Otherwise the database error text, after rolling back the transaction.
    """
    if not title:
        return "Must provide title"
    try:
        db.execute(
            "INSERT INTO saved_query (title, query) VALUES (%s, %s)",
            [title, query],
            commit=True,
        )
    except Exception as exc:
        failure = str(exc)
        db.getDB().rollback()
        return failure
    return None
def _list_saved_queries():
    """Return every saved query's ``id``, ``title`` and ``query``, ordered by title."""
    sql = "SELECT id, title, query FROM saved_query ORDER BY title"
    return db.execute(sql)
def _get_saved_query(query_id):
    """Look up a saved query by id.

    Returns:
        ``(title, query)`` for the matching row, or ``(None, None)`` if none exists.
    """
    row = db.execute("SELECT title, query FROM saved_query WHERE id=%s", [query_id], one=True)
    if not row:
        return (None, None)
    return (row['title'], row['query'])
def _delete_saved_query(query_id):
    """Delete the saved query whose id is *query_id*.

    Raises:
        Whatever ``db.execute`` raises. Before re-raising, the transaction is
        rolled back via ``db.getDB().rollback()``, matching ``_save_query`` and
        ``_execute_sql``, so a failed DELETE does not leave the connection
        mid-transaction.
    """
    try:
        db.execute("DELETE FROM saved_query WHERE id=%s", [query_id], commit=True)
    except Exception:
        db.getDB().rollback()
        raise
def _generate_sql_from_natural_language(natural_query):
"""Generates SQL query from natural language using Gemini REST API."""
gemni_model = os.environ.get("GEMINI_MODEL","gemini-2.0-flash")
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
return None, "GEMINI_API_KEY environment variable not set."
# Using gemini-pro model endpoint
api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemni_model}:generateContent?key={api_key}"
headers = {'Content-Type': 'application/json'}
try:
# Get and format schema
schema_info = _get_schema_info()
schema_string = _generate_create_script(schema_info)
prompt = f"""Given the following database schema:
```sql
{schema_string}
```
Generate a PostgreSQL query that answers the following question: "{natural_query}"
Return ONLY the SQL query, without any explanation or surrounding text/markdown.
"""
# Construct the request payload
payload = json.dumps({
"contents": [{
"parts": [{"text": prompt}]
}]
})
# Make the POST request
response = requests.post(api_url, headers=headers, data=payload)
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
# Parse the response
response_data = response.json()
# Extract the generated text - structure might vary slightly based on API version/response
# Safely navigate the response structure
candidates = response_data.get('candidates', [])
if not candidates:
return None, "No candidates found in API response."
content = candidates[0].get('content', {})
parts = content.get('parts', [])
if not parts:
return None, "No parts found in API response content."
generated_sql = parts[0].get('text', '').strip()
# Basic parsing: remove potential markdown code fences
if generated_sql.startswith("```sql"):
generated_sql = generated_sql[6:]
if generated_sql.endswith("```"):
generated_sql = generated_sql[:-3]
# Remove leading SQL comment lines
sql_lines = generated_sql.strip().splitlines()
filtered_lines = [line for line in sql_lines if not line.strip().startswith('--')]
final_sql = "\n".join(filtered_lines).strip()
return final_sql, None
except requests.exceptions.RequestException as e:
current_app.logger.error(f"Gemini API request error: {e}")
return None, f"Error communicating with API: {e}"
except (KeyError, IndexError, Exception) as e:
current_app.logger.error(f"Error processing Gemini API response: {e} - Response: {response_data if 'response_data' in locals() else 'N/A'}")
return None, f"Error processing API response: {e}"
# --- Routes ---
@sql_explorer_bp.route("/explorer", methods=['GET'])
def sql_explorer():
    """Render the SQL explorer page, or only its ``content`` block when ``htmx`` is truthy."""
    context = {'saved_queries': _list_saved_queries()}
    if not htmx:
        return render_template('sql_explorer.html', **context)
    return render_block(current_app.jinja_env, 'sql_explorer.html', 'content', **context)
@sql_explorer_bp.route("/query", methods=['POST'])
def sql_query():
    """Save the posted query under the posted title and re-render the query panel."""
    form = request.form
    query, title = form.get('query'), form.get('title')
    save_error = _save_query(title, query)
    return render_template(
        'partials/sql_explorer/sql_query.html',
        title=title,
        query=query,
        error=save_error,
        saved_queries=_list_saved_queries(),
    )
@sql_explorer_bp.route("/query/execute", methods=['POST'])
def execute_sql_query():
    """Run the posted SQL and render its rows, column names and any error."""
    results, columns, error = _execute_sql(request.form.get('query'))
    return render_template(
        'partials/sql_explorer/results.html',
        results=results,
        columns=columns,
        error=error,
    )
@sql_explorer_bp.route('/load_query/<int:query_id>', methods=['GET'])
def load_sql_query(query_id):
    """Load saved query *query_id* into the query panel (title/query are None if it is missing)."""
    loaded_title, loaded_query = _get_saved_query(query_id)
    return render_template(
        'partials/sql_explorer/sql_query.html',
        title=loaded_title,
        query=loaded_query,
        saved_queries=_list_saved_queries(),
    )
@sql_explorer_bp.route('/delete_query/<int:query_id>', methods=['DELETE'])
def delete_sql_query(query_id):
    """Delete saved query *query_id* and re-render an empty query panel."""
    _delete_saved_query(query_id)
    remaining = _list_saved_queries()
    return render_template(
        'partials/sql_explorer/sql_query.html',
        title="",
        query="",
        saved_queries=remaining,
    )
@sql_explorer_bp.route("/schema", methods=['GET'])
def sql_schema():
    """Render the public schema both as a Mermaid ER diagram and as CREATE TABLE DDL."""
    info = _get_schema_info()
    return render_template(
        'partials/sql_explorer/schema.html',
        mermaid_code=_generate_mermaid_er(info),
        create_sql=_generate_create_script(info),
    )
_PLOT_ERROR_CLASSES = "p-4 text-red-700 bg-red-100 border border-red-400 rounded"
_PLOT_WARNING_CLASSES = "p-4 text-yellow-700 bg-yellow-100 border border-yellow-400 rounded"


def _plot_alert(css_classes, message):
    """Return an HTML alert ``<div>`` with *message* HTML-escaped.

    Escaping matters because *message* may embed database error text, which
    can echo fragments of the user's SQL.
    """
    return f'<div class="{css_classes}">{html.escape(message)}</div>'


def _render_svg_plot(query, title):
    """Execute *query* and render its rows as an SVG plot titled *title*.

    Returns a Flask response value, one of:
    - the rendered ``svg_plot.html`` partial on success;
    - an alert with status 400 when the SQL fails;
    - a warning alert (200) when the query returns no rows;
    - an alert with status 500 when preparing or rendering the plot raises.
    """
    results, columns, error = _execute_sql(query)
    if error:
        return _plot_alert(_PLOT_ERROR_CLASSES, f"Error executing query: {error}"), 400
    if not results:
        return _plot_alert(_PLOT_WARNING_CLASSES, "No data returned by query.")
    try:
        plot_data = prepare_svg_plot_data(results, columns, title)
        return render_template('partials/sql_explorer/svg_plot.html', **plot_data)
    except Exception as e:
        current_app.logger.exception("Error preparing SVG plot data")
        return _plot_alert(_PLOT_ERROR_CLASSES, f"Error preparing plot data: {e}"), 500


@sql_explorer_bp.route("/plot/<int:query_id>", methods=['GET'])
def plot_query(query_id):
    """Plot the results of saved query *query_id*; 404 if it does not exist."""
    title, query = _get_saved_query(query_id)
    if not query:
        return "Query not found", 404
    return _render_svg_plot(query, title)


@sql_explorer_bp.route("/plot/show", methods=['POST'])
def plot_unsaved_query():
    """Plot the results of the SQL posted in the form, without saving it."""
    query = request.form.get('query')
    if not query:
        return _plot_alert(_PLOT_ERROR_CLASSES, "No query provided."), 400
    # `or` rather than a .get() default so an empty title field also gets the default.
    title = request.form.get('title') or 'SQL Query Plot'
    return _render_svg_plot(query, title)
@sql_explorer_bp.route("/generate_sql", methods=['POST'])
def generate_sql():
    """Return SQL generated from the posted ``natural_query`` as plain text.

    Returns an empty string when no question was posted. Failures are returned
    as a leading SQL comment so they can be shown directly in the query textarea.
    """
    question = request.form.get('natural_query')
    if not question:
        return ""
    sql, failure = _generate_sql_from_natural_language(question)
    if failure:
        return f"-- Error generating SQL: {failure}\n\n"
    return sql or "-- No SQL generated."