refactor(sql_explorer): Replace Plotly with SVG rendering for plots

Replaces the Plotly-based graph generation in the SQL Explorer with direct SVG rendering within an HTML template, similar to the exercise progress sparklines.

- Modifies `routes/sql_explorer.py` endpoints (`plot_query`, `plot_unsaved_query`) to fetch raw data instead of using pandas/Plotly.
- Adds `utils.prepare_svg_plot_data` to process raw SQL results, determine plot type (scatter, line, bar, table), normalize data, and prepare it for SVG.
- Creates `templates/partials/sql_explorer/svg_plot.html` to render the SVG plot with axes, ticks, labels, and basic tooltips.
- Removes the SQL Explorer's use of `generate_plot` and drops the `plotly-2.35.2.min.js` script tag from the base template, removing the direct Plotly dependency for this feature. A rough sketch of the new flow is below.
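A minimal sketch of the new flow (the sample rows and column names are made up; `prepare_svg_plot_data` and the template are the ones added in this commit):

```python
from datetime import date
from utils import prepare_svg_plot_data

# Illustrative query output: one datetime column and one numeric column.
rows = [
    {"StartDate": date(2025, 4, 1), "Weight": 100},
    {"StartDate": date(2025, 4, 8), "Weight": 105},
]
plot_data = prepare_svg_plot_data(rows, ["StartDate", "Weight"], "Top set weight")
# With a datetime x column and a numeric y column the helper picks plot_type='line'
# and returns the viewBox size, margins, ticks, axis labels and scaled points that
# partials/sql_explorer/svg_plot.html renders directly.
```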
Author: Peter Stockings
Date:   2025-04-15 19:34:26 +10:00
parent 51ec18c461
commit e947feb3e3
6 changed files with 636 additions and 152 deletions


@@ -1,39 +1,36 @@
 import pandas as pd
-from utils import get_distinct_colors
+from utils import get_distinct_colors, calculate_estimated_1rm


 class PeopleGraphs:
     def __init__(self, db_connection_method):
         self.execute = db_connection_method

     def get(self, selected_people_ids=None, min_date=None, max_date=None, selected_exercise_ids=None):
-        # Base query
+        """
+        Fetch workout topsets, calculate Estimated1RM in Python,
+        then generate weekly workout & PR graphs.
+        """
+        # Build query (no in-SQL 1RM calculation).
         query = """
             SELECT
                 P.person_id AS "PersonId",
                 P.name AS "PersonName",
                 W.workout_id AS "WorkoutId",
                 W.start_date AS "StartDate",
                 T.topset_id AS "TopSetId",
                 E.exercise_id AS "ExerciseId",
                 E.name AS "ExerciseName",
                 T.repetitions AS "Repetitions",
-                T.weight AS "Weight",
-                round((100 * T.Weight::numeric::integer)/(101.3-2.67123 * T.Repetitions),0)::numeric::integer AS "Estimated1RM"
+                T.weight AS "Weight"
             FROM Person P
             LEFT JOIN Workout W ON P.person_id = W.person_id
             LEFT JOIN TopSet T ON W.workout_id = T.workout_id
             LEFT JOIN Exercise E ON T.exercise_id = E.exercise_id
             WHERE TRUE
         """

-        # Parameters for the query
         params = []

-        # Add optional filters
         if selected_people_ids:
-            placeholders = ", ".join(["%s"] * len(selected_people_ids))
-            query += f" AND P.person_id IN ({placeholders})"
+            query += f" AND P.person_id IN ({', '.join(['%s'] * len(selected_people_ids))})"
             params.extend(selected_people_ids)

         if min_date:
             query += " AND W.start_date >= %s"
@@ -42,143 +39,233 @@ class PeopleGraphs:
             query += " AND W.start_date <= %s"
             params.append(max_date)

         if selected_exercise_ids:
-            placeholders = ", ".join(["%s"] * len(selected_exercise_ids))
-            query += f" AND E.exercise_id IN ({placeholders})"
+            query += f" AND E.exercise_id IN ({', '.join(['%s'] * len(selected_exercise_ids))})"
             params.extend(selected_exercise_ids)

-        # Execute the query
-        topsets = self.execute(query, params)
-
-        # Generate graphs
-        weekly_counts = self.get_workout_counts(topsets, 'week')
-        weekly_pr_counts = self.count_prs_over_time(topsets, 'week')
-        graphs = [self.get_weekly_pr_graph_model('Workouts per week', weekly_counts), self.get_weekly_pr_graph_model('PRs per week', weekly_pr_counts)]
-        return graphs
+        # Execute and convert to DataFrame
+        raw_data = self.execute(query, params)
+        if not raw_data:
+            # Return empty graphs if no data at all
+            return [
+                self.get_graph_model("Workouts per week", {}),
+                self.get_graph_model("PRs per week", {})
+            ]
+
+        df = pd.DataFrame(raw_data)
+
+        # Calculate Estimated1RM in Python
+        df['Estimated1RM'] = df.apply(
+            lambda row: calculate_estimated_1rm(row["Weight"], row["Repetitions"]), axis=1
+        )
+
+        # Build the weekly data models
+        weekly_counts = self.get_workout_counts(df, period='week')
+        weekly_pr_counts = self.count_prs_over_time(df, period='week')
+
+        return [
+            self.get_graph_model("Workouts per week", weekly_counts),
+            self.get_graph_model("PRs per week", weekly_pr_counts)
+        ]

-    def get_weekly_pr_graph_model(self, title, weekly_pr_data):
-        # Assuming weekly_pr_data is in the format {1: {"PersonName": "Alice", "PRCounts": {Timestamp('2022-01-01', freq='W-MON'): 0, ...}}, 2: {...}, ...}
-
-        # Find the overall date range for all users
-        all_dates = [date for user_data in weekly_pr_data.values() for date in user_data["PRCounts"].keys()]
-        min_date, max_date = min(all_dates), max(all_dates)
-        total_span = (max_date - min_date).days or 1
-        relative_positions = [(date - min_date).days / total_span for date in all_dates]
-
-        # Calculate viewBox dimensions
-        max_value = max(max(user_data["PRCounts"].values()) for user_data in weekly_pr_data.values()) or 1
-        min_value = 0
-        value_range = max_value - min_value
-        vb_width = 200
-        vb_height= 75
-
-        plots = []
-        colors = get_distinct_colors(len(weekly_pr_data.items()))
-        for count, (user_id, user_data) in enumerate(weekly_pr_data.items()):
-            pr_counts = user_data["PRCounts"]
-            person_name = user_data["PersonName"]
-
-            values = pr_counts.values()
-            values_scaled = [((value - min_value) / value_range) * vb_height for value in values]
-            plot_points = list(zip(values_scaled, relative_positions))
-            messages = [f'{value} for {person_name} at {date.strftime("%d %b %y")}' for value, date in zip(values, pr_counts.keys())]
-            plot_labels = zip(values_scaled, relative_positions, messages)
-
-            # Create a plot for each user
-            plot = {
-                'label': person_name,  # Use PersonName instead of User ID
-                'color': colors[count],
-                'points': plot_points,
-                'plot_labels': plot_labels
-            }
-            plots.append(plot)
-
-        # Return workout data with SVG dimensions and data points
-        return {
-            'title': title,
-            'vb_width': vb_width,
-            'vb_height': vb_height,
-            'plots': plots
-        }
-
-    def get_workout_counts(self, workouts, period='week'):
-        df = pd.DataFrame(workouts)
-        # Convert 'StartDate' to datetime and set period
-        df['StartDate'] = pd.to_datetime(df['StartDate'])
-        df['Period'] = df['StartDate'].dt.to_period('W' if period == 'week' else 'M')
-        # Group by PersonId, Period and count unique workouts
-        workout_counts = df.groupby(['PersonId', 'Period'])['WorkoutId'].nunique().reset_index()
-        # Convert 'Period' to timestamp using the start date of the period
-        workout_counts['Period'] = workout_counts['Period'].apply(lambda x: x.start_time)
-        # Pivot the result to get periods as columns
-        workout_counts_pivot = workout_counts.pivot(index='PersonId', columns='Period', values='WorkoutId').fillna(0)
-        # Include person names
-        names = df[['PersonId', 'PersonName']].drop_duplicates().set_index('PersonId')
-        workout_counts_final = names.join(workout_counts_pivot, how='left').fillna(0)
-        # Convert DataFrame to dictionary
-        result = workout_counts_final.reset_index().to_dict('records')
-        # Reformat the dictionary to desired structure
-        formatted_result = {}
-        for record in result:
-            person_id = record.pop('PersonId')
-            person_name = record.pop('PersonName')
-            pr_counts = {k: v for k, v in record.items()}
-            formatted_result[person_id] = {'PersonName': person_name, 'PRCounts': pr_counts}
-        return formatted_result
-
-    def count_prs_over_time(self, workouts, period='week'):
-        df = pd.DataFrame(workouts)
-        # Convert 'StartDate' to datetime
-        df['StartDate'] = pd.to_datetime(df['StartDate'])
-        # Set period as week or month
-        df['Period'] = df['StartDate'].dt.to_period('W' if period == 'week' else 'M')
-        # Group by Person, Exercise, and Period to find max Estimated1RM in each period
-        period_max = df.groupby(['PersonId', 'ExerciseId', 'Period'])['Estimated1RM'].max().reset_index()
-        # Determine all-time max Estimated1RM up to the start of each period
-        period_max['AllTimeMax'] = period_max.groupby(['PersonId', 'ExerciseId'])['Estimated1RM'].cummax().shift(1)
-        # Identify PRs as entries where the period's max Estimated1RM exceeds the all-time max
-        period_max['IsPR'] = period_max['Estimated1RM'] > period_max['AllTimeMax']
-        # Count PRs in each period for each person
-        pr_counts = period_max.groupby(['PersonId', 'Period'])['IsPR'].sum().reset_index()
-        # Convert 'Period' to timestamp using the start date of the period
-        pr_counts['Period'] = pr_counts['Period'].apply(lambda x: x.start_time)
-        # Pivot table to get the desired output format
-        output = pr_counts.pivot(index='PersonId', columns='Period', values='IsPR').fillna(0)
-        # Convert only the PR count columns to integers
-        for col in output.columns:
-            output[col] = output[col].astype(int)
-        # Merge with names and convert to desired format
-        names = df[['PersonId', 'PersonName']].drop_duplicates().set_index('PersonId')
-        output = names.join(output, how='left').fillna(0)
-        # Reset the index to bring 'PersonId' back as a column
-        output.reset_index(inplace=True)
-        # Convert to the final dictionary format with PRCounts nested
-        result = {}
-        for index, row in output.iterrows():
-            person_id = row['PersonId']
-            person_name = row['PersonName']
-            pr_counts = row.drop(['PersonId', 'PersonName']).to_dict()
-            result[person_id] = {"PersonName": person_name, "PRCounts": pr_counts}
-        return result
+    def _prepare_period_column(self, df, period='week'):
+        """
+        Convert StartDate to datetime and add a Period column
+        based on 'week' or 'month' as needed.
+        """
+        df['StartDate'] = pd.to_datetime(df['StartDate'], errors='coerce')
+        freq = 'W' if period == 'week' else 'M'
+        df['Period'] = df['StartDate'].dt.to_period(freq)
+        return df
+
+    def get_workout_counts(self, df, period='week'):
+        """
+        Returns a dictionary:
+        {
+            person_id: {
+                'PersonName': 'Alice',
+                'PRCounts': {
+                    Timestamp('2023-01-02'): 2,
+                    ...
+                }
+            },
+            ...
+        }
+        representing how many workouts each person performed per time period.
+        """
+        # Make a copy and prepare Period column
+        df = self._prepare_period_column(df.copy(), period)
+
+        # Count unique workouts per (PersonId, PersonName, Period)
+        grp = (
+            df.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['WorkoutId']
+            .nunique()
+            .rename(columns={'WorkoutId': 'Count'})
+        )
+        # Convert each Period to its start time
+        grp['Period'] = grp['Period'].apply(lambda p: p.start_time)
+
+        return self._pivot_to_graph_dict(
+            grp,
+            index_col='PersonId',
+            name_col='PersonName',
+            period_col='Period',
+            value_col='Count'
+        )
+
+    def count_prs_over_time(self, df, period='week'):
+        """
+        Returns a dictionary:
+        {
+            person_id: {
+                'PersonName': 'Alice',
+                'PRCounts': {
+                    Timestamp('2023-01-02'): 1,
+                    ...
+                }
+            },
+            ...
+        }
+        representing how many PRs each person hit per time period.
+        """
+        # Make a copy and prepare Period column
+        df = self._prepare_period_column(df.copy(), period)
+
+        # Max 1RM per (Person, Exercise, Period)
+        grouped = (
+            df.groupby(['PersonId', 'PersonName', 'ExerciseId', 'Period'], as_index=False)['Estimated1RM']
+            .max()
+            .rename(columns={'Estimated1RM': 'PeriodMax'})
+        )
+        # Sort so we can track "all-time max" up to that row
+        grouped.sort_values(by=['PersonId', 'ExerciseId', 'Period'], inplace=True)
+        # For each person & exercise, track the cumulative max (shifted by 1)
+        grouped['AllTimeMax'] = grouped.groupby(['PersonId', 'ExerciseId'])['PeriodMax'].cummax().shift(1)
+        grouped['IsPR'] = (grouped['PeriodMax'] > grouped['AllTimeMax']).astype(int)
+
+        # Sum PRs across exercises for (Person, Period)
+        pr_counts = (
+            grouped.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['IsPR']
+            .sum()
+            .rename(columns={'IsPR': 'Count'})
+        )
+        pr_counts['Period'] = pr_counts['Period'].apply(lambda p: p.start_time)
+
+        return self._pivot_to_graph_dict(
+            pr_counts,
+            index_col='PersonId',
+            name_col='PersonName',
+            period_col='Period',
+            value_col='Count'
+        )
+
+    def _pivot_to_graph_dict(self, df, index_col, name_col, period_col, value_col):
+        """
+        Convert [index_col, name_col, period_col, value_col]
+        into a nested dictionary for plotting:
+        {
+            person_id: {
+                'PersonName': <...>,
+                'PRCounts': {
+                    <timestamp>: <value>,
+                    ...
+                }
+            },
+            ...
+        }
+        """
+        if df.empty:
+            return {}
+
+        pivoted = df.pivot(
+            index=[index_col, name_col],
+            columns=period_col,
+            values=value_col
+        ).fillna(0)
+        pivoted.reset_index(inplace=True)
+
+        result = {}
+        for _, row in pivoted.iterrows():
+            pid = row[index_col]
+            pname = row[name_col]
+            # Remaining columns = date -> count
+            period_counts = row.drop([index_col, name_col]).to_dict()
+            result[pid] = {
+                'PersonName': pname,
+                'PRCounts': period_counts
+            }
+        return result
+
+    def get_graph_model(self, title, data_dict):
+        """
+        Builds a line-graph model from a dictionary of the form:
+        {
+            person_id: {
+                'PersonName': 'Alice',
+                'PRCounts': {
+                    Timestamp('2023-01-02'): 2,
+                    Timestamp('2023-01-09'): 1,
+                    ...
+                }
+            },
+            ...
+        }
+        """
+        if not data_dict:
+            return {
+                'title': title,
+                'vb_width': 200,
+                'vb_height': 75,
+                'plots': []
+            }
+
+        # Gather all dates & values
+        all_dates = []
+        all_values = []
+        for user_data in data_dict.values():
+            all_dates.extend(user_data['PRCounts'].keys())
+            all_values.extend(user_data['PRCounts'].values())
+
+        min_date = min(all_dates)
+        max_date = max(all_dates)
+        date_span = max((max_date - min_date).days, 1)
+
+        max_val = max(all_values)
+        min_val = 0
+        val_range = max_val - min_val if max_val != min_val else 1
+
+        vb_width, vb_height = 200, 75
+        colors = get_distinct_colors(len(data_dict))
+        plots = []
+
+        for i, (pid, user_data) in enumerate(data_dict.items()):
+            name = user_data['PersonName']
+            pr_counts = user_data['PRCounts']
+
+            # Sort by date so points are in chronological order
+            sorted_pr = sorted(pr_counts.items(), key=lambda x: x[0])
+
+            points = []
+            labels = []
+            for d, val in sorted_pr:
+                # Scale x,y to fit [0..1], then we multiply y by vb_height
+                x = (d - min_date).days / date_span
+                y = (val - min_val) / val_range * vb_height
+                points.append((y, x))
+                labels.append((y, x, f'{val} for {name} at {d.strftime("%d %b %y")}'))
+
+            plots.append({
+                'label': name,
+                'color': colors[i],
+                'points': points,
+                'plot_labels': labels
+            })
+
+        return {
+            'title': title,
+            'vb_width': vb_width,
+            'vb_height': vb_height,
+            'plots': plots
+        }
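For reference, the removed in-SQL "Estimated1RM" column and the `calculate_estimated_1rm` helper that now replaces it use the same formula; a quick check with made-up numbers:

```python
from utils import calculate_estimated_1rm

# Same formula as the removed SQL column:
#   round((100 * weight) / (101.3 - 2.67123 * repetitions), 0)
print(calculate_estimated_1rm(100, 5))  # 100 kg x 5 reps -> 114
print(calculate_estimated_1rm(100, 1))  # 100 kg x 1 rep  -> 101
```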

routes/sql_explorer.py

@@ -5,7 +5,7 @@ from flask import Blueprint, render_template, request, current_app, jsonify
 from jinja2_fragments import render_block
 from flask_htmx import HTMX
 from extensions import db
-from utils import generate_plot
+from utils import prepare_svg_plot_data  # Will be created for SVG data prep

 sql_explorer_bp = Blueprint('sql_explorer', __name__, url_prefix='/sql')
 htmx = HTMX()
@@ -281,17 +281,47 @@ def sql_schema():
 def plot_query(query_id):
     (title, query) = _get_saved_query(query_id)
     if not query: return "Query not found", 404
-    results_df = db.read_sql_as_df(query)
-    plot_div = generate_plot(results_df, title)
-    return plot_div
+    # Fetch raw results instead of DataFrame
+    (results, columns, error) = _execute_sql(query)
+    if error:
+        # Return an HTML snippet indicating the error
+        return f'<div class="p-4 text-red-700 bg-red-100 border border-red-400 rounded">Error executing query: {error}</div>', 400
+    if not results:
+        # Return an HTML snippet indicating no data
+        return '<div class="p-4 text-yellow-700 bg-yellow-100 border border-yellow-400 rounded">No data returned by query.</div>'
+
+    try:
+        # Prepare data for SVG plotting (function to be created in utils.py)
+        plot_data = prepare_svg_plot_data(results, columns, title)
+        # Render the new SVG template
+        return render_template('partials/sql_explorer/svg_plot.html', **plot_data)
+    except Exception as e:
+        current_app.logger.error(f"Error preparing SVG plot data: {e}")
+        # Return an HTML snippet indicating a processing error
+        return f'<div class="p-4 text-red-700 bg-red-100 border border-red-400 rounded">Error preparing plot data: {e}</div>', 500

 @sql_explorer_bp.route("/plot/show", methods=['POST'])
 def plot_unsaved_query():
     query = request.form.get('query')
-    title = request.form.get('title')
-    results_df = db.read_sql_as_df(query)
-    plot_div = generate_plot(results_df, title)
-    return plot_div
+    title = request.form.get('title', 'SQL Query Plot')  # Add default title
+    # Fetch raw results instead of DataFrame
+    (results, columns, error) = _execute_sql(query)
+    if error:
+        # Return an HTML snippet indicating the error
+        return f'<div class="p-4 text-red-700 bg-red-100 border border-red-400 rounded">Error executing query: {error}</div>', 400
+    if not results:
+        # Return an HTML snippet indicating no data
+        return '<div class="p-4 text-yellow-700 bg-yellow-100 border border-yellow-400 rounded">No data returned by query.</div>'
+
+    try:
+        # Prepare data for SVG plotting (function to be created in utils.py)
+        plot_data = prepare_svg_plot_data(results, columns, title)
+        # Render the new SVG template
+        return render_template('partials/sql_explorer/svg_plot.html', **plot_data)
+    except Exception as e:
+        current_app.logger.error(f"Error preparing SVG plot data: {e}")
+        # Return an HTML snippet indicating a processing error
+        return f'<div class="p-4 text-red-700 bg-red-100 border border-red-400 rounded">Error preparing plot data: {e}</div>', 500

 @sql_explorer_bp.route("/generate_sql", methods=['POST'])
 def generate_sql():
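For manual testing, the new endpoints can be exercised with Flask's test client. This is only a sketch: it assumes a Flask `app` with this blueprint registered, the `from app import app` path is hypothetical, and the query is illustrative.

```python
# Hypothetical import path for the project's Flask app; adjust to the real module.
from app import app

with app.test_client() as client:
    resp = client.post("/sql/plot/show", data={
        "query": "SELECT start_date, weight FROM TopSet",  # illustrative query
        "title": "Weight over time",
    })
    # 200 -> rendered partials/sql_explorer/svg_plot.html (or its table fallback)
    # 400 -> the small HTML error <div> returned when _execute_sql reports an error
    print(resp.status_code)
```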


@@ -20,7 +20,6 @@
     <script src="/static/js/sweetalert2@11.js" defer></script>
     <!-- Mermaid -->
     <script src="/static/js/mermaid.min.js"></script>
-    <script src="/static/js/plotly-2.35.2.min.js" defer></script>
     <script>
         // Initialize Mermaid with startOnLoad set to false
         mermaid.initialize({


@@ -10,6 +10,23 @@
 <div class="prose max-w-none">
     <p>Updates and changes to the site will be documented here, with the most recent changes listed first.</p>

+    <!-- New Entry for SQL Explorer SVG Plots -->
+    <hr class="my-6">
+    <h2 class="text-xl font-semibold mb-2">April 15, 2025</h2>
+    <ul class="list-disc pl-5 space-y-1">
+        <li>Replaced Plotly graph generation in SQL Explorer with direct SVG rendering:</li>
+        <ul class="list-disc pl-5 space-y-1">
+            <li>Updated `plot_query` and `plot_unsaved_query` endpoints in `routes/sql_explorer.py` to fetch raw data.</li>
+            <li>Added `prepare_svg_plot_data` function in `utils.py` to process data and determine plot type (scatter, line, bar, or table fallback).</li>
+            <li>Created `templates/partials/sql_explorer/svg_plot.html` template to render SVG plots with axes and basic tooltips.</li>
+            <li>Removed the need for the Plotly library for SQL Explorer plots, reducing dependencies and potentially improving load times.</li>
+        </ul>
+    </ul>
+
     <!-- New Entry for Dismissible Exercise Graph -->
     <hr class="my-6">
     <h2 class="text-xl font-semibold mb-2">April 13, 2025</h2>

templates/partials/sql_explorer/svg_plot.html

@@ -0,0 +1,125 @@
{# Basic SVG Plot Template for SQL Explorer #}
{% set unique_id = range(1000, 9999) | random %} {# Simple unique ID for elements #}
<div class="sql-plot-container p-4 border rounded bg-white shadow" id="sql-plot-{{ unique_id }}">
<h4 class="text-lg font-semibold text-gray-700 text-center mb-2">{{ title }}</h4>
{% if plot_type == 'table' %}
{# Fallback to rendering a table if plot type is not supported or data is unsuitable #}
<div class="overflow-x-auto max-h-96"> {# Limit height and allow scroll #}
<table class="min-w-full divide-y divide-gray-200 text-sm">
<thead class="bg-gray-50">
<tr>
{% for col in original_columns %}
<th scope="col" class="px-4 py-2 text-left font-medium text-gray-500 uppercase tracking-wider">
{{ col }}
</th>
{% endfor %}
</tr>
</thead>
<tbody class="bg-white divide-y divide-gray-200">
{% for row in original_results %}
<tr>
{% for col in original_columns %}
<td class="px-4 py-2 whitespace-nowrap">
{{ row[col] }}
</td>
{% endfor %}
</tr>
{% else %}
<tr>
<td colspan="{{ original_columns|length }}" class="px-4 py-2 text-center text-gray-500">No data
available.</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% else %}
{# SVG Plot Area #}
<div class="relative" _="
on mouseover from .plot-point-{{ unique_id }}
get event.target @data-tooltip
if it
put it into #tooltip-{{ unique_id }}
remove .hidden from #tooltip-{{ unique_id }}
end
on mouseout from .plot-point-{{ unique_id }}
add .hidden to #tooltip-{{ unique_id }}
">
{# Tooltip Element #}
<div id="tooltip-{{ unique_id }}"
class="absolute top-0 left-0 hidden bg-gray-800 text-white text-xs p-1 rounded shadow-lg z-10 pointer-events-none">
Tooltip
</div>
<svg viewBox="0 0 {{ vb_width }} {{ vb_height }}" preserveAspectRatio="xMidYMid meet" class="w-full h-auto">
{# Draw Axes #}
<g class="axes" stroke="#6b7280" stroke-width="1">
{# Y Axis #}
<line x1="{{ margin.left }}" y1="{{ margin.top }}" x2="{{ margin.left }}"
y2="{{ vb_height - margin.bottom }}"></line>
{# X Axis #}
<line x1="{{ margin.left }}" y1="{{ vb_height - margin.bottom }}" x2="{{ vb_width - margin.right }}"
y2="{{ vb_height - margin.bottom }}"></line>
</g>
{# Draw Ticks and Grid Lines #}
<g class="ticks" font-size="10" fill="#6b7280" text-anchor="middle">
{# Y Ticks #}
{% for tick in y_ticks %}
<line x1="{{ margin.left - 5 }}" y1="{{ tick.position }}" x2="{{ vb_width - margin.right }}"
y2="{{ tick.position }}" stroke="#e5e7eb" stroke-width="0.5"></line> {# Grid line #}
<text x="{{ margin.left - 8 }}" y="{{ tick.position + 3 }}" text-anchor="end">{{ tick.label }}</text>
{% endfor %}
{# X Ticks #}
{% for tick in x_ticks %}
<line x1="{{ tick.position }}" y1="{{ margin.top }}" x2="{{ tick.position }}"
y2="{{ vb_height - margin.bottom + 5 }}" stroke="#e5e7eb" stroke-width="0.5"></line> {# Grid line #}
<text x="{{ tick.position }}" y="{{ vb_height - margin.bottom + 15 }}">{{ tick.label }}</text>
{% endfor %}
</g>
{# Draw Axis Labels #}
<g class="axis-labels" font-size="12" fill="#374151" text-anchor="middle">
{# Y Axis Label #}
<text
transform="translate({{ margin.left / 2 - 5 }}, {{ (vb_height - margin.bottom + margin.top) / 2 }}) rotate(-90)">{{
y_axis_label }}</text>
{# X Axis Label #}
<text x="{{ (vb_width - margin.right + margin.left) / 2 }}"
y="{{ vb_height - margin.bottom / 2 + 10 }}">{{ x_axis_label }}</text>
</g>
{# Plot Data Points/Bars #}
{% for plot in plots %}
<g class="plot-series-{{ loop.index }}" fill="{{ plot.color }}" stroke="{{ plot.color }}">
{% if plot_type == 'scatter' %}
{% for p in plot.points %}
<circle cx="{{ p.x }}" cy="{{ p.y }}" r="3" class="plot-point-{{ unique_id }}"
data-tooltip="{{ p.original | tojson | escape }}" />
{% endfor %}
{% elif plot_type == 'line' %}
<path
d="{% for p in plot.points %}{% if loop.first %}M{% else %}L{% endif %}{{ p.x }} {{ p.y }}{% endfor %}"
fill="none" stroke-width="1.5" />
{% for p in plot.points %}
<circle cx="{{ p.x }}" cy="{{ p.y }}" r="2.5" class="plot-point-{{ unique_id }}"
data-tooltip="{{ p.original | tojson | escape }}" />
{% endfor %}
{% elif plot_type == 'bar' %}
{% set bar_w = bar_width | default(10) %}
{% for p in plot.points %}
<rect x="{{ p.x - bar_w / 2 }}" y="{{ p.y }}" width="{{ bar_w }}"
height="{{ (vb_height - margin.bottom) - p.y }}" stroke-width="0.5"
class="plot-point-{{ unique_id }}" data-tooltip="{{ p.original | tojson | escape }}" />
{% endfor %}
{% endif %}
</g>
{% endfor %}
</svg>
</div>
{% endif %}
</div>

utils.py

@@ -3,7 +3,9 @@ from datetime import datetime, date, timedelta
 import numpy as np
 import pandas as pd
 import plotly.express as px
-import plotly.io as pio
+import plotly.io as pio  # Keep for now, might remove later if generate_plot is fully replaced
+import math
+from decimal import Decimal

 def convert_str_to_date(date_str, format='%Y-%m-%d'):
     try:
@@ -142,3 +144,227 @@ def calculate_estimated_1rm(weight, repetitions):
        return 0
    estimated_1rm = round((100 * int(weight)) / (101.3 - 2.67123 * repetitions), 0)
    return int(estimated_1rm)


def _is_numeric(val):
    """Check if a value is numeric (int, float, Decimal)."""
    return isinstance(val, (int, float, Decimal))


def _is_datetime(val):
    """Check if a value is a date or datetime object."""
    return isinstance(val, (date, datetime))


def _get_column_type(results, column_name):
    """Determine the effective type of a column (numeric, datetime, categorical)."""
    numeric_count = 0
    datetime_count = 0
    total_count = 0
    for row in results:
        val = row.get(column_name)
        if val is not None:
            total_count += 1
            if _is_numeric(val):
                numeric_count += 1
            elif _is_datetime(val):
                datetime_count += 1
    if total_count == 0: return 'categorical'  # Default if all null or empty
    if numeric_count / total_count > 0.8: return 'numeric'  # Allow some non-numeric noise
    if datetime_count / total_count > 0.8: return 'datetime'
    return 'categorical'


def _normalize_value(value, min_val, range_val, target_max):
    """Normalize a value to a target range (e.g., SVG coordinate)."""
    if range_val == 0: return target_max / 2  # Avoid division by zero, place in middle
    return ((value - min_val) / range_val) * target_max


def prepare_svg_plot_data(results, columns, title):
    """
    Prepares data from raw SQL results for SVG plotting.
    Determines plot type and scales data.
    """
    if not results:
        raise ValueError("No data provided for plotting.")

    num_columns = len(columns)
    plot_type = 'table'  # Default if no suitable plot found
    plot_data = {}
    x_col, y_col = None, None
    x_type, y_type = None, None

    # --- Determine Plot Type and Columns ---
    if num_columns == 1:
        x_col = columns[0]
        x_type = _get_column_type(results, x_col)
        if x_type == 'numeric':
            plot_type = 'histogram'
        else:
            plot_type = 'bar_count'  # Bar chart of value counts
    elif num_columns >= 2:
        # Prioritize common patterns
        x_col, y_col = columns[0], columns[1]
        x_type = _get_column_type(results, x_col)
        y_type = _get_column_type(results, y_col)

        if x_type == 'numeric' and y_type == 'numeric':
            plot_type = 'scatter'
        elif x_type == 'datetime' and y_type == 'numeric':
            plot_type = 'line'  # Treat datetime as numeric for position
        elif x_type == 'categorical' and y_type == 'numeric':
            plot_type = 'bar'
        elif x_type == 'numeric' and y_type == 'categorical':
            # Could do horizontal bar, but let's stick to vertical for now
            plot_type = 'bar'  # Treat numeric as category label, categorical as value (count?) - less common
            # Or maybe swap? Let's assume categorical X, numeric Y is more likely intended
            x_col, y_col = columns[1], columns[0]  # Try swapping
            x_type, y_type = y_type, x_type
            if not (x_type == 'categorical' and y_type == 'numeric'):
                plot_type = 'table'  # Revert if swap didn't help
        else:  # Other combinations (datetime/cat, cat/cat, etc.) default to table
            plot_type = 'table'

    # --- Basic SVG Setup ---
    vb_width = 500
    vb_height = 300
    margin = {'top': 20, 'right': 20, 'bottom': 50, 'left': 60}  # Increased bottom/left for labels/axes
    draw_width = vb_width - margin['left'] - margin['right']
    draw_height = vb_height - margin['top'] - margin['bottom']

    plot_data = {
        'title': title,
        'plot_type': plot_type,
        'vb_width': vb_width,
        'vb_height': vb_height,
        'margin': margin,
        'draw_width': draw_width,
        'draw_height': draw_height,
        'x_axis_label': x_col or '',
        'y_axis_label': y_col or '',
        'plots': [],
        'x_ticks': [],
        'y_ticks': [],
        'original_results': results,  # Keep original for table fallback
        'original_columns': columns
    }

    if plot_type == 'table':
        return plot_data  # No further processing needed for table fallback

    # --- Data Extraction and Scaling (Specific to Plot Type) ---
    points = []
    x_values_raw = []
    y_values_raw = []

    # Extract relevant data, handling potential type issues
    for row in results:
        x_val_raw = row.get(x_col)
        y_val_raw = row.get(y_col)

        # Convert datetimes to numeric representation (e.g., days since min date)
        if x_type == 'datetime':
            x_values_raw.append(x_val_raw)  # Keep original dates for range calculation
        elif _is_numeric(x_val_raw):
            x_values_raw.append(float(x_val_raw))  # Convert Decimal to float
        # Add handling for categorical X if needed (e.g., bar chart)

        if y_type == 'numeric':
            if _is_numeric(y_val_raw):
                y_values_raw.append(float(y_val_raw))
            else:
                y_values_raw.append(None)  # Mark non-numeric Y as None
        # Add handling for categorical Y if needed

    if not x_values_raw or not y_values_raw:
        plot_data['plot_type'] = 'table'  # Fallback if essential data is missing
        return plot_data

    # Calculate ranges (handle datetime separately)
    if x_type == 'datetime':
        valid_dates = [d for d in x_values_raw if d is not None]
        if not valid_dates:
            plot_data['plot_type'] = 'table'; return plot_data
        min_x_dt, max_x_dt = min(valid_dates), max(valid_dates)
        # Convert dates to days since min_date for numerical scaling
        total_days = (max_x_dt - min_x_dt).days
        x_values_numeric = [(d - min_x_dt).days if d is not None else None for d in x_values_raw]
        min_x, max_x = 0, total_days
    else:  # Numeric or Categorical (treat categorical index as numeric for now)
        valid_x = [x for x in x_values_raw if x is not None]
        if not valid_x:
            plot_data['plot_type'] = 'table'; return plot_data
        min_x, max_x = min(valid_x), max(valid_x)
        x_values_numeric = x_values_raw  # Already numeric (or will be treated as such)

    valid_y = [y for y in y_values_raw if y is not None]
    if not valid_y:
        plot_data['plot_type'] = 'table'; return plot_data
    min_y, max_y = min(valid_y), max(valid_y)

    range_x = max_x - min_x
    range_y = max_y - min_y

    # Scale points
    for i, row in enumerate(results):
        x_num = x_values_numeric[i]
        y_num = y_values_raw[i]  # Use original list which might have None
        if x_num is None or y_num is None: continue  # Skip points with missing essential data

        # Scale X to drawing width, Y to drawing height (inverted Y for SVG)
        scaled_x = margin['left'] + _normalize_value(x_num, min_x, range_x, draw_width)
        scaled_y = margin['top'] + draw_height - _normalize_value(y_num, min_y, range_y, draw_height)
        points.append({
            'x': scaled_x,
            'y': scaled_y,
            'original': row  # Store original row data for tooltips
        })

    # --- Generate Ticks ---
    num_ticks = 5  # Desired number of ticks

    # X Ticks
    x_ticks = []
    if range_x >= 0:
        step_x = (max_x - min_x) / (num_ticks - 1) if num_ticks > 1 and range_x > 0 else 0
        for i in range(num_ticks):
            tick_val_raw = min_x + i * step_x
            tick_pos = margin['left'] + _normalize_value(tick_val_raw, min_x, range_x, draw_width)
            label = ""
            if x_type == 'datetime':
                tick_date = min_x_dt + timedelta(days=tick_val_raw)
                label = tick_date.strftime('%Y-%m-%d')  # Format date label
            else:  # Numeric
                label = f"{tick_val_raw:.1f}" if isinstance(tick_val_raw, float) else str(tick_val_raw)
            x_ticks.append({'value': tick_val_raw, 'label': label, 'position': tick_pos})

    # Y Ticks
    y_ticks = []
    if range_y >= 0:
        step_y = (max_y - min_y) / (num_ticks - 1) if num_ticks > 1 and range_y > 0 else 0
        for i in range(num_ticks):
            tick_val = min_y + i * step_y
            tick_pos = margin['top'] + draw_height - _normalize_value(tick_val, min_y, range_y, draw_height)
            label = f"{tick_val:.1f}" if isinstance(tick_val, float) else str(tick_val)
            y_ticks.append({'value': tick_val, 'label': label, 'position': tick_pos})

    # --- Finalize Plot Data ---
    # For now, put all points into one series
    plot_data['plots'].append({
        'label': f'{y_col} vs {x_col}',
        'color': '#388fed',  # Default color
        'points': points
    })
    plot_data['x_ticks'] = x_ticks
    plot_data['y_ticks'] = y_ticks

    # Add specific adjustments for plot types if needed (e.g., bar width)
    if plot_type == 'bar':
        # Calculate bar width based on number of bars/categories
        # This needs more refinement based on how categorical X is handled
        plot_data['bar_width'] = draw_width / len(points) * 0.8 if points else 10

    return plot_data
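To make the coordinate mapping concrete, this is the arithmetic the new helper applies to a single point with its default 500x300 viewBox (the 0-100 axis range is made up):

```python
# Defaults used by prepare_svg_plot_data
vb_width, vb_height = 500, 300
margin = {'top': 20, 'right': 20, 'bottom': 50, 'left': 60}
draw_width = vb_width - margin['left'] - margin['right']     # 420
draw_height = vb_height - margin['top'] - margin['bottom']   # 230

# _normalize_value(75, min_val=0, range_val=100, target_max=draw_height)
y_norm = ((75 - 0) / 100) * draw_height                      # 172.5
scaled_y = margin['top'] + draw_height - y_norm              # 77.5 (SVG y grows downward)
```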