271 lines
12 KiB
Python
271 lines
12 KiB
Python
import html
import json
import os

import requests
from flask import Blueprint, render_template, request, current_app, jsonify
from flask_htmx import HTMX
from flask_login import login_required, current_user
from jinja2_fragments import render_block

from extensions import db
from utils import prepare_svg_plot_data, get_client_ip
|
|
|
|
# All routes of this blueprint are mounted under the /sql prefix.
sql_explorer_bp = Blueprint(
    name='sql_explorer',
    import_name=__name__,
    url_prefix='/sql',
)

# Evaluated as a boolean in sql_explorer() to choose between returning a page
# fragment and the full page. NOTE(review): this instance is never bound to an
# app here; presumably it only inspects the current request's headers — confirm.
htmx = HTMX()
|
|
|
|
|
|
|
|
|
|
def record_sql_audit(query, success, error_message=None):
    """Record one SQL execution attempt in the ``sql_audit`` table.

    Auditing is best-effort. Any failure is logged and swallowed so that an
    audit problem never breaks the request that triggered it.

    Args:
        query: The SQL text that was executed (or attempted).
        success: True if the execution succeeded, False otherwise.
        error_message: Error text for a failed execution; None on success.
    """
    try:
        # getattr with a default so a user object without ``id`` is stored as NULL.
        person_id = getattr(current_user, 'id', None)
        ip_address = get_client_ip()
        sql = """
            INSERT INTO sql_audit (person_id, query, ip_address, success, error_message)
            VALUES (%s, %s, %s, %s, %s)
        """
        db.execute(sql, [person_id, query, ip_address, success, error_message], commit=True)
    except Exception as e:
        current_app.logger.exception("Failed to record SQL audit: %s", e)
        # A failed INSERT can leave the connection's transaction in an aborted
        # state (e.g. on PostgreSQL). Roll back so later statements on this
        # connection still run. Guard it so auditing stays best-effort.
        try:
            db.getDB().rollback()
        except Exception:
            current_app.logger.exception("Rollback after failed SQL audit also failed")
