Improve neural network and inputs

This commit is contained in:
Peter Stockings
2026-01-10 10:57:22 +11:00
parent c3e942ee60
commit 246a4a14e3
2 changed files with 54 additions and 17 deletions

View File

@@ -135,29 +135,64 @@ export function getInputs(state: GameState): number[] {
const head = state.snake[0];
const food = state.food;
// 8 inputs for the neural network:
// 4 food direction indicators (normalized -1 to 1)
const foodDirX = (food.x - head.x) / state.gridSize;
const foodDirY = (food.y - head.y) / state.gridSize;
// Calculate relative direction vectors based on current direction
// If facing UP (0): Front=(0, -1), Left=(-1, 0), Right=(1, 0)
// If facing RIGHT (1): Front=(1, 0), Left=(0, -1), Right=(0, 1)
// ...and so on
const frontVec = getDirectionVector(state.direction);
const leftVec = getDirectionVector(((state.direction + 3) % 4) as Direction);
const rightVec = getDirectionVector(((state.direction + 1) % 4) as Direction);
// 4 danger sensors (1 if danger, 0 if safe)
const dangerUp = isDanger(state, head.x, head.y - 1);
const dangerDown = isDanger(state, head.x, head.y + 1);
const dangerLeft = isDanger(state, head.x - 1, head.y);
const dangerRight = isDanger(state, head.x + 1, head.y);
// 1. Danger Sensors (Relative)
// Is there danger immediately to my Left, Front, or Right?
const dangerLeft = isDanger(state, head.x + leftVec.x, head.y + leftVec.y);
const dangerFront = isDanger(state, head.x + frontVec.x, head.y + frontVec.y);
const dangerRight = isDanger(state, head.x + rightVec.x, head.y + rightVec.y);
// 2. Food Direction (Relative)
// We want to know if food is to our Left/Right or In Front/Behind relative to head
// We can use dot products or simple coordinate checks
const relFoodX = food.x - head.x;
const relFoodY = food.y - head.y;
// Dot product to project food vector onto our relative axes
const foodFront = relFoodX * frontVec.x + relFoodY * frontVec.y;
const foodSide = relFoodX * rightVec.x + relFoodY * rightVec.y;
// foodSide: Positive = Right, Negative = Left
return [
foodDirX,
foodDirY,
dangerUp ? 1 : 0,
dangerDown ? 1 : 0,
// Sensor 1: Danger Left
dangerLeft ? 1 : 0,
// Sensor 2: Danger Front
dangerFront ? 1 : 0,
// Sensor 3: Danger Right
dangerRight ? 1 : 0,
state.direction / 3, // Normalized direction
state.snake.length / (state.gridSize * state.gridSize), // Normalized length
// Sensor 4: Food is to the Left
foodSide < 0 ? 1 : 0,
// Sensor 5: Food is to the Right
foodSide > 0 ? 1 : 0,
// Sensor 6: Food is Ahead
foodFront > 0 ? 1 : 0,
// Sensor 7: Food is Behind
foodFront < 0 ? 1 : 0,
// Sensor 8: Normalized Length (Growth Sensor)
state.snake.length / (state.gridSize * state.gridSize)
];
}
/**
 * Translates a heading into the single-cell step it represents on the grid.
 *
 * The y axis grows downward here: UP yields a negative y step and DOWN a
 * positive one; LEFT and RIGHT step along x by -1 and +1 respectively.
 *
 * @param dir - The heading to convert.
 * @returns A newly created `{ x, y }` step for the heading, or the zero
 *   vector `{ x: 0, y: 0 }` when `dir` matches none of the four headings.
 */
function getDirectionVector(dir: Direction): Position {
  if (dir === Direction.UP) {
    return { x: 0, y: -1 };
  }
  if (dir === Direction.DOWN) {
    return { x: 0, y: 1 };
  }
  if (dir === Direction.LEFT) {
    return { x: -1, y: 0 };
  }
  if (dir === Direction.RIGHT) {
    return { x: 1, y: 0 };
  }
  // Unrecognised heading: fall back to "no movement" rather than failing.
  return { x: 0, y: 0 };
}
function isDanger(state: GameState, x: number, y: number): boolean {
// Check wall
if (x < 0 || x >= state.gridSize || y < 0 || y >= state.gridSize) {

View File

@@ -12,7 +12,7 @@ export interface Network {
export function createNetwork(
inputSize: number = 8,
hiddenSize: number = 12,
hiddenSize: number = 18,
outputSize: number = 3
): Network {
return {
@@ -53,7 +53,9 @@ export function forward(network: Network, inputs: number[]): number[] {
for (let i = 0; i < network.inputSize; i++) {
sum += inputs[i] * network.weightsIH[i][h];
}
hidden[h] = tanh(sum);
// ReLU activation for the hidden layer: f(x) = max(0, x).
// Cheaper to compute than tanh and less prone to vanishing gradients.
hidden[h] = Math.max(0, sum);
}
// Output layer activation