Files
workout/features/schema.py
2026-01-30 18:47:26 +11:00

137 lines
6.5 KiB
Python

class Schema:
    """Introspect a PostgreSQL schema and render it as a Mermaid ER diagram
    or as a SQL ``CREATE TABLE`` script.

    All database access goes through the callable given to ``__init__``. It is
    invoked as ``execute(sql, params)`` with ``%s`` placeholders, and every row
    it returns must support ``row['column_name']``-style lookup.
    """

    # PostgreSQL type name -> DDL type used by generate_create_script.
    # Types not listed here are emitted upper-cased (simplified mapping).
    _SQL_TYPES = {
        'character varying': 'VARCHAR', 'varchar': 'VARCHAR', 'text': 'TEXT',
        'integer': 'INTEGER', 'bigint': 'BIGINT', 'boolean': 'BOOLEAN',
        'timestamp without time zone': 'TIMESTAMP',
        'timestamp with time zone': 'TIMESTAMPTZ',
        'numeric': 'NUMERIC', 'real': 'REAL', 'date': 'DATE',
    }

    # PostgreSQL type name -> Mermaid ER attribute type. Types not listed fall
    # back to 'string' so every attribute gets a single-word type token.
    _MERMAID_TYPES = {
        'integer': 'int', 'bigint': 'int', 'smallint': 'int',
        'character varying': 'string', 'varchar': 'string', 'text': 'string',
        'date': 'date', 'timestamp without time zone': 'datetime',
        'timestamp with time zone': 'datetime', 'boolean': 'bool',
        'numeric': 'float', 'real': 'float', 'double precision': 'float',
    }

    _TABLES_SQL = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name;
    """

    _COLUMNS_SQL = """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position;
    """

    # Joining on table_name too keeps rows from other tables out even if a
    # constraint name repeats; ORDER BY keeps composite keys in key order,
    # which is the order later written into the PRIMARY KEY (...) clause.
    _PRIMARY_KEYS_SQL = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s AND tc.table_name = %s
        ORDER BY kcu.ordinal_position;
    """

    # Read from pg_catalog instead of information_schema: PostgreSQL does not
    # force constraint names to be unique within a schema, and
    # constraint_column_usage pairs every local column of a composite key
    # with every referenced column. Unnesting conkey/confkey side by side
    # pairs each referencing column with exactly the column it references.
    _FOREIGN_KEYS_SQL = """
        SELECT att.attname AS fk_column,
               ref_cls.relname AS referenced_table,
               ref_att.attname AS referenced_column
        FROM pg_catalog.pg_constraint AS con
        JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid
        JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace
        JOIN pg_catalog.pg_class AS ref_cls ON ref_cls.oid = con.confrelid
        CROSS JOIN LATERAL unnest(con.conkey, con.confkey)
            WITH ORDINALITY AS k(attnum, ref_attnum, ord)
        JOIN pg_catalog.pg_attribute AS att
          ON att.attrelid = con.conrelid AND att.attnum = k.attnum
        JOIN pg_catalog.pg_attribute AS ref_att
          ON ref_att.attrelid = con.confrelid AND ref_att.attnum = k.ref_attnum
        WHERE con.contype = 'f' AND nsp.nspname = %s AND cls.relname = %s
        ORDER BY con.conname, k.ord;
    """

    def __init__(self, db_connection_method):
        """Store the query callable.

        Args:
            db_connection_method: callable ``(sql, params)`` returning an
                iterable of dict-like rows.
        """
        self.execute = db_connection_method

    def get_schema_info(self, schema='public'):
        """Fetch tables, columns, primary keys and foreign keys for *schema*.

        Args:
            schema: name of the database schema to inspect.

        Returns:
            dict mapping table name to::

                {
                    'columns': [(column_name, data_type), ...],    # ordinal order
                    'primary_keys': [column_name, ...],            # key order
                    'foreign_keys': [(fk_column, referenced_table,
                                      referenced_column), ...],
                }
        """
        schema_info = {}
        for table in self._fetch_tables(schema):
            schema_info[table] = {
                'columns': self._fetch_columns(schema, table),
                'primary_keys': self._fetch_primary_keys(schema, table),
                'foreign_keys': self._fetch_foreign_keys(schema, table),
            }
        return schema_info

    def _fetch_tables(self, schema):
        """Return the base-table names in *schema*, ordered by name."""
        rows = self.execute(self._TABLES_SQL, [schema])
        return [row['table_name'] for row in rows]

    def _fetch_columns(self, schema, table):
        """Return ``(column_name, data_type)`` pairs for *table* in ordinal order."""
        rows = self.execute(self._COLUMNS_SQL, [schema, table])
        return [(row['column_name'], row['data_type']) for row in rows]

    def _fetch_primary_keys(self, schema, table):
        """Return the primary-key column names of *table* in key order."""
        rows = self.execute(self._PRIMARY_KEYS_SQL, [schema, table])
        return [row['column_name'] for row in rows]

    def _fetch_foreign_keys(self, schema, table):
        """Return ``(fk_column, referenced_table, referenced_column)`` triples for *table*."""
        rows = self.execute(self._FOREIGN_KEYS_SQL, [schema, table])
        return [
            (row['fk_column'], row['referenced_table'], row['referenced_column'])
            for row in rows
        ]

    @staticmethod
    def _quote_ident(name):
        """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
        return '"' + str(name).replace('"', '""') + '"'

    def _map_data_type_for_sql(self, postgres_type):
        """Map a PostgreSQL type name to a DDL type (simplified).

        Unknown types are returned upper-cased.
        """
        return self._SQL_TYPES.get(postgres_type, postgres_type.upper())

    def _map_data_type(self, postgres_type):
        """Map a PostgreSQL type name to a Mermaid ER type; unknown -> 'string'."""
        return self._MERMAID_TYPES.get(postgres_type, 'string')

    def generate_mermaid_er(self, schema_info):
        """Render *schema_info* (as returned by get_schema_info) as Mermaid ER code.

        Entities are written in sorted table order, with PK/FK markers on
        their attributes. They are followed by one ``||--o{`` relationship
        per foreign-key column, sorted by column name. Sorting keeps the
        output stable across runs.
        """
        mermaid_lines = [
            "%%{init: {'theme': 'default', 'themeCSS': '.er.entityBox { fill: transparent !important; } .er.attributeBoxEven { fill: transparent !important; } .er.attributeBoxOdd { fill: transparent !important; }'}}%%",
            "erDiagram",
        ]
        sorted_tables = sorted(schema_info)

        for table in sorted_tables:
            info = schema_info[table]
            pks = set(info.get('primary_keys', []))
            fks = {fk[0] for fk in info.get('foreign_keys', [])}
            mermaid_lines.append(f" {table} {{")
            for column_name, data_type in info['columns']:
                markers = []
                if column_name in pks:
                    markers.append("PK")
                if column_name in fks:
                    markers.append("FK")
                marker_str = f" {','.join(markers)}" if markers else ""
                mermaid_type = self._map_data_type(data_type)
                mermaid_lines.append(f" {mermaid_type} {column_name}{marker_str}")
            mermaid_lines.append(" }")

        for table in sorted_tables:
            foreign_keys = schema_info[table].get('foreign_keys', [])
            for fk_column, referenced_table, _ in sorted(foreign_keys, key=lambda fk: fk[0]):
                mermaid_lines.append(f" {referenced_table} ||--o{{ {table} : \"{fk_column}\"")

        return "\n".join(mermaid_lines)

    def generate_create_script(self, schema_info):
        """Render *schema_info* (as returned by get_schema_info) as a DDL script.

        Every ``CREATE TABLE`` comes first, in sorted table order. After them,
        every foreign key follows as ``ALTER TABLE ... ADD CONSTRAINT``.
        Deferring the foreign keys means no constraint references a table
        the script has not created yet, whatever the table order.

        Note: each foreign-key column becomes its own single-column
        constraint, because schema_info does not record which columns share
        a constraint. Composite foreign keys are therefore not reproduced
        faithfully.

        Returns:
            The script text, or ``""`` for an empty *schema_info*.
        """
        quote = self._quote_ident
        statements = []
        alter_statements = []

        for table in sorted(schema_info):
            info = schema_info[table]
            column_defs = [
                f' {quote(column_name)} {self._map_data_type_for_sql(data_type)}'
                for column_name, data_type in info['columns']
            ]
            pks = info.get('primary_keys', [])
            if pks:
                pk_columns = ", ".join(quote(pk) for pk in pks)
                column_defs.append(f' PRIMARY KEY ({pk_columns})')
            columns_sql = ",\n".join(column_defs)
            statements.append(f'CREATE TABLE {quote(table)} (\n{columns_sql}\n);')
            statements.append("")

            for fk_column, ref_table, ref_col in info.get('foreign_keys', []):
                constraint_name = quote(f'fk_{table}_{fk_column}')
                alter_statements.append(
                    f'ALTER TABLE {quote(table)} ADD CONSTRAINT {constraint_name} '
                    f'FOREIGN KEY ({quote(fk_column)}) REFERENCES {quote(ref_table)} ({quote(ref_col)});'
                )

        if alter_statements:
            statements.extend(alter_statements)
            statements.append("")
        return "\n".join(statements)