class SQLExplorer:
    """Inspect a SQL database schema and manage a table of saved queries.

    All database access goes through the callable supplied at construction
    time. Within this class it is invoked as ``execute(query)``,
    ``execute(query, params)`` and with the keyword arguments ``commit=True``
    or ``one=True``. Result rows are read by column name
    (``row['column_name']``).
    """

    # --- information_schema queries -------------------------------------
    # Parameters are always [schema] or [schema, table].

    _TABLES_SQL = (
        " SELECT table_name FROM information_schema.tables"
        " WHERE table_schema = %s AND table_type = 'BASE TABLE'; "
    )

    _COLUMNS_SQL = (
        " SELECT column_name, data_type FROM information_schema.columns"
        " WHERE table_schema = %s AND table_name = %s"
        " ORDER BY ordinal_position; "
    )

    # The join also matches on table_name: constraint names are not
    # guaranteed to be unique within a schema, so joining on name + schema
    # alone can pull in columns of a same-named constraint on another table.
    # ORDER BY keeps composite-key columns in their declared order.
    _PRIMARY_KEYS_SQL = (
        " SELECT kcu.column_name"
        " FROM information_schema.table_constraints tc"
        " JOIN information_schema.key_column_usage kcu"
        " ON tc.constraint_name = kcu.constraint_name"
        " AND tc.table_schema = kcu.table_schema"
        " AND tc.table_name = kcu.table_name"
        " WHERE tc.constraint_type = 'PRIMARY KEY'"
        " AND tc.table_schema = %s AND tc.table_name = %s"
        " ORDER BY kcu.ordinal_position; "
    )

    # Same table_name guard as above for the key_column_usage join.
    # NOTE(review): constraint_column_usage is joined by constraint name only,
    # with no positional match against key_column_usage, so a multi-column
    # foreign key presumably yields every FK column paired with every
    # referenced column -- confirm before relying on composite foreign keys.
    _FOREIGN_KEYS_SQL = (
        " SELECT kcu.column_name AS fk_column,"
        " ccu.table_name AS referenced_table,"
        " ccu.column_name AS referenced_column"
        " FROM information_schema.table_constraints AS tc"
        " JOIN information_schema.key_column_usage AS kcu"
        " ON tc.constraint_name = kcu.constraint_name"
        " AND tc.table_schema = kcu.table_schema"
        " AND tc.table_name = kcu.table_name"
        " JOIN information_schema.constraint_column_usage AS ccu"
        " ON ccu.constraint_name = tc.constraint_name"
        " AND ccu.table_schema = tc.table_schema"
        " WHERE tc.constraint_type = 'FOREIGN KEY'"
        " AND tc.table_schema = %s AND tc.table_name = %s; "
    )

    # --- type mappings ---------------------------------------------------

    # information_schema type name -> type keyword used in generated DDL.
    _SQL_TYPE_NAMES = {
        'character varying': 'VARCHAR',
        'varchar': 'VARCHAR',
        'text': 'TEXT',
        'integer': 'INTEGER',
        'bigint': 'BIGINT',
        'boolean': 'BOOLEAN',
        'timestamp without time zone': 'TIMESTAMP',
        'timestamp with time zone': 'TIMESTAMPTZ',
    }

    # information_schema type name -> short type label used in diagrams.
    _DIAGRAM_TYPE_NAMES = {
        'integer': 'int',
        'bigint': 'int',
        'smallint': 'int',
        'character varying': 'string',
        'varchar': 'string',
        'text': 'string',
        'date': 'date',
        'timestamp without time zone': 'datetime',
        'timestamp with time zone': 'datetime',
        'boolean': 'bool',
        'numeric': 'float',
        'real': 'float',
        # Add more mappings as needed
    }

    def __init__(self, db_connection_method):
        """Store the callable used for every database round trip."""
        self.execute = db_connection_method

    # --- schema introspection -------------------------------------------

    def get_schema_info(self, schema='public'):
        """Describe every base table of *schema*.

        Args:
            schema: Name of the schema to inspect.

        Returns:
            A dict keyed by table name. Each value is a dict with:
            ``'columns'`` -- list of ``(column_name, data_type)`` tuples in
            ordinal order; ``'primary_keys'`` -- list of column names;
            ``'foreign_keys'`` -- list of
            ``(fk_column, referenced_table, referenced_column)`` tuples.
        """
        tables_result = self.execute(self._TABLES_SQL, [schema])
        # Materialise the names before issuing further queries, so the
        # per-table queries never run while this result is being iterated.
        tables = [row['table_name'] for row in tables_result]

        schema_info = {}
        for table in tables:
            schema_info[table] = {
                'columns': self._fetch_columns(schema, table),
                'primary_keys': self._fetch_primary_keys(schema, table),
                'foreign_keys': self._fetch_foreign_keys(schema, table),
            }
        return schema_info

    def _fetch_columns(self, schema, table):
        """Return ``(column_name, data_type)`` tuples for one table."""
        rows = self.execute(self._COLUMNS_SQL, [schema, table])
        return [(row['column_name'], row['data_type']) for row in rows]

    def _fetch_primary_keys(self, schema, table):
        """Return the primary-key column names of one table."""
        rows = self.execute(self._PRIMARY_KEYS_SQL, [schema, table])
        return [row['column_name'] for row in rows]

    def _fetch_foreign_keys(self, schema, table):
        """Return ``(fk_column, referenced_table, referenced_column)`` tuples."""
        rows = self.execute(self._FOREIGN_KEYS_SQL, [schema, table])
        return [
            (row['fk_column'], row['referenced_table'], row['referenced_column'])
            for row in rows
        ]

    # --- type mapping ----------------------------------------------------

    def map_data_type_for_sql(self, postgres_type):
        """Map an information_schema type name to a DDL type keyword.

        This is naive: precision, length and similar modifiers are not
        handled. Unmapped names are returned upper-cased.
        """
        return self._SQL_TYPE_NAMES.get(postgres_type, postgres_type.upper())

    def map_data_type(self, postgres_type):
        """Map an information_schema type name to a diagram type label.

        Unmapped names default to ``'string'``.
        """
        return self._DIAGRAM_TYPE_NAMES.get(postgres_type, 'string')

    # --- generators ------------------------------------------------------

    def generate_mermaid_er(self, schema_info):
        """Render *schema_info* as Mermaid ``erDiagram`` source text.

        Args:
            schema_info: Mapping shaped like the result of
                :meth:`get_schema_info`.

        Returns:
            The diagram as one newline-joined string: one entity block per
            table, followed by one relationship line per foreign key.
        """
        mermaid_lines = ["erDiagram"]

        # Entities and their attributes.
        for table, info in schema_info.items():
            mermaid_lines.append(f" {table} {{")
            for column_name, data_type in info['columns']:
                mermaid_data_type = self.map_data_type(data_type)
                mermaid_lines.append(f" {mermaid_data_type} {column_name}")
            mermaid_lines.append(" }")

        # Relationships, emitted after every entity has been declared.
        for table, info in schema_info.items():
            for fk_column, referenced_table, referenced_column in info['foreign_keys']:
                # Mermaid relationship syntax: Table1 }|--|| Table2 : "label"
                mermaid_lines.append(
                    f" {table} }}|--|| {referenced_table} : "
                    f"\"{fk_column} to {referenced_column}\""
                )

        return "\n".join(mermaid_lines)

    @staticmethod
    def _quote_ident(name):
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled so the identifier cannot terminate
        the quoting early.
        """
        return '"' + str(name).replace('"', '""') + '"'

    def generate_create_script(self, schema_info):
        """Render *schema_info* as a DDL script.

        All ``CREATE TABLE`` statements come first, followed by every
        ``ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY`` statement. Emitting
        the foreign keys last means a constraint never references a table
        that appears later in the script.

        Args:
            schema_info: Mapping shaped like the result of
                :meth:`get_schema_info`. The ``'primary_keys'`` and
                ``'foreign_keys'`` entries are optional.

        Returns:
            The script as a single string; empty when *schema_info* is empty.
        """
        create_lines = []
        fk_lines = []

        for table, info in schema_info.items():
            quoted_table = self._quote_ident(table)

            column_defs = [
                f' {self._quote_ident(column_name)} '
                f'{self.map_data_type_for_sql(data_type)}'
                for column_name, data_type in info['columns']
            ]

            pks = info.get('primary_keys', [])
            if pks:
                pk_columns = ", ".join(self._quote_ident(pk) for pk in pks)
                column_defs.append(f' PRIMARY KEY ({pk_columns})')

            create_lines.append(
                f'CREATE TABLE {quoted_table} (\n'
                + ",\n".join(column_defs)
                + '\n);'
            )
            create_lines.append("")  # separate blocks

            for fk_column, ref_table, ref_col in info.get('foreign_keys', []):
                constraint = self._quote_ident(f'fk_{table}_{fk_column}')
                fk_lines.append(
                    f'ALTER TABLE {quoted_table} '
                    f'ADD CONSTRAINT {constraint} '
                    f'FOREIGN KEY ({self._quote_ident(fk_column)}) '
                    f'REFERENCES {self._quote_ident(ref_table)} '
                    f'({self._quote_ident(ref_col)});'
                )

        if fk_lines:
            fk_lines.append("")  # trailing separator, as after each table

        return "\n".join(create_lines + fk_lines)

    # --- ad-hoc and saved queries ----------------------------------------

    def execute_sql(self, query):
        """Run *query* and report the outcome without raising.

        The query text is passed to the database callable verbatim, so
        callers must only hand in SQL they are willing to run.

        Returns:
            A ``(results, columns, error)`` tuple. ``columns`` holds the keys
            of the first row, or ``[]`` when there are no rows. If the call
            raises, ``results`` is ``None`` and ``error`` is the exception
            text; otherwise ``error`` is ``None``.
        """
        results = None
        columns = []
        error = None
        try:
            results = self.execute(query)
            if results:
                # Column names come from the keys of the first row.
                columns = list(results[0].keys())
        except Exception as e:  # boundary: surface any failure to the caller
            error = str(e)
        return (results, columns, error)

    def save_query(self, title, query):
        """Insert a saved query.

        Returns:
            ``None`` on success, ``"Must provide title"`` when *title* is
            empty, or the exception text if the insert fails.
        """
        if not title:
            return "Must provide title"

        error = None
        try:
            self.execute(""" INSERT INTO saved_query (title, query) VALUES (%s, %s)""", [title, query], commit=True)
        except Exception as e:  # boundary: report the failure as text
            error = str(e)
        return error

    def list_saved_queries(self):
        """Return the rows of ``saved_query`` (id, title, query)."""
        return self.execute("SELECT id, title, query FROM saved_query")

    def get_saved_query(self, query_id):
        """Return ``(title, query)`` for the saved query with id *query_id*.

        Raises:
            LookupError: If the database callable returns no row.
        """
        result = self.execute("SELECT title, query FROM saved_query where id=%s", [query_id], one=True)
        if not result:
            # An explicit error instead of failing on subscripting an empty result.
            raise LookupError(f"No saved query with id {query_id!r}")
        return (result['title'], result['query'])

    def delete_saved_query(self, query_id):
        """Delete the saved query with id *query_id*."""
        self.execute("DELETE FROM saved_query where id=%s", [query_id], commit=True)