"""Weekly workout-count and PR-count graph generation, backed by Polars."""
import polars as pl
|
|
from utils import get_distinct_colors
|
|
from datetime import datetime
|
|
|
|
class PeopleGraphs:
    """Builds weekly "workouts per week" and "PRs per week" graph models.

    Raw topset rows are fetched through an injected DB-execute callable; all
    aggregation (Estimated1RM, weekly bucketing, PR detection, coordinate
    scaling) is then done in Polars.
    """

    # Shared SVG viewBox dimensions for every graph this class produces.
    VB_WIDTH = 200
    VB_HEIGHT = 75

    def __init__(self, db_connection_method):
        # Callable with signature (query: str, params: list) -> iterable of
        # row mappings (one dict-like per result row).
        self.execute = db_connection_method

    def get(self, selected_people_ids=None, min_date=None, max_date=None,
            selected_exercise_ids=None):
        """
        Fetch workout topsets, calculate Estimated1RM and graph data in Polars,
        then generate weekly workout & PR graphs.

        All filter arguments are optional; each one narrows the SQL query with
        a parameterized predicate. Returns a two-element list:
        [workouts-per-week graph model, PRs-per-week graph model].
        """
        query, params = self._build_query(
            selected_people_ids, min_date, max_date, selected_exercise_ids
        )

        raw_data = self.execute(query, params)
        if not raw_data:
            return self._empty_graphs()

        df = self._to_dataframe(raw_data)
        if df.is_empty():
            # Every row had a null StartDate (people with no workouts).
            return self._empty_graphs()

        df = self._with_estimated_1rm(df).with_columns(
            # Bucket each workout into its ISO week start.
            pl.col("StartDate").dt.truncate("1w").alias("Period")
        )

        # 1. Workouts per week (distinct workouts per person per week).
        workout_counts = (
            df.group_by(['PersonId', 'PersonName', 'Period'])
            .agg(pl.col('WorkoutId').n_unique().alias('Count'))
        )

        # 2. PRs per week.
        pr_counts = self._weekly_pr_counts(df)

        return [
            self._build_graph_model("Workouts per week", workout_counts),
            self._build_graph_model("PRs per week", pr_counts),
        ]

    def _build_query(self, selected_people_ids, min_date, max_date,
                     selected_exercise_ids):
        """Assemble the filtered topset query and its positional parameters.

        Only the *number* of %s placeholders is interpolated into the SQL;
        all values travel through `params`, so this stays injection-safe.
        """
        query = """
        SELECT
            P.person_id AS "PersonId",
            P.name AS "PersonName",
            W.workout_id AS "WorkoutId",
            W.start_date AS "StartDate",
            T.topset_id AS "TopSetId",
            E.exercise_id AS "ExerciseId",
            E.name AS "ExerciseName",
            T.repetitions AS "Repetitions",
            T.weight AS "Weight"
        FROM Person P
        LEFT JOIN Workout W ON P.person_id = W.person_id
        LEFT JOIN TopSet T ON W.workout_id = T.workout_id
        LEFT JOIN Exercise E ON T.exercise_id = E.exercise_id
        WHERE TRUE
        """
        params = []
        if selected_people_ids:
            placeholders = ', '.join(['%s'] * len(selected_people_ids))
            query += f" AND P.person_id IN ({placeholders})"
            params.extend(selected_people_ids)
        if min_date:
            query += " AND W.start_date >= %s"
            params.append(min_date)
        if max_date:
            query += " AND W.start_date <= %s"
            params.append(max_date)
        if selected_exercise_ids:
            placeholders = ', '.join(['%s'] * len(selected_exercise_ids))
            query += f" AND E.exercise_id IN ({placeholders})"
            params.extend(selected_exercise_ids)
        return query, params

    def _to_dataframe(self, raw_data):
        """Load raw rows into Polars, coercing types and dropping null dates."""
        # Explicitly specify schema to ensure correct types.
        schema_overrides = {
            "Weight": pl.Float64,
            "Repetitions": pl.Int64,
            "StartDate": pl.Date,
        }
        try:
            df = pl.DataFrame(
                raw_data,
                schema_overrides=schema_overrides,
                infer_schema_length=10000,
            )
        except Exception:
            # Overrides can clash with the driver's row shape; fall back to
            # full inference. StartDate is still normalized below.
            # (Was a bare `except:` — that would also swallow KeyboardInterrupt.)
            df = pl.DataFrame(raw_data)

        # LEFT JOINs yield people with no workouts as null StartDate rows;
        # force the Date dtype and drop those rows.
        return (
            df.with_columns(pl.col("StartDate").cast(pl.Date))
            .filter(pl.col("StartDate").is_not_null())
        )

    def _with_estimated_1rm(self, df):
        """Add an integer Estimated1RM column (Brzycki-style formula).

        0 repetitions maps to 0 to avoid a nonsense extrapolation.
        """
        # SQL cast Weight::integer rounds to nearest; Weight.round(0) matches
        # that here so results agree with the original SQL computation.
        return df.with_columns(
            pl.when(pl.col("Repetitions") == 0)
            .then(0)
            .otherwise(
                (pl.lit(100) * pl.col("Weight").round(0).cast(pl.Int64))
                / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
            )
            .round(0)
            .cast(pl.Int64)
            .alias("Estimated1RM")
        )

    def _weekly_pr_counts(self, df):
        """Count, per person per week, exercises whose weekly-best 1RM beat
        the lifter's all-time best from *earlier* weeks."""
        weekly_best = (
            df.group_by(['PersonId', 'PersonName', 'ExerciseId', 'Period'])
            .agg(pl.col('Estimated1RM').max().alias('PeriodMax'))
            .sort(['PersonId', 'ExerciseId', 'Period'])
        )
        # All-time max *before* each period: running max shifted one row,
        # 0 for a lifter's first recorded week of an exercise.
        #
        # BUGFIX: the whole cum_max -> shift -> fill_null chain must run
        # inside .over(); previously .shift(1) was applied AFTER the window,
        # so the first week of each person/exercise partition inherited the
        # previous partition's running max and PRs were miscounted at
        # partition boundaries.
        weekly_best = weekly_best.with_columns(
            pl.col("PeriodMax")
            .cum_max()
            .shift(1)
            .fill_null(0)
            .over(['PersonId', 'ExerciseId'])
            .alias("AllTimeMax")
        ).with_columns(
            (pl.col("PeriodMax") > pl.col("AllTimeMax"))
            .cast(pl.Int64)
            .alias("IsPR")
        )
        return (
            weekly_best.group_by(['PersonId', 'PersonName', 'Period'])
            .agg(pl.col('IsPR').sum().alias('Count'))
        )

    def _empty_graphs(self):
        """The pair of empty graphs returned when nothing matched the filters."""
        return [
            self._empty_graph("Workouts per week"),
            self._empty_graph("PRs per week"),
        ]

    def _empty_graph(self, title):
        """A graph model with no plots, keeping the viewBox shape consistent."""
        return {
            'title': title,
            'vb_width': self.VB_WIDTH,
            'vb_height': self.VB_HEIGHT,
            'plots': []
        }

    def _build_graph_model(self, title, df: pl.DataFrame):
        """Turn a (PersonId, PersonName, Period, Count) frame into a graph model.

        Missing person/week combinations are densified to Count=0 so each
        plotted line drops to zero instead of skipping weeks.
        """
        if df.is_empty():
            return self._empty_graph(title)

        # 1. Scaling stats from the sparse data (global span and max value).
        stats = df.select([
            pl.col("Period").min().alias("min_date"),
            pl.col("Period").max().alias("max_date"),
            pl.col("Count").max().alias("max_val"),
        ])
        min_date = stats.get_column("min_date")[0]
        max_date = stats.get_column("max_date")[0]
        max_val = stats.get_column("max_val")[0]

        date_span = max((max_date - min_date).days, 1)  # avoid /0: single week
        val_range = max(max_val, 1)                     # avoid /0: all zeros

        # 2. Make data "dense" so lines connect to 0 for missing weeks.
        # This replicates the behavior of the original pivot().fill_null(0).
        all_periods = df.select("Period").unique().sort("Period")
        all_people = df.select(["PersonId", "PersonName"]).unique(subset=["PersonId"])
        dense = (
            all_people.join(all_periods, how="cross")
            .join(
                df.select(["PersonId", "Period", "Count"]),
                on=["PersonId", "Period"],
                how="left",
            )
            .with_columns(pl.col("Count").fill_null(0))
        )

        # 3. Vectorized coordinate calculation and label formatting.
        # NOTE: x is normalized to [0, 1] while y is scaled to the viewBox
        # height — presumably the rendering template multiplies x by vb_width.
        dense = dense.with_columns([
            ((pl.col("Period") - min_date).dt.total_days() / date_span).alias("x_norm"),
            ((pl.col("Count") / val_range) * self.VB_HEIGHT).alias("y_scaled"),
            (
                pl.col("Count").cast(pl.String) +
                " for " + pl.col("PersonName") +
                " at " + pl.col("Period").dt.strftime("%d %b %y")
            ).alias("msg"),
        ]).sort(["PersonId", "Period"])

        # 4. Group by person to build the 'plots' structure.
        person_ids = (
            dense.select("PersonId")
            .unique(maintain_order=True)
            .get_column("PersonId")
            .to_list()
        )
        colors = get_distinct_colors(len(person_ids))
        color_map = dict(zip(person_ids, colors))

        plots = []
        for person_id in person_ids:
            # Every id came from `dense`, so the filtered frame is non-empty.
            person_df = dense.filter(pl.col("PersonId") == person_id)
            y_vals = person_df.get_column("y_scaled").to_list()
            x_norms = person_df.get_column("x_norm").to_list()
            msgs = person_df.get_column("msg").to_list()
            plots.append({
                'label': person_df.get_column("PersonName")[0],
                'color': color_map[person_id],
                'points': list(zip(y_vals, x_norms)),
                'plot_labels': list(zip(y_vals, x_norms, msgs)),
            })

        return {
            'title': title,
            'vb_width': self.VB_WIDTH,
            'vb_height': self.VB_HEIGHT,
            'plots': plots
        }
|