Add fitness graph

This commit is contained in:
Peter Stockings
2026-01-10 11:49:28 +11:00
parent 246a4a14e3
commit de1563dae6
8 changed files with 292 additions and 42 deletions

View File

@@ -0,0 +1,140 @@
interface FitnessGraphProps {
history: Array<{ generation: number; best: number; average: number }>;
width?: number | string;
height?: number | string;
className?: string;
}
export default function FitnessGraph({ history, width = "100%", height = 150, className = "" }: FitnessGraphProps) {
if (history.length < 2) {
return (
<div style={{
width,
height,
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
color: '#666',
fontSize: '0.8rem',
background: 'rgba(0,0,0,0.2)',
borderRadius: '4px'
}}>
Waiting for data...
</div>
);
}
const PADDING = 20; // Internal padding
// Use internal coordinate system for viewBox
const VIEW_WIDTH = 500;
const VIEW_HEIGHT = 200;
const GRAPH_WIDTH = VIEW_WIDTH - PADDING * 2;
const GRAPH_HEIGHT = VIEW_HEIGHT - PADDING * 2;
// Find min/max for scaling
const maxFitness = Math.max(...history.map(h => h.best), 1);
const minGeneration = history[0].generation;
const maxGeneration = history[history.length - 1].generation;
const genRange = Math.max(maxGeneration - minGeneration, 1);
// Helper to scale points
const getX = (gen: number) => {
return PADDING + ((gen - minGeneration) / genRange) * GRAPH_WIDTH;
};
const getY = (fitness: number) => {
// Invert Y because SVG 0 is top
return PADDING + GRAPH_HEIGHT - (fitness / maxFitness) * GRAPH_HEIGHT;
};
// Generate path data
const bestPath = history.map((p, i) =>
`${i === 0 ? 'M' : 'L'} ${getX(p.generation)} ${getY(p.best)}`
).join(' ');
const averagePath = history.map((p, i) =>
`${i === 0 ? 'M' : 'L'} ${getX(p.generation)} ${getY(p.average)}`
).join(' ');
// Areas (closed paths for gradients)
const bestArea = bestPath + ` L ${getX(history[history.length - 1].generation)} ${GRAPH_HEIGHT + PADDING} L ${getX(minGeneration)} ${GRAPH_HEIGHT + PADDING} Z`;
const averageArea = averagePath + ` L ${getX(history[history.length - 1].generation)} ${GRAPH_HEIGHT + PADDING} L ${getX(minGeneration)} ${GRAPH_HEIGHT + PADDING} Z`;
return (
<div className={`fitness-graph-container ${className}`} style={{ width: '100%', height, position: 'relative' }}>
{/* Legend Overlay */}
<div style={{
position: 'absolute',
top: 0,
right: 0,
display: 'flex',
gap: '12px',
fontSize: '0.75rem',
fontWeight: 600,
background: 'rgba(0,0,0,0.4)',
padding: '4px 8px',
borderRadius: '0 0 0 8px',
pointerEvents: 'none',
backdropFilter: 'blur(2px)'
}}>
<div style={{ color: '#4ecdc4', display: 'flex', alignItems: 'center', gap: '6px' }}>
<div style={{ width: 8, height: 8, background: '#4ecdc4', borderRadius: '50%' }}></div>
Best: {Math.round(history[history.length - 1].best)}
</div>
<div style={{ color: '#4a9eff', display: 'flex', alignItems: 'center', gap: '6px' }}>
<div style={{ width: 8, height: 8, background: '#4a9eff', borderRadius: '50%' }}></div>
Avg: {Math.round(history[history.length - 1].average)}
</div>
</div>
<svg
width="100%"
height="100%"
viewBox={`0 0 ${VIEW_WIDTH} ${VIEW_HEIGHT}`}
preserveAspectRatio="none"
style={{ overflow: 'visible' }}
>
<defs>
<linearGradient id="gradBest" x1="0%" y1="0%" x2="0%" y2="100%">
<stop offset="0%" stopColor="#4ecdc4" stopOpacity={0.4} />
<stop offset="100%" stopColor="#4ecdc4" stopOpacity={0} />
</linearGradient>
<linearGradient id="gradAvg" x1="0%" y1="0%" x2="0%" y2="100%">
<stop offset="0%" stopColor="#4a9eff" stopOpacity={0.3} />
<stop offset="100%" stopColor="#4a9eff" stopOpacity={0} />
</linearGradient>
</defs>
{/* Grid Lines (Horizontal) */}
{[0, 0.25, 0.5, 0.75, 1].map(ratio => {
const y = PADDING + ratio * GRAPH_HEIGHT;
return (
<line
key={ratio}
x1={PADDING}
y1={y}
x2={VIEW_WIDTH - PADDING}
y2={y}
stroke="#333"
strokeWidth="1"
strokeDasharray="4 4"
opacity="0.5"
/>
);
})}
{/* Average Area */}
<path d={averageArea} fill="url(#gradAvg)" />
{/* Average Line */}
<path d={averagePath} fill="none" stroke="#4a9eff" strokeWidth="2" strokeOpacity="0.8" />
{/* Best Area */}
<path d={bestArea} fill="url(#gradBest)" />
{/* Best Line */}
<path d={bestPath} fill="none" stroke="#4ecdc4" strokeWidth="2.5" />
</svg>
</div>
);
}

