diff --git a/routes/sql_explorer.py b/routes/sql_explorer.py index a10c160..af6e377 100644 --- a/routes/sql_explorer.py +++ b/routes/sql_explorer.py @@ -1,4 +1,7 @@ -from flask import Blueprint, render_template, request, current_app +import os +import requests # Import requests library +import json # Import json library +from flask import Blueprint, render_template, request, current_app, jsonify from jinja2_fragments import render_block from flask_htmx import HTMX from extensions import db @@ -7,8 +10,6 @@ from utils import generate_plot sql_explorer_bp = Blueprint('sql_explorer', __name__, url_prefix='/sql') htmx = HTMX() -# --- Database Helper Functions (Moved from features/sql_explorer.py) --- - def _get_schema_info(schema='public'): """Fetches schema information directly.""" tables_result = db.execute(""" @@ -62,7 +63,6 @@ def _map_data_type_for_sql(postgres_type): 'integer': 'INTEGER', 'bigint': 'BIGINT', 'boolean': 'BOOLEAN', 'timestamp without time zone': 'TIMESTAMP', 'timestamp with time zone': 'TIMESTAMPTZ', 'numeric': 'NUMERIC', 'real': 'REAL', 'date': 'DATE' - # Add more as needed }.get(postgres_type, postgres_type.upper()) def _map_data_type(postgres_type): @@ -108,7 +108,6 @@ def _generate_create_script(schema_info): pk_columns = ", ".join(f'"{pk}"' for pk in pks) column_defs.append(f' PRIMARY KEY ({pk_columns})') - # Format column definitions with newlines before using in f-string columns_sql = ",\n".join(column_defs) create_stmt = f'CREATE TABLE "{table}" (\n{columns_sql}\n);' lines.append(create_stmt) @@ -122,16 +121,17 @@ def _generate_create_script(schema_info): lines.append("") return "\n".join(lines) + def _execute_sql(query): """Executes arbitrary SQL query, returning results, columns, and error.""" results, columns, error = None, [], None try: - results = db.execute(query) # Use the imported db object directly + results = db.execute(query) if results: columns = list(results[0].keys()) if isinstance(results, list) and results else [] except Exception as e: error = str(e) - db.getDB().rollback() # Rollback on error + db.getDB().rollback() return (results, columns, error) def _save_query(title, query): @@ -142,7 +142,7 @@ def _save_query(title, query): db.execute("INSERT INTO saved_query (title, query) VALUES (%s, %s)", [title, query], commit=True) except Exception as e: error = str(e) - db.getDB().rollback() # Rollback on error + db.getDB().rollback() return error def _list_saved_queries(): @@ -158,12 +158,79 @@ def _delete_saved_query(query_id): """Deletes a saved query.""" db.execute("DELETE FROM saved_query WHERE id=%s", [query_id], commit=True) +def _generate_sql_from_natural_language(natural_query): + """Generates SQL query from natural language using Gemini REST API.""" + gemni_model = os.environ.get("GEMINI_MODEL","gemini-2.0-flash") + api_key = os.environ.get("GEMINI_API_KEY") + if not api_key: + return None, "GEMINI_API_KEY environment variable not set." + + # Using gemini-pro model endpoint + api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemni_model}:generateContent?key={api_key}" + headers = {'Content-Type': 'application/json'} + + try: + # Get and format schema + schema_info = _get_schema_info() + schema_string = _generate_create_script(schema_info) + + prompt = f"""Given the following database schema: +```sql +{schema_string} +``` + +Generate a PostgreSQL query that answers the following question: "{natural_query}" + +Return ONLY the SQL query, without any explanation or surrounding text/markdown. +""" + # Construct the request payload + payload = json.dumps({ + "contents": [{ + "parts": [{"text": prompt}] + }] + }) + + # Make the POST request + response = requests.post(api_url, headers=headers, data=payload) + response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) + + # Parse the response + response_data = response.json() + + # Extract the generated text - structure might vary slightly based on API version/response + # Safely navigate the response structure + candidates = response_data.get('candidates', []) + if not candidates: + return None, "No candidates found in API response." + + content = candidates[0].get('content', {}) + parts = content.get('parts', []) + if not parts: + return None, "No parts found in API response content." + + generated_sql = parts[0].get('text', '').strip() + + # Basic parsing: remove potential markdown code fences + if generated_sql.startswith("```sql"): + generated_sql = generated_sql[6:] + if generated_sql.endswith("```"): + generated_sql = generated_sql[:-3] + + return generated_sql.strip(), None + + except requests.exceptions.RequestException as e: + current_app.logger.error(f"Gemini API request error: {e}") + return None, f"Error communicating with API: {e}" + except (KeyError, IndexError, Exception) as e: + current_app.logger.error(f"Error processing Gemini API response: {e} - Response: {response_data if 'response_data' in locals() else 'N/A'}") + return None, f"Error processing API response: {e}" + # --- Routes --- @sql_explorer_bp.route("/explorer", methods=['GET']) def sql_explorer(): - saved_queries = _list_saved_queries() # Use local helper + saved_queries = _list_saved_queries() if htmx: return render_block(current_app.jinja_env, 'sql_explorer.html', 'content', saved_queries=saved_queries) return render_template('sql_explorer.html', saved_queries=saved_queries) @@ -172,44 +239,44 @@ def sql_explorer(): def sql_query(): query = request.form.get('query') title = request.form.get('title') - error = _save_query(title, query) # Use local helper - saved_queries = _list_saved_queries() # Use local helper + error = _save_query(title, query) + saved_queries = _list_saved_queries() return render_template('partials/sql_explorer/sql_query.html', title=title, query=query, error=error, saved_queries=saved_queries) @sql_explorer_bp.route("/query/execute", methods=['POST']) def execute_sql_query(): query = request.form.get('query') - (results, columns, error) = _execute_sql(query) # Use local helper + (results, columns, error) = _execute_sql(query) return render_template('partials/sql_explorer/results.html', results=results, columns=columns, error=error) @sql_explorer_bp.route('/load_query/', methods=['GET']) def load_sql_query(query_id): - (title, query) = _get_saved_query(query_id) # Use local helper - saved_queries = _list_saved_queries() # Use local helper + (title, query) = _get_saved_query(query_id) + saved_queries = _list_saved_queries() return render_template('partials/sql_explorer/sql_query.html', title=title, query=query, saved_queries=saved_queries) @sql_explorer_bp.route('/delete_query/', methods=['DELETE']) def delete_sql_query(query_id): - _delete_saved_query(query_id) # Use local helper - saved_queries = _list_saved_queries() # Use local helper + _delete_saved_query(query_id) + saved_queries = _list_saved_queries() return render_template('partials/sql_explorer/sql_query.html', title="", query="", saved_queries=saved_queries) @sql_explorer_bp.route("/schema", methods=['GET']) def sql_schema(): - schema_info = _get_schema_info() # Use local helper - mermaid_code = _generate_mermaid_er(schema_info) # Use local helper - create_sql = _generate_create_script(schema_info) # Use local helper + schema_info = _get_schema_info() + mermaid_code = _generate_mermaid_er(schema_info) + create_sql = _generate_create_script(schema_info) return render_template('partials/sql_explorer/schema.html', mermaid_code=mermaid_code, create_sql=create_sql) @sql_explorer_bp.route("/plot/", methods=['GET']) def plot_query(query_id): - (title, query) = _get_saved_query(query_id) # Use local helper + (title, query) = _get_saved_query(query_id) if not query: return "Query not found", 404 - results_df = db.read_sql_as_df(query) # Keep using db.py for pandas interaction + results_df = db.read_sql_as_df(query) plot_div = generate_plot(results_df, title) return plot_div @@ -217,6 +284,21 @@ def plot_query(query_id): def plot_unsaved_query(): query = request.form.get('query') title = request.form.get('title') - results_df = db.read_sql_as_df(query) # Keep using db.py for pandas interaction + results_df = db.read_sql_as_df(query) plot_div = generate_plot(results_df, title) - return plot_div \ No newline at end of file + return plot_div + +@sql_explorer_bp.route("/generate_sql", methods=['POST']) +def generate_sql(): + """Generates SQL from natural language via Gemini REST API.""" + natural_query = request.form.get('natural_query') + if not natural_query: + return "" + + generated_sql, error = _generate_sql_from_natural_language(natural_query) + + if error: + # Return error message prepended, to be displayed in the textarea + return f"-- Error generating SQL: {error}\n\n" + + return generated_sql if generated_sql else "-- No SQL generated." \ No newline at end of file diff --git a/templates/changelog/changelog.html b/templates/changelog/changelog.html index 58f2744..972ca63 100644 --- a/templates/changelog/changelog.html +++ b/templates/changelog/changelog.html @@ -10,6 +10,15 @@

Updates and changes to the site will be documented here, with the most recent changes listed first.

+ +
+

April 5, 2025

+
    +
  • Added experimental feature to SQL Explorer to generate SQL queries from natural language using the + Gemini REST API. Requires `GEMINI_API_KEY` environment variable.
  • +
+ +

March 31, 2025

diff --git a/templates/partials/sql_explorer/sql_query.html b/templates/partials/sql_explorer/sql_query.html index 144a7ed..1106a15 100644 --- a/templates/partials/sql_explorer/sql_query.html +++ b/templates/partials/sql_explorer/sql_query.html @@ -23,7 +23,34 @@ on input set my.style.height to 0 then set my.style.height to my.scrollHeight + 'px'">{{ query }}
- + +
+ +
+ + +
+
+ +