Switch to using polars

.gitignore (vendored): 7 changes
@@ -158,3 +158,10 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+
+# Exclude backup sql files
+**/*.sql
+
+# Exclude experimental jupyter notebooks
+**/*.ipynb

app.py: 2 changes
@@ -16,7 +16,7 @@ from routes.export import export_bp # Import the new export blueprint
 from routes.tags import tags_bp # Import the new tags blueprint
 from routes.programs import programs_bp # Import the new programs blueprint
 from extensions import db
-from utils import convert_str_to_date, generate_plot
+from utils import convert_str_to_date
 from flask_htmx import HTMX
 import minify_html
 import os

db.py: 9 changes
@@ -5,7 +5,6 @@ from datetime import datetime
 from dateutil.relativedelta import relativedelta
 from urllib.parse import urlparse
 from flask import g
-import pandas as pd
 from features.exercises import Exercises
 from features.people_graphs import PeopleGraphs
 from features.person_overview import PersonOverview
@@ -62,13 +61,7 @@ class DataBase():

         return (rv[0] if rv else None) if one else rv

-    def read_sql_as_df(self, query, params=None):
-        conn = self.getDB()
-        try:
-            df = pd.read_sql(query, conn, params=params)
-            return df
-        except Exception as e:
-            raise e
-
     def get_exercise(self, exercise_id):
         exercise = self.execute(
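The removed read_sql_as_df helper gets no direct replacement in this commit; callers now receive lists of dicts from self.execute and build pl.DataFrame from those. If a query-to-DataFrame helper were still wanted, a hypothetical polars sketch (not part of this commit) could look like:

import polars as pl

def read_sql_as_df(conn, query: str) -> pl.DataFrame:
    # pl.read_database executes the query on an open DB-API connection
    # and returns a polars DataFrame
    return pl.read_database(query, connection=conn)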

features/people_graphs.py
@@ -1,5 +1,5 @@
-import pandas as pd
-from utils import get_distinct_colors, calculate_estimated_1rm
+import polars as pl
+from utils import get_distinct_colors

 class PeopleGraphs:
     def __init__(self, db_connection_method):
@@ -7,7 +7,7 @@ class PeopleGraphs:

     def get(self, selected_people_ids=None, min_date=None, max_date=None, selected_exercise_ids=None):
         """
-        Fetch workout topsets, calculate Estimated1RM in Python,
+        Fetch workout topsets, calculate Estimated1RM in Polars,
         then generate weekly workout & PR graphs.
         """
         # Build query (no in-SQL 1RM calculation).
@@ -51,15 +51,41 @@ class PeopleGraphs:
             self.get_graph_model("PRs per week", {})
         ]

-        df = pd.DataFrame(raw_data)
-
-        # Calculate Estimated1RM in Python
-        df['Estimated1RM'] = df.apply(
-            lambda row: calculate_estimated_1rm(row["Weight"], row["Repetitions"]), axis=1
+        # Explicitly specify the schema so numeric and date columns get correct types
+        schema_overrides = {
+            "Weight": pl.Float64,
+            "Repetitions": pl.Int64,
+            "StartDate": pl.Date  # or pl.Datetime, depending on the DB driver
+        }
+
+        # self.execute returns a list of dicts, so Polars can infer the schema;
+        # the overrides make the inference robust against DB driver quirks
+        try:
+            df = pl.DataFrame(raw_data, schema_overrides=schema_overrides, infer_schema_length=1000)
+        except Exception:
+            # Fall back to plain inference if the overrides clash with the input types
+            df = pl.DataFrame(raw_data)
+
+        # Calculate Estimated1RM in Polars.
+        # Formula: round((100 * weight) / (101.3 - 2.67123 * repetitions), 0),
+        # returning 0 when Repetitions == 0 to avoid a bad denominator,
+        # as the original Python helper did.
+        df = df.with_columns(
+            pl.when(pl.col("Repetitions") == 0)
+            .then(0)
+            .otherwise(
+                (pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
+            )
+            .round(0)
+            .cast(pl.Int64)
+            .alias("Estimated1RM")
         )

         # Build the weekly data models
         weekly_counts = self.get_workout_counts(df, period='week')
         weekly_pr_counts = self.count_prs_over_time(df, period='week')

         return [
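As a quick sanity check of the expression above, here is a minimal standalone sketch with made-up rows (not part of the commit):

import polars as pl

df = pl.DataFrame({"Weight": [100.0, 80.0], "Repetitions": [5, 0]})
df = df.with_columns(
    pl.when(pl.col("Repetitions") == 0)
    .then(0)
    .otherwise(
        (pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
    )
    .round(0)
    .cast(pl.Int64)
    .alias("Estimated1RM")
)
print(df["Estimated1RM"].to_list())  # [114, 0]: 10000 / 87.94385 rounds to 114; 0 reps maps to 0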
@@ -67,43 +93,48 @@ class PeopleGraphs:
             self.get_graph_model("PRs per week", weekly_pr_counts)
         ]

-    def _prepare_period_column(self, df, period='week'):
+    def _prepare_period_column(self, df: pl.DataFrame, period='week'):
         """
-        Convert StartDate to datetime and add a Period column
-        based on 'week' or 'month' as needed.
+        Convert StartDate to a proper date type and add a Period column
+        represented as the start date of that period.
         """
-        df['StartDate'] = pd.to_datetime(df['StartDate'], errors='coerce')
-        freq = 'W' if period == 'week' else 'M'
-        df['Period'] = df['StartDate'].dt.to_period(freq)
+        # Ensure StartDate is a Date/Datetime column
+        if df["StartDate"].dtype == pl.String:
+            df = df.with_columns(pl.col("StartDate").str.strptime(pl.Date, "%Y-%m-%d"))  # adjust format if needed
+        elif df["StartDate"].dtype == pl.Object:
+            # Plain Python date/datetime objects from the DB driver
+            df = df.with_columns(pl.col("StartDate").cast(pl.Date))
+
+        # Polars has no direct equivalent of Pandas' to_period, but
+        # dt.truncate floors each date to the start of its period
+        # (weeks start on Monday, matching the Postgres convention)
+        if period == 'week':
+            df = df.with_columns(
+                pl.col("StartDate").dt.truncate("1w").alias("Period")
+            )
+        else:  # month
+            df = df.with_columns(
+                pl.col("StartDate").dt.truncate("1mo").alias("Period")
+            )

         return df

-    def get_workout_counts(self, df, period='week'):
+    def get_workout_counts(self, df: pl.DataFrame, period='week'):
         """
-        Returns a dictionary:
-        {
-            person_id: {
-                'PersonName': 'Alice',
-                'PRCounts': {
-                    Timestamp('2023-01-02'): 2,
-                    ...
-                }
-            },
-            ...
-        }
-        representing how many workouts each person performed per time period.
+        Returns workout counts per person per period.
         """
-        # Make a copy and prepare Period column
-        df = self._prepare_period_column(df.copy(), period)
+        df = self._prepare_period_column(df, period)
+
+        # Render Period as a string for consistent pivoting
+        df = df.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))

         # Count unique workouts per (PersonId, PersonName, Period)
         grp = (
-            df.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['WorkoutId']
-            .nunique()
-            .rename(columns={'WorkoutId': 'Count'})
+            df.group_by(['PersonId', 'PersonName', 'Period'])
+            .agg(pl.col('WorkoutId').n_unique().alias('Count'))
         )
-        # Convert each Period to its start time
-        grp['Period'] = grp['Period'].apply(lambda p: p.start_time)

         return self._pivot_to_graph_dict(
             grp,
             index_col='PersonId',
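A minimal sketch of the period bucketing and counting above, with made-up rows (not part of the commit):

import polars as pl
from datetime import date

df = pl.DataFrame({
    "PersonId": [1, 1, 2],
    "PersonName": ["Alice", "Alice", "Bob"],
    "WorkoutId": [10, 11, 12],
    "StartDate": [date(2023, 1, 3), date(2023, 1, 5), date(2023, 1, 4)],
})

# dt.truncate("1w") floors each date to its week's Monday, so both of
# Alice's workouts land in the week starting 2023-01-02
df = df.with_columns(pl.col("StartDate").dt.truncate("1w").alias("Period"))
counts = df.group_by(["PersonId", "PersonName", "Period"]).agg(
    pl.col("WorkoutId").n_unique().alias("Count")
)
print(counts.sort("PersonId").to_dicts())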
@@ -112,46 +143,47 @@ class PeopleGraphs:
             value_col='Count'
         )

-    def count_prs_over_time(self, df, period='week'):
+    def count_prs_over_time(self, df: pl.DataFrame, period='week'):
         """
-        Returns a dictionary:
-        {
-            person_id: {
-                'PersonName': 'Alice',
-                'PRCounts': {
-                    Timestamp('2023-01-02'): 1,
-                    ...
-                }
-            },
-            ...
-        }
-        representing how many PRs each person hit per time period.
+        Returns the number of PRs hit per person per period.
         """
-        # Make a copy and prepare Period column
-        df = self._prepare_period_column(df.copy(), period)
+        df = self._prepare_period_column(df, period)

-        # Max 1RM per (Person, Exercise, Period)
+        # Max 1RM per (Person, Exercise, Period) - 'PeriodMax'
         grouped = (
-            df.groupby(['PersonId', 'PersonName', 'ExerciseId', 'Period'], as_index=False)['Estimated1RM']
-            .max()
-            .rename(columns={'Estimated1RM': 'PeriodMax'})
+            df.group_by(['PersonId', 'PersonName', 'ExerciseId', 'Period'])
+            .agg(pl.col('Estimated1RM').max().alias('PeriodMax'))
         )

         # Sort so we can track the all-time max
-        grouped.sort_values(by=['PersonId', 'ExerciseId', 'Period'], inplace=True)
+        grouped = grouped.sort(by=['PersonId', 'ExerciseId', 'Period'])

-        # For each person & exercise, track the cumulative max (shifted by 1)
-        grouped['AllTimeMax'] = grouped.groupby(['PersonId', 'ExerciseId'])['PeriodMax'].cummax().shift(1)
-        grouped['IsPR'] = (grouped['PeriodMax'] > grouped['AllTimeMax']).astype(int)
+        # AllTimeMax = the running max up to (but excluding) the current row,
+        # computed per (PersonId, ExerciseId); shift(1) must run inside the
+        # window so values never leak across group boundaries
+        grouped = grouped.with_columns(
+            pl.col("PeriodMax")
+            .cum_max()
+            .shift(1)
+            .over(['PersonId', 'ExerciseId'])
+            .alias("AllTimeMax")
+        )
+
+        grouped = grouped.with_columns(
+            pl.col("AllTimeMax").fill_null(0)
+        )
+
+        grouped = grouped.with_columns(
+            (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
+        )
+
+        # Render Period as a string for consistent pivoting
+        grouped = grouped.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))

         # Sum PRs across exercises for (Person, Period)
         pr_counts = (
-            grouped.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['IsPR']
-            .sum()
-            .rename(columns={'IsPR': 'Count'})
+            grouped.group_by(['PersonId', 'PersonName', 'Period'])
+            .agg(pl.col('IsPR').sum().alias('Count'))
         )
-        pr_counts['Period'] = pr_counts['Period'].apply(lambda p: p.start_time)

         return self._pivot_to_graph_dict(
             pr_counts,
             index_col='PersonId',
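The PR detection hinges on that window expression: the running maximum is shifted by one row inside each (PersonId, ExerciseId) window, so a row counts as a PR only when it beats everything that came before it. A minimal sketch with made-up rows (not part of the commit):

import polars as pl

grouped = pl.DataFrame({
    "PersonId":   [1, 1, 1, 2],
    "ExerciseId": [7, 7, 7, 7],
    "Period":     ["2023-01-02", "2023-01-09", "2023-01-16", "2023-01-02"],
    "PeriodMax":  [100, 110, 105, 90],
})

grouped = grouped.with_columns(
    # cum_max then shift, both inside the window, so nothing leaks between people
    pl.col("PeriodMax").cum_max().shift(1).over(["PersonId", "ExerciseId"]).alias("AllTimeMax")
).with_columns(
    pl.col("AllTimeMax").fill_null(0)
).with_columns(
    (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
)
print(grouped["IsPR"].to_list())  # [1, 1, 0, 1]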
@@ -160,38 +192,47 @@ class PeopleGraphs:
             value_col='Count'
         )

-    def _pivot_to_graph_dict(self, df, index_col, name_col, period_col, value_col):
+    def _pivot_to_graph_dict(self, df: pl.DataFrame, index_col, name_col, period_col, value_col):
         """
-        Convert [index_col, name_col, period_col, value_col]
-        into a nested dictionary for plotting:
-        {
-            person_id: {
-                'PersonName': <...>,
-                'PRCounts': {
-                    <timestamp>: <value>,
-                    ...
-                }
-            },
-            ...
-        }
+        Convert an [index_col, name_col, period_col, value_col] DataFrame
+        into the nested dict structure expected by the visualization.
         """
-        if df.empty:
+        if df.is_empty():
             return {}

         pivoted = df.pivot(
+            values=value_col,
             index=[index_col, name_col],
             columns=period_col,
-            values=value_col
-        ).fillna(0)
-
-        pivoted.reset_index(inplace=True)
+            aggregate_function="sum"
+        ).fill_null(0)
+
+        rows = pivoted.to_dicts()
+        date_cols = [c for c in pivoted.columns if c not in [index_col, name_col]]

         result = {}
-        for _, row in pivoted.iterrows():
+        for row in rows:
             pid = row[index_col]
             pname = row[name_col]
-            # Remaining columns = date -> count
-            period_counts = row.drop([index_col, name_col]).to_dict()
+            # The pivot keeps every date column for every row, with nulls
+            # filled as 0, matching the old Pandas behavior; parse each
+            # date string back into a date object for the graph model
+            period_counts = {}
+            for dc in date_cols:
+                val = row[dc]
+                try:
+                    d_obj = datetime.strptime(str(dc), "%Y-%m-%d").date()
+                    period_counts[d_obj] = val
+                except ValueError:
+                    # Should not happen, since we control the Period format
+                    print(f"Warning: Could not parse date column {dc}")
+                    period_counts[dc] = val

             result[pid] = {
                 'PersonName': pname,
                 'PRCounts': period_counts
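A minimal sketch of the pivot-and-rebuild step above, with made-up rows (not part of the commit). One caveat: the columns= argument matches the polars>=0.20 line added to requirements.txt; polars 1.x renames it to on=.

import polars as pl

grp = pl.DataFrame({
    "PersonId": [1, 1, 2],
    "PersonName": ["Alice", "Alice", "Bob"],
    "Period": ["2023-01-02", "2023-01-09", "2023-01-02"],
    "Count": [2, 1, 3],
})

pivoted = grp.pivot(
    values="Count",
    index=["PersonId", "PersonName"],
    columns="Period",
    aggregate_function="sum",
).fill_null(0)

# Every row keeps every date column, zeros included, mirroring the old
# Pandas pivot + fillna(0) behavior
print(pivoted.to_dicts())
# [{'PersonId': 1, 'PersonName': 'Alice', '2023-01-02': 2, '2023-01-09': 1},
#  {'PersonId': 2, 'PersonName': 'Bob', '2023-01-02': 3, '2023-01-09': 0}]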
@@ -201,18 +242,8 @@ class PeopleGraphs:

     def get_graph_model(self, title, data_dict):
         """
-        Builds a line-graph model from a dictionary of the form:
-        {
-            person_id: {
-                'PersonName': 'Alice',
-                'PRCounts': {
-                    Timestamp('2023-01-02'): 2,
-                    Timestamp('2023-01-09'): 1,
-                    ...
-                }
-            },
-            ...
-        }
+        Builds a line-graph model from the nested per-person dictionary.
+        This part stays plain Python, since it only manipulates dicts.
         """
         if not data_dict:
             return {
@@ -229,8 +260,21 @@ class PeopleGraphs:
             all_dates.extend(user_data['PRCounts'].keys())
             all_values.extend(user_data['PRCounts'].values())

+        if not all_dates:
+            return {
+                'title': title,
+                'vb_width': 200,
+                'vb_height': 75,
+                'plots': []
+            }
+
         min_date = min(all_dates)
         max_date = max(all_dates)

+        # min_date and max_date must be real date objects for the subtraction
+        # below; _pivot_to_graph_dict parses the pivot columns back into dates
+        # to guarantee that
         date_span = max((max_date - min_date).days, 1)

         max_val = max(all_values)
@@ -269,3 +313,5 @@ class PeopleGraphs:
             'vb_height': vb_height,
             'plots': plots
         }
+
+from datetime import datetime
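Note that the commit appends from datetime import datetime at the very end of the module (it is needed by datetime.strptime in _pivot_to_graph_dict). This works because the module body finishes executing before any method runs, but the import would conventionally sit at the top of the file with the others.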

requirements.txt
@@ -9,7 +9,6 @@ minify-html==0.10.3
 jinja2-fragments==0.3.0
 Werkzeug==2.2.2
 numpy==1.19.5
-pandas==1.3.1
 python-dotenv==1.0.1
 plotly==5.24.1
 wtforms==3.2.1
@@ -17,4 +16,6 @@ flask-wtf==1.2.2
 Flask-Login==0.6.3
 Flask-Bcrypt==1.0.1
 email-validator==2.2.0
 requests==2.26.0
+polars>=0.20.0
+pyarrow>=14.0.0

utils.py: 27 changes
@@ -1,7 +1,7 @@
 import colorsys
 from datetime import datetime, date, timedelta
 import numpy as np
-import pandas as pd
 import plotly.express as px
 import plotly.io as pio # Keep for now, might remove later if generate_plot is fully replaced
 import math
@@ -110,32 +110,7 @@ def get_distinct_colors(n):
         colors.append(hex_color)
     return colors

-def generate_plot(df: pd.DataFrame, title: str) -> str:
-    """
-    Analyzes the DataFrame and generates an appropriate Plotly visualization.
-    Returns the Plotly figure as a div string.
-    Optimized for speed.
-    """
-    if df.empty:
-        return "<p>No data available to plot.</p>"
-
-    num_columns = len(df.columns)
-
-    # Dictionary-based lookup for faster decision-making
-    plot_funcs = {
-        1: lambda: px.histogram(df, x=df.columns[0], title=title)
-        if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) else px.bar(df, x=df.columns[0], title=title),
-
-        2: lambda: px.scatter(df, x=df.columns[0], y=df.columns[1], title=title)
-        if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) and pd.api.types.is_numeric_dtype(df.iloc[:, 1])
-        else px.bar(df, x=df.columns[0], y=df.columns[1], title=title)
-    }
-
-    # Select plot function based on column count
-    fig = plot_funcs.get(num_columns, lambda: px.imshow(df.corr(numeric_only=True), text_auto=True, title=title))()
-
-    # Use static rendering for speed
-    return pio.to_html(fig, full_html=False, include_plotlyjs=False, config={'staticPlot': True})
-
-
 def calculate_estimated_1rm(weight, repetitions):