From c1c4c4a960f1950bc2525f00b252a50b6ed488b1 Mon Sep 17 00:00:00 2001 From: Peter Stockings Date: Sat, 1 Feb 2025 21:06:21 +1100 Subject: [PATCH] Speed up sql plot generation --- utils.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/utils.py b/utils.py index 5f435a5..5c95e97 100644 --- a/utils.py +++ b/utils.py @@ -106,38 +106,33 @@ def get_distinct_colors(n): colors.append(hex_color) return colors -def generate_plot(df, title): +def generate_plot(df: pd.DataFrame, title: str) -> str: """ Analyzes the DataFrame and generates an appropriate Plotly visualization. Returns the Plotly figure as a div string. + Optimized for speed. """ if df.empty: return "

No data available to plot.

" num_columns = len(df.columns) - # Simple logic to decide plot type based on DataFrame structure - if num_columns == 1: - # Single column: perhaps a histogram or bar chart - column = df.columns[0] - if pd.api.types.is_numeric_dtype(df[column]): - fig = px.histogram(df, x=column, title=title) - else: - fig = px.bar(df, x=column, title=title) - elif num_columns == 2: - # Two columns: scatter plot or line chart - col1, col2 = df.columns - if pd.api.types.is_numeric_dtype(df[col1]) and pd.api.types.is_numeric_dtype(df[col2]): - fig = px.scatter(df, x=col1, y=col2, title=title) - else: - fig = px.bar(df, x=col1, y=col2, title=title) - else: - # More than two columns: heatmap or other complex plots - fig = px.imshow(df.corr(), text_auto=True, title=title) + # Dictionary-based lookup for faster decision-making + plot_funcs = { + 1: lambda: px.histogram(df, x=df.columns[0], title=title) + if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) else px.bar(df, x=df.columns[0], title=title), + + 2: lambda: px.scatter(df, x=df.columns[0], y=df.columns[1], title=title) + if pd.api.types.is_numeric_dtype(df.iloc[:, 0]) and pd.api.types.is_numeric_dtype(df.iloc[:, 1]) + else px.bar(df, x=df.columns[0], y=df.columns[1], title=title) + } + + # Select plot function based on column count + fig = plot_funcs.get(num_columns, lambda: px.imshow(df.corr(numeric_only=True), text_auto=True, title=title))() + + # Use static rendering for speed + return pio.to_html(fig, full_html=False, include_plotlyjs='cdn', config={'staticPlot': True}) - # Convert Plotly figure to HTML div - plot_div = pio.to_html(fig, full_html=False) - return plot_div def calculate_estimated_1rm(weight, repetitions): # Ensure the inputs are numeric