Switch to using polars

This commit is contained in:
Peter Stockings
2026-01-29 00:05:25 +11:00
parent dd82f461be
commit 3a0d4531b6
6 changed files with 156 additions and 134 deletions


@@ -1,5 +1,5 @@
import pandas as pd
from utils import get_distinct_colors, calculate_estimated_1rm
import polars as pl
from datetime import datetime # used by _pivot_to_graph_dict
from utils import get_distinct_colors
class PeopleGraphs:
def __init__(self, db_connection_method):
@@ -7,7 +7,7 @@ class PeopleGraphs:
def get(self, selected_people_ids=None, min_date=None, max_date=None, selected_exercise_ids=None):
"""
Fetch workout topsets, calculate Estimated1RM in Python,
Fetch workout topsets, calculate Estimated1RM in Polars,
then generate weekly workout & PR graphs.
"""
# Build query (no in-SQL 1RM calculation).
@@ -51,15 +51,41 @@ class PeopleGraphs:
self.get_graph_model("PRs per week", {})
]
df = pd.DataFrame(raw_data)
# Explicitly specify schema to ensure correct types
schema_overrides = {
"Weight": pl.Float64,
"Repetitions": pl.Int64,
"StartDate": pl.Date # Or pl.Datetime depending on DB driver, but usually Date for dates
}
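# Illustrative (assumed) row shape returned by self.execute:
# {"WorkoutId": 1, "PersonId": 2, "PersonName": "Alice", "ExerciseId": 3,
# "Weight": 100.0, "Repetitions": 5, "StartDate": "2023-01-02"}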
# self.execute typically returns a list of dicts, which Polars can infer
# a schema from; the overrides above pin down the DB driver's types.
try:
df = pl.DataFrame(raw_data, schema_overrides=schema_overrides, infer_schema_length=1000)
except Exception:
# Fall back to plain schema inference if the overrides clash with the input types
df = pl.DataFrame(raw_data)
# Calculate Estimated1RM in Python
df['Estimated1RM'] = df.apply(
lambda row: calculate_estimated_1rm(row["Weight"], row["Repetitions"]), axis=1
)
# Calculate Estimated1RM in Polars
# Formula: round((100 * weight) / (101.3 - 2.67123 * repetitions), 0)
# The original code avoided division by zero only when Repetitions == 0,
# returning 0 in that case, so replicate that branch explicitly.
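# Worked example (assumed inputs, 100 kg x 5 reps):
# 100 * 100 / (101.3 - 2.67123 * 5) = 10000 / 87.94385 ≈ 113.7 -> rounds to 114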
df = df.with_columns(
pl.when(pl.col("Repetitions") == 0)
.then(0)
.otherwise(
(pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
)
.round(0)
.cast(pl.Int64)
.alias("Estimated1RM")
)
# Build the weekly data models
weekly_counts = self.get_workout_counts(df, period='week')
weekly_pr_counts = self.count_prs_over_time(df, period='week')
return [
@@ -67,43 +93,48 @@ class PeopleGraphs:
self.get_graph_model("PRs per week", weekly_pr_counts)
]
def _prepare_period_column(self, df, period='week'):
def _prepare_period_column(self, df: pl.DataFrame, period='week'):
"""
Convert StartDate to datetime and add a Period column
based on 'week' or 'month' as needed.
Convert StartDate to proper date type and add a Period column
represented as the start date of that period.
"""
df['StartDate'] = pd.to_datetime(df['StartDate'], errors='coerce')
freq = 'W' if period == 'week' else 'M'
df['Period'] = df['StartDate'].dt.to_period(freq)
# Ensure StartDate is Date/Datetime
if df["StartDate"].dtype == pl.String:
df = df.with_columns(pl.col("StartDate").str.strptime(pl.Date, "%Y-%m-%d")) # Adjust format if needed
elif df["StartDate"].dtype == pl.Object:
# If it's python datetime objects
df = df.with_columns(pl.col("StartDate").cast(pl.Date))
# Truncate to week or month
if period == 'week':
# Polars has no direct equivalent of pandas' to_period.
# dt.truncate("1w") floors each date to the start of its ISO week
# (Monday), matching Postgres' date_trunc('week', ...) convention.
df = df.with_columns(
pl.col("StartDate").dt.truncate("1w").alias("Period")
)
else: # month
df = df.with_columns(
pl.col("StartDate").dt.truncate("1mo").alias("Period")
)
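# Illustrative: 2023-01-05 (a Thursday) -> truncate("1w") gives 2023-01-02,
# the Monday of that week; truncate("1mo") gives 2023-01-01.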
return df
def get_workout_counts(self, df, period='week'):
def get_workout_counts(self, df: pl.DataFrame, period='week'):
"""
Returns a dictionary:
{
person_id: {
'PersonName': 'Alice',
'PRCounts': {
Timestamp('2023-01-02'): 2,
...
}
},
...
}
representing how many workouts each person performed per time period.
Returns workout counts per person per period.
"""
# Make a copy and prepare Period column
df = self._prepare_period_column(df.copy(), period)
df = self._prepare_period_column(df, period)
# Ensure Period is string for consistent pivoting
df = df.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))
# Count unique workouts per (PersonId, PersonName, Period)
grp = (
df.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['WorkoutId']
.nunique()
.rename(columns={'WorkoutId': 'Count'})
df.group_by(['PersonId', 'PersonName', 'Period'])
.agg(pl.col('WorkoutId').n_unique().alias('Count'))
)
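# Illustrative grp row: {"PersonId": 1, "PersonName": "Alice",
# "Period": "2023-01-02", "Count": 2}, i.e. two distinct workouts that week.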
# Convert each Period to its start time
grp['Period'] = grp['Period'].apply(lambda p: p.start_time)
return self._pivot_to_graph_dict(
grp,
index_col='PersonId',
@@ -112,46 +143,47 @@ class PeopleGraphs:
value_col='Count'
)
def count_prs_over_time(self, df, period='week'):
def count_prs_over_time(self, df: pl.DataFrame, period='week'):
"""
Returns a dictionary:
{
person_id: {
'PersonName': 'Alice',
'PRCounts': {
Timestamp('2023-01-02'): 1,
...
}
},
...
}
representing how many PRs each person hit per time period.
Returns number of PRs hit per person per period.
"""
# Make a copy and prepare Period column
df = self._prepare_period_column(df.copy(), period)
df = self._prepare_period_column(df, period)
# Max 1RM per (Person, Exercise, Period)
# Max 1RM per (Person, Exercise, Period) - 'PeriodMax'
grouped = (
df.groupby(['PersonId', 'PersonName', 'ExerciseId', 'Period'], as_index=False)['Estimated1RM']
.max()
.rename(columns={'Estimated1RM': 'PeriodMax'})
df.group_by(['PersonId', 'PersonName', 'ExerciseId', 'Period'])
.agg(pl.col('Estimated1RM').max().alias('PeriodMax'))
)
# Sort so we can track "all-time max" up to that row
grouped.sort_values(by=['PersonId', 'ExerciseId', 'Period'], inplace=True)
# Sort so we can track "all-time max"
grouped = grouped.sort(by=['PersonId', 'ExerciseId', 'Period'])
# For each person & exercise, track the cumulative max (shifted by 1)
grouped['AllTimeMax'] = grouped.groupby(['PersonId', 'ExerciseId'])['PeriodMax'].cummax().shift(1)
grouped['IsPR'] = (grouped['PeriodMax'] > grouped['AllTimeMax']).astype(int)
# Calculate AllTimeMax representing the max UP TO the previous row.
grouped = grouped.with_columns(
pl.col("PeriodMax")
.cum_max()
# Shift inside the window so each (person, exercise) group starts from
# null instead of inheriting the previous group's running max.
.shift(1)
.over(['PersonId', 'ExerciseId'])
.alias("AllTimeMax")
)
grouped = grouped.with_columns(
pl.col("AllTimeMax").fill_null(0)
)
grouped = grouped.with_columns(
(pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
)
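# Illustrative run for one (person, exercise) with PeriodMax = [100, 110, 105, 120]:
# AllTimeMax (cum_max shifted, null filled with 0) = [0, 100, 110, 110]
# IsPR = [1, 1, 0, 1], so the first recorded period always counts as a PR.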
# Ensure Period is string for consistent pivoting
grouped = grouped.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))
# Sum PRs across exercises for (Person, Period)
pr_counts = (
grouped.groupby(['PersonId', 'PersonName', 'Period'], as_index=False)['IsPR']
.sum()
.rename(columns={'IsPR': 'Count'})
grouped.group_by(['PersonId', 'PersonName', 'Period'])
.agg(pl.col('IsPR').sum().alias('Count'))
)
pr_counts['Period'] = pr_counts['Period'].apply(lambda p: p.start_time)
return self._pivot_to_graph_dict(
pr_counts,
index_col='PersonId',
@@ -160,38 +192,47 @@ class PeopleGraphs:
value_col='Count'
)
def _pivot_to_graph_dict(self, df, index_col, name_col, period_col, value_col):
def _pivot_to_graph_dict(self, df: pl.DataFrame, index_col, name_col, period_col, value_col):
"""
Convert [index_col, name_col, period_col, value_col]
into a nested dictionary for plotting:
{
person_id: {
'PersonName': <...>,
'PRCounts': {
<timestamp>: <value>,
...
}
},
...
}
Convert Polars DataFrame to the nested dict structure expected by visualization.
"""
if df.empty:
if df.is_empty():
return {}
# Pivot
pivoted = df.pivot(
values=value_col,
index=[index_col, name_col],
columns=period_col,
values=value_col
).fillna(0)
pivoted.reset_index(inplace=True)
aggregate_function="sum"
).fill_null(0)
rows = pivoted.to_dicts()
date_cols = [c for c in pivoted.columns if c not in [index_col, name_col]]
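# Illustrative pivoted frame (Period dates become string column names):
# PersonId | PersonName | 2023-01-02 | 2023-01-09
# 1 | Alice | 2 | 0
# which the loop below reshapes into:
# {1: {'PersonName': 'Alice', 'PRCounts': {date(2023, 1, 2): 2, date(2023, 1, 9): 0}}}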
result = {}
for _, row in pivoted.iterrows():
for row in rows:
pid = row[index_col]
pname = row[name_col]
# Remaining columns = date -> count
period_counts = row.drop([index_col, name_col]).to_dict()
period_counts = {}
for dc in date_cols:
val = row[dc]
# Keep zero values: the original pandas pivot kept every date column
# for every row (NaNs filled with 0), and get_graph_model decides
# what actually gets plotted.
# Parse date string back to date object
try:
d_obj = datetime.strptime(str(dc), "%Y-%m-%d").date()
period_counts[d_obj] = val
except ValueError:
# Should not happen if we controlled the format
print(f"Warning: Could not parse date column {dc}")
period_counts[dc] = val
result[pid] = {
'PersonName': pname,
'PRCounts': period_counts
@@ -201,18 +242,8 @@ class PeopleGraphs:
def get_graph_model(self, title, data_dict):
"""
Builds a line-graph model from a dictionary of the form:
{
person_id: {
'PersonName': 'Alice',
'PRCounts': {
Timestamp('2023-01-02'): 2,
Timestamp('2023-01-09'): 1,
...
}
},
...
}
Builds a line-graph model from the dictionary.
This part stays plain Python, since it only manipulates dicts.
"""
if not data_dict:
return {
@@ -229,8 +260,21 @@ class PeopleGraphs:
all_dates.extend(user_data['PRCounts'].keys())
all_values.extend(user_data['PRCounts'].values())
if not all_dates:
return {
'title': title,
'vb_width': 200,
'vb_height': 75,
'plots': []
}
min_date = min(all_dates)
max_date = max(all_dates)
# min_date/max_date must be real date objects for the subtraction below;
# _pivot_to_graph_dict parses the pivot's string columns back to dates
# to guarantee this.
date_span = max((max_date - min_date).days, 1)
max_val = max(all_values)
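# Illustrative: min 2023-01-02, max 2023-03-06 -> (max_date - min_date).days = 63;
# the max(..., 1) guard keeps the span non-zero when all dates coincide.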
@@ -269,3 +313,5 @@ class PeopleGraphs:
'vb_height': vb_height,
'plots': plots
}