// Listing metadata: 118 lines, 3.8 KiB, TypeScript.
import { DenseNetwork } from '../../apps/LunarLander/DenseNetwork';
/**
 * Tunable parameters of the genetic algorithm.
 */
export interface GAConfig {
  /** Number of genomes in every generation. */
  populationSize: number;
  /** Per-weight probability (compared against `Math.random()`) that a weight is perturbed. */
  mutationRate: number;
  /** Scale of a perturbation: a mutated weight moves by a uniform value in [-mutationAmount, +mutationAmount). */
  mutationAmount: number;
  /** Number of best agents to keep unchanged. */
  elitism: number;
}
/**
 * Configuration used when a `SimpleGA` is constructed without an explicit one.
 */
export const DEFAULT_GA_CONFIG: GAConfig = {
  populationSize: 50,
  mutationRate: 0.05, // Reduced from 0.1
  mutationAmount: 0.2, // Reduced from 0.5
  elitism: 5,
};
/**
 * Minimal genetic algorithm operating on flat weight vectors ("genomes")
 * obtained from `DenseNetwork.getWeights()`.
 *
 * Each call to {@link SimpleGA.evolve} builds the next generation from:
 *  1. exact copies of the fittest genomes (elitism),
 *  2. occasional brand-new genomes (diversity injection), and
 *  3. children produced by tournament selection, uniform crossover and mutation.
 */
export class SimpleGA {
  /** Chance that a non-elite slot is filled with a fresh genome instead of a child. */
  private static readonly IMMIGRANT_RATE = 0.15;
  /** Number of random contestants sampled (with replacement) per tournament. */
  private static readonly TOURNAMENT_SIZE = 3;

  private readonly layerSizes: number[];
  private readonly config: GAConfig;

  /**
   * @param layerSizes Layer sizes forwarded unchanged to the `DenseNetwork` constructor.
   * @param config GA parameters; the object is stored by reference, not copied.
   */
  constructor(layerSizes: number[], config: GAConfig = DEFAULT_GA_CONFIG) {
    this.layerSizes = layerSizes;
    this.config = config;
  }

  /**
   * Creates the initial generation.
   *
   * @returns `config.populationSize` genomes, each taken from a newly constructed network.
   */
  createPopulation(): Float32Array[] {
    const pop: Float32Array[] = [];
    for (let i = 0; i < this.config.populationSize; i++) {
      pop.push(this.freshGenome());
    }
    return pop;
  }

  /**
   * Produces the next generation from the current one.
   *
   * @param currentPop Genomes of the current generation.
   * @param fitnesses Fitness of each genome, index-aligned with `currentPop`
   *   (higher is better). Missing or `NaN` entries are ranked as the worst possible.
   * @returns A new array holding at most `config.populationSize` genomes
   *   (exactly that many when `populationSize` is a non-negative integer);
   *   input genomes are never mutated or shared with the result.
   */
  evolve(currentPop: Float32Array[], fitnesses: number[]): Float32Array[] {
    const popSize = this.config.populationSize;

    // 1. Rank indices by fitness, best first. A comparison-based comparator is
    //    used because subtraction yields NaN for NaN / same-signed infinities,
    //    which makes the sort order unspecified.
    const ranked = currentPop
      .map((_, i) => i)
      .sort((a, b) => {
        const fa = SimpleGA.fitnessOf(fitnesses, a);
        const fb = SimpleGA.fitnessOf(fitnesses, b);
        return fb > fa ? 1 : fb < fa ? -1 : 0;
      });

    const nextPop: Float32Array[] = [];

    // 2. Elitism: copy the best genomes verbatim, never exceeding the
    //    available genomes or the target population size.
    for (
      let i = 0;
      i < this.config.elitism && i < ranked.length && nextPop.length < popSize;
      i++
    ) {
      const eliteIdx = ranked[i];
      const elite = eliteIdx === undefined ? undefined : currentPop[eliteIdx];
      if (elite !== undefined) {
        nextPop.push(new Float32Array(elite));
      }
    }

    // 3. Fill the remaining slots.
    const canBreed = ranked.length > 0;
    while (nextPop.length < popSize) {
      // Diversity injection (random immigrants). With nothing to breed from,
      // every slot is filled this way.
      if (!canBreed || Math.random() < SimpleGA.IMMIGRANT_RATE) {
        nextPop.push(this.freshGenome());
        continue;
      }

      const p1 = currentPop[this.tournamentSelect(ranked, fitnesses)];
      const p2 = currentPop[this.tournamentSelect(ranked, fitnesses)];
      if (p1 === undefined || p2 === undefined) {
        // Defensive: selection should always yield a valid index when canBreed.
        nextPop.push(this.freshGenome());
        continue;
      }

      const child = this.crossover(p1, p2);
      this.mutate(child);
      nextPop.push(child);
    }

    return nextPop;
  }

  /** Fitness at `index`, with missing or NaN values mapped to -Infinity (worst). */
  private static fitnessOf(fitnesses: number[], index: number): number {
    const f = fitnesses[index];
    return f === undefined || Number.isNaN(f) ? -Infinity : f;
  }

  /** Weights of a newly constructed network (presumably randomly initialised by DenseNetwork). */
  private freshGenome(): Float32Array {
    return new DenseNetwork(this.layerSizes).getWeights();
  }

  /**
   * Tournament selection: samples a fixed number of contestants uniformly
   * (with replacement) from `indices` and returns the fittest one.
   *
   * @returns An element of `indices`, or -1 only when `indices` is empty.
   */
  private tournamentSelect(indices: number[], fitnesses: number[]): number {
    let bestIndex = -1;
    let bestFitness = -Infinity;

    for (let round = 0; round < SimpleGA.TOURNAMENT_SIZE; round++) {
      const candidate = indices[Math.floor(Math.random() * indices.length)];
      if (candidate === undefined) {
        continue; // only possible when `indices` is empty
      }
      const fitness = SimpleGA.fitnessOf(fitnesses, candidate);
      // The first contestant always becomes the provisional winner, so a
      // tournament made solely of -Infinity/NaN fitnesses still yields a
      // valid index instead of -1.
      if (bestIndex === -1 || fitness > bestFitness) {
        bestFitness = fitness;
        bestIndex = candidate;
      }
    }
    return bestIndex;
  }

  /**
   * Uniform crossover: each gene is taken from either parent with equal
   * probability. The child has the length of `w1`; positions that `w2` does
   * not have are inherited from `w1` rather than becoming NaN.
   */
  private crossover(w1: Float32Array, w2: Float32Array): Float32Array {
    const child = new Float32Array(w1.length);
    for (let i = 0; i < w1.length; i++) {
      const fromFirst = Math.random() < 0.5 || i >= w2.length;
      child[i] = (fromFirst ? w1[i] : w2[i]) ?? 0;
    }
    return child;
  }

  /**
   * Mutates `weights` in place: each weight is, with probability
   * `config.mutationRate`, shifted by a uniform value in
   * [-mutationAmount, +mutationAmount). Weights are not clamped.
   */
  private mutate(weights: Float32Array): void {
    for (let i = 0; i < weights.length; i++) {
      if (Math.random() < this.config.mutationRate) {
        const delta = (Math.random() * 2 - 1) * this.config.mutationAmount;
        weights[i] = (weights[i] ?? 0) + delta;
      }
    }
  }
}