import colorsys
import math
from datetime import datetime, date, timedelta
from decimal import Decimal

import numpy as np
import pandas as pd

# plotly is only used by the legacy generate_plot(). It is imported optionally
# so the SVG helpers below (which need only numpy) keep working without it.
try:
    import plotly.express as px
    import plotly.io as pio  # Keep for now, might remove later if generate_plot is fully replaced
except ImportError:  # pragma: no cover - depends on the installed packages
    px = None
    pio = None


def convert_str_to_date(date_str, format='%Y-%m-%d'):
    """Parse ``date_str`` with ``format`` and return a :class:`datetime.date`.

    Returns ``None`` instead of raising when ``date_str`` does not match
    ``format`` (``ValueError``) or is not a string, e.g. ``None`` (``TypeError``).
    """
    try:
        return datetime.strptime(date_str, format).date()
    except (ValueError, TypeError):
        return None


# Fixed SVG viewBox used by the exercise progress graph.
_EXERCISE_VB_WIDTH = 200.0
_EXERCISE_VB_HEIGHT = 75.0


def _scale_to_height(values, height):
    """Linearly map the float array ``values`` onto ``[0, height]`` (zeros if constant)."""
    low = values.min()
    span = (values.max() - low) or 1  # constant series: avoid dividing by zero
    return (values - low) / span * height


def _least_squares_slope(x, y):
    """Return the ordinary-least-squares slope of ``y`` against ``x``.

    Returns 0.0 when the slope is undefined (fewer than two points, or every
    ``x`` equal).
    """
    if len(x) < 2:
        return 0.0
    x_centered = x - x.mean()
    denominator = float(np.dot(x_centered, x_centered))
    if denominator == 0:
        return 0.0
    return float(np.dot(x_centered, y - y.mean())) / denominator


def _best_fit_points(x, y, degree):
    """Fit a degree-``degree`` polynomial to ``(x, y)`` and evaluate it at ``x``.

    Returns a list of ``(fitted_y, x)`` tuples, or ``[]`` when no meaningful
    fit exists.
    """
    # polyfit needs at least two distinct x positions; with a single one its
    # internal column scaling divides by zero and yields NaNs / LinAlgError.
    if np.unique(x).size < 2:
        return []
    try:
        coeffs = np.polyfit(x, y, degree)
    except (np.linalg.LinAlgError, ValueError):
        return []
    fitted = np.poly1d(coeffs)(x)
    if not np.all(np.isfinite(fitted)):
        return []
    return list(zip(fitted.tolist(), x.tolist()))


