Improve neural network and inputs
This commit is contained in:
@@ -135,29 +135,64 @@ export function getInputs(state: GameState): number[] {
|
|||||||
const head = state.snake[0];
|
const head = state.snake[0];
|
||||||
const food = state.food;
|
const food = state.food;
|
||||||
|
|
||||||
// 8 inputs for the neural network:
|
// Calculate relative direction vectors based on current direction
|
||||||
// 4 food direction indicators (normalized -1 to 1)
|
// If facing UP (0): Front=(0, -1), Left=(-1, 0), Right=(1, 0)
|
||||||
const foodDirX = (food.x - head.x) / state.gridSize;
|
// If facing RIGHT (1): Front=(1, 0), Left=(0, -1), Right=(0, 1)
|
||||||
const foodDirY = (food.y - head.y) / state.gridSize;
|
// ...and so on
|
||||||
|
|
||||||
|
const frontVec = getDirectionVector(state.direction);
|
||||||
|
const leftVec = getDirectionVector(((state.direction + 3) % 4) as Direction);
|
||||||
|
const rightVec = getDirectionVector(((state.direction + 1) % 4) as Direction);
|
||||||
|
|
||||||
// 4 danger sensors (1 if danger, 0 if safe)
|
// 1. Danger Sensors (Relative)
|
||||||
const dangerUp = isDanger(state, head.x, head.y - 1);
|
// Is there danger immediately to my Left, Front, or Right?
|
||||||
const dangerDown = isDanger(state, head.x, head.y + 1);
|
const dangerLeft = isDanger(state, head.x + leftVec.x, head.y + leftVec.y);
|
||||||
const dangerLeft = isDanger(state, head.x - 1, head.y);
|
const dangerFront = isDanger(state, head.x + frontVec.x, head.y + frontVec.y);
|
||||||
const dangerRight = isDanger(state, head.x + 1, head.y);
|
const dangerRight = isDanger(state, head.x + rightVec.x, head.y + rightVec.y);
|
||||||
|
|
||||||
|
// 2. Food Direction (Relative)
|
||||||
|
// We want to know if food is to our Left/Right or In Front/Behind relative to head
|
||||||
|
// We can use dot products or simple coordinate checks
|
||||||
|
const relFoodX = food.x - head.x;
|
||||||
|
const relFoodY = food.y - head.y;
|
||||||
|
|
||||||
|
// Dot product to project food vector onto our relative axes
|
||||||
|
const foodFront = relFoodX * frontVec.x + relFoodY * frontVec.y;
|
||||||
|
const foodSide = relFoodX * rightVec.x + relFoodY * rightVec.y;
|
||||||
|
// foodSide: Positive = Right, Negative = Left
|
||||||
|
|
||||||
return [
|
return [
|
||||||
foodDirX,
|
// Sensor 1: Danger Left
|
||||||
foodDirY,
|
|
||||||
dangerUp ? 1 : 0,
|
|
||||||
dangerDown ? 1 : 0,
|
|
||||||
dangerLeft ? 1 : 0,
|
dangerLeft ? 1 : 0,
|
||||||
|
// Sensor 2: Danger Front
|
||||||
|
dangerFront ? 1 : 0,
|
||||||
|
// Sensor 3: Danger Right
|
||||||
dangerRight ? 1 : 0,
|
dangerRight ? 1 : 0,
|
||||||
state.direction / 3, // Normalized direction
|
|
||||||
state.snake.length / (state.gridSize * state.gridSize), // Normalized length
|
// Sensor 4: Food is to the Left
|
||||||
|
foodSide < 0 ? 1 : 0,
|
||||||
|
// Sensor 5: Food is to the Right
|
||||||
|
foodSide > 0 ? 1 : 0,
|
||||||
|
// Sensor 6: Food is Ahead
|
||||||
|
foodFront > 0 ? 1 : 0,
|
||||||
|
// Sensor 7: Food is Behind
|
||||||
|
foodFront < 0 ? 1 : 0,
|
||||||
|
|
||||||
|
// Sensor 8: Normalized Length (Growth Sensor)
|
||||||
|
state.snake.length / (state.gridSize * state.gridSize)
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Converts a heading into a one-cell step on the grid.
 *
 * UP maps to y - 1 and DOWN to y + 1; LEFT maps to x - 1 and RIGHT to x + 1.
 * Any value that is not one of the four known directions yields the zero
 * vector `{ x: 0, y: 0 }`.
 *
 * @param dir - Heading to convert.
 * @returns A freshly allocated offset with one of x / y set to -1 or 1.
 */
function getDirectionVector(dir: Direction): Position {
  let dx = 0;
  let dy = 0;

  if (dir === Direction.UP) {
    dy = -1;
  } else if (dir === Direction.DOWN) {
    dy = 1;
  } else if (dir === Direction.LEFT) {
    dx = -1;
  } else if (dir === Direction.RIGHT) {
    dx = 1;
  }
  // Unknown headings fall through with both deltas still 0.

  return { x: dx, y: dy };
}
|
||||||
|
|
||||||
function isDanger(state: GameState, x: number, y: number): boolean {
|
function isDanger(state: GameState, x: number, y: number): boolean {
|
||||||
// Check wall
|
// Check wall
|
||||||
if (x < 0 || x >= state.gridSize || y < 0 || y >= state.gridSize) {
|
if (x < 0 || x >= state.gridSize || y < 0 || y >= state.gridSize) {
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ export interface Network {
|
|||||||
|
|
||||||
export function createNetwork(
|
export function createNetwork(
|
||||||
inputSize: number = 8,
|
inputSize: number = 8,
|
||||||
hiddenSize: number = 12,
|
hiddenSize: number = 18,
|
||||||
outputSize: number = 3
|
outputSize: number = 3
|
||||||
): Network {
|
): Network {
|
||||||
return {
|
return {
|
||||||
@@ -53,7 +53,9 @@ export function forward(network: Network, inputs: number[]): number[] {
|
|||||||
for (let i = 0; i < network.inputSize; i++) {
|
for (let i = 0; i < network.inputSize; i++) {
|
||||||
sum += inputs[i] * network.weightsIH[i][h];
|
sum += inputs[i] * network.weightsIH[i][h];
|
||||||
}
|
}
|
||||||
hidden[h] = tanh(sum);
|
// ReLU activation for hidden layer: f(x) = max(0, x)
|
||||||
|
// Faster and solves vanishing gradient better than tanh
|
||||||
|
hidden[h] = Math.max(0, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output layer activation
|
// Output layer activation
|
||||||
|
|||||||
Reference in New Issue
Block a user