From c76b4cd6fc58e4ba60d3a005fb13abfd8954c1f0 Mon Sep 17 00:00:00 2001 From: Peter Stockings Date: Tue, 24 Feb 2026 21:41:55 +1100 Subject: [PATCH] perf: connection pooling, query consolidation, inline chart data, batch milestones --- app/auth.py | 9 ++-- app/db.py | 40 +++++++++----- app/routes/api.py | 27 +++++++--- app/routes/dashboard.py | 98 +++++++++++++++++++++++++--------- app/routes/leaderboard.py | 29 +++++++--- app/templates/dashboard.html | 25 ++++----- app/utils.py | 75 ++++++++++++++++++-------- migrations/003_add_indexes.sql | 4 ++ 8 files changed, 219 insertions(+), 88 deletions(-) create mode 100644 migrations/003_add_indexes.sql diff --git a/app/auth.py b/app/auth.py index b3c733e..214ca79 100644 --- a/app/auth.py +++ b/app/auth.py @@ -1,5 +1,5 @@ from functools import wraps -from flask import session, redirect, url_for, request, jsonify +from flask import g, session, redirect, url_for, request, jsonify from app.db import query_one @@ -14,11 +14,14 @@ def login_required(f): def get_current_user(): - """Get the current logged-in user from the database.""" + """Get the current logged-in user (cached per-request on g).""" + if "current_user" in g: + return g.current_user user_id = session.get("user_id") if user_id is None: return None - return query_one("SELECT * FROM users WHERE id = %s", (user_id,)) + g.current_user = query_one("SELECT * FROM users WHERE id = %s", (user_id,)) + return g.current_user def privacy_guard(f): diff --git a/app/db.py b/app/db.py index 848bc15..566e3fe 100644 --- a/app/db.py +++ b/app/db.py @@ -1,33 +1,41 @@ import psycopg2 import psycopg2.extras +import psycopg2.pool from flask import g, current_app +# Module-level connection pool (initialised once per process) +_pool = None + def init_db(app): - """Test the database connection on startup.""" + """Initialise the connection pool on startup.""" + global _pool try: - conn = psycopg2.connect(app.config["DATABASE_URL"]) - conn.close() - print(" * Database connection OK") + _pool = psycopg2.pool.SimpleConnectionPool( + minconn=2, + maxconn=10, + dsn=app.config["DATABASE_URL"], + ) + print(" * Database connection pool OK (2–10 connections)") except Exception as e: - print(f" * Database connection FAILED: {e}") + print(f" * Database connection pool FAILED: {e}") def get_db(): - """Get a database connection for the current request.""" + """Get a pooled database connection for the current request.""" if "db" not in g: - g.db = psycopg2.connect( - current_app.config["DATABASE_URL"], - cursor_factory=psycopg2.extras.RealDictCursor, - ) + g.db = _pool.getconn() + g.db.cursor_factory = psycopg2.extras.RealDictCursor return g.db def close_db(exception=None): - """Close database connection at end of request.""" + """Return database connection to the pool at end of request.""" db = g.pop("db", None) if db is not None: - db.close() + if exception: + db.rollback() + _pool.putconn(db) def query(sql, params=None): @@ -62,3 +70,11 @@ def execute_returning(sql, params=None): row = cur.fetchone() db.commit() return row + + +def execute_many(sql, params_list): + """Execute a batch INSERT/UPDATE/DELETE and commit.""" + db = get_db() + with db.cursor() as cur: + cur.executemany(sql, params_list) + db.commit() diff --git a/app/routes/api.py b/app/routes/api.py index 35f173a..fce3823 100644 --- a/app/routes/api.py +++ b/app/routes/api.py @@ -44,18 +44,24 @@ def progress_over_time(): where_sql = " AND ".join(where_clauses) + # Use CTE for first_weight instead of correlated subquery rows = query(f""" + WITH first_weights AS ( + SELECT DISTINCT ON (user_id) user_id, weight_kg AS first_weight + FROM checkins + ORDER BY user_id, checked_in_at ASC + ) SELECT u.id AS user_id, u.display_name, u.username, u.starting_weight_kg, - (SELECT weight_kg FROM checkins - WHERE user_id = u.id ORDER BY checked_in_at ASC LIMIT 1) AS first_weight, + fw.first_weight, c.weight_kg, c.checked_in_at FROM checkins c JOIN users u ON u.id = c.user_id + LEFT JOIN first_weights fw ON fw.user_id = u.id WHERE {where_sql} ORDER BY u.id, c.checked_in_at ASC """, params) @@ -144,17 +150,26 @@ def chart_data(user_id): @login_required def comparison(): """Return all-user comparison data for bar charts.""" + # Use CTE with window functions instead of correlated subqueries users = query(""" + WITH user_weights AS ( + SELECT + user_id, + FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at ASC) AS first_weight, + FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS current_weight, + ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS rn + FROM checkins + ) SELECT u.id, u.display_name, u.username, u.starting_weight_kg, - (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at ASC LIMIT 1) as first_weight, - (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at DESC LIMIT 1) as current_weight + uw.first_weight, + uw.current_weight FROM users u - WHERE (SELECT COUNT(*) FROM checkins WHERE user_id = u.id) > 0 - AND u.is_private = FALSE + JOIN user_weights uw ON uw.user_id = u.id AND uw.rn = 1 + WHERE u.is_private = FALSE ORDER BY u.display_name """) diff --git a/app/routes/dashboard.py b/app/routes/dashboard.py index 0de76f3..c3af30f 100644 --- a/app/routes/dashboard.py +++ b/app/routes/dashboard.py @@ -1,6 +1,8 @@ +from datetime import timezone from flask import Blueprint, render_template from app.auth import login_required, get_current_user from app.db import query, query_one +from app.config import SYDNEY_TZ from app.utils import calculate_streak, calculate_weight_change bp = Blueprint("dashboard", __name__) @@ -10,38 +12,53 @@ bp = Blueprint("dashboard", __name__) @login_required def index(): user = get_current_user() + uid = user["id"] - # Get latest check-in - latest = query_one( - "SELECT * FROM checkins WHERE user_id = %s ORDER BY checked_in_at DESC LIMIT 1", - (user["id"],), - ) - - # Get check-in count - stats = query_one( - "SELECT COUNT(*) as total_checkins FROM checkins WHERE user_id = %s", - (user["id"],), - ) - - # Calculate weight change - first_checkin = query_one( - "SELECT weight_kg FROM checkins WHERE user_id = %s ORDER BY checked_in_at ASC LIMIT 1", - (user["id"],), - ) + # --- Single query: latest, first, count via window functions ---------- + summary = query_one(""" + SELECT + total, + first_weight, + latest_weight, + latest_bmi, + latest_at + FROM ( + SELECT + COUNT(*) OVER () AS total, + FIRST_VALUE(weight_kg) OVER (ORDER BY checked_in_at ASC) AS first_weight, + FIRST_VALUE(weight_kg) OVER (ORDER BY checked_in_at DESC) AS latest_weight, + FIRST_VALUE(bmi) OVER (ORDER BY checked_in_at DESC) AS latest_bmi, + FIRST_VALUE(checked_in_at) OVER (ORDER BY checked_in_at DESC) AS latest_at, + ROW_NUMBER() OVER (ORDER BY checked_in_at DESC) AS rn + FROM checkins + WHERE user_id = %s + ) sub + WHERE rn = 1 + """, (uid,)) + # Build lightweight "latest" dict for the template + latest = None weight_change = None weight_change_pct = None - if latest and first_checkin: + total_checkins = 0 + + if summary: + total_checkins = summary["total"] + latest = { + "weight_kg": summary["latest_weight"], + "bmi": summary["latest_bmi"], + "checked_in_at": summary["latest_at"], + } kg_lost, pct_lost = calculate_weight_change( - first_checkin["weight_kg"], latest["weight_kg"] + summary["first_weight"], summary["latest_weight"] ) - weight_change = round(-kg_lost, 1) # negative = gained, positive = lost + weight_change = round(-kg_lost, 1) weight_change_pct = round(-pct_lost, 1) # Recent check-ins (last 5) recent_checkins = query( "SELECT * FROM checkins WHERE user_id = %s ORDER BY checked_in_at DESC LIMIT 5", - (user["id"],), + (uid,), ) # Activity feed (recent check-ins from all users) @@ -52,26 +69,57 @@ def index(): WHERE u.is_private = FALSE OR u.id = %s ORDER BY c.checked_in_at DESC LIMIT 10 - """, (user["id"],)) + """, (uid,)) # Milestones milestones = query( "SELECT * FROM milestones WHERE user_id = %s ORDER BY achieved_at DESC", - (user["id"],), + (uid,), ) # Streak - streak = calculate_streak(user["id"]) + streak = calculate_streak(uid) + + # --- Pre-compute chart data (eliminates 2 client-side fetches) -------- + chart_checkins = query( + """SELECT weight_kg, bmi, checked_in_at + FROM checkins WHERE user_id = %s + ORDER BY checked_in_at ASC""", + (uid,), + ) + + chart_labels = [] + chart_weights = [] + weekly_labels = [] + weekly_changes = [] + + for i, c in enumerate(chart_checkins): + dt = c["checked_in_at"] + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + label = dt.astimezone(SYDNEY_TZ).strftime("%d %b") + chart_labels.append(label) + chart_weights.append(float(c["weight_kg"])) + + if i > 0: + prev_w = float(chart_checkins[i - 1]["weight_kg"]) + curr_w = float(c["weight_kg"]) + weekly_labels.append(label) + weekly_changes.append(round(curr_w - prev_w, 1)) return render_template( "dashboard.html", user=user, latest=latest, - stats=stats, + stats={"total_checkins": total_checkins}, weight_change=weight_change, weight_change_pct=weight_change_pct, recent_checkins=recent_checkins, activity=activity, milestones=milestones, streak=streak, + chart_labels=chart_labels, + chart_weights=chart_weights, + weekly_labels=weekly_labels, + weekly_changes=weekly_changes, ) diff --git a/app/routes/leaderboard.py b/app/routes/leaderboard.py index 35f52c2..1e098a3 100644 --- a/app/routes/leaderboard.py +++ b/app/routes/leaderboard.py @@ -2,7 +2,7 @@ from flask import Blueprint, render_template from app.auth import login_required from app.db import query, query_one from app.config import SYDNEY_TZ -from app.utils import calculate_streak, calculate_weight_change +from app.utils import calculate_streaks_bulk, calculate_weight_change from datetime import timezone bp = Blueprint("leaderboard", __name__) @@ -11,23 +11,38 @@ bp = Blueprint("leaderboard", __name__) @bp.route("/leaderboard") @login_required def index(): - # Get all users with their weight stats + # Get all users with weight stats using window functions (no correlated subqueries) users = query(""" + WITH user_weights AS ( + SELECT + user_id, + FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at ASC) AS first_weight, + FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS current_weight, + COUNT(*) OVER (PARTITION BY user_id) AS total_checkins, + MAX(checked_in_at) OVER (PARTITION BY user_id) AS last_checkin, + ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS rn + FROM checkins + ) SELECT u.id, u.display_name, u.username, u.starting_weight_kg, u.goal_weight_kg, - (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at ASC LIMIT 1) as first_weight, - (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at DESC LIMIT 1) as current_weight, - (SELECT COUNT(*) FROM checkins WHERE user_id = u.id) as total_checkins, - (SELECT checked_in_at FROM checkins WHERE user_id = u.id ORDER BY checked_in_at DESC LIMIT 1) as last_checkin + uw.first_weight, + uw.current_weight, + uw.total_checkins, + uw.last_checkin FROM users u + JOIN user_weights uw ON uw.user_id = u.id AND uw.rn = 1 WHERE u.is_private = FALSE ORDER BY u.created_at """) + # Batch-compute streaks for all users in one query + user_ids = [u["id"] for u in users] + all_streaks = calculate_streaks_bulk(user_ids) + # Calculate rankings ranked = [] for u in users: @@ -41,7 +56,7 @@ def index(): total_to_lose = start_w - goal goal_progress = min(100, round((weight_lost / total_to_lose) * 100, 1)) if total_to_lose > 0 else 0 - streak = calculate_streak(u["id"]) + streak = all_streaks.get(u["id"], {"current": 0, "best": 0}) ranked.append({ **u, "weight_lost": weight_lost, diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html index 7a729b3..0cca95f 100644 --- a/app/templates/dashboard.html +++ b/app/templates/dashboard.html @@ -143,21 +143,18 @@ {% endblock %} \ No newline at end of file diff --git a/app/utils.py b/app/utils.py index 585c0c1..3b55e09 100644 --- a/app/utils.py +++ b/app/utils.py @@ -4,7 +4,7 @@ Shared business-logic helpers. Keep route handlers thin — calculation logic lives here. """ -from app.db import query, execute +from app.db import query, execute_many from app.config import SYDNEY_TZ from datetime import datetime, timedelta @@ -66,19 +66,11 @@ def calculate_weight_change(start_w, current_w): # Streaks # --------------------------------------------------------------------------- -def calculate_streak(user_id): - """Calculate current and best consecutive-day check-in streaks.""" - rows = query( - """SELECT DISTINCT (checked_in_at AT TIME ZONE 'UTC' AT TIME ZONE 'Australia/Sydney')::date AS d - FROM checkins WHERE user_id = %s ORDER BY d DESC""", - (user_id,), - ) - if not rows: +def _compute_streak_from_dates(days, today): + """Compute current and best streak from a sorted-desc list of dates.""" + if not days: return {"current": 0, "best": 0} - days = [r["d"] for r in rows] - today = datetime.now(SYDNEY_TZ).date() - # Current streak: must include today or yesterday to count current = 0 expected = today @@ -105,6 +97,50 @@ def calculate_streak(user_id): return {"current": current, "best": best} +def calculate_streak(user_id): + """Calculate current and best consecutive-day check-in streaks.""" + rows = query( + """SELECT DISTINCT (checked_in_at AT TIME ZONE 'UTC' AT TIME ZONE 'Australia/Sydney')::date AS d + FROM checkins WHERE user_id = %s ORDER BY d DESC""", + (user_id,), + ) + days = [r["d"] for r in rows] + today = datetime.now(SYDNEY_TZ).date() + return _compute_streak_from_dates(days, today) + + +def calculate_streaks_bulk(user_ids): + """Calculate streaks for multiple users in a single query. + + Returns a dict: {user_id: {"current": int, "best": int}}. + """ + if not user_ids: + return {} + + placeholders = ",".join(["%s"] * len(user_ids)) + rows = query( + f"""SELECT user_id, + (checked_in_at AT TIME ZONE 'UTC' AT TIME ZONE 'Australia/Sydney')::date AS d + FROM checkins + WHERE user_id IN ({placeholders}) + GROUP BY user_id, d + ORDER BY user_id, d DESC""", + tuple(user_ids), + ) + + # Group by user + from collections import defaultdict + user_days = defaultdict(list) + for r in rows: + user_days[r["user_id"]].append(r["d"]) + + today = datetime.now(SYDNEY_TZ).date() + result = {} + for uid in user_ids: + result[uid] = _compute_streak_from_dates(user_days.get(uid, []), today) + return result + + # --------------------------------------------------------------------------- # Milestone checker # --------------------------------------------------------------------------- @@ -136,15 +172,12 @@ def check_milestones(user_id, user): ("lost_20kg", total_lost >= 20), ] - for key, achieved in milestone_checks: - if achieved: - try: - execute( - "INSERT INTO milestones (user_id, milestone_key) VALUES (%s, %s) ON CONFLICT DO NOTHING", - (user_id, key), - ) - except Exception: - pass + achieved = [(user_id, key) for key, ok in milestone_checks if ok] + if achieved: + execute_many( + "INSERT INTO milestones (user_id, milestone_key) VALUES (%s, %s) ON CONFLICT DO NOTHING", + achieved, + ) # --------------------------------------------------------------------------- diff --git a/migrations/003_add_indexes.sql b/migrations/003_add_indexes.sql new file mode 100644 index 0000000..e9b3d0c --- /dev/null +++ b/migrations/003_add_indexes.sql @@ -0,0 +1,4 @@ +-- Migration 003: Add indexes for performance +-- Partial index to speed up queries that filter on is_private = FALSE + +CREATE INDEX IF NOT EXISTS idx_users_public ON users(id) WHERE is_private = FALSE;