def get_exercise_graph_model(title, estimated_1rm, repetitions, weight, start_dates,
                             messages, epoch, person_id, exercise_id,
                             min_date=None, max_date=None, degree=1):
    """Build the template model for an exercise's progress graph.

    Every series is scaled independently onto ``[0, vb_height]`` of a fixed
    200x75 viewBox. X positions are fractions in ``[0, 1]`` of the span
    between the earliest and latest of ``start_dates``. Each plotted point is a
    ``(scaled_value, relative_x)`` tuple.

    Args:
        title: Graph title, passed through.
        estimated_1rm, repetitions, weight: Per-session numeric series
            (ints, floats or Decimals), aligned with ``start_dates``.
        start_dates: ``date`` objects, one per session.
        messages: Labels zipped with the x positions into ``plot_labels``.
        epoch, person_id, exercise_id: Passed through to the model.
        min_date, max_date: Currently ignored; the returned ``min_date`` /
            ``max_date`` are always derived from ``start_dates``.
            NOTE(review): presumably intended for the 'Custom' epoch — confirm.
        degree: Polynomial degree of the E1RM best-fit curve.

    Returns:
        A dict with the viewBox size, the three scaled series, the best-fit
        curve, and ``best_fit_formula``. The formula holds the least-squares
        E1RM trend in units per week and per month (30 days); it is negative
        when the lift is declining.

    Raises:
        ValueError: if ``start_dates`` or any numeric series is empty.
    """
    dates = list(start_dates)
    # Convert once to float arrays so Decimal values from the DB scale cleanly.
    e1rm_values = np.asarray(estimated_1rm, dtype=float)
    rep_values = np.asarray(repetitions, dtype=float)
    weight_values = np.asarray(weight, dtype=float)
    if not dates or 0 in (e1rm_values.size, rep_values.size, weight_values.size):
        raise ValueError("get_exercise_graph_model needs at least one data point per series")

    first_date, last_date = min(dates), max(dates)
    total_span = (last_date - first_date).days or 1
    day_offsets = np.array([(d - first_date).days for d in dates], dtype=float)
    relative_positions = day_offsets / total_span

    vb_width, vb_height = _EXERCISE_VB_WIDTH, _EXERCISE_VB_HEIGHT
    e1rm_scaled = _scale_to_height(e1rm_values, vb_height)
    reps_scaled = _scale_to_height(rep_values, vb_height)
    weight_scaled = _scale_to_height(weight_values, vb_height)

    # Trend of the raw E1RM values per day (not the scaled ones), so the sign
    # reflects whether the lift is improving or declining.
    slope_per_day = _least_squares_slope(day_offsets, e1rm_values)
    best_fit_formula = {
        'kg_per_week': round(slope_per_day * 7, 1),
        'kg_per_month': round(slope_per_day * 30, 1),
    }
    best_fit_points = _best_fit_points(relative_positions, e1rm_scaled, degree)

    x_list = relative_positions.tolist()
    repetitions_data = {
        'label': 'Reps',
        'color': '#388fed',
        'points': list(zip(reps_scaled.tolist(), x_list)),
    }
    weight_data = {
        'label': 'Weight',
        'color': '#bd3178',
        'points': list(zip(weight_scaled.tolist(), x_list)),
    }
    estimated_1rm_data = {
        'label': 'E1RM',
        'color': '#2ca02c',
        'points': list(zip(e1rm_scaled.tolist(), x_list)),
    }

    return {
        'title': title,
        'vb_width': vb_width,
        'vb_height': vb_height,
        'plots': [repetitions_data, weight_data, estimated_1rm_data],
        'best_fit_points': best_fit_points,
        'best_fit_formula': best_fit_formula,
        'plot_labels': list(zip(x_list, messages)),
        'epochs': ['Custom', '1M', '3M', '6M', 'All'],
        'selected_epoch': epoch,
        'person_id': person_id,
        'exercise_id': exercise_id,
        'min_date': first_date,
        'max_date': last_date,
        'degree': degree,
    }


def get_distinct_colors(n):
    """Return ``n`` hex colour strings with hues spaced evenly around the wheel.

    Lightness (0.6) and saturation (0.4) are fixed, so the colours differ only
    in hue. ``n <= 0`` yields an empty list.
    """
    colors = []
    for i in range(n):
        hue = i / n
        # colorsys uses HLS argument order: (hue, lightness, saturation).
        red, green, blue = colorsys.hls_to_rgb(hue, 0.6, 0.4)
        colors.append('#{:02x}{:02x}{:02x}'.format(int(red * 255), int(green * 255), int(blue * 255)))
    return colors


def generate_plot(df: pd.DataFrame, title: str) -> str:
    """Pick a Plotly chart for ``df`` and return it as an HTML fragment string.

    * 1 column: histogram if numeric, otherwise a bar chart.
    * 2 columns: scatter if both numeric, otherwise a bar chart.
    * 3+ columns: heat map of the numeric-column correlation matrix.

    The figure is rendered with ``staticPlot`` and without embedding plotly.js
    (``include_plotlyjs=False``), so the page must load plotly.js itself.

    Raises:
        ImportError: if plotly is not installed (only checked for non-empty ``df``).
    """
    if df.empty:
        # NOTE(review): plain-text placeholder; confirm whether the template
        # expects this wrapped in HTML markup.
        return "\n\nNo data available to plot.\n\n"
    if px is None or pio is None:
        raise ImportError("plotly is required for generate_plot()")

    first_col = df.columns[0]
    first_is_numeric = pd.api.types.is_numeric_dtype(df.iloc[:, 0])
    if len(df.columns) == 1:
        if first_is_numeric:
            fig = px.histogram(df, x=first_col, title=title)
        else:
            fig = px.bar(df, x=first_col, title=title)
    elif len(df.columns) == 2:
        second_col = df.columns[1]
        if first_is_numeric and pd.api.types.is_numeric_dtype(df.iloc[:, 1]):
            fig = px.scatter(df, x=first_col, y=second_col, title=title)
        else:
            fig = px.bar(df, x=first_col, y=second_col, title=title)
    else:
        fig = px.imshow(df.corr(numeric_only=True), text_auto=True, title=title)

    # Static rendering: no interactivity, cheaper on the client.
    return pio.to_html(fig, full_html=False, include_plotlyjs=False, config={'staticPlot': True})


