Add fitness graph
This commit is contained in:
140
src/apps/SnakeAI/FitnessGraph.tsx
Normal file
140
src/apps/SnakeAI/FitnessGraph.tsx
Normal file
@@ -0,0 +1,140 @@
|
||||
interface FitnessGraphProps {
|
||||
history: Array<{ generation: number; best: number; average: number }>;
|
||||
width?: number | string;
|
||||
height?: number | string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function FitnessGraph({ history, width = "100%", height = 150, className = "" }: FitnessGraphProps) {
|
||||
if (history.length < 2) {
|
||||
return (
|
||||
<div style={{
|
||||
width,
|
||||
height,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
color: '#666',
|
||||
fontSize: '0.8rem',
|
||||
background: 'rgba(0,0,0,0.2)',
|
||||
borderRadius: '4px'
|
||||
}}>
|
||||
Waiting for data...
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const PADDING = 20; // Internal padding
|
||||
// Use internal coordinate system for viewBox
|
||||
const VIEW_WIDTH = 500;
|
||||
const VIEW_HEIGHT = 200;
|
||||
|
||||
const GRAPH_WIDTH = VIEW_WIDTH - PADDING * 2;
|
||||
const GRAPH_HEIGHT = VIEW_HEIGHT - PADDING * 2;
|
||||
|
||||
// Find min/max for scaling
|
||||
const maxFitness = Math.max(...history.map(h => h.best), 1);
|
||||
const minGeneration = history[0].generation;
|
||||
const maxGeneration = history[history.length - 1].generation;
|
||||
const genRange = Math.max(maxGeneration - minGeneration, 1);
|
||||
|
||||
// Helper to scale points
|
||||
const getX = (gen: number) => {
|
||||
return PADDING + ((gen - minGeneration) / genRange) * GRAPH_WIDTH;
|
||||
};
|
||||
|
||||
const getY = (fitness: number) => {
|
||||
// Invert Y because SVG 0 is top
|
||||
return PADDING + GRAPH_HEIGHT - (fitness / maxFitness) * GRAPH_HEIGHT;
|
||||
};
|
||||
|
||||
// Generate path data
|
||||
const bestPath = history.map((p, i) =>
|
||||
`${i === 0 ? 'M' : 'L'} ${getX(p.generation)} ${getY(p.best)}`
|
||||
).join(' ');
|
||||
|
||||
const averagePath = history.map((p, i) =>
|
||||
`${i === 0 ? 'M' : 'L'} ${getX(p.generation)} ${getY(p.average)}`
|
||||
).join(' ');
|
||||
|
||||
|
||||
// Areas (closed paths for gradients)
|
||||
const bestArea = bestPath + ` L ${getX(history[history.length - 1].generation)} ${GRAPH_HEIGHT + PADDING} L ${getX(minGeneration)} ${GRAPH_HEIGHT + PADDING} Z`;
|
||||
const averageArea = averagePath + ` L ${getX(history[history.length - 1].generation)} ${GRAPH_HEIGHT + PADDING} L ${getX(minGeneration)} ${GRAPH_HEIGHT + PADDING} Z`;
|
||||
|
||||
return (
|
||||
<div className={`fitness-graph-container ${className}`} style={{ width: '100%', height, position: 'relative' }}>
|
||||
{/* Legend Overlay */}
|
||||
<div style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
display: 'flex',
|
||||
gap: '12px',
|
||||
fontSize: '0.75rem',
|
||||
fontWeight: 600,
|
||||
background: 'rgba(0,0,0,0.4)',
|
||||
padding: '4px 8px',
|
||||
borderRadius: '0 0 0 8px',
|
||||
pointerEvents: 'none',
|
||||
backdropFilter: 'blur(2px)'
|
||||
}}>
|
||||
<div style={{ color: '#4ecdc4', display: 'flex', alignItems: 'center', gap: '6px' }}>
|
||||
<div style={{ width: 8, height: 8, background: '#4ecdc4', borderRadius: '50%' }}></div>
|
||||
Best: {Math.round(history[history.length - 1].best)}
|
||||
</div>
|
||||
<div style={{ color: '#4a9eff', display: 'flex', alignItems: 'center', gap: '6px' }}>
|
||||
<div style={{ width: 8, height: 8, background: '#4a9eff', borderRadius: '50%' }}></div>
|
||||
Avg: {Math.round(history[history.length - 1].average)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<svg
|
||||
width="100%"
|
||||
height="100%"
|
||||
viewBox={`0 0 ${VIEW_WIDTH} ${VIEW_HEIGHT}`}
|
||||
preserveAspectRatio="none"
|
||||
style={{ overflow: 'visible' }}
|
||||
>
|
||||
<defs>
|
||||
<linearGradient id="gradBest" x1="0%" y1="0%" x2="0%" y2="100%">
|
||||
<stop offset="0%" stopColor="#4ecdc4" stopOpacity={0.4} />
|
||||
<stop offset="100%" stopColor="#4ecdc4" stopOpacity={0} />
|
||||
</linearGradient>
|
||||
<linearGradient id="gradAvg" x1="0%" y1="0%" x2="0%" y2="100%">
|
||||
<stop offset="0%" stopColor="#4a9eff" stopOpacity={0.3} />
|
||||
<stop offset="100%" stopColor="#4a9eff" stopOpacity={0} />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
|
||||
{/* Grid Lines (Horizontal) */}
|
||||
{[0, 0.25, 0.5, 0.75, 1].map(ratio => {
|
||||
const y = PADDING + ratio * GRAPH_HEIGHT;
|
||||
return (
|
||||
<line
|
||||
key={ratio}
|
||||
x1={PADDING}
|
||||
y1={y}
|
||||
x2={VIEW_WIDTH - PADDING}
|
||||
y2={y}
|
||||
stroke="#333"
|
||||
strokeWidth="1"
|
||||
strokeDasharray="4 4"
|
||||
opacity="0.5"
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Average Area */}
|
||||
<path d={averageArea} fill="url(#gradAvg)" />
|
||||
{/* Average Line */}
|
||||
<path d={averagePath} fill="none" stroke="#4a9eff" strokeWidth="2" strokeOpacity="0.8" />
|
||||
|
||||
{/* Best Area */}
|
||||
<path d={bestArea} fill="url(#gradBest)" />
|
||||
{/* Best Line */}
|
||||
<path d={bestPath} fill="none" stroke="#4ecdc4" strokeWidth="2.5" />
|
||||
</svg>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -328,9 +328,12 @@ input[type='range']::-webkit-slider-thumb:hover {
|
||||
}
|
||||
|
||||
.progress-indicator {
|
||||
background: #080808;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #222;
|
||||
background: linear-gradient(135deg, #2a2a3e 0%, #1a1a2e 100%);
|
||||
padding: 1.5rem;
|
||||
border-radius: 12px;
|
||||
border: 1px solid #3a3a4e;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.progress-label {
|
||||
|
||||
@@ -7,10 +7,6 @@ import Tips from './Tips';
|
||||
import BestSnakeDisplay from './BestSnakeDisplay';
|
||||
import {
|
||||
createPopulation,
|
||||
evaluatePopulation,
|
||||
evolveGeneration,
|
||||
getBestIndividual,
|
||||
getAverageFitness,
|
||||
type Population,
|
||||
} from '../../lib/snakeAI/evolution';
|
||||
import type { EvolutionConfig } from '../../lib/snakeAI/types';
|
||||
@@ -24,6 +20,8 @@ const DEFAULT_CONFIG: EvolutionConfig = {
|
||||
maxGameSteps: 20000,
|
||||
};
|
||||
|
||||
import EvolutionWorker from '../../lib/snakeAI/evolution.worker?worker';
|
||||
|
||||
export default function SnakeAI() {
|
||||
const [population, setPopulation] = useState<Population>(() =>
|
||||
createPopulation(DEFAULT_CONFIG)
|
||||
@@ -32,30 +30,85 @@ export default function SnakeAI() {
|
||||
const [isRunning, setIsRunning] = useState(false);
|
||||
const [speed, setSpeed] = useState(5);
|
||||
const [gamesPlayed, setGamesPlayed] = useState(0);
|
||||
const [fitnessHistory, setFitnessHistory] = useState<Array<{ generation: number, best: number, average: number }>>([]);
|
||||
|
||||
// Compute derived values from population
|
||||
const bestIndividual = getBestIndividual(population);
|
||||
const averageFitness = getAverageFitness(population);
|
||||
// Keep a ref to population for the worker
|
||||
const populationRef = useRef(population);
|
||||
useEffect(() => {
|
||||
populationRef.current = population;
|
||||
}, [population]);
|
||||
|
||||
const animationFrameRef = useRef<number>();
|
||||
const lastUpdateRef = useRef<number>(0);
|
||||
|
||||
const runGeneration = useCallback(() => {
|
||||
setPopulation((prev) => {
|
||||
try {
|
||||
// Evaluate current generation
|
||||
const evaluated = evaluatePopulation(prev, config);
|
||||
// Compute derived values for display
|
||||
// If we have stats from the last generation, use them. Otherwise default to 0.
|
||||
const currentBestFitness = population.lastGenerationStats?.bestFitness || 0;
|
||||
const currentAverageFitness = population.lastGenerationStats?.averageFitness || 0;
|
||||
|
||||
// Evolve to next generation
|
||||
const nextGen = evolveGeneration(evaluated, config);
|
||||
const workerRef = useRef<Worker | null>(null);
|
||||
const isProcessingRef = useRef(false);
|
||||
|
||||
return nextGen;
|
||||
} catch (error) {
|
||||
console.error("SnakeAI: Generation update failed", error);
|
||||
return prev;
|
||||
useEffect(() => {
|
||||
workerRef.current = new EvolutionWorker();
|
||||
workerRef.current.onmessage = (e) => {
|
||||
const { type, payload } = e.data; // payload is the NEW population
|
||||
if (type === 'SUCCESS') {
|
||||
// Critical: Update ref immediately to prevent race condition with next animation frame
|
||||
populationRef.current = payload;
|
||||
setPopulation(payload);
|
||||
|
||||
// Update history if we have stats
|
||||
if (payload.lastGenerationStats) {
|
||||
setFitnessHistory(prev => {
|
||||
const newEntry = {
|
||||
generation: payload.generation - 1, // The stats are for the gen that just finished
|
||||
best: payload.lastGenerationStats!.bestFitness,
|
||||
average: payload.lastGenerationStats!.averageFitness
|
||||
};
|
||||
// Keep last 100 generations to avoid memory issues if running for eternity
|
||||
const newHistory = [...prev, newEntry];
|
||||
if (newHistory.length > 100) return newHistory.slice(newHistory.length - 100);
|
||||
return newHistory;
|
||||
});
|
||||
}
|
||||
|
||||
isProcessingRef.current = false;
|
||||
} else {
|
||||
console.error("Worker error:", payload);
|
||||
isProcessingRef.current = false;
|
||||
}
|
||||
};
|
||||
|
||||
return () => {
|
||||
workerRef.current?.terminate();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const runGeneration = useCallback((generations: number = 1) => {
|
||||
if (isProcessingRef.current || !workerRef.current) return;
|
||||
|
||||
isProcessingRef.current = true;
|
||||
// We need to send the *current* population.
|
||||
// Since this is inside a callback, we need to be careful about closure staleness.
|
||||
// However, we can't easily access the "latest" state inside a callback without refs or dependency.
|
||||
// But 'population' is in the dependency array of the effect calling this? No.
|
||||
// The animate loop calls this.
|
||||
|
||||
// Let's use a functional update approach? No, we need to SEND data.
|
||||
// We will use a ref to track current population for the worker to ensure we always send latest
|
||||
// OR rely on the fact that 'population' is in dependency of runGeneration (it wasn't before).
|
||||
|
||||
// Wait, 'runGeneration' lines 43-58 previously used setPopulation(prev => ...).
|
||||
// It didn't need 'population' in dependency.
|
||||
// Now we need it.
|
||||
|
||||
workerRef.current.postMessage({
|
||||
population: populationRef.current, // Use a ref for latest population
|
||||
config,
|
||||
generations
|
||||
});
|
||||
}, [config]);
|
||||
}, [config]); // populationRef will be handled separately
|
||||
|
||||
// Update stats when generation changes
|
||||
useEffect(() => {
|
||||
@@ -93,7 +146,7 @@ export default function SnakeAI() {
|
||||
}
|
||||
|
||||
if (elapsed >= updateInterval) {
|
||||
runGeneration();
|
||||
runGeneration(1);
|
||||
lastUpdateRef.current = timestamp;
|
||||
}
|
||||
} else {
|
||||
@@ -102,9 +155,9 @@ export default function SnakeAI() {
|
||||
// Speed 100 -> 10 gens per frame (~600 eps)
|
||||
const gensPerFrame = Math.floor((speed - 10) / 10);
|
||||
|
||||
for (let i = 0; i < gensPerFrame; i++) {
|
||||
runGeneration();
|
||||
}
|
||||
// For turbo mode, we just fire once per frame (or whenever the worker is ready)
|
||||
// asking for multiple generations
|
||||
runGeneration(gensPerFrame);
|
||||
lastUpdateRef.current = timestamp;
|
||||
}
|
||||
|
||||
@@ -122,7 +175,9 @@ export default function SnakeAI() {
|
||||
|
||||
const handleReset = () => {
|
||||
setIsRunning(false);
|
||||
setPopulation(createPopulation(config));
|
||||
const newPop = createPopulation(config);
|
||||
populationRef.current = newPop;
|
||||
setPopulation(newPop);
|
||||
setGamesPlayed(0);
|
||||
};
|
||||
|
||||
@@ -162,10 +217,11 @@ export default function SnakeAI() {
|
||||
|
||||
<Stats
|
||||
generation={population.generation}
|
||||
bestFitness={bestIndividual.fitness}
|
||||
bestFitness={currentBestFitness}
|
||||
bestFitnessEver={population.bestFitnessEver}
|
||||
averageFitness={averageFitness}
|
||||
averageFitness={currentAverageFitness}
|
||||
gamesPlayed={gamesPlayed}
|
||||
history={fitnessHistory}
|
||||
/>
|
||||
|
||||
<Tips />
|
||||
|
||||
@@ -24,6 +24,11 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
|
||||
const [currentGame, setCurrentGame] = useState<GameState | null>(null);
|
||||
const animationFrameRef = useRef<number>();
|
||||
const lastUpdateRef = useRef<number>(0);
|
||||
const networkRef = useRef(network);
|
||||
|
||||
useEffect(() => {
|
||||
networkRef.current = network;
|
||||
}, [network]);
|
||||
|
||||
const CELL_SIZE = CELL_SIZES[size];
|
||||
|
||||
@@ -32,7 +37,7 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
|
||||
if (network) {
|
||||
setCurrentGame(createGame(gridSize));
|
||||
}
|
||||
}, [network, gridSize]);
|
||||
}, [network?.id, gridSize]);
|
||||
|
||||
// Animation loop to step through game
|
||||
useEffect(() => {
|
||||
@@ -54,8 +59,11 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
|
||||
}
|
||||
|
||||
// Get neural network decision
|
||||
const currentNetwork = networkRef.current;
|
||||
if (!currentNetwork) return prevGame;
|
||||
|
||||
const inputs = getInputs(prevGame);
|
||||
const action = getAction(network, inputs);
|
||||
const action = getAction(currentNetwork, inputs);
|
||||
|
||||
// Step the game forward
|
||||
return step(prevGame, action);
|
||||
@@ -74,7 +82,7 @@ export default function SnakeCanvas({ network, gridSize, showGrid = true, size =
|
||||
cancelAnimationFrame(animationFrameRef.current);
|
||||
}
|
||||
};
|
||||
}, [network, currentGame, gridSize]);
|
||||
}, [network?.id, !!currentGame, gridSize]); // Use ID and boolean existence check to prevent loop restart on every frame
|
||||
|
||||
// Set canvas size once when props change (not on every render)
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import FitnessGraph from './FitnessGraph';
|
||||
|
||||
interface StatsProps {
|
||||
generation: number;
|
||||
bestFitness: number;
|
||||
bestFitnessEver: number;
|
||||
averageFitness: number;
|
||||
gamesPlayed: number;
|
||||
history: Array<{ generation: number; best: number; average: number }>;
|
||||
}
|
||||
|
||||
export default function Stats({
|
||||
@@ -12,6 +15,7 @@ export default function Stats({
|
||||
bestFitnessEver,
|
||||
averageFitness,
|
||||
gamesPlayed,
|
||||
history,
|
||||
}: StatsProps) {
|
||||
return (
|
||||
<div className="stats-panel">
|
||||
@@ -45,17 +49,10 @@ export default function Stats({
|
||||
</div>
|
||||
|
||||
<div className="progress-indicator">
|
||||
<div className="progress-label">
|
||||
Improvement: {bestFitnessEver > 0 ? ((bestFitness / bestFitnessEver) * 100).toFixed(1) : 0}%
|
||||
</div>
|
||||
<div className="progress-bar">
|
||||
<div
|
||||
className="progress-fill"
|
||||
style={{
|
||||
width: `${bestFitnessEver > 0 ? Math.min(100, (bestFitness / bestFitnessEver) * 100) : 0}%`,
|
||||
}}
|
||||
/>
|
||||
<div className="progress-label" style={{ marginBottom: '0.5rem' }}>
|
||||
Fitness History
|
||||
</div>
|
||||
<FitnessGraph history={history} height={120} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -13,6 +13,10 @@ export interface Population {
|
||||
generation: number;
|
||||
bestFitnessEver: number;
|
||||
bestNetworkEver: Network | null;
|
||||
lastGenerationStats?: {
|
||||
bestFitness: number;
|
||||
averageFitness: number;
|
||||
};
|
||||
}
|
||||
|
||||
export function createPopulation(config: EvolutionConfig): Population {
|
||||
@@ -81,6 +85,10 @@ export function evolveGeneration(
|
||||
// Sort by fitness (descending)
|
||||
const sorted = [...population.individuals].sort((a, b) => b.fitness - a.fitness);
|
||||
|
||||
// Calculate stats for this generation BEFORE creating the new one
|
||||
const currentBestFitness = sorted[0].fitness;
|
||||
const currentAverageFitness = sorted.reduce((sum, ind) => sum + ind.fitness, 0) / sorted.length;
|
||||
|
||||
const newIndividuals: Individual[] = [];
|
||||
|
||||
// Elite preservation (top performers survive unchanged)
|
||||
@@ -122,6 +130,10 @@ export function evolveGeneration(
|
||||
generation: population.generation + 1,
|
||||
bestFitnessEver: population.bestFitnessEver,
|
||||
bestNetworkEver: population.bestNetworkEver,
|
||||
lastGenerationStats: {
|
||||
bestFitness: currentBestFitness,
|
||||
averageFitness: currentAverageFitness
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -142,6 +154,7 @@ function selectParent(sorted: Individual[]): Individual {
|
||||
|
||||
function crossover(parent1: Network, parent2: Network): Network {
|
||||
const child = cloneNetwork(parent1);
|
||||
child.id = Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
||||
|
||||
// Single-point crossover on weights and biases
|
||||
const crossoverRate = 0.5;
|
||||
@@ -182,6 +195,7 @@ function crossover(parent1: Network, parent2: Network): Network {
|
||||
|
||||
function mutate(network: Network, mutationRate: number): Network {
|
||||
const mutated = cloneNetwork(network);
|
||||
mutated.id = Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
||||
|
||||
// Mutate input-hidden weights
|
||||
for (let i = 0; i < mutated.weightsIH.length; i++) {
|
||||
|
||||
25
src/lib/snakeAI/evolution.worker.ts
Normal file
25
src/lib/snakeAI/evolution.worker.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { evaluatePopulation, evolveGeneration, type Population } from './evolution';
|
||||
import type { EvolutionConfig } from './types';
|
||||
|
||||
self.onmessage = (e: MessageEvent) => {
|
||||
const { population, config, generations = 1 } = e.data as {
|
||||
population: Population;
|
||||
config: EvolutionConfig;
|
||||
generations?: number;
|
||||
};
|
||||
|
||||
try {
|
||||
let currentPop = population;
|
||||
|
||||
for (let i = 0; i < generations; i++) {
|
||||
// Run the heavy computation
|
||||
const evaluated = evaluatePopulation(currentPop, config);
|
||||
currentPop = evolveGeneration(evaluated, config);
|
||||
}
|
||||
|
||||
// Send back the result
|
||||
self.postMessage({ type: 'SUCCESS', payload: currentPop });
|
||||
} catch (error) {
|
||||
self.postMessage({ type: 'ERROR', payload: error });
|
||||
}
|
||||
};
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Action } from './types';
|
||||
|
||||
export interface Network {
|
||||
id: string;
|
||||
inputSize: number;
|
||||
hiddenSize: number;
|
||||
outputSize: number;
|
||||
@@ -16,6 +17,7 @@ export function createNetwork(
|
||||
outputSize: number = 3
|
||||
): Network {
|
||||
return {
|
||||
id: generateId(),
|
||||
inputSize,
|
||||
hiddenSize,
|
||||
outputSize,
|
||||
@@ -26,6 +28,10 @@ export function createNetwork(
|
||||
};
|
||||
}
|
||||
|
||||
function generateId(): string {
|
||||
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
||||
}
|
||||
|
||||
function createRandomMatrix(rows: number, cols: number): number[][] {
|
||||
const matrix: number[][] = [];
|
||||
for (let i = 0; i < rows; i++) {
|
||||
@@ -101,6 +107,7 @@ export function getAction(network: Network, inputs: number[]): Action {
|
||||
|
||||
export function cloneNetwork(network: Network): Network {
|
||||
return {
|
||||
id: network.id,
|
||||
inputSize: network.inputSize,
|
||||
hiddenSize: network.hiddenSize,
|
||||
outputSize: network.outputSize,
|
||||
|
||||
Reference in New Issue
Block a user