"""SQL explorer blueprint: run ad-hoc SQL, save/load queries, plot results,
and generate SQL from natural language via the Gemini REST API."""

import html
import json
import os

import requests
from flask import Blueprint, current_app, jsonify, render_template, request
from flask_htmx import HTMX
from jinja2_fragments import render_block

from extensions import db
from utils import prepare_svg_plot_data

sql_explorer_bp = Blueprint('sql_explorer', __name__, url_prefix='/sql')
htmx = HTMX()

# Gemini generateContent endpoint; the model name is filled in per call.
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
# Seconds to wait for Gemini. Without a timeout, requests waits forever on a
# stalled connection and ties up the worker.
GEMINI_REQUEST_TIMEOUT = 30

_NO_DATA_BOX = '<div class="p-4 text-yellow-700 bg-yellow-100 border border-yellow-400 rounded">No data returned by query.</div>'


def _error_box(message):
    """Return *message*, HTML-escaped, wrapped in a red alert box.

    Escaping matters: DB error messages can echo fragments of the submitted
    SQL, and this markup is swapped straight into the page.
    """
    return f'<div class="p-4 text-red-700 bg-red-100 border border-red-400 rounded">{html.escape(message)}</div>'


def _execute_sql(query):
    """Execute an arbitrary SQL statement exactly as given.

    NOTE(review): this intentionally runs user-supplied SQL verbatim (it backs
    the SQL explorer); access control must be enforced elsewhere.

    Returns a ``(results, columns, error)`` tuple. ``columns`` is taken from
    the keys of the first row when ``results`` is a non-empty list (assumes
    ``db.execute`` returns mapping-like rows for SELECTs — TODO confirm);
    otherwise it is ``[]``. ``error`` is the exception text, or None.
    """
    results, columns, error = None, [], None
    try:
        results = db.execute(query)
        if isinstance(results, list) and results:
            columns = list(results[0].keys())
    except Exception as e:
        error = str(e)
        # Roll back so the failed statement doesn't leave the connection in an
        # aborted transaction that rejects subsequent queries.
        db.getDB().rollback()
    return (results, columns, error)


def _save_query(title, query):
    """Insert *query* into ``saved_query`` under *title*.

    Returns None on success, or an error message string (missing title or
    the database exception text).
    """
    if not title:
        return "Must provide title"
    try:
        db.execute("INSERT INTO saved_query (title, query) VALUES (%s, %s)", [title, query], commit=True)
    except Exception as e:
        db.getDB().rollback()
        return str(e)
    return None


def _list_saved_queries():
    """Return all saved queries (id, title, query), ordered by title."""
    return db.execute("SELECT id, title, query FROM saved_query ORDER BY title")


def _get_saved_query(query_id):
    """Return ``(title, query)`` for *query_id*, or ``(None, None)`` if absent."""
    result = db.execute("SELECT title, query FROM saved_query WHERE id=%s", [query_id], one=True)
    return (result['title'], result['query']) if result else (None, None)


def _delete_saved_query(query_id):
    """Delete the saved query with id *query_id* and commit."""
    db.execute("DELETE FROM saved_query WHERE id=%s", [query_id], commit=True)


def _get_schema_create_script():
    """Return the current database schema rendered as CREATE statements."""
    schema_info = db.schema.get_schema_info()
    return db.schema.generate_create_script(schema_info)


def _build_sql_prompt(schema_string, natural_query):
    """Build the Gemini prompt asking for one PostgreSQL query, SQL only."""
    return f"""Given the following database schema:
```sql
{schema_string}
```

Generate a PostgreSQL query that answers the following question: "{natural_query}"

Return ONLY the SQL query, without any explanation or surrounding text/markdown.
"""


def _extract_gemini_text(response_data):
    """Pull the first candidate's first text part from a generateContent reply.

    Returns ``(text, error)``; exactly one of them is None.
    """
    candidates = response_data.get('candidates') or []
    if not candidates:
        return None, "No candidates found in API response."
    parts = (candidates[0].get('content') or {}).get('parts') or []
    if not parts:
        return None, "No parts found in API response content."
    return parts[0].get('text', ''), None


def _clean_generated_sql(text):
    """Strip markdown code fences and full-line ``--`` comments from model output."""
    sql = text.strip()
    if sql.startswith("```"):
        _fence_line, newline, rest = sql.partition("\n")
        if newline:
            # Drop the whole opening fence line, whatever language tag it
            # carries (```sql, ```postgresql, ```SQL, ...).
            sql = rest
        else:
            # Single-line answer such as ```sql SELECT 1```.
            sql = sql[6:] if sql.startswith("```sql") else sql[3:]
    sql = sql.strip()
    if sql.endswith("```"):
        sql = sql[:-3]
    kept_lines = [line for line in sql.strip().splitlines() if not line.strip().startswith('--')]
    return "\n".join(kept_lines).strip()


