import { DenseNetwork } from '../../apps/LunarLander/DenseNetwork';

/** Hyper-parameters for {@link SimpleGA}. */
export interface GAConfig {
  /** Number of genomes produced by `createPopulation` and by each `evolve` call. */
  populationSize: number;
  /** Per-weight probability that a child's weight is perturbed during mutation. */
  mutationRate: number;
  /** Half-width of the uniform perturbation added to a mutated weight. */
  mutationAmount: number;
  /** Number of best agents to keep unchanged (capped at `populationSize`). */
  elitism: number;
  /**
   * Probability that a non-elite slot is filled with a freshly initialised network
   * ("random immigrant") instead of a bred child. Defaults to 0.15 when omitted.
   */
  immigrantRate?: number;
  /** Number of individuals sampled per tournament selection. Defaults to 3 when omitted. */
  tournamentSize?: number;
}

export const DEFAULT_GA_CONFIG: GAConfig = {
  populationSize: 50,
  mutationRate: 0.05, // Reduced from 0.1
  mutationAmount: 0.2, // Reduced from 0.5
  elitism: 5,
};

/** Immigrant rate used when `GAConfig.immigrantRate` is omitted (raised from 5% to 15% to combat stagnation). */
const DEFAULT_IMMIGRANT_RATE = 0.15;

/** Tournament size used when `GAConfig.tournamentSize` is omitted. */
const DEFAULT_TOURNAMENT_SIZE = 3;

/** A genome paired with the (sanitized) fitness used to rank it. */
interface ScoredGenome {
  genome: Float32Array;
  score: number;
}

/**
 * Maps a raw fitness to a value that is safe to compare: a missing or NaN fitness
 * ranks below every real number instead of poisoning sort/tournament comparisons.
 */
function rankableFitness(fitness: number | undefined): number {
  return fitness === undefined || Number.isNaN(fitness) ? -Infinity : fitness;
}

/** Returns a uniformly random element of a non-empty array; throws on an empty one. */
function pickRandom<T>(items: readonly T[]): T {
  const item = items[Math.floor(Math.random() * items.length)];
  if (item === undefined) {
    throw new RangeError('pickRandom: cannot pick from an empty array');
  }
  return item;
}

/**
 * Minimal generational genetic algorithm over flat `Float32Array` weight vectors
 * obtained from `DenseNetwork.getWeights()`.
 *
 * Each generation keeps the top `elitism` genomes verbatim, then fills the remaining
 * slots with either random immigrants or children bred from tournament-selected
 * parents via uniform crossover followed by in-place mutation.
 */
export class SimpleGA {
  private readonly layerSizes: number[];
  private readonly config: GAConfig;

  /**
   * @param layerSizes Layer sizes forwarded to `DenseNetwork` for every random genome.
   * @param config GA hyper-parameters; defaults to {@link DEFAULT_GA_CONFIG}.
   */
  constructor(layerSizes: number[], config: GAConfig = DEFAULT_GA_CONFIG) {
    this.layerSizes = layerSizes;
    this.config = config;
  }

  /** Creates `populationSize` independently initialised genomes. */
  createPopulation(): Float32Array[] {
    const pop: Float32Array[] = [];
    for (let i = 0; i < this.config.populationSize; i++) {
      pop.push(this.randomGenome());
    }
    return pop;
  }

  /**
   * Produces the next generation.
   *
   * @param currentPop Genomes of the current generation (not mutated).
   * @param fitnesses Fitness of each genome, aligned by index with `currentPop`;
   *   higher is better. Missing or NaN entries rank below every other genome.
   * @returns A new array of `populationSize` genomes. Elites are copies of their
   *   parents. If `currentPop` is empty, a fresh random population is returned.
   */
  evolve(currentPop: readonly Float32Array[], fitnesses: readonly number[]): Float32Array[] {
    const popSize = this.config.populationSize;

    // With no parents there is nothing to select from; restart from random genomes
    // rather than failing inside crossover on an undefined parent.
    if (currentPop.length === 0) {
      return this.createPopulation();
    }

    // 1. Rank by fitness (descending). Explicit comparisons keep the comparator
    //    consistent even for ±Infinity scores (subtraction would yield NaN).
    const ranked: ScoredGenome[] = currentPop
      .map((genome, i) => ({ genome, score: rankableFitness(fitnesses[i]) }))
      .sort((a, b) => (a.score < b.score ? 1 : a.score > b.score ? -1 : 0));

    const nextPop: Float32Array[] = [];

    // 2. Elitism: exact copies of the best genomes, never more than the population holds.
    const eliteCount = Math.min(this.config.elitism, popSize);
    for (const { genome } of ranked) {
      if (nextPop.length >= eliteCount) break;
      nextPop.push(new Float32Array(genome));
    }

    // 3. Fill the rest with random immigrants or bred children.
    const immigrantRate = this.config.immigrantRate ?? DEFAULT_IMMIGRANT_RATE;
    while (nextPop.length < popSize) {
      // Diversity injection (random immigrants) to combat stagnation.
      if (Math.random() < immigrantRate) {
        nextPop.push(this.randomGenome());
        continue;
      }

      const parentA = this.tournamentSelect(ranked);
      const parentB = this.tournamentSelect(ranked);
      const child = this.crossover(parentA, parentB);
      this.mutate(child);
      nextPop.push(child);
    }
    return nextPop;
  }

  /**
   * Tournament selection: samples `tournamentSize` genomes uniformly with replacement
   * and returns the fittest (the earliest sample wins ties).
   */
  private tournamentSelect(ranked: readonly ScoredGenome[]): Float32Array {
    const size = this.config.tournamentSize ?? DEFAULT_TOURNAMENT_SIZE;
    // Seed with a real contestant so the winner is always a valid genome, even when
    // every sampled score is -Infinity.
    let best = pickRandom(ranked);
    for (let round = 1; round < size; round++) {
      const challenger = pickRandom(ranked);
      if (challenger.score > best.score) {
        best = challenger;
      }
    }
    return best.genome;
  }

  /**
   * Uniform crossover: each weight is copied from either parent with probability 0.5.
   * The child has `w1.length` weights; if `w2` is shorter, missing genes come from `w1`.
   */
  private crossover(w1: Float32Array, w2: Float32Array): Float32Array {
    return Float32Array.from(w1, (gene, i) => (Math.random() < 0.5 ? gene : (w2[i] ?? gene)));
  }

  /**
   * Mutates `weights` in place: each weight is, with probability `mutationRate`,
   * shifted by a uniform delta in [-mutationAmount, mutationAmount).
   */
  private mutate(weights: Float32Array): void {
    const { mutationRate, mutationAmount } = this.config;
    for (let i = 0; i < weights.length; i++) {
      if (Math.random() < mutationRate) {
        // No clamping here; presumably the network's activations (e.g. tanh) bound
        // the outputs — TODO confirm against DenseNetwork.
        weights[i] += (Math.random() * 2 - 1) * mutationAmount;
      }
    }
  }

  /** Weights of a freshly initialised `DenseNetwork` with this GA's layer sizes. */
  private randomGenome(): Float32Array {
    return new DenseNetwork(this.layerSizes).getWeights();
  }
}