diff --git a/app/auth.py b/app/auth.py
index b3c733e..214ca79 100644
--- a/app/auth.py
+++ b/app/auth.py
@@ -1,5 +1,5 @@
from functools import wraps
-from flask import session, redirect, url_for, request, jsonify
+from flask import g, session, redirect, url_for, request, jsonify
from app.db import query_one
@@ -14,11 +14,14 @@ def login_required(f):
def get_current_user():
- """Get the current logged-in user from the database."""
+ """Get the current logged-in user (cached per-request on g)."""
+ if "current_user" in g:
+ return g.current_user
user_id = session.get("user_id")
if user_id is None:
return None
- return query_one("SELECT * FROM users WHERE id = %s", (user_id,))
+ g.current_user = query_one("SELECT * FROM users WHERE id = %s", (user_id,))
+ return g.current_user
def privacy_guard(f):
diff --git a/app/db.py b/app/db.py
index 848bc15..566e3fe 100644
--- a/app/db.py
+++ b/app/db.py
@@ -1,33 +1,41 @@
import psycopg2
import psycopg2.extras
+import psycopg2.pool
from flask import g, current_app
+# Module-level connection pool (initialised once per process)
+_pool = None
+
def init_db(app):
- """Test the database connection on startup."""
+ """Initialise the connection pool on startup."""
+ global _pool
try:
- conn = psycopg2.connect(app.config["DATABASE_URL"])
- conn.close()
- print(" * Database connection OK")
+ _pool = psycopg2.pool.SimpleConnectionPool(
+ minconn=2,
+ maxconn=10,
+ dsn=app.config["DATABASE_URL"],
+ )
+ print(" * Database connection pool OK (2–10 connections)")
except Exception as e:
- print(f" * Database connection FAILED: {e}")
+ print(f" * Database connection pool FAILED: {e}")
def get_db():
- """Get a database connection for the current request."""
+ """Get a pooled database connection for the current request."""
if "db" not in g:
- g.db = psycopg2.connect(
- current_app.config["DATABASE_URL"],
- cursor_factory=psycopg2.extras.RealDictCursor,
- )
+ g.db = _pool.getconn()
+ g.db.cursor_factory = psycopg2.extras.RealDictCursor
return g.db
def close_db(exception=None):
- """Close database connection at end of request."""
+ """Return database connection to the pool at end of request."""
db = g.pop("db", None)
if db is not None:
- db.close()
+ if exception:
+ db.rollback()
+ _pool.putconn(db)
def query(sql, params=None):
@@ -62,3 +70,11 @@ def execute_returning(sql, params=None):
row = cur.fetchone()
db.commit()
return row
+
+
+def execute_many(sql, params_list):
+ """Execute a batch INSERT/UPDATE/DELETE and commit."""
+ db = get_db()
+ with db.cursor() as cur:
+ cur.executemany(sql, params_list)
+ db.commit()
diff --git a/app/routes/api.py b/app/routes/api.py
index 35f173a..fce3823 100644
--- a/app/routes/api.py
+++ b/app/routes/api.py
@@ -44,18 +44,24 @@ def progress_over_time():
where_sql = " AND ".join(where_clauses)
+ # Use CTE for first_weight instead of correlated subquery
rows = query(f"""
+ WITH first_weights AS (
+ SELECT DISTINCT ON (user_id) user_id, weight_kg AS first_weight
+ FROM checkins
+ ORDER BY user_id, checked_in_at ASC
+ )
SELECT
u.id AS user_id,
u.display_name,
u.username,
u.starting_weight_kg,
- (SELECT weight_kg FROM checkins
- WHERE user_id = u.id ORDER BY checked_in_at ASC LIMIT 1) AS first_weight,
+ fw.first_weight,
c.weight_kg,
c.checked_in_at
FROM checkins c
JOIN users u ON u.id = c.user_id
+ LEFT JOIN first_weights fw ON fw.user_id = u.id
WHERE {where_sql}
ORDER BY u.id, c.checked_in_at ASC
""", params)
@@ -144,17 +150,26 @@ def chart_data(user_id):
@login_required
def comparison():
"""Return all-user comparison data for bar charts."""
+ # Use CTE with window functions instead of correlated subqueries
users = query("""
+ WITH user_weights AS (
+ SELECT
+ user_id,
+ FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at ASC) AS first_weight,
+ FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS current_weight,
+ ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS rn
+ FROM checkins
+ )
SELECT
u.id,
u.display_name,
u.username,
u.starting_weight_kg,
- (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at ASC LIMIT 1) as first_weight,
- (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at DESC LIMIT 1) as current_weight
+ uw.first_weight,
+ uw.current_weight
FROM users u
- WHERE (SELECT COUNT(*) FROM checkins WHERE user_id = u.id) > 0
- AND u.is_private = FALSE
+ JOIN user_weights uw ON uw.user_id = u.id AND uw.rn = 1
+ WHERE u.is_private = FALSE
ORDER BY u.display_name
""")
diff --git a/app/routes/dashboard.py b/app/routes/dashboard.py
index 0de76f3..c3af30f 100644
--- a/app/routes/dashboard.py
+++ b/app/routes/dashboard.py
@@ -1,6 +1,8 @@
+from datetime import timezone
from flask import Blueprint, render_template
from app.auth import login_required, get_current_user
from app.db import query, query_one
+from app.config import SYDNEY_TZ
from app.utils import calculate_streak, calculate_weight_change
bp = Blueprint("dashboard", __name__)
@@ -10,38 +12,53 @@ bp = Blueprint("dashboard", __name__)
@login_required
def index():
user = get_current_user()
+ uid = user["id"]
- # Get latest check-in
- latest = query_one(
- "SELECT * FROM checkins WHERE user_id = %s ORDER BY checked_in_at DESC LIMIT 1",
- (user["id"],),
- )
-
- # Get check-in count
- stats = query_one(
- "SELECT COUNT(*) as total_checkins FROM checkins WHERE user_id = %s",
- (user["id"],),
- )
-
- # Calculate weight change
- first_checkin = query_one(
- "SELECT weight_kg FROM checkins WHERE user_id = %s ORDER BY checked_in_at ASC LIMIT 1",
- (user["id"],),
- )
+ # --- Single query: latest, first, count via window functions ----------
+ summary = query_one("""
+ SELECT
+ total,
+ first_weight,
+ latest_weight,
+ latest_bmi,
+ latest_at
+ FROM (
+ SELECT
+ COUNT(*) OVER () AS total,
+ FIRST_VALUE(weight_kg) OVER (ORDER BY checked_in_at ASC) AS first_weight,
+ FIRST_VALUE(weight_kg) OVER (ORDER BY checked_in_at DESC) AS latest_weight,
+ FIRST_VALUE(bmi) OVER (ORDER BY checked_in_at DESC) AS latest_bmi,
+ FIRST_VALUE(checked_in_at) OVER (ORDER BY checked_in_at DESC) AS latest_at,
+ ROW_NUMBER() OVER (ORDER BY checked_in_at DESC) AS rn
+ FROM checkins
+ WHERE user_id = %s
+ ) sub
+ WHERE rn = 1
+ """, (uid,))
+ # Build lightweight "latest" dict for the template
+ latest = None
weight_change = None
weight_change_pct = None
- if latest and first_checkin:
+ total_checkins = 0
+
+ if summary:
+ total_checkins = summary["total"]
+ latest = {
+ "weight_kg": summary["latest_weight"],
+ "bmi": summary["latest_bmi"],
+ "checked_in_at": summary["latest_at"],
+ }
kg_lost, pct_lost = calculate_weight_change(
- first_checkin["weight_kg"], latest["weight_kg"]
+ summary["first_weight"], summary["latest_weight"]
)
- weight_change = round(-kg_lost, 1) # negative = gained, positive = lost
+ weight_change = round(-kg_lost, 1)
weight_change_pct = round(-pct_lost, 1)
# Recent check-ins (last 5)
recent_checkins = query(
"SELECT * FROM checkins WHERE user_id = %s ORDER BY checked_in_at DESC LIMIT 5",
- (user["id"],),
+ (uid,),
)
# Activity feed (recent check-ins from all users)
@@ -52,26 +69,57 @@ def index():
WHERE u.is_private = FALSE OR u.id = %s
ORDER BY c.checked_in_at DESC
LIMIT 10
- """, (user["id"],))
+ """, (uid,))
# Milestones
milestones = query(
"SELECT * FROM milestones WHERE user_id = %s ORDER BY achieved_at DESC",
- (user["id"],),
+ (uid,),
)
# Streak
- streak = calculate_streak(user["id"])
+ streak = calculate_streak(uid)
+
+ # --- Pre-compute chart data (eliminates 2 client-side fetches) --------
+ chart_checkins = query(
+ """SELECT weight_kg, bmi, checked_in_at
+ FROM checkins WHERE user_id = %s
+ ORDER BY checked_in_at ASC""",
+ (uid,),
+ )
+
+ chart_labels = []
+ chart_weights = []
+ weekly_labels = []
+ weekly_changes = []
+
+ for i, c in enumerate(chart_checkins):
+ dt = c["checked_in_at"]
+ if dt.tzinfo is None:
+ dt = dt.replace(tzinfo=timezone.utc)
+ label = dt.astimezone(SYDNEY_TZ).strftime("%d %b")
+ chart_labels.append(label)
+ chart_weights.append(float(c["weight_kg"]))
+
+ if i > 0:
+ prev_w = float(chart_checkins[i - 1]["weight_kg"])
+ curr_w = float(c["weight_kg"])
+ weekly_labels.append(label)
+ weekly_changes.append(round(curr_w - prev_w, 1))
return render_template(
"dashboard.html",
user=user,
latest=latest,
- stats=stats,
+ stats={"total_checkins": total_checkins},
weight_change=weight_change,
weight_change_pct=weight_change_pct,
recent_checkins=recent_checkins,
activity=activity,
milestones=milestones,
streak=streak,
+ chart_labels=chart_labels,
+ chart_weights=chart_weights,
+ weekly_labels=weekly_labels,
+ weekly_changes=weekly_changes,
)
diff --git a/app/routes/leaderboard.py b/app/routes/leaderboard.py
index 35f52c2..1e098a3 100644
--- a/app/routes/leaderboard.py
+++ b/app/routes/leaderboard.py
@@ -2,7 +2,7 @@ from flask import Blueprint, render_template
from app.auth import login_required
from app.db import query, query_one
from app.config import SYDNEY_TZ
-from app.utils import calculate_streak, calculate_weight_change
+from app.utils import calculate_streaks_bulk, calculate_weight_change
from datetime import timezone
bp = Blueprint("leaderboard", __name__)
@@ -11,23 +11,38 @@ bp = Blueprint("leaderboard", __name__)
@bp.route("/leaderboard")
@login_required
def index():
- # Get all users with their weight stats
+ # Get all users with weight stats using window functions (no correlated subqueries)
users = query("""
+ WITH user_weights AS (
+ SELECT
+ user_id,
+ FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at ASC) AS first_weight,
+ FIRST_VALUE(weight_kg) OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS current_weight,
+ COUNT(*) OVER (PARTITION BY user_id) AS total_checkins,
+ MAX(checked_in_at) OVER (PARTITION BY user_id) AS last_checkin,
+ ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY checked_in_at DESC) AS rn
+ FROM checkins
+ )
SELECT
u.id,
u.display_name,
u.username,
u.starting_weight_kg,
u.goal_weight_kg,
- (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at ASC LIMIT 1) as first_weight,
- (SELECT weight_kg FROM checkins WHERE user_id = u.id ORDER BY checked_in_at DESC LIMIT 1) as current_weight,
- (SELECT COUNT(*) FROM checkins WHERE user_id = u.id) as total_checkins,
- (SELECT checked_in_at FROM checkins WHERE user_id = u.id ORDER BY checked_in_at DESC LIMIT 1) as last_checkin
+ uw.first_weight,
+ uw.current_weight,
+ uw.total_checkins,
+ uw.last_checkin
FROM users u
+ JOIN user_weights uw ON uw.user_id = u.id AND uw.rn = 1
WHERE u.is_private = FALSE
ORDER BY u.created_at
""")
+ # Batch-compute streaks for all users in one query
+ user_ids = [u["id"] for u in users]
+ all_streaks = calculate_streaks_bulk(user_ids)
+
# Calculate rankings
ranked = []
for u in users:
@@ -41,7 +56,7 @@ def index():
total_to_lose = start_w - goal
goal_progress = min(100, round((weight_lost / total_to_lose) * 100, 1)) if total_to_lose > 0 else 0
- streak = calculate_streak(u["id"])
+ streak = all_streaks.get(u["id"], {"current": 0, "best": 0})
ranked.append({
**u,
"weight_lost": weight_lost,
diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html
index 7a729b3..0cca95f 100644
--- a/app/templates/dashboard.html
+++ b/app/templates/dashboard.html
@@ -143,21 +143,18 @@
{% endblock %}
\ No newline at end of file
diff --git a/app/utils.py b/app/utils.py
index 585c0c1..3b55e09 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -4,7 +4,7 @@ Shared business-logic helpers.
Keep route handlers thin — calculation logic lives here.
"""
-from app.db import query, execute
+from app.db import query, execute_many
from app.config import SYDNEY_TZ
from datetime import datetime, timedelta
@@ -66,19 +66,11 @@ def calculate_weight_change(start_w, current_w):
# Streaks
# ---------------------------------------------------------------------------
-def calculate_streak(user_id):
- """Calculate current and best consecutive-day check-in streaks."""
- rows = query(
- """SELECT DISTINCT (checked_in_at AT TIME ZONE 'UTC' AT TIME ZONE 'Australia/Sydney')::date AS d
- FROM checkins WHERE user_id = %s ORDER BY d DESC""",
- (user_id,),
- )
- if not rows:
+def _compute_streak_from_dates(days, today):
+ """Compute current and best streak from a sorted-desc list of dates."""
+ if not days:
return {"current": 0, "best": 0}
- days = [r["d"] for r in rows]
- today = datetime.now(SYDNEY_TZ).date()
-
# Current streak: must include today or yesterday to count
current = 0
expected = today
@@ -105,6 +97,50 @@ def calculate_streak(user_id):
return {"current": current, "best": best}
+def calculate_streak(user_id):
+ """Calculate current and best consecutive-day check-in streaks."""
+ rows = query(
+ """SELECT DISTINCT (checked_in_at AT TIME ZONE 'UTC' AT TIME ZONE 'Australia/Sydney')::date AS d
+ FROM checkins WHERE user_id = %s ORDER BY d DESC""",
+ (user_id,),
+ )
+ days = [r["d"] for r in rows]
+ today = datetime.now(SYDNEY_TZ).date()
+ return _compute_streak_from_dates(days, today)
+
+
+def calculate_streaks_bulk(user_ids):
+ """Calculate streaks for multiple users in a single query.
+
+ Returns a dict: {user_id: {"current": int, "best": int}}.
+ """
+ if not user_ids:
+ return {}
+
+ placeholders = ",".join(["%s"] * len(user_ids))
+ rows = query(
+ f"""SELECT user_id,
+ (checked_in_at AT TIME ZONE 'UTC' AT TIME ZONE 'Australia/Sydney')::date AS d
+ FROM checkins
+ WHERE user_id IN ({placeholders})
+ GROUP BY user_id, d
+ ORDER BY user_id, d DESC""",
+ tuple(user_ids),
+ )
+
+ # Group by user
+ from collections import defaultdict
+ user_days = defaultdict(list)
+ for r in rows:
+ user_days[r["user_id"]].append(r["d"])
+
+ today = datetime.now(SYDNEY_TZ).date()
+ result = {}
+ for uid in user_ids:
+ result[uid] = _compute_streak_from_dates(user_days.get(uid, []), today)
+ return result
+
+
# ---------------------------------------------------------------------------
# Milestone checker
# ---------------------------------------------------------------------------
@@ -136,15 +172,12 @@ def check_milestones(user_id, user):
("lost_20kg", total_lost >= 20),
]
- for key, achieved in milestone_checks:
- if achieved:
- try:
- execute(
- "INSERT INTO milestones (user_id, milestone_key) VALUES (%s, %s) ON CONFLICT DO NOTHING",
- (user_id, key),
- )
- except Exception:
- pass
+ achieved = [(user_id, key) for key, ok in milestone_checks if ok]
+ if achieved:
+ execute_many(
+ "INSERT INTO milestones (user_id, milestone_key) VALUES (%s, %s) ON CONFLICT DO NOTHING",
+ achieved,
+ )
# ---------------------------------------------------------------------------
diff --git a/migrations/003_add_indexes.sql b/migrations/003_add_indexes.sql
new file mode 100644
index 0000000..e9b3d0c
--- /dev/null
+++ b/migrations/003_add_indexes.sql
@@ -0,0 +1,4 @@
+-- Migration 003: Add indexes for performance
+-- Partial index to speed up queries that filter on is_private = FALSE
+
+CREATE INDEX IF NOT EXISTS idx_users_public ON users(id) WHERE is_private = FALSE;