def _generate_sql_from_natural_language(natural_query):
    """Generate a SQL query for *natural_query* using the Gemini REST API.

    The model comes from ``GEMINI_MODEL`` (default ``gemini-2.0-flash``) and
    the key from ``GEMINI_API_KEY``. The current DB schema is included in the
    prompt.

    Returns ``(sql, None)`` on success or ``(None, error_message)`` on failure.
    """
    gemini_model = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return None, "GEMINI_API_KEY environment variable not set."

    api_url = GEMINI_API_URL.format(model=gemini_model)
    # The key travels in a header, not a ?key= query parameter. requests puts
    # the request URL in HTTPError messages, and those are logged and returned
    # to the browser below, so a URL-borne key would leak.
    headers = {'Content-Type': 'application/json', 'x-goog-api-key': api_key}

    try:
        schema_string = _get_schema_create_script()
    except Exception as e:
        current_app.logger.error(f"Error reading database schema: {e}")
        return None, f"Error reading database schema: {e}"

    response_data = None
    try:
        payload = json.dumps({
            "contents": [{"parts": [{"text": _build_sql_prompt(schema_string, natural_query)}]}]
        })
        response = requests.post(api_url, headers=headers, data=payload, timeout=GEMINI_REQUEST_TIMEOUT)
        response.raise_for_status()  # 4xx/5xx -> requests.exceptions.HTTPError
        response_data = response.json()

        text, error = _extract_gemini_text(response_data)
        if error:
            return None, error
        return _clean_generated_sql(text), None
    except requests.exceptions.RequestException as e:
        current_app.logger.error(f"Gemini API request error: {e}")
        return None, f"Error communicating with API: {e}"
    except Exception as e:
        shown = response_data if response_data is not None else 'N/A'
        current_app.logger.exception(f"Error processing Gemini API response: {e} - Response: {shown}")
        return None, f"Error processing API response: {e}"


def _render_plot(query, title):
    """Run *query* and render its rows as an SVG plot fragment.

    Returns a Flask response value:
    - the rendered plot;
    - an escaped error box with status 400 if the SQL fails;
    - a "no data" box (status 200) if no rows come back;
    - an escaped error box with status 500 if plot preparation or rendering
      fails.
    """
    results, columns, error = _execute_sql(query)
    if error:
        return _error_box(f"Error executing query: {error}"), 400
    if not results:
        return _NO_DATA_BOX
    try:
        plot_data = prepare_svg_plot_data(results, columns, title)
        return render_template('partials/sql_explorer/svg_plot.html', **plot_data)
    except Exception as e:
        current_app.logger.error(f"Error preparing SVG plot data: {e}")
        return _error_box(f"Error preparing plot data: {e}"), 500


# --- Routes ---

@sql_explorer_bp.route("/explorer", methods=['GET'])
def sql_explorer():
    """Render the explorer page, or only its 'content' block for htmx requests."""
    saved_queries = _list_saved_queries()
    if htmx:
        return render_block(current_app.jinja_env, 'sql_explorer.html', 'content', saved_queries=saved_queries)
    return render_template('sql_explorer.html', saved_queries=saved_queries)


@sql_explorer_bp.route("/query", methods=['POST'])
def sql_query():
    """Save the submitted query under its title and re-render the query panel."""
    query = request.form.get('query')
    title = request.form.get('title')
    error = _save_query(title, query)
    saved_queries = _list_saved_queries()
    return render_template('partials/sql_explorer/sql_query.html',
                           title=title, query=query, error=error, saved_queries=saved_queries)


@sql_explorer_bp.route("/query/execute", methods=['POST'])
def execute_sql_query():
    """Execute the submitted SQL and render the results table (or error)."""
    query = request.form.get('query')
    results, columns, error = _execute_sql(query)
    return render_template('partials/sql_explorer/results.html',
                           results=results, columns=columns, error=error)


# NOTE(review): the <int:...> converter assumes saved_query.id is an integer —
# confirm against the schema.
@sql_explorer_bp.route('/load_query/<int:query_id>', methods=['GET'])
def load_sql_query(query_id):
    """Load a saved query into the query panel."""
    title, query = _get_saved_query(query_id)
    saved_queries = _list_saved_queries()
    return render_template('partials/sql_explorer/sql_query.html',
                           title=title, query=query, saved_queries=saved_queries)


@sql_explorer_bp.route('/delete_query/<int:query_id>', methods=['DELETE'])
def delete_sql_query(query_id):
    """Delete a saved query and re-render an empty query panel."""
    _delete_saved_query(query_id)
    saved_queries = _list_saved_queries()
    return render_template('partials/sql_explorer/sql_query.html',
                           title="", query="", saved_queries=saved_queries)


@sql_explorer_bp.route("/schema", methods=['GET'])
def sql_schema():
    """Render the database schema as CREATE statements."""
    return render_template('partials/sql_explorer/schema.html', create_sql=_get_schema_create_script())


@sql_explorer_bp.route("/plot/<int:query_id>", methods=['GET'])
def plot_query(query_id):
    """Plot the results of a saved query; 404 if the query does not exist."""
    title, query = _get_saved_query(query_id)
    if not query:
        return "Query not found", 404
    return _render_plot(query, title)


@sql_explorer_bp.route("/plot/show", methods=['POST'])
def plot_unsaved_query():
    """Plot the results of the SQL submitted in the form (not saved)."""
    query = request.form.get('query')
    # `or` rather than a .get() default, so an empty title field also falls back.
    title = request.form.get('title') or 'SQL Query Plot'
    return _render_plot(query, title)


@sql_explorer_bp.route("/generate_sql", methods=['POST'])
def generate_sql():
    """Generate SQL from natural language via the Gemini REST API.

    Returns plain SQL text for the query textarea. Failures come back as a
    leading ``--`` comment line, so the textarea still holds valid SQL-ish
    text.
    """
    natural_query = request.form.get('natural_query')
    if not natural_query:
        return ""
    generated_sql, error = _generate_sql_from_natural_language(natural_query)
    if error:
        return f"-- Error generating SQL: {error}\n\n"
    return generated_sql if generated_sql else "-- No SQL generated."