def calculate_estimated_1rm(weight, repetitions):
    """Estimate a one-rep max as ``100 * weight / (101.3 - 2.67123 * repetitions)``.

    ``weight`` and ``repetitions`` may be ints, floats, Decimals or numeric
    strings. The result is rounded to the nearest whole number.

    Returns:
        The estimate as an ``int``, or 0 when ``repetitions`` is zero or
        negative (no completed reps, so no meaningful estimate).
    """
    reps = float(repetitions)
    if reps <= 0:
        return 0
    # NOTE(review): the denominator reaches zero near 37.9 reps and is negative
    # beyond that, so very high rep counts produce meaningless estimates.
    estimated_1rm = 100 * float(weight) / (101.3 - 2.67123 * reps)
    return int(round(estimated_1rm, 0))


def _is_numeric(val):
    """Check if a value is numeric (int, float, Decimal; bool counts as int)."""
    return isinstance(val, (int, float, Decimal))


def _is_datetime(val):
    """Check if a value is a date or datetime object."""
    return isinstance(val, (date, datetime))


def _get_column_type(results, column_name):
    """Classify a column of dict rows as 'numeric', 'datetime' or 'categorical'.

    ``None`` values are ignored. A type wins when more than 80% of the
    non-null values have it, which tolerates a little noise. All-null or
    empty columns are 'categorical'.
    """
    numeric_count = 0
    datetime_count = 0
    total_count = 0
    for row in results:
        val = row.get(column_name)
        if val is None:
            continue
        total_count += 1
        if _is_numeric(val):
            numeric_count += 1
        elif _is_datetime(val):
            datetime_count += 1

    if total_count == 0:
        return 'categorical'
    if numeric_count / total_count > 0.8:
        return 'numeric'
    if datetime_count / total_count > 0.8:
        return 'datetime'
    return 'categorical'


def _normalize_value(value, min_val, range_val, target_max):
    """Map ``value`` from ``[min_val, min_val + range_val]`` onto ``[0, target_max]``.

    A zero range places the value in the middle of the target range.
    """
    if range_val == 0:
        return target_max / 2
    return ((value - min_val) / range_val) * target_max


# Plot types whose points prepare_svg_plot_data can currently compute.
# 'histogram', 'bar_count' and 'bar' need binning / categorical x positions,
# which are not implemented, so they fall back to a table.
_SVG_SCALABLE_PLOTS = ('scatter', 'line')


def _choose_svg_plot(results, columns):
    """Return ``(plot_type, x_col, y_col, x_type, y_type)`` for the result columns."""
    if len(columns) == 1:
        x_col = columns[0]
        x_type = _get_column_type(results, x_col)
        plot_type = 'histogram' if x_type == 'numeric' else 'bar_count'
        return plot_type, x_col, None, x_type, None
    if len(columns) < 2:
        return 'table', None, None, None, None

    x_col, y_col = columns[0], columns[1]
    x_type = _get_column_type(results, x_col)
    y_type = _get_column_type(results, y_col)
    if y_type == 'numeric':
        if x_type == 'numeric':
            return 'scatter', x_col, y_col, x_type, y_type
        if x_type == 'datetime':
            return 'line', x_col, y_col, x_type, y_type
        return 'bar', x_col, y_col, x_type, y_type
    if x_type == 'numeric' and y_type == 'categorical':
        # Swap so the categorical column is the x axis of a vertical bar chart.
        return 'bar', y_col, x_col, y_type, x_type
    return 'table', x_col, y_col, x_type, y_type


def _format_tick(value):
    """Format a numeric tick label with one decimal place for floats."""
    return f"{value:.1f}" if isinstance(value, float) else str(value)