|
|
|
|
def record_llm_audit(prompt, response, model, success, error_message=None):
    """Record one LLM interaction in the ``llm_audit`` table.

    Auditing is best-effort. Any failure is logged and swallowed so that an
    audit problem never breaks the request that triggered it.

    Args:
        prompt: The prompt text sent to the model.
        response: The model's (post-processed) output, or None on failure.
        model: The model name that was called.
        success: True if the interaction succeeded, False otherwise.
        error_message: Error text for a failed interaction; None on success.
    """
    try:
        # getattr with a default so a user object without ``id`` is stored as NULL.
        person_id = getattr(current_user, 'id', None)
        ip_address = get_client_ip()
        sql = """
            INSERT INTO llm_audit (person_id, prompt, response, model, ip_address, success, error_message)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
        db.execute(sql, [person_id, prompt, response, model, ip_address, success, error_message], commit=True)
    except Exception as e:
        current_app.logger.exception("Failed to record LLM audit: %s", e)
        # A failed INSERT can leave the connection's transaction in an aborted
        # state (e.g. on PostgreSQL). Roll back so later statements on this
        # connection still run. Guard it so auditing stays best-effort.
        try:
            db.getDB().rollback()
        except Exception:
            current_app.logger.exception("Rollback after failed LLM audit also failed")
|
|
|
|
def _execute_sql(query):
|
|
"""Executes arbitrary SQL query, returning results, columns, and error."""
|
|
results, columns, error = None, [], None
|
|
try:
|
|
results = db.execute(query)
|
|
if results:
|
|
columns = list(results[0].keys()) if isinstance(results, list) and results else []
|
|
record_sql_audit(query, True)
|
|
except Exception as e:
|
|
error = str(e)
|
|
db.getDB().rollback()
|
|
record_sql_audit(query, False, error)
|
|
return (results, columns, error)
|
|
|
|
def _save_query(title, query):
|
|
"""Saves a query to the database."""
|
|
error = None
|
|
if not title: return "Must provide title"
|
|
try:
|
|
db.execute("INSERT INTO saved_query (title, query) VALUES (%s, %s)", [title, query], commit=True)
|
|
except Exception as e:
|
|
error = str(e)
|
|
db.getDB().rollback()
|
|
return error
|
|
|
|
def _list_saved_queries():
    """Return id, title and query text of every saved query, ordered by title."""
    listing_sql = "SELECT id, title, query FROM saved_query ORDER BY title"
    return db.execute(listing_sql)
|
|
|
|
def _get_saved_query(query_id):
    """Look up one saved query by its id.

    Returns:
        ``(title, query)`` for the matching row, or ``(None, None)`` when the
        lookup returns nothing.
    """
    row = db.execute("SELECT title, query FROM saved_query WHERE id=%s", [query_id], one=True)
    if not row:
        return (None, None)
    return (row['title'], row['query'])
|
|
|
|
def _delete_saved_query(query_id):
    """Delete the saved query with the given id (executed with ``commit=True``)."""
    delete_sql = "DELETE FROM saved_query WHERE id=%s"
    db.execute(delete_sql, [query_id], commit=True)
|
|
|
|
def _generate_sql_from_natural_language(natural_query):
|
|
"""Generates SQL query from natural language using Gemini REST API."""
|
|
gemni_model = os.environ.get("GEMINI_MODEL","gemini-2.0-flash")
|
|
api_key = os.environ.get("GEMINI_API_KEY")
|
|
if not api_key:
|
|
return None, "GEMINI_API_KEY environment variable not set."
|
|
|
|
# Using gemini-pro model endpoint
|
|
api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemni_model}:generateContent?key={api_key}"
|
|
headers = {'Content-Type': 'application/json'}
|
|
|
|
prompt = natural_query
|
|
try:
|
|
# Get and format schema
|
|
schema_info = db.schema.get_schema_info()
|
|
schema_string = db.schema.generate_create_script(schema_info)
|
|
|
|
prompt = f"""Given the following database schema:
|
|
```sql
|
|
{schema_string}
|
|
```
|
|
|
|
Generate a PostgreSQL query that answers the following question: "{natural_query}"
|
|
|
|
Return ONLY the SQL query, without any explanation or surrounding text/markdown.
|
|
"""
|
|
# Construct the request payload
|
|
payload = json.dumps({
|
|
"contents": [{
|
|
"parts": [{"text": prompt}]
|
|
}]
|
|
})
|
|
|
|
# Make the POST request
|
|
response = requests.post(api_url, headers=headers, data=payload)
|
|
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
|
|
|
|
# Parse the response
|
|
response_data = response.json()
|
|
|
|
# Extract the generated text - structure might vary slightly based on API version/response
|
|
# Safely navigate the response structure
|
|
candidates = response_data.get('candidates', [])
|
|
if not candidates:
|
|
return None, "No candidates found in API response."
|
|
|
|
content = candidates[0].get('content', {})
|
|
parts = content.get('parts', [])
|
|
if not parts:
|
|
return None, "No parts found in API response content."
|
|
|
|
generated_sql = parts[0].get('text', '').strip()
|
|
|
|
# Basic parsing: remove potential markdown code fences
|
|
if generated_sql.startswith("```sql"):
|
|
generated_sql = generated_sql[6:]
|
|
if generated_sql.endswith("```"):
|
|
generated_sql = generated_sql[:-3]
|
|
|
|
# Remove leading SQL comment lines
|
|
sql_lines = generated_sql.strip().splitlines()
|
|
filtered_lines = [line for line in sql_lines if not line.strip().startswith('--')]
|
|
final_sql = "\n".join(filtered_lines).strip()
|
|
|
|
generated_sql, error = final_sql, None
|
|
record_llm_audit(prompt, generated_sql, gemni_model, True)
|
|
return generated_sql, error
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
current_app.logger.error(f"Gemini API request error: {e}")
|
|
error_msg = f"Error communicating with API: {e}"
|
|
record_llm_audit(prompt, None, gemni_model, False, error_msg)
|
|
return None, error_msg
|
|
except (KeyError, IndexError, Exception) as e:
|
|
current_app.logger.error(f"Error processing Gemini API response: {e} - Response: {response_data if 'response_data' in locals() else 'N/A'}")
|
|
error_msg = f"Error processing API response: {e}"
|
|
record_llm_audit(prompt, None, gemni_model, False, error_msg)
|
|
return None, error_msg
|
|
|
|
|
|
# --- Routes ---
|
|
|
|
@sql_explorer_bp.route("/explorer", methods=['GET'])
def sql_explorer():
    """Show the SQL explorer together with the list of saved queries.

    When ``htmx`` is truthy only the template's ``content`` block is rendered.
    Otherwise the full page is rendered.

    NOTE(review): unlike the save/execute/plot routes this one is not
    ``@login_required`` yet lists saved query text. Confirm this is intended.
    """
    context = {'saved_queries': _list_saved_queries()}
    if not htmx:
        return render_template('sql_explorer.html', **context)
    return render_block(current_app.jinja_env, 'sql_explorer.html', 'content', **context)
|
|
|
|
@sql_explorer_bp.route("/query", methods=['POST'])
@login_required
def sql_query():
    """Save the submitted query under the submitted title and re-render the editor.

    Any message returned by ``_save_query`` is passed to the template as ``error``.
    """
    form = request.form
    query, title = form.get('query'), form.get('title')
    save_error = _save_query(title, query)
    return render_template(
        'partials/sql_explorer/sql_query.html',
        title=title,
        query=query,
        error=save_error,
        saved_queries=_list_saved_queries(),
    )
|
|
|
|
@sql_explorer_bp.route("/query/execute", methods=['POST'])
@login_required
def execute_sql_query():
    """Run the submitted SQL and render the results partial (rows, column names, error)."""
    outcome = _execute_sql(request.form.get('query'))
    results, columns, error = outcome
    return render_template(
        'partials/sql_explorer/results.html',
        results=results,
        columns=columns,
        error=error,
    )
|
|
|
|
@sql_explorer_bp.route('/load_query/<int:query_id>', methods=['GET'])
def load_sql_query(query_id):
    """Load a saved query into the editor partial.

    An unknown id renders the editor with ``None`` title and query.
    NOTE(review): not ``@login_required``, so saved query text is readable
    anonymously. Confirm this is intended.
    """
    saved_title, saved_sql = _get_saved_query(query_id)
    return render_template(
        'partials/sql_explorer/sql_query.html',
        title=saved_title,
        query=saved_sql,
        saved_queries=_list_saved_queries(),
    )
|
|
|
|
@sql_explorer_bp.route('/delete_query/<int:query_id>', methods=['DELETE'])
@login_required
def delete_sql_query(query_id):
    """Delete a saved query, then render an empty editor with the refreshed list."""
    _delete_saved_query(query_id)
    remaining = _list_saved_queries()
    return render_template(
        'partials/sql_explorer/sql_query.html',
        title="",
        query="",
        saved_queries=remaining,
    )
|
|
|
|
@sql_explorer_bp.route("/schema", methods=['GET'])
def sql_schema():
    """Render the database schema as a CREATE script.

    NOTE(review): not ``@login_required``, so the full schema is visible to
    anonymous users. Confirm this is intended.
    """
    create_sql = db.schema.generate_create_script(db.schema.get_schema_info())
    return render_template('partials/sql_explorer/schema.html', create_sql=create_sql)
|
|
|
|
# HTML snippets returned (for htmx to swap in) when no plot can be drawn.
_PLOT_ERROR_HTML = '<div class="p-4 text-red-700 bg-red-100 border border-red-400 rounded">{}</div>'
_PLOT_EMPTY_HTML = '<div class="p-4 text-yellow-700 bg-yellow-100 border border-yellow-400 rounded">No data returned by query.</div>'


def _plot_response(query, title):
    """Execute *query* and render its rows as the SVG plot partial.

    Returns a Flask response value. This is either the rendered plot, or an HTML
    snippet explaining why no plot was drawn: status 400 for query errors,
    200 for an empty result, 500 for plot-preparation errors.
    """
    results, columns, error = _execute_sql(query)
    if error:
        # Database error text can echo user-supplied SQL. Escape it so it
        # cannot inject markup or script into the page.
        return _PLOT_ERROR_HTML.format(f"Error executing query: {html.escape(error)}"), 400
    if not results:
        return _PLOT_EMPTY_HTML

    try:
        plot_data = prepare_svg_plot_data(results, columns, title)
        return render_template('partials/sql_explorer/svg_plot.html', **plot_data)
    except Exception as e:
        current_app.logger.exception("Error preparing SVG plot data: %s", e)
        return _PLOT_ERROR_HTML.format(f"Error preparing plot data: {html.escape(str(e))}"), 500


@sql_explorer_bp.route("/plot/<int:query_id>", methods=['GET'])
@login_required
def plot_query(query_id):
    """Plot the results of a saved query; 404 if the id is unknown."""
    title, query = _get_saved_query(query_id)
    if not query:
        return "Query not found", 404
    return _plot_response(query, title)


@sql_explorer_bp.route("/plot/show", methods=['POST'])
@login_required
def plot_unsaved_query():
    """Plot the results of the query submitted in the form (not necessarily saved)."""
    query = request.form.get('query')
    # ``or`` instead of a get() default so an empty title field also falls back.
    title = request.form.get('title') or 'SQL Query Plot'
    return _plot_response(query, title)
|
|
|
|
@sql_explorer_bp.route("/generate_sql", methods=['POST'])
@login_required
def generate_sql():
    """Translate the submitted natural-language question into SQL via the Gemini REST API.

    The returned text is meant for the SQL textarea. It is one of:
    - the generated SQL;
    - an SQL-comment line describing the error;
    - an empty string when no question was submitted.

    NOTE(review): the text is returned unescaped; confirm how the client
    inserts it into the page.
    """
    natural_query = request.form.get('natural_query')
    if not natural_query:
        return ""

    generated_sql, error = _generate_sql_from_natural_language(natural_query)
    if error:
        # Shown as a SQL comment so the textarea content stays valid-looking SQL.
        return f"-- Error generating SQL: {error}\n\n"
    return generated_sql or "-- No SQL generated."