Add ability to plot saved queries using plotly, need to check performance in production, also need to improve generate_plot function

This commit is contained in:
Peter Stockings
2024-11-09 16:49:08 +11:00
parent 23def088bb
commit c7013e0eac
5 changed files with 85 additions and 13 deletions

10
app.py
View File

@@ -7,7 +7,7 @@ import jinja_partials
from jinja2_fragments import render_block from jinja2_fragments import render_block
from decorators import validate_person, validate_topset, validate_workout from decorators import validate_person, validate_topset, validate_workout
from db import DataBase from db import DataBase
from utils import count_prs_over_time, get_people_and_exercise_rep_maxes, convert_str_to_date, get_earliest_and_latest_workout_date, filter_workout_topsets, first_and_last_visible_days_in_month, get_weekly_pr_graph_model, get_workout_counts from utils import count_prs_over_time, get_people_and_exercise_rep_maxes, convert_str_to_date, get_earliest_and_latest_workout_date, filter_workout_topsets, first_and_last_visible_days_in_month, get_weekly_pr_graph_model, get_workout_counts, generate_plot
from flask_htmx import HTMX from flask_htmx import HTMX
import minify_html import minify_html
from urllib.parse import quote from urllib.parse import quote
@@ -526,6 +526,14 @@ def sql_schema():
mermaid_code = db.sql_explorer.generate_mermaid_er(schema_info) mermaid_code = db.sql_explorer.generate_mermaid_er(schema_info)
return render_template('partials/sql_explorer/schema.html', mermaid_code=mermaid_code) return render_template('partials/sql_explorer/schema.html', mermaid_code=mermaid_code)
@app.route("/plot/<int:query_id>", methods=['GET'])
def plot_query(query_id):
(title, query) = db.sql_explorer.get_saved_query(query_id)
#(results, columns, error) = db.sql_explorer.execute_sql(query)
results_df = db.read_sql_as_df(query)
plot_div = generate_plot(results_df, title)
return plot_div
@app.teardown_appcontext @app.teardown_appcontext
def closeConnection(exception): def closeConnection(exception):
db.close_connection() db.close_connection()

9
db.py
View File

@@ -6,6 +6,7 @@ from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import g from flask import g
import pandas as pd
from features.calendar import Calendar from features.calendar import Calendar
from features.exercises import Exercises from features.exercises import Exercises
from features.stats import Stats from features.stats import Stats
@@ -61,6 +62,14 @@ class DataBase():
return (rv[0] if rv else None) if one else rv return (rv[0] if rv else None) if one else rv
def read_sql_as_df(self, query, params=None):
conn = self.getDB()
try:
df = pd.read_sql(query, conn, params=params)
return df
except Exception as e:
raise e
def get_exercise(self, exercise_id): def get_exercise(self, exercise_id):
exercise = self.execute( exercise = self.execute(
'SELECT exercise_id, name FROM exercise WHERE exercise_id=%s LIMIT 1', [exercise_id], one=True) 'SELECT exercise_id, name FROM exercise WHERE exercise_id=%s LIMIT 1', [exercise_id], one=True)

View File

@@ -11,3 +11,4 @@ Werkzeug==2.2.2
numpy==1.19.5 numpy==1.19.5
pandas==1.3.1 pandas==1.3.1
python-dotenv==1.0.1 python-dotenv==1.0.1
plotly==5.24.1

View File

