Switch to using polars

Author: Peter Stockings
Date:   2026-01-29 00:05:25 +11:00
parent dd82f461be
commit 3a0d4531b6
6 changed files with 156 additions and 134 deletions

.gitignore

@@ -158,3 +158,10 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# Exclude backup sql files
+**/*.sql
+
+# Exclude experimental jupyter notebooks
+**/*.ipynb

app.py

@@ -16,7 +16,7 @@ from routes.export import export_bp  # Import the new export blueprint
 from routes.tags import tags_bp  # Import the new tags blueprint
 from routes.programs import programs_bp  # Import the new programs blueprint
 from extensions import db
-from utils import convert_str_to_date, generate_plot
+from utils import convert_str_to_date
 from flask_htmx import HTMX
 import minify_html
 import os

db.py

@@ -5,7 +5,6 @@ from datetime import datetime
 from dateutil.relativedelta import relativedelta
 from urllib.parse import urlparse
 from flask import g
-import pandas as pd
 from features.exercises import Exercises
 from features.people_graphs import PeopleGraphs
 from features.person_overview import PersonOverview
@@ -62,13 +61,7 @@ class DataBase():
         return (rv[0] if rv else None) if one else rv

-    def read_sql_as_df(self, query, params=None):
-        conn = self.getDB()
-        try:
-            df = pd.read_sql(query, conn, params=params)
-            return df
-        except Exception as e:
-            raise e

     def get_exercise(self, exercise_id):
         exercise = self.execute(

features/people_graphs.py

@@ -1,5 +1,5 @@
-import pandas as pd
-from utils import get_distinct_colors, calculate_estimated_1rm
+import polars as pl
+from utils import get_distinct_colors

 class PeopleGraphs:
     def __init__(self, db_connection_method):
@@ -7,7 +7,7 @@ class PeopleGraphs:
     def get(self, selected_people_ids=None, min_date=None, max_date=None, selected_exercise_ids=None):
         """
-        Fetch workout topsets, calculate Estimated1RM in Python,
+        Fetch workout topsets, calculate Estimated1RM in Polars,
         then generate weekly workout & PR graphs.
         """
         # Build query (no in-SQL 1RM calculation).
@@ -51,15 +51,41 @@
                 self.get_graph_model("PRs per week", {})
             ]

-        df = pd.DataFrame(raw_data)
-
-        # Calculate Estimated1RM in Python
-        df['Estimated1RM'] = df.apply(
-            lambda row: calculate_estimated_1rm(row["Weight"], row["Repetitions"]), axis=1
-        )
+        # Explicitly specify the schema to ensure correct types.
+        schema_overrides = {
+            "Weight": pl.Float64,
+            "Repetitions": pl.Int64,
+            "StartDate": pl.Date  # or pl.Datetime, depending on the DB driver
+        }
+        # self.execute returns a list of dicts, so Polars can infer the schema;
+        # the overrides make inference robust against odd DB types.
+        try:
+            df = pl.DataFrame(raw_data, schema_overrides=schema_overrides, infer_schema_length=1000)
+        except Exception:
+            # Fall back to plain inference if the overrides clash with the input types.
+            df = pl.DataFrame(raw_data)
+
+        # Calculate Estimated1RM in Polars.
+        # Formula: round((100 * weight) / (101.3 - 2.67123 * repetitions), 0)
+        # The original code only guarded against division by zero when
+        # Repetitions == 0, so replicate exactly that logic here.
+        df = df.with_columns(
+            pl.when(pl.col("Repetitions") == 0)
+            .then(0)
+            .otherwise(
+                (pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
+            )
+            .round(0)
+            .cast(pl.Int64)
+            .alias("Estimated1RM")
+        )

         # Build the weekly data models
         weekly_counts = self.get_workout_counts(df, period='week')
         weekly_pr_counts = self.count_prs_over_time(df, period='week')

         return [
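For context, the when/then/otherwise expression above reproduces the estimated-1RM formula from utils.calculate_estimated_1rm, 100 * weight / (101.3 - 2.67123 * repetitions), including the original Repetitions == 0 guard. A minimal standalone sketch with hypothetical sample values (assuming polars >= 0.20, as pinned in requirements.txt):

    import polars as pl

    df = pl.DataFrame({"Weight": [100.0, 80.0, 60.0], "Repetitions": [1, 5, 0]})
    df = df.with_columns(
        pl.when(pl.col("Repetitions") == 0)
        .then(0)
        .otherwise(
            (pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
        )
        .round(0)
        .cast(pl.Int64)
        .alias("Estimated1RM")
    )
    print(df["Estimated1RM"].to_list())  # [101, 91, 0]; zero-rep rows map to 0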
@@ -67,43 +93,48 @@
             self.get_graph_model("PRs per week", weekly_pr_counts)
         ]

-    def _prepare_period_column(self, df, period='week'):
+    def _prepare_period_column(self, df: pl.DataFrame, period='week'):
         """
-        Convert StartDate to datetime and add a Period column
-        based on 'week' or 'month' as needed.
+        Convert StartDate to a proper date type and add a Period column
+        represented as the start date of that period.
         """
-        df['StartDate'] = pd.to_datetime(df['StartDate'], errors='coerce')
-        freq = 'W' if period == 'week' else 'M'
-        df['Period'] = df['StartDate'].dt.to_period(freq)
+        # Ensure StartDate is Date/Datetime.
+        if df["StartDate"].dtype == pl.String:
+            df = df.with_columns(pl.col("StartDate").str.strptime(pl.Date, "%Y-%m-%d"))  # adjust format if needed
+        elif df["StartDate"].dtype == pl.Object:
+            # The column holds Python datetime objects.
+            df = df.with_columns(pl.col("StartDate").cast(pl.Date))
+
+        # Truncate to week or month. Polars has no direct equivalent of
+        # Pandas' to_period; dt.truncate("1w") floors to the start of the
+        # week (Monday), matching Postgres/ISO week conventions.
+        if period == 'week':
+            df = df.with_columns(
+                pl.col("StartDate").dt.truncate("1w").alias("Period")
+            )
+        else:  # month
+            df = df.with_columns(
+                pl.col("StartDate").dt.truncate("1mo").alias("Period")
+            )
         return df

-    def get_workout_counts(self, df, period='week'):
+    def get_workout_counts(self, df: pl.DataFrame, period='week'):
         """
-        Returns a dictionary:
-        {
-            person_id: {
-                'PersonName': 'Alice',
-                'PRCounts': {
-                    Timestamp('2023-01-02'): 2,
-                    ...
-                }
-            },
-            ...
-        }
-        representing how many workouts each person performed per time period.
+        Returns workout counts per person per period.
         """
-        # Make a copy and prepare Period column
-        df = self._prepare_period_column(df.copy(), period)
+        df = self._prepare_period_column(df, period)
+
+        # Ensure Period is a string for consistent pivoting.
+        df = df.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))

         # Count unique workouts per (PersonId, PersonName, Period)
         grp = (
-            df.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['WorkoutId']
-            .nunique()
-            .rename(columns={'WorkoutId': 'Count'})
+            df.group_by(['PersonId', 'PersonName', 'Period'])
+            .agg(pl.col('WorkoutId').n_unique().alias('Count'))
         )

-        # Convert each Period to its start time
-        grp['Period'] = grp['Period'].apply(lambda p: p.start_time)
-
         return self._pivot_to_graph_dict(
             grp,
             index_col='PersonId',
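To make the to_period replacement concrete: dt.truncate("1w") floors each date to the Monday of its week, so the truncated value already is the period's start date, which is why the pandas-era p.start_time conversion disappears above. A small sketch with made-up dates:

    import polars as pl
    from datetime import date

    df = pl.DataFrame({"StartDate": [date(2023, 1, 3), date(2023, 1, 8), date(2023, 1, 9)]})
    df = df.with_columns(pl.col("StartDate").dt.truncate("1w").alias("Period"))
    print(df["Period"].to_list())
    # [date(2023, 1, 2), date(2023, 1, 2), date(2023, 1, 9)]
    # Tue 01-03 and Sun 01-08 fall in the week starting Mon 01-02;
    # Mon 01-09 starts its own week.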
@@ -112,46 +143,47 @@ class PeopleGraphs:
             value_col='Count'
         )

-    def count_prs_over_time(self, df, period='week'):
+    def count_prs_over_time(self, df: pl.DataFrame, period='week'):
         """
-        Returns a dictionary:
-        {
-            person_id: {
-                'PersonName': 'Alice',
-                'PRCounts': {
-                    Timestamp('2023-01-02'): 1,
-                    ...
-                }
-            },
-            ...
-        }
-        representing how many PRs each person hit per time period.
+        Returns the number of PRs hit per person per period.
         """
-        # Make a copy and prepare Period column
-        df = self._prepare_period_column(df.copy(), period)
+        df = self._prepare_period_column(df, period)

-        # Max 1RM per (Person, Exercise, Period)
+        # Max 1RM per (Person, Exercise, Period) - 'PeriodMax'
         grouped = (
-            df.groupby(['PersonId', 'PersonName', 'ExerciseId', 'Period'], as_index=False)['Estimated1RM']
-            .max()
-            .rename(columns={'Estimated1RM': 'PeriodMax'})
+            df.group_by(['PersonId', 'PersonName', 'ExerciseId', 'Period'])
+            .agg(pl.col('Estimated1RM').max().alias('PeriodMax'))
         )

-        # Sort so we can track "all-time max" up to that row
-        grouped.sort_values(by=['PersonId', 'ExerciseId', 'Period'], inplace=True)
+        # Sort so we can track the "all-time max"
+        grouped = grouped.sort(by=['PersonId', 'ExerciseId', 'Period'])

-        # For each person & exercise, track the cumulative max (shifted by 1)
-        grouped['AllTimeMax'] = grouped.groupby(['PersonId', 'ExerciseId'])['PeriodMax'].cummax().shift(1)
-        grouped['IsPR'] = (grouped['PeriodMax'] > grouped['AllTimeMax']).astype(int)
+        # AllTimeMax = the running max up to (but excluding) the current row.
+        grouped = grouped.with_columns(
+            pl.col("PeriodMax")
+            .cum_max()
+            .over(['PersonId', 'ExerciseId'])
+            .shift(1)
+            .alias("AllTimeMax")
+        )
+        grouped = grouped.with_columns(
+            pl.col("AllTimeMax").fill_null(0)
+        )
+        grouped = grouped.with_columns(
+            (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
+        )
+
+        # Ensure Period is a string for consistent pivoting.
+        grouped = grouped.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))

         # Sum PRs across exercises for (Person, Period)
         pr_counts = (
-            grouped.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['IsPR']
-            .sum()
-            .rename(columns={'IsPR': 'Count'})
+            grouped.group_by(['PersonId', 'PersonName', 'Period'])
+            .agg(pl.col('IsPR').sum().alias('Count'))
         )

-        pr_counts['Period'] = pr_counts['Period'].apply(lambda p: p.start_time)
-
         return self._pivot_to_graph_dict(
             pr_counts,
             index_col='PersonId',
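The PR logic above hinges on cum_max().over(...).shift(1): within each (PersonId, ExerciseId) window, AllTimeMax holds the best 1RM from earlier periods, and a row counts as a PR only if it beats that running best. A toy sketch for a single person and exercise (hypothetical numbers):

    import polars as pl

    grouped = pl.DataFrame({
        "PersonId": [1, 1, 1, 1],
        "ExerciseId": [10, 10, 10, 10],
        "Period": ["2023-01-02", "2023-01-09", "2023-01-16", "2023-01-23"],
        "PeriodMax": [100, 95, 110, 110],
    }).sort(["PersonId", "ExerciseId", "Period"])

    grouped = grouped.with_columns(
        pl.col("PeriodMax").cum_max().over(["PersonId", "ExerciseId"]).shift(1).alias("AllTimeMax")
    ).with_columns(
        pl.col("AllTimeMax").fill_null(0)
    ).with_columns(
        (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
    )
    print(grouped["IsPR"].to_list())  # [1, 0, 1, 0]: the first week and the 110 week are PRs

Note that shift(1) is applied outside the window here, faithfully mirroring the pandas cummax().shift(1); with several (person, exercise) groups in one frame, shifting inside the window (.cum_max().shift(1).over(...)) would keep the lag per group.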
@@ -160,38 +192,47 @@ class PeopleGraphs:
             value_col='Count'
         )

-    def _pivot_to_graph_dict(self, df, index_col, name_col, period_col, value_col):
+    def _pivot_to_graph_dict(self, df: pl.DataFrame, index_col, name_col, period_col, value_col):
         """
-        Convert [index_col, name_col, period_col, value_col]
-        into a nested dictionary for plotting:
-        {
-            person_id: {
-                'PersonName': <...>,
-                'PRCounts': {
-                    <timestamp>: <value>,
-                    ...
-                }
-            },
-            ...
-        }
+        Convert the Polars DataFrame to the nested dict structure expected
+        by the visualization code.
         """
-        if df.empty:
+        if df.is_empty():
             return {}

-        # Pivot
         pivoted = df.pivot(
+            values=value_col,
             index=[index_col, name_col],
             columns=period_col,
-            values=value_col
-        ).fillna(0)
-        pivoted.reset_index(inplace=True)
+            aggregate_function="sum"
+        ).fill_null(0)
+
+        rows = pivoted.to_dicts()
+        date_cols = [c for c in pivoted.columns if c not in [index_col, name_col]]

         result = {}
-        for _, row in pivoted.iterrows():
+        for row in rows:
             pid = row[index_col]
             pname = row[name_col]
-            # Remaining columns = date -> count
-            period_counts = row.drop([index_col, name_col]).to_dict()
+            period_counts = {}
+            for dc in date_cols:
+                val = row[dc]
+                # The original pandas pivot kept every date column for every
+                # row and filled NaNs with 0; zeros are kept here as well so
+                # the plotting behavior stays identical.
+                # Parse the date string back into a date object.
+                try:
+                    d_obj = datetime.strptime(str(dc), "%Y-%m-%d").date()
+                    period_counts[d_obj] = val
+                except ValueError:
+                    # Should not happen, since we control the format above.
+                    print(f"Warning: Could not parse date column {dc}")
+                    period_counts[dc] = val
             result[pid] = {
                 'PersonName': pname,
                 'PRCounts': period_counts
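Two polars-specific details in the pivot above: unlike pandas, pivot takes an explicit aggregate_function ("sum" is effectively a no-op here, since each (person, period) pair is already unique after group_by), and to_dicts() replaces the pandas iterrows() pass. A small sketch of the shape it produces, using hypothetical rows and the 0.20-era pivot signature from this diff:

    import polars as pl

    grp = pl.DataFrame({
        "PersonId": [1, 1, 2],
        "PersonName": ["Alice", "Alice", "Bob"],
        "Period": ["2023-01-02", "2023-01-09", "2023-01-02"],
        "Count": [2, 1, 3],
    })
    pivoted = grp.pivot(
        values="Count",
        index=["PersonId", "PersonName"],
        columns="Period",
        aggregate_function="sum",
    ).fill_null(0)  # fill_null supplies 0 for periods a person has no rows in
    print(pivoted.to_dicts())
    # [{'PersonId': 1, 'PersonName': 'Alice', '2023-01-02': 2, '2023-01-09': 1},
    #  {'PersonId': 2, 'PersonName': 'Bob', '2023-01-02': 3, '2023-01-09': 0}]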
@@ -201,18 +242,8 @@ class PeopleGraphs:
     def get_graph_model(self, title, data_dict):
         """
-        Builds a line-graph model from a dictionary of the form:
-        {
-            person_id: {
-                'PersonName': 'Alice',
-                'PRCounts': {
-                    Timestamp('2023-01-02'): 2,
-                    Timestamp('2023-01-09'): 1,
-                    ...
-                }
-            },
-            ...
-        }
+        Builds a line-graph model from the nested per-person dictionary.
+        This stays plain Python, since it only manipulates dicts.
         """
         if not data_dict:
             return {
@@ -229,8 +260,21 @@ class PeopleGraphs:
             all_dates.extend(user_data['PRCounts'].keys())
             all_values.extend(user_data['PRCounts'].values())

+        if not all_dates:
+            return {
+                'title': title,
+                'vb_width': 200,
+                'vb_height': 75,
+                'plots': []
+            }
+
         min_date = min(all_dates)
         max_date = max(all_dates)
+
+        # min_date/max_date must be comparable (date objects, not strings);
+        # _pivot_to_graph_dict converts the pivot columns back to dates above.
         date_span = max((max_date - min_date).days, 1)

         max_val = max(all_values)
@@ -269,3 +313,5 @@ class PeopleGraphs:
             'vb_height': vb_height,
             'plots': plots
         }
+
+from datetime import datetime

requirements.txt

@@ -9,7 +9,6 @@ minify-html==0.10.3
 jinja2-fragments==0.3.0
 Werkzeug==2.2.2
 numpy==1.19.5
-pandas==1.3.1
 python-dotenv==1.0.1
 plotly==5.24.1
 wtforms==3.2.1
@@ -17,4 +16,6 @@ flask-wtf==1.2.2
 Flask-Login==0.6.3
 Flask-Bcrypt==1.0.1
 email-validator==2.2.0
 requests==2.26.0
+polars>=0.20.0
+pyarrow>=14.0.0

utils.py

@@ -1,7 +1,7 @@
 import colorsys
 from datetime import datetime, date, timedelta
 import numpy as np
-import pandas as pd
 import plotly.express as px
 import plotly.io as pio  # Keep for now, might remove later if generate_plot is fully replaced
 import math
@@ -110,32 +110,7 @@ def get_distinct_colors(n):
         colors.append(hex_color)
     return colors

-def generate_plot(df: pd.DataFrame, title: str) -> str:
-    """
-    Analyzes the DataFrame and generates an appropriate Plotly visualization.
-    Returns the Plotly figure as a div string.
-    Optimized for speed.
-    """
-    if df.empty:
-        return "<p>No data available to plot.</p>"
-
-    num_columns = len(df.columns)
-
-    # Dictionary-based lookup for faster decision-making
-    plot_funcs = {
-        1: lambda: px.histogram(df, x=df.columns[0], title=title)
-        if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) else px.bar(df, x=df.columns[0], title=title),
-        2: lambda: px.scatter(df, x=df.columns[0], y=df.columns[1], title=title)
-        if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) and pd.api.types.is_numeric_dtype(df.iloc[:, 1])
-        else px.bar(df, x=df.columns[0], y=df.columns[1], title=title)
-    }
-
-    # Select plot function based on column count
-    fig = plot_funcs.get(num_columns, lambda: px.imshow(df.corr(numeric_only=True), text_auto=True, title=title))()
-
-    # Use static rendering for speed
-    return pio.to_html(fig, full_html=False, include_plotlyjs=False, config={'staticPlot': True})

 def calculate_estimated_1rm(weight, repetitions):