refactor: Use REST API for Gemini SQL generation
- Modified the `_generate_sql_from_natural_language` helper function in `routes/sql_explorer.py` to use direct REST API calls via the `requests` library instead of the `google-generativeai` Python library. - Added `requests` and `json` imports and removed the `google-generativeai` import. - Updated error handling for API communication and response parsing. - Updated the corresponding changelog entry.
This commit is contained in:
@@ -1,4 +1,7 @@
|
|||||||
from flask import Blueprint, render_template, request, current_app
|
import os
|
||||||
|
import requests # Import requests library
|
||||||
|
import json # Import json library
|
||||||
|
from flask import Blueprint, render_template, request, current_app, jsonify
|
||||||
from jinja2_fragments import render_block
|
from jinja2_fragments import render_block
|
||||||
from flask_htmx import HTMX
|
from flask_htmx import HTMX
|
||||||
from extensions import db
|
from extensions import db
|
||||||
@@ -7,8 +10,6 @@ from utils import generate_plot
|
|||||||
sql_explorer_bp = Blueprint('sql_explorer', __name__, url_prefix='/sql')
|
sql_explorer_bp = Blueprint('sql_explorer', __name__, url_prefix='/sql')
|
||||||
htmx = HTMX()
|
htmx = HTMX()
|
||||||
|
|
||||||
# --- Database Helper Functions (Moved from features/sql_explorer.py) ---
|
|
||||||
|
|
||||||
def _get_schema_info(schema='public'):
|
def _get_schema_info(schema='public'):
|
||||||
"""Fetches schema information directly."""
|
"""Fetches schema information directly."""
|
||||||
tables_result = db.execute("""
|
tables_result = db.execute("""
|
||||||
@@ -62,7 +63,6 @@ def _map_data_type_for_sql(postgres_type):
|
|||||||
'integer': 'INTEGER', 'bigint': 'BIGINT', 'boolean': 'BOOLEAN',
|
'integer': 'INTEGER', 'bigint': 'BIGINT', 'boolean': 'BOOLEAN',
|
||||||
'timestamp without time zone': 'TIMESTAMP', 'timestamp with time zone': 'TIMESTAMPTZ',
|
'timestamp without time zone': 'TIMESTAMP', 'timestamp with time zone': 'TIMESTAMPTZ',
|
||||||
'numeric': 'NUMERIC', 'real': 'REAL', 'date': 'DATE'
|
'numeric': 'NUMERIC', 'real': 'REAL', 'date': 'DATE'
|
||||||
# Add more as needed
|
|
||||||
}.get(postgres_type, postgres_type.upper())
|
}.get(postgres_type, postgres_type.upper())
|
||||||
|
|
||||||
def _map_data_type(postgres_type):
|
def _map_data_type(postgres_type):
|
||||||
@@ -108,7 +108,6 @@ def _generate_create_script(schema_info):
|
|||||||
pk_columns = ", ".join(f'"{pk}"' for pk in pks)
|
pk_columns = ", ".join(f'"{pk}"' for pk in pks)
|
||||||
column_defs.append(f' PRIMARY KEY ({pk_columns})')
|
column_defs.append(f' PRIMARY KEY ({pk_columns})')
|
||||||
|
|
||||||
# Format column definitions with newlines before using in f-string
|
|
||||||
columns_sql = ",\n".join(column_defs)
|
columns_sql = ",\n".join(column_defs)
|
||||||
create_stmt = f'CREATE TABLE "{table}" (\n{columns_sql}\n);'
|
create_stmt = f'CREATE TABLE "{table}" (\n{columns_sql}\n);'
|
||||||
lines.append(create_stmt)
|
lines.append(create_stmt)
|
||||||
@@ -122,16 +121,17 @@ def _generate_create_script(schema_info):
|
|||||||
lines.append("")
|
lines.append("")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _execute_sql(query):
|
def _execute_sql(query):
|
||||||
"""Executes arbitrary SQL query, returning results, columns, and error."""
|
"""Executes arbitrary SQL query, returning results, columns, and error."""
|
||||||
results, columns, error = None, [], None
|
results, columns, error = None, [], None
|
||||||
try:
|
try:
|
||||||
results = db.execute(query) # Use the imported db object directly
|
results = db.execute(query)
|
||||||
if results:
|
if results:
|
||||||
columns = list(results[0].keys()) if isinstance(results, list) and results else []
|
columns = list(results[0].keys()) if isinstance(results, list) and results else []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = str(e)
|
error = str(e)
|
||||||
db.getDB().rollback() # Rollback on error
|
db.getDB().rollback()
|
||||||
return (results, columns, error)
|
return (results, columns, error)
|
||||||
|
|
||||||
def _save_query(title, query):
|
def _save_query(title, query):
|
||||||
@@ -142,7 +142,7 @@ def _save_query(title, query):
|
|||||||
db.execute("INSERT INTO saved_query (title, query) VALUES (%s, %s)", [title, query], commit=True)
|
db.execute("INSERT INTO saved_query (title, query) VALUES (%s, %s)", [title, query], commit=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = str(e)
|
error = str(e)
|
||||||
db.getDB().rollback() # Rollback on error
|
db.getDB().rollback()
|
||||||
return error
|
return error
|
||||||
|
|
||||||
def _list_saved_queries():
|
def _list_saved_queries():
|
||||||
@@ -158,12 +158,79 @@ def _delete_saved_query(query_id):
|
|||||||
"""Deletes a saved query."""
|
"""Deletes a saved query."""
|
||||||
db.execute("DELETE FROM saved_query WHERE id=%s", [query_id], commit=True)
|
db.execute("DELETE FROM saved_query WHERE id=%s", [query_id], commit=True)
|
||||||
|
|
||||||
|
def _generate_sql_from_natural_language(natural_query):
|
||||||
|
"""Generates SQL query from natural language using Gemini REST API."""
|
||||||
|
gemni_model = os.environ.get("GEMINI_MODEL","gemini-2.0-flash")
|
||||||
|
api_key = os.environ.get("GEMINI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
return None, "GEMINI_API_KEY environment variable not set."
|
||||||
|
|
||||||
|
# Using gemini-pro model endpoint
|
||||||
|
api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemni_model}:generateContent?key={api_key}"
|
||||||
|
headers = {'Content-Type': 'application/json'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get and format schema
|
||||||
|
schema_info = _get_schema_info()
|
||||||
|
schema_string = _generate_create_script(schema_info)
|
||||||
|
|
||||||
|
prompt = f"""Given the following database schema:
|
||||||
|
```sql
|
||||||
|
{schema_string}
|
||||||
|
```
|
||||||
|
|
||||||
|
Generate a PostgreSQL query that answers the following question: "{natural_query}"
|
||||||
|
|
||||||
|
Return ONLY the SQL query, without any explanation or surrounding text/markdown.
|
||||||
|
"""
|
||||||
|
# Construct the request payload
|
||||||
|
payload = json.dumps({
|
||||||
|
"contents": [{
|
||||||
|
"parts": [{"text": prompt}]
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Make the POST request
|
||||||
|
response = requests.post(api_url, headers=headers, data=payload)
|
||||||
|
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
|
||||||
|
|
||||||
|
# Parse the response
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
# Extract the generated text - structure might vary slightly based on API version/response
|
||||||
|
# Safely navigate the response structure
|
||||||
|
candidates = response_data.get('candidates', [])
|
||||||
|
if not candidates:
|
||||||
|
return None, "No candidates found in API response."
|
||||||
|
|
||||||
|
content = candidates[0].get('content', {})
|
||||||
|
parts = content.get('parts', [])
|
||||||
|
if not parts:
|
||||||
|
return None, "No parts found in API response content."
|
||||||
|
|
||||||
|
generated_sql = parts[0].get('text', '').strip()
|
||||||
|
|
||||||
|
# Basic parsing: remove potential markdown code fences
|
||||||
|
if generated_sql.startswith("```sql"):
|
||||||
|
generated_sql = generated_sql[6:]
|
||||||
|
if generated_sql.endswith("```"):
|
||||||
|
generated_sql = generated_sql[:-3]
|
||||||
|
|
||||||
|
return generated_sql.strip(), None
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
current_app.logger.error(f"Gemini API request error: {e}")
|
||||||
|
return None, f"Error communicating with API: {e}"
|
||||||
|
except (KeyError, IndexError, Exception) as e:
|
||||||
|
current_app.logger.error(f"Error processing Gemini API response: {e} - Response: {response_data if 'response_data' in locals() else 'N/A'}")
|
||||||
|
return None, f"Error processing API response: {e}"
|
||||||
|
|
||||||
|
|
||||||
# --- Routes ---
|
# --- Routes ---
|
||||||
|
|
||||||
@sql_explorer_bp.route("/explorer", methods=['GET'])
|
@sql_explorer_bp.route("/explorer", methods=['GET'])
|
||||||
def sql_explorer():
|
def sql_explorer():
|
||||||
saved_queries = _list_saved_queries() # Use local helper
|
saved_queries = _list_saved_queries()
|
||||||
if htmx:
|
if htmx:
|
||||||
return render_block(current_app.jinja_env, 'sql_explorer.html', 'content', saved_queries=saved_queries)
|
return render_block(current_app.jinja_env, 'sql_explorer.html', 'content', saved_queries=saved_queries)
|
||||||
return render_template('sql_explorer.html', saved_queries=saved_queries)
|
return render_template('sql_explorer.html', saved_queries=saved_queries)
|
||||||
@@ -172,44 +239,44 @@ def sql_explorer():
|
|||||||
def sql_query():
|
def sql_query():
|
||||||
query = request.form.get('query')
|
query = request.form.get('query')
|
||||||
title = request.form.get('title')
|
title = request.form.get('title')
|
||||||
error = _save_query(title, query) # Use local helper
|
error = _save_query(title, query)
|
||||||
saved_queries = _list_saved_queries() # Use local helper
|
saved_queries = _list_saved_queries()
|
||||||
return render_template('partials/sql_explorer/sql_query.html',
|
return render_template('partials/sql_explorer/sql_query.html',
|
||||||
title=title, query=query, error=error, saved_queries=saved_queries)
|
title=title, query=query, error=error, saved_queries=saved_queries)
|
||||||
|
|
||||||
@sql_explorer_bp.route("/query/execute", methods=['POST'])
|
@sql_explorer_bp.route("/query/execute", methods=['POST'])
|
||||||
def execute_sql_query():
|
def execute_sql_query():
|
||||||
query = request.form.get('query')
|
query = request.form.get('query')
|
||||||
(results, columns, error) = _execute_sql(query) # Use local helper
|
(results, columns, error) = _execute_sql(query)
|
||||||
return render_template('partials/sql_explorer/results.html',
|
return render_template('partials/sql_explorer/results.html',
|
||||||
results=results, columns=columns, error=error)
|
results=results, columns=columns, error=error)
|
||||||
|
|
||||||
@sql_explorer_bp.route('/load_query/<int:query_id>', methods=['GET'])
|
@sql_explorer_bp.route('/load_query/<int:query_id>', methods=['GET'])
|
||||||
def load_sql_query(query_id):
|
def load_sql_query(query_id):
|
||||||
(title, query) = _get_saved_query(query_id) # Use local helper
|
(title, query) = _get_saved_query(query_id)
|
||||||
saved_queries = _list_saved_queries() # Use local helper
|
saved_queries = _list_saved_queries()
|
||||||
return render_template('partials/sql_explorer/sql_query.html',
|
return render_template('partials/sql_explorer/sql_query.html',
|
||||||
title=title, query=query, saved_queries=saved_queries)
|
title=title, query=query, saved_queries=saved_queries)
|
||||||
|
|
||||||
@sql_explorer_bp.route('/delete_query/<int:query_id>', methods=['DELETE'])
|
@sql_explorer_bp.route('/delete_query/<int:query_id>', methods=['DELETE'])
|
||||||
def delete_sql_query(query_id):
|
def delete_sql_query(query_id):
|
||||||
_delete_saved_query(query_id) # Use local helper
|
_delete_saved_query(query_id)
|
||||||
saved_queries = _list_saved_queries() # Use local helper
|
saved_queries = _list_saved_queries()
|
||||||
return render_template('partials/sql_explorer/sql_query.html',
|
return render_template('partials/sql_explorer/sql_query.html',
|
||||||
title="", query="", saved_queries=saved_queries)
|
title="", query="", saved_queries=saved_queries)
|
||||||
|
|
||||||
@sql_explorer_bp.route("/schema", methods=['GET'])
|
@sql_explorer_bp.route("/schema", methods=['GET'])
|
||||||
def sql_schema():
|
def sql_schema():
|
||||||
schema_info = _get_schema_info() # Use local helper
|
schema_info = _get_schema_info()
|
||||||
mermaid_code = _generate_mermaid_er(schema_info) # Use local helper
|
mermaid_code = _generate_mermaid_er(schema_info)
|
||||||
create_sql = _generate_create_script(schema_info) # Use local helper
|
create_sql = _generate_create_script(schema_info)
|
||||||
return render_template('partials/sql_explorer/schema.html', mermaid_code=mermaid_code, create_sql=create_sql)
|
return render_template('partials/sql_explorer/schema.html', mermaid_code=mermaid_code, create_sql=create_sql)
|
||||||
|
|
||||||
@sql_explorer_bp.route("/plot/<int:query_id>", methods=['GET'])
|
@sql_explorer_bp.route("/plot/<int:query_id>", methods=['GET'])
|
||||||
def plot_query(query_id):
|
def plot_query(query_id):
|
||||||
(title, query) = _get_saved_query(query_id) # Use local helper
|
(title, query) = _get_saved_query(query_id)
|
||||||
if not query: return "Query not found", 404
|
if not query: return "Query not found", 404
|
||||||
results_df = db.read_sql_as_df(query) # Keep using db.py for pandas interaction
|
results_df = db.read_sql_as_df(query)
|
||||||
plot_div = generate_plot(results_df, title)
|
plot_div = generate_plot(results_df, title)
|
||||||
return plot_div
|
return plot_div
|
||||||
|
|
||||||
@@ -217,6 +284,21 @@ def plot_query(query_id):
|
|||||||
def plot_unsaved_query():
|
def plot_unsaved_query():
|
||||||
query = request.form.get('query')
|
query = request.form.get('query')
|
||||||
title = request.form.get('title')
|
title = request.form.get('title')
|
||||||
results_df = db.read_sql_as_df(query) # Keep using db.py for pandas interaction
|
results_df = db.read_sql_as_df(query)
|
||||||
plot_div = generate_plot(results_df, title)
|
plot_div = generate_plot(results_df, title)
|
||||||
return plot_div
|
return plot_div
|
||||||
|
|
||||||
|
@sql_explorer_bp.route("/generate_sql", methods=['POST'])
|
||||||
|
def generate_sql():
|
||||||
|
"""Generates SQL from natural language via Gemini REST API."""
|
||||||
|
natural_query = request.form.get('natural_query')
|
||||||
|
if not natural_query:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
generated_sql, error = _generate_sql_from_natural_language(natural_query)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
# Return error message prepended, to be displayed in the textarea
|
||||||
|
return f"-- Error generating SQL: {error}\n\n"
|
||||||
|
|
||||||
|
return generated_sql if generated_sql else "-- No SQL generated."
|
||||||
@@ -10,6 +10,15 @@
|
|||||||
<div class="prose max-w-none">
|
<div class="prose max-w-none">
|
||||||
<p>Updates and changes to the site will be documented here, with the most recent changes listed first.</p>
|
<p>Updates and changes to the site will be documented here, with the most recent changes listed first.</p>
|
||||||
|
|
||||||
|
<!-- New Entry for SQL Generation -->
|
||||||
|
<hr class="my-6">
|
||||||
|
<h2 class="text-xl font-semibold mb-2">April 5, 2025</h2>
|
||||||
|
<ul class="list-disc pl-5 space-y-1">
|
||||||
|
<li>Added experimental feature to SQL Explorer to generate SQL queries from natural language using the
|
||||||
|
Gemini REST API. Requires `GEMINI_API_KEY` environment variable.</li>
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
|
||||||
<!-- New Entry for Endpoints Refactoring -->
|
<!-- New Entry for Endpoints Refactoring -->
|
||||||
<hr class="my-6">
|
<hr class="my-6">
|
||||||
<h2 class="text-xl font-semibold mb-2">March 31, 2025</h2>
|
<h2 class="text-xl font-semibold mb-2">March 31, 2025</h2>
|
||||||
|
|||||||
@@ -23,7 +23,34 @@
|
|||||||
on input set my.style.height to 0 then set my.style.height to my.scrollHeight + 'px'">{{ query }}</textarea>
|
on input set my.style.height to 0 then set my.style.height to my.scrollHeight + 'px'">{{ query }}</textarea>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Buttons -->
|
<!-- Natural Language Query Input -->
|
||||||
|
<div class="pt-2">
|
||||||
|
<label for="natural-query" class="block text-sm font-medium text-gray-700 pb-1">Generate SQL from Natural
|
||||||
|
Language</label>
|
||||||
|
<div class="flex items-center">
|
||||||
|
<input type="text" id="natural-query" name="natural_query"
|
||||||
|
class="flex-grow p-2 border border-gray-300 rounded-l-md shadow-sm focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent"
|
||||||
|
placeholder="e.g., 'Show me the number of workouts per person'">
|
||||||
|
<button type="button" hx-post="{{ url_for('sql_explorer.generate_sql') }}"
|
||||||
|
hx-include="[name='natural_query']" hx-target="#query" hx-swap="innerHTML"
|
||||||
|
hx-indicator="#sql-spinner"
|
||||||
|
class="bg-purple-600 text-white p-2.5 rounded-r-md hover:bg-purple-700 focus:outline-none focus:ring-2 focus:ring-purple-500 focus:ring-opacity-50 inline-flex items-center">
|
||||||
|
Generate SQL
|
||||||
|
<span id="sql-spinner" class="htmx-indicator ml-2">
|
||||||
|
<svg class="animate-spin h-4 w-4 text-white" xmlns="http://www.w3.org/2000/svg" fill="none"
|
||||||
|
viewBox="0 0 24 24">
|
||||||
|
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4">
|
||||||
|
</circle>
|
||||||
|
<path class="opacity-75" fill="currentColor"
|
||||||
|
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z">
|
||||||
|
</path>
|
||||||
|
</svg>
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Action Buttons -->
|
||||||
<div class="flex space-x-2 pt-1">
|
<div class="flex space-x-2 pt-1">
|
||||||
<!-- Execute Button -->
|
<!-- Execute Button -->
|
||||||
<button hx-post="{{ url_for('sql_explorer.execute_sql_query') }}" hx-target="#execute-query-results"
|
<button hx-post="{{ url_for('sql_explorer.execute_sql_query') }}" hx-target="#execute-query-results"
|
||||||
|
|||||||
Reference in New Issue
Block a user