184 lines
5.3 KiB
TypeScript
184 lines
5.3 KiB
TypeScript
import type { Genome, ActivationFunction } from './genome';
|
|
|
|
/*
 * Feed-forward neural network built from a NEAT genome.
 *
 * The network topologically sorts its nodes once at construction time and
 * evaluates them in that order, so every node is computed only after all of
 * its inputs.
 */
|
|
|
|
/** A single weighted incoming edge of a network node. */
interface WeightedInput {
  /** Factor applied to the source node's value. */
  weight: number;
  /** Id of the node whose value is read through this edge. */
  sourceId: number;
}

/** Runtime state for one node of the evaluated network. */
interface NetworkNode {
  /** Node id, copied from the genome's node gene. */
  id: number;
  /** Activation applied to the weighted input sum. */
  activation: ActivationFunction;
  /** Incoming edges that feed this node's weighted sum. */
  inputs: WeightedInput[];
  /** Value produced by the most recent evaluation. */
  value: number;
}
|
|
|
|
/**
 * Feed-forward neural network built from a NEAT genome.
 *
 * The constructor copies the genome's nodes and enabled connections into an
 * internal graph and computes a topological evaluation order once. Each
 * {@link NeuralNetwork.activate} call is then a single forward pass.
 *
 * Connections whose `from` or `to` id is not among `genome.nodes` are ignored.
 * Nodes on a cycle, or downstream of one, never become ready in the topological
 * sort. They are therefore never evaluated and keep the value 0; an output node
 * in that situation is reported as 0.
 */
export class NeuralNetwork {
  /** Ids of `'input'` nodes in genome order; position i receives `inputs[i]`. */
  private readonly inputNodes: number[] = [];

  /** Same ids as `inputNodes`, for O(1) membership checks during evaluation. */
  private readonly inputNodeSet = new Set<number>();

  /** Ids of `'output'` nodes in genome order; defines the order of results. */
  private readonly outputNodes: number[] = [];

  /** Per-node runtime state, keyed by node id. */
  private readonly nodes = new Map<number, NetworkNode>();

  /** Node ids in the order a forward pass evaluates them. */
  private evaluationOrder: number[] = [];

  /**
   * @param genome - Source genome; only `nodes` and enabled `connections` are read.
   */
  constructor(genome: Genome) {
    this.buildNetwork(genome);
  }

  /**
   * Build the node map, record input/output ids, wire enabled connections,
   * and compute the evaluation order.
   */
  private buildNetwork(genome: Genome): void {
    for (const nodeGene of genome.nodes) {
      this.nodes.set(nodeGene.id, {
        id: nodeGene.id,
        activation: nodeGene.activation,
        inputs: [],
        value: 0,
      });

      if (nodeGene.type === 'input') {
        this.inputNodes.push(nodeGene.id);
        this.inputNodeSet.add(nodeGene.id);
      } else if (nodeGene.type === 'output') {
        this.outputNodes.push(nodeGene.id);
      }
    }

    for (const conn of genome.connections) {
      if (!conn.enabled) continue;

      // Only wire edges whose endpoints both exist. An edge from an unknown
      // source could never contribute a value anyway.
      const targetNode = this.nodes.get(conn.to);
      if (targetNode && this.nodes.has(conn.from)) {
        targetNode.inputs.push({ weight: conn.weight, sourceId: conn.from });
      }
    }

    this.evaluationOrder = this.topologicalSort(genome);
  }

  /**
   * Kahn's algorithm over the enabled connections.
   *
   * Ready nodes are dequeued in FIFO order. The initial queue follows
   * `genome.nodes` order. Nodes that never reach in-degree 0 (cycles and
   * their descendants) are absent from the result.
   *
   * @returns Node ids in evaluation order.
   */
  private topologicalSort(genome: Genome): number[] {
    const inDegree = new Map<number, number>();
    const adj = new Map<number, number[]>();

    for (const node of genome.nodes) {
      inDegree.set(node.id, 0);
      adj.set(node.id, []);
    }

    for (const conn of genome.connections) {
      if (!conn.enabled) continue;

      const successors = adj.get(conn.from);
      const targetDegree = inDegree.get(conn.to);
      // Ignore edges touching an id that is not a node of this genome.
      // Otherwise an unknown source would crash here, and an unknown target
      // would be scheduled as a node that does not exist.
      if (successors === undefined || targetDegree === undefined) continue;

      successors.push(conn.to);
      inDegree.set(conn.to, targetDegree + 1);
    }

    const order: number[] = [];
    for (const [nodeId, degree] of inDegree) {
      if (degree === 0) order.push(nodeId);
    }

    // `order` doubles as the FIFO queue. An array iterator re-reads the length
    // on every step, so ids appended inside the loop are visited too. This
    // avoids the O(n) cost of Array#shift per dequeue.
    for (const nodeId of order) {
      for (const neighbor of adj.get(nodeId) ?? []) {
        const remaining = (inDegree.get(neighbor) ?? 0) - 1;
        inDegree.set(neighbor, remaining);
        if (remaining === 0) order.push(neighbor);
      }
    }

    return order;
  }

  /**
   * Run one forward pass.
   *
   * @param inputs - One value per input node, in the genome's input-node order.
   * @returns One value per output node, in the genome's output-node order.
   * @throws Error if `inputs.length` differs from the number of input nodes.
   */
  activate(inputs: number[]): number[] {
    if (inputs.length !== this.inputNodes.length) {
      throw new Error(`Expected ${this.inputNodes.length} inputs, got ${inputs.length}`);
    }

    // Every pass starts from a clean slate, so results do not depend on earlier calls.
    for (const node of this.nodes.values()) {
      node.value = 0;
    }

    for (const [i, nodeId] of this.inputNodes.entries()) {
      const node = this.nodes.get(nodeId);
      // Lengths were checked above; `?? 0` only satisfies indexed-access typing.
      if (node) node.value = inputs[i] ?? 0;
    }

    for (const nodeId of this.evaluationOrder) {
      // Input nodes keep the values assigned above.
      if (this.inputNodeSet.has(nodeId)) continue;

      const node = this.nodes.get(nodeId);
      if (!node) continue;

      let sum = 0;
      for (const { sourceId, weight } of node.inputs) {
        const source = this.nodes.get(sourceId);
        if (source) sum += source.value * weight;
      }

      node.value = this.applyActivation(sum, node.activation);
    }

    return this.outputNodes.map((id) => this.nodes.get(id)?.value ?? 0);
  }

  /**
   * Apply the named activation function to `x`.
   * Unrecognized names fall back to `tanh`.
   */
  private applyActivation(x: number, activation: ActivationFunction): number {
    switch (activation) {
      case 'tanh':
        return Math.tanh(x);
      case 'sigmoid':
        return 1 / (1 + Math.exp(-x));
      case 'relu':
        return Math.max(0, x);
      case 'linear':
        return x;
      default:
        return Math.tanh(x);
    }
  }
}
|
|
|
|
/**
 * Factory wrapper around the {@link NeuralNetwork} constructor.
 *
 * @param genome - Genome whose nodes and enabled connections define the network.
 * @returns A new `NeuralNetwork` instance for `genome`.
 */
export function createNetwork(genome: Genome): NeuralNetwork {
  const network = new NeuralNetwork(genome);
  return network;
}
|