View File

@@ -328,9 +328,12 @@ input[type='range']::-webkit-slider-thumb:hover {
}
.progress-indicator {
background: #080808;
padding: 0.75rem;
border: 1px solid #222;
background: linear-gradient(135deg, #2a2a3e 0%, #1a1a2e 100%);
padding: 1.5rem;
border-radius: 12px;
border: 1px solid #3a3a4e;
display: flex;
flex-direction: column;
}
.progress-label {

View File

@@ -7,10 +7,6 @@ import Tips from './Tips';
import BestSnakeDisplay from './BestSnakeDisplay';
import {
createPopulation,
evaluatePopulation,
evolveGeneration,
getBestIndividual,
getAverageFitness,
type Population,
} from '../../lib/snakeAI/evolution';
import type { EvolutionConfig } from '../../lib/snakeAI/types';
@@ -24,6 +20,8 @@ const DEFAULT_CONFIG: EvolutionConfig = {
maxGameSteps: 20000,
};
import EvolutionWorker from '../../lib/snakeAI/evolution.worker?worker';
export default function SnakeAI() {
const [population, setPopulation] = useState<Population>(() =>
createPopulation(DEFAULT_CONFIG)
@@ -32,30 +30,85 @@ export default function SnakeAI() {
const [isRunning, setIsRunning] = useState(false);
const [speed, setSpeed] = useState(5);
const [gamesPlayed, setGamesPlayed] = useState(0);
const [fitnessHistory, setFitnessHistory] = useState<Array<{ generation: number, best: number, average: number }>>([]);
// Compute derived values from population
const bestIndividual = getBestIndividual(population);
const averageFitness = getAverageFitness(population);
// Keep a ref to population for the worker
const populationRef = useRef(population);
useEffect(() => {
populationRef.current = population;
}, [population]);
const animationFrameRef = useRef<number>();
const lastUpdateRef = useRef<number>(0);
const runGeneration = useCallback(() => {
setPopulation((prev) => {
try {
// Evaluate current generation
const evaluated = evaluatePopulation(prev, config);
// Compute derived values for display
// If we have stats from the last generation, use them. Otherwise default to 0.
const currentBestFitness = population.lastGenerationStats?.bestFitness || 0;
const currentAverageFitness = population.lastGenerationStats?.averageFitness || 0;
// Evolve to next generation
const nextGen = evolveGeneration(evaluated, config);
const workerRef = useRef<Worker | null>(null);
const isProcessingRef = useRef(false);
return nextGen;
} catch (error) {
console.error("SnakeAI: Generation update failed", error);
return prev;
useEffect(() => {
workerRef.current = new EvolutionWorker();
workerRef.current.onmessage = (e) => {
const { type, payload } = e.data; // payload is the NEW population
if (type === 'SUCCESS') {
// Critical: Update ref immediately to prevent race condition with next animation frame
populationRef.current = payload;
setPopulation(payload);
// Update history if we have stats
if (payload.lastGenerationStats) {
setFitnessHistory(prev => {
const newEntry = {
generation: payload.generation - 1, // The stats are for the gen that just finished
best: payload.lastGenerationStats!.bestFitness,
average: payload.lastGenerationStats!.averageFitness
};
// Keep last 100 generations to avoid memory issues if running for eternity
const newHistory = [...prev, newEntry];
if (newHistory.length > 100) return newHistory.slice(newHistory.length - 100);
return newHistory;
});
}
isProcessingRef.current = false;
} else {
console.error("Worker error:", payload);
isProcessingRef.current = false;
}
};
return () => {
workerRef.current?.terminate();
};
}, []);
const runGeneration = useCallback((generations: number = 1) => {
if (isProcessingRef.current || !workerRef.current) return;
isProcessingRef.current = true;
// We need to send the *current* population.
// Since this is inside a callback, we need to be careful about closure staleness.
// However, we can't easily access the "latest" state inside a callback without refs or dependency.
// But 'population' is in the dependency array of the effect calling this? No.
// The animate loop calls this.
// Let's use a functional update approach? No, we need to SEND data.
// We will use a ref to track current population for the worker to ensure we always send latest
// OR rely on the fact that 'population' is in dependency of runGeneration (it wasn't before).
// Wait, 'runGeneration' lines 43-58 previously used setPopulation(prev => ...).
// It didn't need 'population' in dependency.
// Now we need it.
workerRef.current.postMessage({
population: populationRef.current, // Use a ref for latest population
config,
generations
});
}, [config]);
}, [config]); // populationRef will be handled separately
// Update stats when generation changes
useEffect(() => {
@@ -93,7 +146,7 @@ export default function SnakeAI() {
}
if (elapsed >= updateInterval) {
runGeneration();
runGeneration(1);
lastUpdateRef.current = timestamp;
}
} else {
@@ -102,9 +155,9 @@ export default function SnakeAI() {
// Speed 100 -> 10 gens per frame (~600 eps)
const gensPerFrame = Math.floor((speed - 10) / 10);
for (let i = 0; i < gensPerFrame; i++) {
runGeneration();
}
// For turbo mode, we just fire once per frame (or whenever the worker is ready)
// asking for multiple generations
runGeneration(gensPerFrame);
lastUpdateRef.current = timestamp;
}
@@ -122,7 +175,9 @@ export default function SnakeAI() {
const handleReset = () => {
setIsRunning(false);
setPopulation(createPopulation(config));
const newPop = createPopulation(config);
populationRef.current = newPop;
setPopulation(newPop);
setGamesPlayed(0);
};
@@ -162,10 +217,11 @@ export default function SnakeAI() {
<Stats
generation={population.generation}
bestFitness={bestIndividual.fitness}
bestFitness={currentBestFitness}
bestFitnessEver={population.bestFitnessEver}
averageFitness={averageFitness}
averageFitness={currentAverageFitness}
gamesPlayed={gamesPlayed}
history={fitnessHistory}
/>
<Tips />

View File

@@ -24,6 +24,11 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
const [currentGame, setCurrentGame] = useState<GameState | null>(null);
const animationFrameRef = useRef<number>();
const lastUpdateRef = useRef<number>(0);
const networkRef = useRef(network);
useEffect(() => {
networkRef.current = network;
}, [network]);
const CELL_SIZE = CELL_SIZES[size];
@@ -32,7 +37,7 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
if (network) {
setCurrentGame(createGame(gridSize));
}
}, [network, gridSize]);
}, [network?.id, gridSize]);
// Animation loop to step through game
useEffect(() => {
@@ -54,8 +59,11 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
}
// Get neural network decision
const currentNetwork = networkRef.current;
if (!currentNetwork) return prevGame;
const inputs = getInputs(prevGame);
const action = getAction(network, inputs);
const action = getAction(currentNetwork, inputs);
// Step the game forward
return step(prevGame, action);
@@ -74,7 +82,7 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
cancelAnimationFrame(animationFrameRef.current);
}
};
}, [network, currentGame, gridSize]);
}, [network?.id, !!currentGame, gridSize]); // Use ID and boolean existence check to prevent loop restart on every frame
// Set canvas size once when props change (not on every render)
useEffect(() => {

View File

@@ -1,9 +1,12 @@
import FitnessGraph from './FitnessGraph';
interface StatsProps {
generation: number;
bestFitness: number;
bestFitnessEver: number;
averageFitness: number;
gamesPlayed: number;
history: Array<{ generation: number; best: number; average: number }>;
}
export default function Stats({
@@ -12,6 +15,7 @@ export default function Stats({
bestFitnessEver,
averageFitness,
gamesPlayed,
history,
}: StatsProps) {
return (
<div className="stats-panel">
@@ -45,17 +49,10 @@ export default function Stats({
</div>
<div className="progress-indicator">
<div className="progress-label">
Improvement: {bestFitnessEver > 0 ? ((bestFitness / bestFitnessEver) * 100).toFixed(1) : 0}%
</div>
<div className="progress-bar">
<div
className="progress-fill"
style={{
width: `${bestFitnessEver > 0 ? Math.min(100, (bestFitness / bestFitnessEver) * 100) : 0}%`,
}}
/>
<div className="progress-label" style={{ marginBottom: '0.5rem' }}>
Fitness History
</div>
<FitnessGraph history={history} height={120} />
</div>
</div>
);

View File

@@ -13,6 +13,10 @@ export interface Population {
generation: number;
bestFitnessEver: number;
bestNetworkEver: Network | null;
lastGenerationStats?: {
bestFitness: number;
averageFitness: number;
};
}
export function createPopulation(config: EvolutionConfig): Population {
@@ -81,6 +85,10 @@ export function evolveGeneration(
// Sort by fitness (descending)
const sorted = [...population.individuals].sort((a, b) => b.fitness - a.fitness);
// Calculate stats for this generation BEFORE creating the new one
const currentBestFitness = sorted[0].fitness;
const currentAverageFitness = sorted.reduce((sum, ind) => sum + ind.fitness, 0) / sorted.length;
const newIndividuals: Individual[] = [];
// Elite preservation (top performers survive unchanged)
@@ -122,6 +130,10 @@ export function evolveGeneration(
generation: population.generation + 1,
bestFitnessEver: population.bestFitnessEver,
bestNetworkEver: population.bestNetworkEver,
lastGenerationStats: {
bestFitness: currentBestFitness,
averageFitness: currentAverageFitness
}
};
}
@@ -142,6 +154,7 @@ function selectParent(sorted: Individual[]): Individual {
function crossover(parent1: Network, parent2: Network): Network {
const child = cloneNetwork(parent1);
child.id = Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
// Single-point crossover on weights and biases
const crossoverRate = 0.5;
@@ -182,6 +195,7 @@ function crossover(parent1: Network, parent2: Network): Network {
function mutate(network: Network, mutationRate: number): Network {
const mutated = cloneNetwork(network);
mutated.id = Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
// Mutate input-hidden weights
for (let i = 0; i < mutated.weightsIH.length; i++) {

View File

@@ -0,0 +1,25 @@
import { evaluatePopulation, evolveGeneration, type Population } from './evolution';
import type { EvolutionConfig } from './types';
self.onmessage = (e: MessageEvent) => {
const { population, config, generations = 1 } = e.data as {
population: Population;
config: EvolutionConfig;
generations?: number;
};
try {
let currentPop = population;
for (let i = 0; i < generations; i++) {
// Run the heavy computation
const evaluated = evaluatePopulation(currentPop, config);
currentPop = evolveGeneration(evaluated, config);
}
// Send back the result
self.postMessage({ type: 'SUCCESS', payload: currentPop });
} catch (error) {
self.postMessage({ type: 'ERROR', payload: error });
}
};

View File

@@ -1,6 +1,7 @@
import { Action } from './types';
export interface Network {
id: string;
inputSize: number;
hiddenSize: number;
outputSize: number;
@@ -16,6 +17,7 @@ export function createNetwork(
outputSize: number = 3
): Network {
return {
id: generateId(),
inputSize,
hiddenSize,
outputSize,
@@ -26,6 +28,10 @@ export function createNetwork(
};
}
function generateId(): string {
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
}
function createRandomMatrix(rows: number, cols: number): number[][] {
const matrix: number[][] = [];
for (let i = 0; i < rows; i++) {
@@ -101,6 +107,7 @@ export function getAction(network: Network, inputs: number[]): Action {
export function cloneNetwork(network: Network): Network {
return {
id: network.id,
inputSize: network.inputSize,
hiddenSize: network.hiddenSize,
outputSize: network.outputSize,