def _build_ticks(min_val, range_val, to_position, to_label, num_ticks=5):
    """Return evenly spaced tick dicts from ``min_val`` to ``min_val + range_val``.

    A zero range yields a single tick rather than ``num_ticks`` identical
    ticks stacked at the same position.
    """
    if range_val == 0 or num_ticks < 2:
        values = [min_val]
    else:
        step = range_val / (num_ticks - 1)
        values = [min_val + i * step for i in range(num_ticks)]
    return [{'value': v, 'label': to_label(v), 'position': to_position(v)} for v in values]


def prepare_svg_plot_data(results, columns, title):
    """Prepare raw SQL result rows (dicts) for SVG plotting.

    The first two columns decide the plot type. Only 'scatter' (numeric x and
    y) and 'line' (datetime x, numeric y) are scaled into points. Every other
    combination is returned with ``plot_type == 'table'``, and the original
    rows are kept for the table view.

    Rows whose x or y value is missing or of the wrong type are skipped as a
    whole point. Line-plot points are sorted by x so the line does not zigzag.

    Raises:
        ValueError: if ``results`` is empty.
    """
    if not results:
        raise ValueError("No data provided for plotting.")

    plot_type, x_col, y_col, x_type, y_type = _choose_svg_plot(results, columns)

    # --- Basic SVG Setup ---
    vb_width = 500
    vb_height = 300
    margin = {'top': 20, 'right': 20, 'bottom': 50, 'left': 60}  # room for axis labels
    draw_width = vb_width - margin['left'] - margin['right']
    draw_height = vb_height - margin['top'] - margin['bottom']

    plot_data = {
        'title': title,
        'plot_type': plot_type,
        'vb_width': vb_width,
        'vb_height': vb_height,
        'margin': margin,
        'draw_width': draw_width,
        'draw_height': draw_height,
        'x_axis_label': x_col or '',
        'y_axis_label': y_col or '',
        'plots': [],
        'x_ticks': [],
        'y_ticks': [],
        'original_results': results,  # Keep original for table fallback
        'original_columns': columns,
    }

    if plot_type not in _SVG_SCALABLE_PLOTS:
        plot_data['plot_type'] = 'table'
        return plot_data

    # Collect complete (x, y, row) triples per row, so a missing value drops
    # that point instead of shifting later values onto the wrong row.
    x_is_valid = _is_datetime if x_type == 'datetime' else _is_numeric
    pairs = []
    for row in results:
        x_raw = row.get(x_col)
        y_raw = row.get(y_col)
        if x_is_valid(x_raw) and _is_numeric(y_raw):
            pairs.append((x_raw, float(y_raw), row))

    if not pairs:
        plot_data['plot_type'] = 'table'
        return plot_data

    if x_type == 'datetime':
        # Express dates as (fractional) days since the earliest one.
        origin = min(x for x, _, _ in pairs)
        x_positions = [(x - origin).total_seconds() / 86400 for x, _, _ in pairs]

        def x_label(value):
            return (origin + timedelta(days=value)).strftime('%Y-%m-%d')
    else:
        x_positions = [float(x) for x, _, _ in pairs]
        x_label = _format_tick
    y_values = [y for _, y, _ in pairs]

    min_x, max_x = min(x_positions), max(x_positions)
    min_y, max_y = min(y_values), max(y_values)
    range_x = max_x - min_x
    range_y = max_y - min_y

    def x_to_svg(value):
        return margin['left'] + _normalize_value(value, min_x, range_x, draw_width)

    def y_to_svg(value):
        # SVG y grows downward, so invert within the drawing area.
        return margin['top'] + draw_height - _normalize_value(value, min_y, range_y, draw_height)

    points = [
        {'x': x_to_svg(x), 'y': y_to_svg(y), 'original': row}  # row kept for tooltips
        for x, (_, y, row) in zip(x_positions, pairs)
    ]
    if plot_type == 'line':
        points.sort(key=lambda point: point['x'])

    # For now, put all points into one series.
    plot_data['plots'].append({
        'label': f'{y_col} vs {x_col}',
        'color': '#388fed',  # Default color
        'points': points,
    })
    plot_data['x_ticks'] = _build_ticks(min_x, range_x, x_to_svg, x_label)
    plot_data['y_ticks'] = _build_ticks(min_y, range_y, y_to_svg, _format_tick)
    return plot_data