import polars as pl from utils import get_distinct_colors class PeopleGraphs: def __init__(self, db_connection_method): self.execute = db_connection_method def get(self, selected_people_ids=None, min_date=None, max_date=None, selected_exercise_ids=None): """ Fetch workout topsets, calculate Estimated1RM in Polars, then generate weekly workout & PR graphs. """ # Build query (no in-SQL 1RM calculation). query = """ SELECT P.person_id AS "PersonId", P.name AS "PersonName", W.workout_id AS "WorkoutId", W.start_date AS "StartDate", T.topset_id AS "TopSetId", E.exercise_id AS "ExerciseId", E.name AS "ExerciseName", T.repetitions AS "Repetitions", T.weight AS "Weight" FROM Person P LEFT JOIN Workout W ON P.person_id = W.person_id LEFT JOIN TopSet T ON W.workout_id = T.workout_id LEFT JOIN Exercise E ON T.exercise_id = E.exercise_id WHERE TRUE """ params = [] if selected_people_ids: query += f" AND P.person_id IN ({', '.join(['%s'] * len(selected_people_ids))})" params.extend(selected_people_ids) if min_date: query += " AND W.start_date >= %s" params.append(min_date) if max_date: query += " AND W.start_date <= %s" params.append(max_date) if selected_exercise_ids: query += f" AND E.exercise_id IN ({', '.join(['%s'] * len(selected_exercise_ids))})" params.extend(selected_exercise_ids) # Execute and convert to DataFrame raw_data = self.execute(query, params) if not raw_data: # Return empty graphs if no data at all return [ self.get_graph_model("Workouts per week", {}), self.get_graph_model("PRs per week", {}) ] # Explicitly specify schema to ensure correct types schema_overrides = { "Weight": pl.Float64, "Repetitions": pl.Int64, "StartDate": pl.Date # Or pl.Datetime depending on DB driver, but usually Date for dates } # Depending on how 'self.execute' returns data (list of dicts usually), # Polars can infer schema. For robustness with DB types: try: df = pl.DataFrame(raw_data, schema_overrides=schema_overrides, infer_schema_length=1000) except: # Fallback if specific schema injection fails due to mismatched input types df = pl.DataFrame(raw_data) # Calculate Estimated1RM in Polars # Formula: round((100 * int(weight)) / (101.3 - 2.67123 * repetitions), 0) # Handle division by zero implicitly by filter or usage? # The original code only avoided div by zero if Repetitions == 0. # Polars handles nulls/NaNs usually, but let's replicate logic. df = df.with_columns( pl.when(pl.col("Repetitions") == 0) .then(0) .otherwise( (pl.lit(100) * pl.col("Weight")) / (pl.lit(101.3) - pl.lit(2.67123) * pl.col("Repetitions")) ) .round(0) .cast(pl.Int64) .alias("Estimated1RM") ) # Build the weekly data models weekly_counts = self.get_workout_counts(df, period='week') weekly_pr_counts = self.count_prs_over_time(df, period='week') return [ self.get_graph_model("Workouts per week", weekly_counts), self.get_graph_model("PRs per week", weekly_pr_counts) ] def _prepare_period_column(self, df: pl.DataFrame, period='week'): """ Convert StartDate to proper date type and add a Period column represented as the start date of that period. """ # Ensure StartDate is Date/Datetime if df["StartDate"].dtype == pl.String: df = df.with_columns(pl.col("StartDate").str.strptime(pl.Date, "%Y-%m-%d")) # Adjust format if needed elif df["StartDate"].dtype == pl.Object: # If it's python datetime objects df = df.with_columns(pl.col("StartDate").cast(pl.Date)) # Truncate to week or month if period == 'week': # Polars doesn't have a direct 'to_period' like Pandas. # We can use dt.truncate("1w") which floors to start of week (Monday usually) # Postgres/standard weeks usually start Monday. df = df.with_columns( pl.col("StartDate").dt.truncate("1w").alias("Period") ) else: # month df = df.with_columns( pl.col("StartDate").dt.truncate("1mo").alias("Period") ) return df def get_workout_counts(self, df: pl.DataFrame, period='week'): """ Returns workout counts per person per period. """ df = self._prepare_period_column(df, period) # Ensure Period is string for consistent pivoting df = df.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d")) # Count unique workouts per (PersonId, PersonName, Period) grp = ( df.group_by(['PersonId', 'PersonName', 'Period']) .agg(pl.col('WorkoutId').n_unique().alias('Count')) ) return self._pivot_to_graph_dict( grp, index_col='PersonId', name_col='PersonName', period_col='Period', value_col='Count' ) def count_prs_over_time(self, df: pl.DataFrame, period='week'): """ Returns number of PRs hit per person per period. """ df = self._prepare_period_column(df, period) # Max 1RM per (Person, Exercise, Period) - 'PeriodMax' grouped = ( df.group_by(['PersonId', 'PersonName', 'ExerciseId', 'Period']) .agg(pl.col('Estimated1RM').max().alias('PeriodMax')) ) # Sort so we can track "all-time max" grouped = grouped.sort(by=['PersonId', 'ExerciseId', 'Period']) # Calculate AllTimeMax representing the max UP TO the previous row. grouped = grouped.with_columns( pl.col("PeriodMax") .cum_max() .over(['PersonId', 'ExerciseId']) .shift(1) .alias("AllTimeMax") ) grouped = grouped.with_columns( pl.col("AllTimeMax").fill_null(0) ) grouped = grouped.with_columns( (pl.col("PeriodMax") > pl.col("AllTimeMax")).cast(pl.Int64).alias("IsPR") ) # Ensure Period is string for consistent pivoting grouped = grouped.with_columns(pl.col("Period").dt.strftime("%Y-%m-%d")) # Sum PRs across exercises for (Person, Period) pr_counts = ( grouped.group_by(['PersonId', 'PersonName', 'Period']) .agg(pl.col('IsPR').sum().alias('Count')) ) return self._pivot_to_graph_dict( pr_counts, index_col='PersonId', name_col='PersonName', period_col='Period', value_col='Count' ) def _pivot_to_graph_dict(self, df: pl.DataFrame, index_col, name_col, period_col, value_col): """ Convert Polars DataFrame to the nested dict structure expected by visualization. """ if df.is_empty(): return {} # Pivot pivoted = df.pivot( values=value_col, index=[index_col, name_col], columns=period_col, aggregate_function="sum" ).fill_null(0) rows = pivoted.to_dicts() date_cols = [c for c in pivoted.columns if c not in [index_col, name_col]] result = {} for row in rows: pid = row[index_col] pname = row[name_col] period_counts = {} for dc in date_cols: val = row[dc] # If val is 0, we can skip if sparse behavior is desired, # but let's keep it to match original behavior exactly if possible. # Only adding if val > 0 or if we want full zeros? # Original Pandas pivot keeps all columns (dates) for all rows, filling NaNs (0). # The iteration later in 'get_graph_model' determines what to plot. # Parse date string back to date object try: d_obj = datetime.strptime(str(dc), "%Y-%m-%d").date() period_counts[d_obj] = val except ValueError: # Should not happen if we controlled the format print(f"Warning: Could not parse date column {dc}") period_counts[dc] = val result[pid] = { 'PersonName': pname, 'PRCounts': period_counts } return result def get_graph_model(self, title, data_dict): """ Builds a line-graph model from the dictionary. This part remains mostly standard Python as it manipulates dicts. """ if not data_dict: return { 'title': title, 'vb_width': 200, 'vb_height': 75, 'plots': [] } # Gather all dates & values all_dates = [] all_values = [] for user_data in data_dict.values(): all_dates.extend(user_data['PRCounts'].keys()) all_values.extend(user_data['PRCounts'].values()) if not all_dates: return { 'title': title, 'vb_width': 200, 'vb_height': 75, 'plots': [] } min_date = min(all_dates) max_date = max(all_dates) # Ensure min_date/max_date are comparable types. # If they are strings vs dates, that's an issue. # We tried to enforce conversion in _pivot_to_graph_dict. date_span = max((max_date - min_date).days, 1) max_val = max(all_values) min_val = 0 val_range = max_val - min_val if max_val != min_val else 1 vb_width, vb_height = 200, 75 colors = get_distinct_colors(len(data_dict)) plots = [] for i, (pid, user_data) in enumerate(data_dict.items()): name = user_data['PersonName'] pr_counts = user_data['PRCounts'] # Sort by date so points are in chronological order sorted_pr = sorted(pr_counts.items(), key=lambda x: x[0]) points = [] labels = [] for d, val in sorted_pr: # Scale x,y to fit [0..1], then we multiply y by vb_height x = (d - min_date).days / date_span y = (val - min_val) / val_range * vb_height points.append((y, x)) labels.append((y, x, f'{val} for {name} at {d.strftime("%d %b %y")}')) plots.append({ 'label': name, 'color': colors[i], 'points': points, 'plot_labels': labels }) return { 'title': title, 'vb_width': vb_width, 'vb_height': vb_height, 'plots': plots } from datetime import datetime