# app.py: view added by this patch. Relies on the module-level `app` and `db`
# objects and on `generate_plot` imported from utils.
# NOTE(review): the `int` converter assumes saved-query ids are integers --
# confirm against the saved-query table. The URL variable itself is required:
# without it Flask calls the view with no arguments and `query_id` is missing.
@app.route("/plot/<int:query_id>", methods=['GET'])
def plot_query(query_id):
    """Run the saved SQL query ``query_id`` and return its plot as HTML.

    Args:
        query_id: Identifier of the saved query, taken from the URL.

    Returns:
        The HTML fragment produced by ``generate_plot`` for the query result.
    """
    # get_saved_query yields a (title, query) pair.
    title, query = db.sql_explorer.get_saved_query(query_id)
    results_df = db.read_sql_as_df(query)
    return generate_plot(results_df, title)
# db.py: method of class DataBase added by this patch.
def read_sql_as_df(self, query, params=None):
    """Execute ``query`` and return the result set as a pandas DataFrame.

    Args:
        query: SQL text to execute.
        params: Optional bind parameters, forwarded to ``pandas.read_sql``.

    Returns:
        A ``pandas.DataFrame`` with one column per selected field.

    Raises:
        Whatever ``pandas.read_sql`` or the underlying driver raises; errors
        propagate unchanged to the caller.
    """
    # No try/except here: catching an exception only to re-raise it adds
    # nothing but an extra traceback frame.
    return pd.read_sql(query, self.getDB(), params=params)
+ +
# utils.py: helper added by this patch. Uses the module-level `pd`
# (pandas), `px` (plotly.express) and `pio` (plotly.io) imports.
def generate_plot(df, title):
    """Pick a Plotly chart for ``df`` and return it as an HTML fragment.

    The chart type depends on the number of columns and their dtypes:

    * one column: histogram when numeric, otherwise a bar chart;
    * two columns: scatter when both are numeric, otherwise a bar chart;
    * three or more columns: heatmap of the correlation matrix of the
      numeric (and boolean) columns.

    Args:
        df: pandas DataFrame holding the data to plot.
        title: Title shown on the chart.

    Returns:
        A ``str`` of HTML: either the figure rendered with
        ``plotly.io.to_html(fig, full_html=False)`` or a short ``<div>``
        message when there is nothing that can be plotted.
    """
    if df is None or df.empty:
        return "<div>No data available to plot.</div>"

    is_numeric = pd.api.types.is_numeric_dtype
    num_columns = len(df.columns)

    if num_columns == 1:
        column = df.columns[0]
        if is_numeric(df[column]):
            fig = px.histogram(df, x=column, title=title)
        else:
            fig = px.bar(df, x=column, title=title)
    elif num_columns == 2:
        col1, col2 = df.columns
        if is_numeric(df[col1]) and is_numeric(df[col2]):
            fig = px.scatter(df, x=col1, y=col2, title=title)
        else:
            fig = px.bar(df, x=col1, y=col2, title=title)
    else:
        # Correlation is only defined for numeric data. Select those columns
        # explicitly instead of relying on DataFrame.corr() to drop the rest,
        # which newer pandas versions no longer do by default.
        numeric_df = df.select_dtypes(include=["number", "bool"])
        if numeric_df.shape[1] == 0:
            return "<div>No numeric columns available to plot.</div>"
        fig = px.imshow(numeric_df.corr(), text_auto=True, title=title)

    # full_html=False yields an embeddable fragment rather than a whole page.
    # NOTE(review): to_html still inlines plotly.js by default
    # (include_plotlyjs=True), which makes every response large -- consider
    # "cdn" or a shared script tag once the template side is confirmed.
    return pio.to_html(fig, full_html=False)