From a4c29e0d8fdded1c905875b6e13e20eae4f0e8c4 Mon Sep 17 00:00:00 2001 From: Peter Stockings Date: Wed, 23 Jul 2025 21:58:43 +1000 Subject: [PATCH] Add db connection pooling --- app.py | 4 +++ db.py | 74 +++++++++++++++++++++++++++++---------------------- extensions.py | 1 + 3 files changed, 47 insertions(+), 32 deletions(-) diff --git a/app.py b/app.py index 206d0a0..800f1ea 100644 --- a/app.py +++ b/app.py @@ -204,3 +204,7 @@ if __name__ == '__main__': # Bind to PORT if defined, otherwise default to 5000. port = int(os.environ.get('PORT', 5000)) app.run(host='127.0.0.1', port=port) + +@app.teardown_appcontext +def teardown_db(exception): + db.close_conn() diff --git a/db.py b/db.py index 090d855..2b1638a 100644 --- a/db.py +++ b/db.py @@ -1,52 +1,62 @@ import json import os import psycopg2 +from psycopg2 import pool from psycopg2.extras import RealDictCursor from urllib.parse import urlparse from flask import g class DataBase(): - def __init__(self, app=None): + def __init__(self): + self.pool = None + + def init_app(self, app): db_url = urlparse(os.environ['DATABASE_URL']) - # if db_url is null then throw error if not db_url: raise Exception("No DATABASE_URL environment variable set") + + self.pool = psycopg2.pool.SimpleConnectionPool( + 1, 20, + database=db_url.path[1:], + user=db_url.username, + password=db_url.password, + host=db_url.hostname, + port=db_url.port + ) + + app.teardown_appcontext(self.close_conn) - def getDB(self): - db = getattr(g, 'database', None) - if db is None: - db_url = urlparse(os.environ['DATABASE_URL']) - g.database = psycopg2.connect( - database=db_url.path[1:], - user=db_url.username, - password=db_url.password, - host=db_url.hostname, - port=db_url.port - ) - db = g.database - return db + def get_conn(self): + if 'db_conn' not in g: + g.db_conn = self.pool.getconn() + return g.db_conn - def close_connection(exception): - db = getattr(g, 'database', None) - if db is not None: - db.close() + def close_conn(self, e=None): + db_conn = g.pop('db_conn', None) + if db_conn is not None: + self.pool.putconn(db_conn) + + def close_all_connections(self): + if self.pool: + self.pool.closeall() def execute(self, query, args=(), one=False, commit=False): - conn = self.getDB() + conn = self.get_conn() cur = conn.cursor(cursor_factory=RealDictCursor) - cur.execute(query, args) - rv = None - if cur.description is not None: - rv = cur.fetchall() - if commit: - try: + try: + cur.execute(query, args) + rv = None + if cur.description is not None: + rv = cur.fetchall() + if commit: conn.commit() - except: - conn.rollback() - cur.close() - - return (rv[0] if rv else None) if one else rv - + return (rv[0] if rv else None) if one else rv + except Exception as e: + conn.rollback() + raise e + finally: + cur.close() + def get_http_functions_for_user(self, user_id): http_functions = self.execute( 'SELECT id, user_id, NAME, script_content, invoked_count, environment_info, is_public, log_request, log_response, version_number FROM http_functions WHERE user_id=%s ORDER by id DESC', [user_id]) diff --git a/extensions.py b/extensions.py index a910c31..f92327a 100644 --- a/extensions.py +++ b/extensions.py @@ -16,6 +16,7 @@ environment = Environment( def init_app(app): htmx.init_app(app) + db.init_app(app) # Add all Flask's default Jinja2 globals and filters environment.globals.update(