@@ -94,22 +94,36 @@
<tbody> <tbody>
{% for saved in saved_queries %} {% for saved in saved_queries %}
<tr class="hover:bg-gray-100 transition-colors duration-200"> <tr class="hover:bg-gray-100 transition-colors duration-200">
<td class="py-4 px-6 border-b">{{ saved.title }}</td> <!-- Query Title as Load Action -->
<td class="py-4 px-6 border-b">
<a href="#" hx-get="{{ url_for('load_sql_query', query_id=saved.id) }}"
hx-target="#sql-query"
class="flex items-center text-blue-500 hover:text-blue-700 cursor-pointer">
<!-- Load Icon (Heroicon: Eye) -->
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-2" fill="none"
viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" />
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M2.458 12C3.732 7.943 7.523 5 12 5c4.477 0 8.268 2.943 9.542 7-1.274 4.057-5.065 7-9.542 7-4.477 0-8.268-2.943-9.542-7z" />
</svg>
{{ saved.title }}
</a>
</td>
<td class="py-4 px-6 border-b"> <td class="py-4 px-6 border-b">
<div class="flex space-x-4"> <div class="flex space-x-4">
<!-- Load Action --> <!-- Plot Action -->
<a href="#" hx-get="{{ url_for('load_sql_query', query_id=saved.id) }}" <a href="#" hx-get="{{ url_for('plot_query', query_id=saved.id) }}"
hx-target="#sql-query" hx-target="#sql-plot-results"
class="flex items-center text-blue-500 hover:text-blue-700 cursor-pointer"> class="flex items-center text-green-500 hover:text-green-700 cursor-pointer"
<!-- Load Icon (Heroicon: Eye) --> hx-trigger="click">
<!-- Plot Icon (Heroicon: Chart Bar) -->
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-1" fill="none" <svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-1" fill="none"
viewBox="0 0 24 24" stroke="currentColor"> viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /> d="M3 3h2l.4 2M7 13h10l4-8H5.4M7 13L5.4 5M7 13l-2 9m5-9v9m4-9v9m5-9v9" />
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M2.458 12C3.732 7.943 7.523 5 12 5c4.477 0 8.268 2.943 9.542 7-1.274 4.057-5.065 7-9.542 7-4.477 0-8.268-2.943-9.542-7z" />
</svg> </svg>
Load Plot
</a> </a>
<!-- Delete Action --> <!-- Delete Action -->
@@ -137,5 +151,10 @@
{% endif %} {% endif %}
</div> </div>
<!-- Plot Results Section -->
<div id="sql-plot-results" class="mt-8">
<!-- Plot will be loaded here via htmx -->
</div>
</div> </div>

View File

@@ -2,6 +2,8 @@ import colorsys
from datetime import datetime, date, timedelta from datetime import datetime, date, timedelta
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import plotly.express as px
import plotly.io as pio
def get_workouts(topsets): def get_workouts(topsets):
# Ensure all entries have 'WorkoutId' and 'TopSetId', then sort by 'WorkoutId' and 'TopSetId' # Ensure all entries have 'WorkoutId' and 'TopSetId', then sort by 'WorkoutId' and 'TopSetId'
@@ -440,3 +442,36 @@ def get_distinct_colors(n):
hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255)) hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255))
colors.append(hex_color) colors.append(hex_color)
return colors return colors
def generate_plot(df, title):
"""
Analyzes the DataFrame and generates an appropriate Plotly visualization.
Returns the Plotly figure as a div string.
"""
if df.empty:
return "<p>No data available to plot.</p>"
num_columns = len(df.columns)
# Simple logic to decide plot type based on DataFrame structure
if num_columns == 1:
# Single column: perhaps a histogram or bar chart
column = df.columns[0]
if pd.api.types.is_numeric_dtype(df[column]):
fig = px.histogram(df, x=column, title=title)
else:
fig = px.bar(df, x=column, title=title)
elif num_columns == 2:
# Two columns: scatter plot or line chart
col1, col2 = df.columns
if pd.api.types.is_numeric_dtype(df[col1]) and pd.api.types.is_numeric_dtype(df[col2]):
fig = px.scatter(df, x=col1, y=col2, title=title)
else:
fig = px.bar(df, x=col1, y=col2, title=title)
else:
# More than two columns: heatmap or other complex plots
fig = px.imshow(df.corr(), text_auto=True, title=title)
# Convert Plotly figure to HTML div
plot_div = pio.to_html(fig, full_html=False)
return plot_div