From 3a07b9d97f26e99c958fca69021cd2d5453efd6e Mon Sep 17 00:00:00 2001 From: Peter Stockings Date: Wed, 6 Nov 2024 22:48:51 +1100 Subject: [PATCH] WIP: Render database schema using Mermaid.js Still need to: * Move mermaid.js to static files * Make template for mermaid wrapper * Create new page for SQL viewer then add explorer --- app.py | 16 +++++++ db.py | 2 + features/sql_viewer.py | 94 +++++++++++++++++++++++++++++++++++++++++ templates/base.html | 13 ++++++ templates/settings.html | 2 + 5 files changed, 127 insertions(+) create mode 100644 features/sql_viewer.py diff --git a/app.py b/app.py index 81b2d1b..030a86c 100644 --- a/app.py +++ b/app.py @@ -469,6 +469,22 @@ def delete_exercise(exercise_id): db.exercises.delete_exercise(exercise_id) return "" +@ app.route("/sql_schema", methods=['GET']) +def get_sql_schema(): + schema_info = db.sql_viewer.get_schema_info() + mermaid_code = db.sql_viewer.generate_mermaid_er(schema_info) + html_content = f''' +
+
+{mermaid_code} +
+
+ ''' + return html_content + @app.teardown_appcontext def closeConnection(exception): db.close_connection() diff --git a/db.py b/db.py index e69e630..cfc0789 100644 --- a/db.py +++ b/db.py @@ -10,6 +10,7 @@ from features.calendar import Calendar from features.exercises import Exercises from features.stats import Stats from features.workout import Workout +from features.sql_viewer import SQLViewer from utils import count_prs_over_time, get_all_exercises_from_topsets, get_exercise_graph_model, get_stats_from_topsets, get_topsets_for_person, get_weekly_pr_graph_model, get_workout_counts, get_workouts @@ -19,6 +20,7 @@ class DataBase(): self.stats = Stats(self.execute) self.workout = Workout(self.execute) self.exercises = Exercises(self.execute) + self.sql_viewer = SQLViewer(self.execute) db_url = urlparse(os.environ['DATABASE_URL']) # if db_url is null then throw error if not db_url: diff --git a/features/sql_viewer.py b/features/sql_viewer.py new file mode 100644 index 0000000..1477802 --- /dev/null +++ b/features/sql_viewer.py @@ -0,0 +1,94 @@ +class SQLViewer: + def __init__(self, db_connection_method): + self.execute = db_connection_method + + def get_schema_info(self, schema='public'): + # Get all table names in the specified schema + tables_result = self.execute(""" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = %s AND table_type = 'BASE TABLE'; + """, [schema]) + tables = [row['table_name'] for row in tables_result] + + schema_info = {} + + for table in tables: + # Get columns and data types + columns_result = self.execute(""" + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = %s AND table_name = %s; + """, [schema, table]) + columns = [(row['column_name'], row['data_type']) for row in columns_result] + + # Get foreign keys + foreign_keys_result = self.execute(""" + SELECT + kcu.column_name AS fk_column, + ccu.table_name AS referenced_table, + ccu.column_name AS referenced_column + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE + tc.constraint_type = 'FOREIGN KEY' AND + tc.table_schema = %s AND + tc.table_name = %s; + """, [schema, table]) + foreign_keys = [ + (row['fk_column'], row['referenced_table'], row['referenced_column']) + for row in foreign_keys_result + ] + + schema_info[table] = { + 'columns': columns, + 'foreign_keys': foreign_keys + } + + return schema_info + + def map_data_type(self, postgres_type): + type_mapping = { + 'integer': 'int', + 'bigint': 'int', + 'smallint': 'int', + 'character varying': 'string', + 'varchar': 'string', + 'text': 'string', + 'date': 'date', + 'timestamp without time zone': 'datetime', + 'timestamp with time zone': 'datetime', + 'boolean': 'bool', + 'numeric': 'float', + 'real': 'float' + # Add more mappings as needed + } + return type_mapping.get(postgres_type, 'string') # Default to 'string' if type not mapped + + def generate_mermaid_er(self, schema_info): + mermaid_lines = ["erDiagram"] + + for table, info in schema_info.items(): + # Define the table and its columns + mermaid_lines.append(f" {table} {{") + for column_name, data_type in info['columns']: + # Convert PostgreSQL data types to Mermaid-compatible types + mermaid_data_type = self.map_data_type(data_type) + mermaid_lines.append(f" {mermaid_data_type} {column_name}") + mermaid_lines.append(" }") + + # Define relationships + for table, info in schema_info.items(): + for fk_column, referenced_table, referenced_column in info['foreign_keys']: + # Mermaid relationship syntax: [Table1] }|--|| [Table2] : "FK_name" + relation = f" {table} }}|--|| {referenced_table} : \"{fk_column} to {referenced_column}\"" + mermaid_lines.append(relation) + + return "\n".join(mermaid_lines) + diff --git a/templates/base.html b/templates/base.html index d4efbf7..4f79271 100644 --- a/templates/base.html +++ b/templates/base.html @@ -16,6 +16,19 @@ + + + diff --git a/templates/settings.html b/templates/settings.html index 63bb885..c4c45de 100644 --- a/templates/settings.html +++ b/templates/settings.html @@ -181,6 +181,8 @@ + +
{% endblock %} \ No newline at end of file