Improve neural network and inputs
This commit is contained in:
@@ -135,29 +135,64 @@ export function getInputs(state: GameState): number[] {
|
||||
const head = state.snake[0];
|
||||
const food = state.food;
|
||||
|
||||
// 8 inputs for the neural network:
|
||||
// 4 food direction indicators (normalized -1 to 1)
|
||||
const foodDirX = (food.x - head.x) / state.gridSize;
|
||||
const foodDirY = (food.y - head.y) / state.gridSize;
|
||||
// Calculate relative direction vectors based on current direction
|
||||
// If facing UP (0): Front=(0, -1), Left=(-1, 0), Right=(1, 0)
|
||||
// If facing RIGHT (1): Front=(1, 0), Left=(0, -1), Right=(0, 1)
|
||||
// ...and so on
|
||||
|
||||
const frontVec = getDirectionVector(state.direction);
|
||||
const leftVec = getDirectionVector(((state.direction + 3) % 4) as Direction);
|
||||
const rightVec = getDirectionVector(((state.direction + 1) % 4) as Direction);
|
||||
|
||||
// 4 danger sensors (1 if danger, 0 if safe)
|
||||
const dangerUp = isDanger(state, head.x, head.y - 1);
|
||||
const dangerDown = isDanger(state, head.x, head.y + 1);
|
||||
const dangerLeft = isDanger(state, head.x - 1, head.y);
|
||||
const dangerRight = isDanger(state, head.x + 1, head.y);
|
||||
// 1. Danger Sensors (Relative)
|
||||
// Is there danger immediately to my Left, Front, or Right?
|
||||
const dangerLeft = isDanger(state, head.x + leftVec.x, head.y + leftVec.y);
|
||||
const dangerFront = isDanger(state, head.x + frontVec.x, head.y + frontVec.y);
|
||||
const dangerRight = isDanger(state, head.x + rightVec.x, head.y + rightVec.y);
|
||||
|
||||
// 2. Food Direction (Relative)
|
||||
// We want to know if food is to our Left/Right or In Front/Behind relative to head
|
||||
// We can use dot products or simple coordinate checks
|
||||
const relFoodX = food.x - head.x;
|
||||
const relFoodY = food.y - head.y;
|
||||
|
||||
// Dot product to project food vector onto our relative axes
|
||||
const foodFront = relFoodX * frontVec.x + relFoodY * frontVec.y;
|
||||
const foodSide = relFoodX * rightVec.x + relFoodY * rightVec.y;
|
||||
// foodSide: Positive = Right, Negative = Left
|
||||
|
||||
return [
|
||||
foodDirX,
|
||||
foodDirY,
|
||||
dangerUp ? 1 : 0,
|
||||
dangerDown ? 1 : 0,
|
||||
// Sensor 1: Danger Left
|
||||
dangerLeft ? 1 : 0,
|
||||
// Sensor 2: Danger Front
|
||||
dangerFront ? 1 : 0,
|
||||
// Sensor 3: Danger Right
|
||||
dangerRight ? 1 : 0,
|
||||
state.direction / 3, // Normalized direction
|
||||
state.snake.length / (state.gridSize * state.gridSize), // Normalized length
|
||||
|
||||
// Sensor 4: Food is to the Left
|
||||
foodSide < 0 ? 1 : 0,
|
||||
// Sensor 5: Food is to the Right
|
||||
foodSide > 0 ? 1 : 0,
|
||||
// Sensor 6: Food is Ahead
|
||||
foodFront > 0 ? 1 : 0,
|
||||
// Sensor 7: Food is Behind
|
||||
foodFront < 0 ? 1 : 0,
|
||||
|
||||
// Sensor 8: Normalized Length (Growth Sensor)
|
||||
state.snake.length / (state.gridSize * state.gridSize)
|
||||
];
|
||||
}
|
||||
|
||||
/**
 * Converts a heading into the one-cell step taken when moving that way.
 *
 * The y axis grows downward here: UP maps to a step of -1 on y and DOWN
 * to +1. LEFT and RIGHT map to -1 and +1 on x respectively.
 *
 * @param dir - Heading to convert.
 * @returns A freshly allocated step vector; the zero vector when `dir`
 *   matches none of the four known headings.
 */
function getDirectionVector(dir: Direction): Position {
  // Start from "no movement" so an unrecognised heading falls through unchanged.
  const step: Position = { x: 0, y: 0 };

  if (dir === Direction.UP) {
    step.y = -1;
  } else if (dir === Direction.DOWN) {
    step.y = 1;
  } else if (dir === Direction.LEFT) {
    step.x = -1;
  } else if (dir === Direction.RIGHT) {
    step.x = 1;
  }

  return step;
}
|
||||
|
||||
function isDanger(state: GameState, x: number, y: number): boolean {
|
||||
// Check wall
|
||||
if (x < 0 || x >= state.gridSize || y < 0 || y >= state.gridSize) {
|
||||
|
||||
@@ -12,7 +12,7 @@ export interface Network {
|
||||
|
||||
export function createNetwork(
|
||||
inputSize: number = 8,
|
||||
hiddenSize: number = 12,
|
||||
hiddenSize: number = 18,
|
||||
outputSize: number = 3
|
||||
): Network {
|
||||
return {
|
||||
@@ -53,7 +53,9 @@ export function forward(network: Network, inputs: number[]): number[] {
|
||||
for (let i = 0; i < network.inputSize; i++) {
|
||||
sum += inputs[i] * network.weightsIH[i][h];
|
||||
}
|
||||
hidden[h] = tanh(sum);
|
||||
// ReLU activation for hidden layer: f(x) = max(0, x)
|
||||
// Faster and solves vanishing gradient better than tanh
|
||||
hidden[h] = Math.max(0, sum);
|
||||
}
|
||||
|
||||
// Output layer activation
|
||||
|
||||
Reference in New Issue
Block a user