138 lines
5.2 KiB
Python
138 lines
5.2 KiB
Python
class SQLExplorer:
    """Inspect a PostgreSQL schema, render it as a Mermaid ER diagram, and
    run / persist ad-hoc SQL queries.

    Every database round-trip goes through the callable given to
    ``__init__``. As used in this class it is called as
    ``execute(sql[, params][, one=True][, commit=True])`` and returns rows
    that can be indexed by column name (``row['table_name']``); with
    ``one=True`` it returns a single such row.
    NOTE(review): assumes a missing row with ``one=True`` comes back as
    ``None``/falsy — confirm against the helper.
    """

    # PostgreSQL ``information_schema.columns.data_type`` values mapped to
    # the short type labels used in the Mermaid diagram. Built once here
    # instead of on every map_data_type() call.
    _TYPE_MAPPING = {
        'integer': 'int',
        'bigint': 'int',
        'smallint': 'int',
        'character varying': 'string',
        'varchar': 'string',
        'text': 'string',
        'date': 'date',
        'timestamp without time zone': 'datetime',
        'timestamp with time zone': 'datetime',
        'boolean': 'bool',
        'numeric': 'float',
        'real': 'float',
        'double precision': 'float',
        # Add more mappings as needed
    }

    def __init__(self, db_connection_method):
        """Store the callable used to execute every SQL statement.

        Args:
            db_connection_method: query-executing callable (see the class
                docstring for how it is invoked).
        """
        self.execute = db_connection_method

    def get_schema_info(self, schema='public'):
        """Describe every base table in *schema*.

        Args:
            schema: name of the PostgreSQL schema to inspect.

        Returns:
            dict mapping each table name (in name order) to::

                {'columns': [(column_name, data_type), ...],          # declaration order
                 'foreign_keys': [(fk_column, referenced_table, referenced_column), ...]}
        """
        return {
            table: {
                'columns': self._get_columns(schema, table),
                'foreign_keys': self._get_foreign_keys(schema, table),
            }
            for table in self._list_tables(schema)
        }

    def _list_tables(self, schema):
        """Return the names of the base tables (views excluded) in *schema*."""
        rows = self.execute("""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = %s AND table_type = 'BASE TABLE'
            ORDER BY table_name;
        """, [schema])
        return [row['table_name'] for row in rows or []]

    def _get_columns(self, schema, table):
        """Return ``(column_name, data_type)`` pairs in declaration order."""
        rows = self.execute("""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            ORDER BY ordinal_position;
        """, [schema, table])
        return [(row['column_name'], row['data_type']) for row in rows or []]

    def _get_foreign_keys(self, schema, table):
        """Return ``(fk_column, referenced_table, referenced_column)`` triples.

        Each local FK column is paired with the referenced column at the
        same position of the referenced key (``position_in_unique_constraint``),
        so composite foreign keys yield one row per column pair rather than
        a cross product. Only references to tables in the same schema are
        returned.
        NOTE(review): relies on ``referential_constraints`` resolving the
        referenced key; an FK that targets a bare unique index (no PK/UNIQUE
        constraint) may not be reported — confirm if that case matters.
        """
        rows = self.execute("""
            SELECT
                kcu.column_name AS fk_column,
                ref.table_name AS referenced_table,
                ref.column_name AS referenced_column
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON kcu.constraint_schema = tc.constraint_schema
                AND kcu.constraint_name = tc.constraint_name
                AND kcu.table_name = tc.table_name
            JOIN information_schema.referential_constraints AS rc
                ON rc.constraint_schema = tc.constraint_schema
                AND rc.constraint_name = tc.constraint_name
            JOIN information_schema.key_column_usage AS ref
                ON ref.constraint_schema = rc.unique_constraint_schema
                AND ref.constraint_name = rc.unique_constraint_name
                AND ref.ordinal_position = kcu.position_in_unique_constraint
            WHERE
                tc.constraint_type = 'FOREIGN KEY'
                AND tc.table_schema = %s
                AND tc.table_name = %s
                AND ref.table_schema = tc.table_schema
            ORDER BY tc.constraint_name, kcu.ordinal_position;
        """, [schema, table])
        return [
            (row['fk_column'], row['referenced_table'], row['referenced_column'])
            for row in rows or []
        ]

    def map_data_type(self, postgres_type):
        """Translate a PostgreSQL ``data_type`` into a Mermaid attribute type.

        Types not listed in ``_TYPE_MAPPING`` (e.g. ``uuid``, ``jsonb``)
        fall back to ``'string'``.
        """
        return self._TYPE_MAPPING.get(postgres_type, 'string')

    def generate_mermaid_er(self, schema_info):
        """Render *schema_info* (shaped like get_schema_info()'s result) as a
        Mermaid ``erDiagram`` definition.

        Each table becomes an entity block listing ``<type> <column>``
        attributes; each foreign key becomes a ``}|--||`` relationship from
        the referencing table to the referenced table, labelled
        ``"<fk_column> to <referenced_column>"``.

        Returns:
            The diagram source as a single newline-joined string.
        """
        lines = ["erDiagram"]

        # Entities first, so every table is declared before relationships.
        for table, info in schema_info.items():
            lines.append(f" {table} {{")
            lines.extend(
                f" {self.map_data_type(data_type)} {column_name}"
                for column_name, data_type in info['columns']
            )
            lines.append(" }")

        # Relationship syntax: <Table1> }|--|| <Table2> : "label"
        for table, info in schema_info.items():
            for fk_column, referenced_table, referenced_column in info['foreign_keys']:
                lines.append(
                    f" {table} }}|--|| {referenced_table} : \"{fk_column} to {referenced_column}\""
                )

        return "\n".join(lines)

    def execute_sql(self, query):
        """Run *query* verbatim and capture the outcome instead of raising.

        The SQL is executed as given: this method is the explorer's
        arbitrary-query entry point, so access control belongs to the caller.

        Returns:
            ``(results, columns, error)`` — the rows returned by the execute
            callable, the column names taken from the first row's keys
            (``[]`` when there are no rows), and ``str(exc)`` if anything
            raised (otherwise ``None``).
        """
        results, columns, error = None, [], None
        try:
            results = self.execute(query)
            if results:
                columns = list(results[0].keys())
        except Exception as exc:  # boundary: report the failure to the caller
            error = str(exc)
        return (results, columns, error)

    def save_query(self, title, query):
        """Insert *query* under *title* into the ``saved_query`` table.

        Returns:
            ``None`` on success; otherwise an error string — either
            ``"Must provide title"`` for an empty/blank title or the
            database error message.
        """
        if not title or (isinstance(title, str) and not title.strip()):
            return "Must provide title"
        try:
            self.execute("""
                INSERT INTO saved_query (title, query)
                VALUES (%s, %s)""", [title, query], commit=True)
        except Exception as exc:  # surfaced as a message, not raised
            return str(exc)
        return None

    def list_saved_queries(self):
        """Return every row (``id``, ``title``, ``query``) of ``saved_query``,
        ordered by ``id`` so the listing is stable."""
        return self.execute("SELECT id, title, query FROM saved_query ORDER BY id")

    def get_saved_query(self, query_id):
        """Fetch one saved query.

        Returns:
            ``(title, query)`` for the row with this id.

        Raises:
            LookupError: if the execute callable returns no row for *query_id*.
        """
        row = self.execute("SELECT title, query FROM saved_query where id=%s", [query_id], one=True)
        if not row:
            raise LookupError(f"No saved query with id {query_id!r}")
        return (row['title'], row['query'])

    def delete_saved_query(self, query_id):
        """Delete the saved query with *query_id* and commit."""
        self.execute("DELETE FROM saved_query where id=%s", [query_id], commit=True)