Speed up people_graphs
@@ -1,5 +1,6 @@
 import polars as pl
 from utils import get_distinct_colors
+from datetime import datetime
 
 class PeopleGraphs:
     def __init__(self, db_connection_method):
@@ -7,10 +8,9 @@ class PeopleGraphs:
 
     def get(self, selected_people_ids=None, min_date=None, max_date=None, selected_exercise_ids=None):
         """
-        Fetch workout topsets, calculate Estimated1RM in Polars,
+        Fetch workout topsets, calculate Estimated1RM and graph data in Polars,
         then generate weekly workout & PR graphs.
         """
         # Build query (no in-SQL 1RM calculation).
         query = """
            SELECT
                P.person_id AS "PersonId",
@@ -42,267 +42,155 @@ class PeopleGraphs:
             query += f" AND E.exercise_id IN ({', '.join(['%s'] * len(selected_exercise_ids))})"
             params.extend(selected_exercise_ids)
 
         # Execute and convert to DataFrame
         raw_data = self.execute(query, params)
         if not raw_data:
             # Return empty graphs if no data at all
             return [
-                self.get_graph_model("Workouts per week", {}),
-                self.get_graph_model("PRs per week", {})
+                self._empty_graph("Workouts per week"),
+                self._empty_graph("PRs per week")
             ]
 
         # Explicitly specify schema to ensure correct types
         schema_overrides = {
             "Weight": pl.Float64,
             "Repetitions": pl.Int64,
-            "StartDate": pl.Date  # Or pl.Datetime depending on DB driver, but usually Date for dates
+            "StartDate": pl.Date
         }
 
         # Depending on how 'self.execute' returns data (list of dicts usually),
         # Polars can infer schema. For robustness with DB types:
         try:
-            df = pl.DataFrame(raw_data, schema_overrides=schema_overrides, infer_schema_length=1000)
+            df = pl.DataFrame(raw_data, schema_overrides=schema_overrides, infer_schema_length=10000)
         except:
             # Fallback if specific schema injection fails due to mismatched input types
             df = pl.DataFrame(raw_data)
 
         # Force StartDate to Date type and filter nulls
         df = df.with_columns(pl.col("StartDate").cast(pl.Date)).filter(pl.col("StartDate").is_not_null())
 
-        # Calculate Estimated1RM in Polars
-        # Formula: round((100 * int(weight)) / (101.3 - 2.67123 * repetitions), 0)
-        # Handle division by zero implicitly by filter or usage?
-        # The original code only avoided div by zero if Repetitions == 0.
-        # Polars handles nulls/NaNs usually, but let's replicate logic.
 
+        if df.is_empty():
+            return [
+                self._empty_graph("Workouts per week"),
+                self._empty_graph("PRs per week")
+            ]
 
+        # Calculate Estimated1RM
+        # SQL cast Weight::integer rounds to nearest. Matching that here.
         df = df.with_columns(
             pl.when(pl.col("Repetitions") == 0)
             .then(0)
             .otherwise(
-                (pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
+                (pl.lit(100) * pl.col("Weight").round(0).cast(pl.Int64)) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions"))
             )
             .round(0)
             .cast(pl.Int64)
             .alias("Estimated1RM")
         )
 
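For reference, the Estimated1RM expression above is a standard one-rep-max estimate (the constants match Lander's formula): 100 * weight / (101.3 - 2.67123 * reps), rounded to the nearest integer, with zero-rep sets mapped to 0. A minimal standalone sketch of the formula, using made-up sample rows rather than real database output:

    import polars as pl

    # Hypothetical topset rows; in the real code these come from the workout query above.
    sample = pl.DataFrame({"Weight": [100.0, 82.5, 60.0], "Repetitions": [5, 8, 0]})

    estimated_1rm = (
        pl.when(pl.col("Repetitions") == 0)
        .then(0)
        .otherwise((100 * pl.col("Weight")) / (101.3 - 2.67123 * pl.col("Repetitions")))
        .round(0)
        .cast(pl.Int64)
        .alias("Estimated1RM")
    )

    print(sample.with_columns(estimated_1rm))
    # 100 kg x 5 reps -> roughly 114, 82.5 kg x 8 reps -> roughly 103, 0 reps -> 0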
-        # Build the weekly data models
-        weekly_counts = self.get_workout_counts(df, period='week')
-        weekly_pr_counts = self.count_prs_over_time(df, period='week')
+        # Prepare period-truncated column
+        df = df.with_columns(
+            pl.col("StartDate").dt.truncate("1w").alias("Period")
+        )
 
-        return [
-            self.get_graph_model("Workouts per week", weekly_counts),
-            self.get_graph_model("PRs per week", weekly_pr_counts)
-        ]
-
-    def _prepare_period_column(self, df: pl.DataFrame, period='week'):
-        """
-        Convert StartDate to proper date type and add a Period column
-        represented as the start date of that period.
-        """
-        # Ensure StartDate is Date/Datetime
-        if df["StartDate"].dtype == pl.String:
-            df = df.with_columns(pl.col("StartDate").str.strptime(pl.Date, "%Y-%m-%d"))  # Adjust format if needed
-        elif df["StartDate"].dtype == pl.Object:
-            # If it's python datetime objects
-            df = df.with_columns(pl.col("StartDate").cast(pl.Date))
-
-        # Truncate to week or month
-        if period == 'week':
-            # Polars doesn't have a direct 'to_period' like Pandas.
-            # We can use dt.truncate("1w") which floors to start of week (Monday usually)
-            # Postgres/standard weeks usually start Monday.
-            df = df.with_columns(
-                pl.col("StartDate").dt.truncate("1w").alias("Period")
-            )
-        else:  # month
-            df = df.with_columns(
-                pl.col("StartDate").dt.truncate("1mo").alias("Period")
-            )
-
-        return df
-
-    def get_workout_counts(self, df: pl.DataFrame, period='week'):
-        """
-        Returns workout counts per person per period.
-        """
-        df = self._prepare_period_column(df, period)
-
-        # Ensure Period is string for consistent pivoting
-        df = df.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))
-
-        # Count unique workouts per (PersonId, PersonName, Period)
-        grp = (
+        # 1. Workouts per week
+        workout_counts = (
             df.group_by(['PersonId', 'PersonName', 'Period'])
             .agg(pl.col('WorkoutId').n_unique().alias('Count'))
         )
 
-        return self._pivot_to_graph_dict(
-            grp,
-            index_col='PersonId',
-            name_col='PersonName',
-            period_col='Period',
-            value_col='Count'
-        )
-
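The weekly bucketing used here comes from dt.truncate("1w"), which floors each date to the start of its week (Mondays in Polars), so counting distinct WorkoutId values per (person, week) is a single group_by. A small self-contained sketch with invented rows:

    import polars as pl
    from datetime import date

    workouts = pl.DataFrame({
        "PersonId": [1, 1, 2],
        "PersonName": ["Ann", "Ann", "Bob"],
        "WorkoutId": [10, 11, 12],
        "StartDate": [date(2024, 1, 2), date(2024, 1, 4), date(2024, 1, 3)],
    })

    weekly = (
        workouts.with_columns(pl.col("StartDate").dt.truncate("1w").alias("Period"))
        .group_by(["PersonId", "PersonName", "Period"])
        .agg(pl.col("WorkoutId").n_unique().alias("Count"))
    )
    print(weekly)
    # Both of Ann's sessions land in the week of 2024-01-01, so her Count is 2; Bob's is 1.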
-    def count_prs_over_time(self, df: pl.DataFrame, period='week'):
-        """
-        Returns number of PRs hit per person per period.
-        """
-        df = self._prepare_period_column(df, period)
-
-        # Max 1RM per (Person, Exercise, Period) - 'PeriodMax'
-        grouped = (
+        # 2. PRs per week
+        grouped_prs = (
             df.group_by(['PersonId', 'PersonName', 'ExerciseId', 'Period'])
             .agg(pl.col('Estimated1RM').max().alias('PeriodMax'))
+            .sort(['PersonId', 'ExerciseId', 'Period'])
         )
 
-        # Sort so we can track "all-time max"
-        grouped = grouped.sort(by=['PersonId', 'ExerciseId', 'Period'])
 
         # Calculate AllTimeMax representing the max UP TO the previous row.
-        grouped = grouped.with_columns(
+        grouped_prs = grouped_prs.with_columns(
             pl.col("PeriodMax")
             .cum_max()
             .over(['PersonId', 'ExerciseId'])
             .shift(1)
+            .fill_null(0)
             .alias("AllTimeMax")
         )
 
-        grouped = grouped.with_columns(
-            pl.col("AllTimeMax").fill_null(0)
-        )
-
-        grouped = grouped.with_columns(
+        grouped_prs = grouped_prs.with_columns(
             (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
         )
 
-        # Ensure Period is string for consistent pivoting
-        grouped = grouped.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d"))
 
         # Sum PRs across exercises for (Person, Period)
         pr_counts = (
-            grouped.group_by(['PersonId', 'PersonName', 'Period'])
+            grouped_prs.group_by(['PersonId', 'PersonName', 'Period'])
             .agg(pl.col('IsPR').sum().alias('Count'))
         )
 
+        return [
+            self._build_graph_model("Workouts per week", workout_counts),
+            self._build_graph_model("PRs per week", pr_counts)
+        ]
 
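The PR count above is the classic running-maximum trick: within each (person, exercise) group, cum_max().shift(1) gives the best estimated 1RM from any earlier period, so a period counts as a PR whenever its PeriodMax beats that number. A small sketch with invented numbers for a single person and exercise:

    import polars as pl

    history = pl.DataFrame({
        "PersonId": [1, 1, 1, 1],
        "ExerciseId": [7, 7, 7, 7],
        "Period": ["2024-01-01", "2024-01-08", "2024-01-15", "2024-01-22"],
        "PeriodMax": [100, 95, 110, 110],
    })

    flagged = history.with_columns(
        pl.col("PeriodMax").cum_max().over(["PersonId", "ExerciseId"]).shift(1).fill_null(0).alias("AllTimeMax")
    ).with_columns(
        (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR")
    )
    print(flagged)
    # AllTimeMax: 0, 100, 100, 110 -> IsPR: 1, 0, 1, 0 (weeks 1 and 3 set new bests)

Note that the rows have to be in chronological order within each group for the cumulative max to mean "best so far", which is why the group_by chain above ends with a sort on Period.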
+    def _empty_graph(self, title):
+        return {
+            'title': title,
+            'vb_width': 200,
+            'vb_height': 75,
+            'plots': []
+        }
+
+    def _build_graph_model(self, title, df: pl.DataFrame):
+        if df.is_empty():
+            return self._empty_graph(title)
+
+        # 1. Scaling stats from the sparse data (to find global span and max value)
+        stats = df.select([
+            pl.col("Period").min().alias("min_date"),
+            pl.col("Period").max().alias("max_date"),
+            pl.col("Count").max().alias("max_val")
+        ])
+        min_date = stats.get_column("min_date")[0]
+        max_date = stats.get_column("max_date")[0]
+        max_val = stats.get_column("max_val")[0]
 
-        return self._pivot_to_graph_dict(
-            pr_counts,
-            index_col='PersonId',
-            name_col='PersonName',
-            period_col='Period',
-            value_col='Count'
-        )
+        date_span = max((max_date - min_date).days, 1)
+        val_range = max(max_val, 1)
+        vb_width, vb_height = 200, 75
+
+        # 2. Make data "dense" so lines connect to 0 for missing weeks
+        # This replicates the behavior of the original pivot().fill_null(0)
+        all_periods = df.select("Period").unique().sort("Period")
+        all_people = df.select(["PersonId", "PersonName"]).unique(subset=["PersonId"])
+        dense_df = all_people.join(all_periods, how="cross")
+
+        df = dense_df.join(df.select(["PersonId", "Period", "Count"]), on=["PersonId", "Period"], how="left").with_columns(
+            pl.col("Count").fill_null(0)
+        )
 
-    def _pivot_to_graph_dict(self, df: pl.DataFrame, index_col, name_col, period_col, value_col):
-        """
-        Convert Polars DataFrame to the nested dict structure expected by visualization.
-        """
-        if df.is_empty():
-            return {}
+        # 3. Vectorized coordinate calculation and label formatting
+        df = df.with_columns([
+            (((pl.col("Period") - min_date).dt.total_days() / date_span)).alias("x_norm"),
+            ((pl.col("Count") / val_range) * vb_height).alias("y_scaled"),
+            (
+                pl.col("Count").cast(pl.String) +
+                " for " + pl.col("PersonName") +
+                " at " + pl.col("Period").dt.strftime("%d %b %y")
+            ).alias("msg")
+        ]).sort(["PersonId", "Period"])
 
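The coordinate block above replaces a per-point Python loop with vectorized expressions: x_norm is the fraction of the overall date span elapsed at each Period, y_scaled maps Count onto the 75-unit-high viewbox, and msg is the hover label. A compact sketch of the same scaling, assuming a Date-typed Period column and made-up counts:

    import polars as pl
    from datetime import date

    vb_height = 75
    counts = pl.DataFrame({
        "Period": [date(2024, 1, 1), date(2024, 1, 8), date(2024, 1, 15)],
        "Count": [1, 3, 2],
    })

    min_date = counts["Period"].min()
    date_span = max((counts["Period"].max() - min_date).days, 1)  # guard against a zero-day span
    val_range = max(counts["Count"].max(), 1)

    coords = counts.with_columns([
        ((pl.col("Period") - min_date).dt.total_days() / date_span).alias("x_norm"),
        ((pl.col("Count") / val_range) * vb_height).alias("y_scaled"),
    ])
    print(coords)
    # x_norm: 0.0, 0.5, 1.0 and y_scaled: 25.0, 75.0, 50.0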
-        # Pivot
-        pivoted = df.pivot(
-            values=value_col,
-            index=[index_col, name_col],
-            columns=period_col,
-            aggregate_function="sum"
-        ).fill_null(0)
+        # 4. Group by person to build the 'plots' structure
+        distinct_people = df.select("PersonId").unique(maintain_order=True).get_column("PersonId").to_list()
+        colors = get_distinct_colors(len(distinct_people))
+        color_map = {pid: colors[i] for i, pid in enumerate(distinct_people)}
 
-        rows = pivoted.to_dicts()
-        date_cols = [c for c in pivoted.columns if c not in [index_col, name_col]]
-
-        result = {}
-        for row in rows:
-            pid = row[index_col]
-            pname = row[name_col]
-
-            period_counts = {}
-            for dc in date_cols:
-                val = row[dc]
-                # If val is 0, we can skip if sparse behavior is desired,
-                # but let's keep it to match original behavior exactly if possible.
-                # Only adding if val > 0 or if we want full zeros?
-                # Original Pandas pivot keeps all columns (dates) for all rows, filling NaNs (0).
-                # The iteration later in 'get_graph_model' determines what to plot.
-
-                # Parse date string back to date object
-                try:
-                    d_obj = datetime.strptime(str(dc), "%Y-%m-%d").date()
-                    period_counts[d_obj] = val
-                except ValueError:
-                    # Should not happen if we controlled the format
-                    print(f"Warning: Could not parse date column {dc}")
-                    period_counts[dc] = val
-
-            result[pid] = {
-                'PersonName': pname,
-                'PRCounts': period_counts
-            }
-
-        return result
-
-    def get_graph_model(self, title, data_dict):
-        """
-        Builds a line-graph model from the dictionary.
-        This part remains mostly standard Python as it manipulates dicts.
-        """
-        if not data_dict:
-            return {
-                'title': title,
-                'vb_width': 200,
-                'vb_height': 75,
-                'plots': []
-            }
-
-        # Gather all dates & values
-        all_dates = []
-        all_values = []
-        for user_data in data_dict.values():
-            all_dates.extend(user_data['PRCounts'].keys())
-            all_values.extend(user_data['PRCounts'].values())
-
-        if not all_dates:
-            return {
-                'title': title,
-                'vb_width': 200,
-                'vb_height': 75,
-                'plots': []
-            }
-
-        min_date = min(all_dates)
-        max_date = max(all_dates)
-
-        # Ensure min_date/max_date are comparable types.
-        # If they are strings vs dates, that's an issue.
-        # We tried to enforce conversion in _pivot_to_graph_dict.
-
-        date_span = max((max_date - min_date).days, 1)
-
-        max_val = max(all_values)
-        min_val = 0
-        val_range = max_val - min_val if max_val != min_val else 1
-
-        vb_width, vb_height = 200, 75
-        colors = get_distinct_colors(len(data_dict))
         plots = []
 
-        for i, (pid, user_data) in enumerate(data_dict.items()):
-            name = user_data['PersonName']
-            pr_counts = user_data['PRCounts']
-            # Sort by date so points are in chronological order
-            sorted_pr = sorted(pr_counts.items(), key=lambda x: x[0])
-
-            points = []
-            labels = []
-            for d, val in sorted_pr:
-                # Scale x,y to fit [0..1], then we multiply y by vb_height
-                x = (d - min_date).days / date_span
-                y = (val - min_val) / val_range * vb_height
-                points.append((y, x))
-                labels.append((y, x, f'{val} for {name} at {d.strftime("%d %b %y")}'))
+        for pid in distinct_people:
+            person_df = df.filter(pl.col("PersonId") == pid)
+            if person_df.is_empty():
+                continue
+
+            name = person_df.get_column("PersonName")[0]
+
+            y_vals = person_df.get_column("y_scaled").to_list()
+            x_norms = person_df.get_column("x_norm").to_list()
+            msgs = person_df.get_column("msg").to_list()
+
+            points = list(zip(y_vals, x_norms))
+            labels = list(zip(y_vals, x_norms, msgs))
 
             plots.append({
                 'label': name,
-                'color': colors[i],
+                'color': color_map[pid],
                 'points': points,
                 'plot_labels': labels
             })
@@ -313,5 +201,3 @@ class PeopleGraphs:
             'vb_height': vb_height,
             'plots': plots
         }
-
-from datetime import datetime
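One behavioural detail from _build_graph_model worth calling out: the cross join is what makes every person appear in every week, so plotted lines drop to zero on quiet weeks instead of skipping them, mirroring what the old pivot().fill_null(0) produced. A minimal sketch of that densification step, again with invented rows:

    import polars as pl
    from datetime import date

    sparse = pl.DataFrame({
        "PersonId": [1, 2],
        "PersonName": ["Ann", "Bob"],
        "Period": [date(2024, 1, 1), date(2024, 1, 8)],
        "Count": [2, 1],
    })

    all_periods = sparse.select("Period").unique().sort("Period")
    all_people = sparse.select(["PersonId", "PersonName"]).unique(subset=["PersonId"])

    dense = (
        all_people.join(all_periods, how="cross")        # every (person, week) combination
        .join(sparse.select(["PersonId", "Period", "Count"]), on=["PersonId", "Period"], how="left")
        .with_columns(pl.col("Count").fill_null(0))      # weeks without data become 0
    )
    print(dense)
    # Four rows: Ann and Bob each get both weeks, with zeros where they had no workouts.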