Speed up sql plot generation

This commit is contained in:
Peter Stockings
2025-02-01 21:06:21 +11:00
parent 5fe003bcbf
commit c1c4c4a960

View File

@@ -106,38 +106,33 @@ def get_distinct_colors(n):
colors.append(hex_color) colors.append(hex_color)
return colors return colors
def generate_plot(df, title): def generate_plot(df: pd.DataFrame, title: str) -> str:
""" """
Analyzes the DataFrame and generates an appropriate Plotly visualization. Analyzes the DataFrame and generates an appropriate Plotly visualization.
Returns the Plotly figure as a div string. Returns the Plotly figure as a div string.
Optimized for speed.
""" """
if df.empty: if df.empty:
return "<p>No data available to plot.</p>" return "<p>No data available to plot.</p>"
num_columns = len(df.columns) num_columns = len(df.columns)
# Simple logic to decide plot type based on DataFrame structure # Dictionary-based lookup for faster decision-making
if num_columns == 1: plot_funcs = {
# Single column: perhaps a histogram or bar chart 1: lambda: px.histogram(df, x=df.columns[0], title=title)
column = df.columns[0] if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) else px.bar(df, x=df.columns[0], title=title),
if pd.api.types.is_numeric_dtype(df[column]):
fig = px.histogram(df, x=column, title=title) 2: lambda: px.scatter(df, x=df.columns[0], y=df.columns[1], title=title)
else: if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) and pd.api.types.is_numeric_dtype(df.iloc[:, 1])
fig = px.bar(df, x=column, title=title) else px.bar(df, x=df.columns[0], y=df.columns[1], title=title)
elif num_columns == 2: }
# Two columns: scatter plot or line chart
col1, col2 = df.columns # Select plot function based on column count
if pd.api.types.is_numeric_dtype(df[col1]) and pd.api.types.is_numeric_dtype(df[col2]): fig = plot_funcs.get(num_columns, lambda: px.imshow(df.corr(numeric_only=True), text_auto=True, title=title))()
fig = px.scatter(df, x=col1, y=col2, title=title)
else: # Use static rendering for speed
fig = px.bar(df, x=col1, y=col2, title=title) return pio.to_html(fig, full_html=False, include_plotlyjs='cdn', config={'staticPlot': True})
else:
# More than two columns: heatmap or other complex plots
fig = px.imshow(df.corr(), text_auto=True, title=title)
# Convert Plotly figure to HTML div
plot_div = pio.to_html(fig, full_html=False)
return plot_div
def calculate_estimated_1rm(weight, repetitions): def calculate_estimated_1rm(weight, repetitions):
# Ensure the inputs are numeric # Ensure the inputs are numeric