Issue
I'm a high school senior who's working on a project for my CS research class (I'm very lucky to have the opportunity to be in such a class)! The project is to make an AI learn the popular game, Snake, with a Multilayer Perceptron (MLP) that learns through Genetic Algorithm (GA). This project is heavily inspired by many videos I've seen on Youtube accomplishing what I've just described, as you can see here and here. I've written the project described above using JavaFX and an AI library called Neuroph.
This is what my program looks like currently:
The name is irrelevant, as I have a list of nouns and adjectives I used to generate them from (I thought it would make it more interesting). The number in the parenthesis for Score is the best score in that generation, since only 1 snake is shown at a time.
When breeding, I set x% of the snakes to be parents (in this case, 20). The number of children is then divided up evenly for each pair of snake parents. The "genes" in this case, are the weights of the MLP. Since my library doesn't really support biases, I added a bias neuron to the input layer and connected it to all of the other neurons in every layer for its weights to act as biases instead (as described in a thread here). Each of the snake's children has a 50, 50 chance of getting either parents' gene for every gene. There is also a 5% chance for a gene to mutate, where it's set to a random number between -1.0 and 1.0.
Each snake's MLP has 3 layers: 18 input neurons, 14 hidden ones, and 4 output neurons (for each direction). The inputs I feed it are the x of head, y of head, x of food, y of food, and steps left. It also looks in 4 directions, and check for the distance to food, wall, and itself (if it doesn't see it, it's set to -1.0). There's also the bias neuron I talked about which brings the number to 18 after adding it.
The way I calculate a snake's score is through my fitness function, which is (apples consumed × 5 + seconds alive / 2)
Here is my GAMLPAgent.java, where all the MLP and GA stuff happens.
package agents;
import graphics.Snake;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;
import javafx.scene.shape.Rectangle;
import org.neuroph.core.Layer;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.TransferFunctionType;
import util.Direction;
/**
*
* @author Preston Tang
*
* GAMLPAgent stands for Genetic Algorithm Multi-Layer Perceptron Agent
*/
public class GAMLPAgent implements Comparable<GAMLPAgent> {
public Snake mask;
private final MultiLayerPerceptron mlp;
private final int width;
private final int height;
private final double size;
private final double mutationRate = 0.05;
public GAMLPAgent(Snake mask, int width, int height, double size) {
this.mask = mask;
this.width = width;
this.height = height;
this.size = size;
//Input: x of head, y of head, x of food, y of food, steps left
//Input: 4 directions, check for distance to food, wall, and self + 1 bias neuron (18 total)
//6 hidden perceptrons (2 hidden layer(s))
//Output: A direction, 4 possibilities
mlp = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 18, 14, 4);
//Adding connections
List<Layer> layers = mlp.getLayers();
for (int r = 0; r < layers.size(); r++) {
for (int c = 0; c < layers.get(r).getNeuronsCount(); c++) {
mlp.getInputNeurons().get(mlp.getInputsCount() - 1).addInputConnection(layers.get(r).getNeuronAt(c));
}
}
// System.out.println(mlp.getInputNeurons().get(17).getInputConnections() + " " + mlp.getInputNeurons().get(17).getOutConnections());
mlp.randomizeWeights();
// System.out.println(Arrays.toString(mlp.getInputNeurons().get(17).getWeights()));
}
public void compute() {
if (mask.isAlive()) {
Rectangle head = mask.getSnakeParts().get(0);
Rectangle food = mask.getFood();
double headX = head.getX();
double headY = head.getY();
double foodX = mask.getFood().getX();
double foodY = mask.getFood().getY();
int stepsLeft = mask.getSteps();
double foodL = -1.0, wallL, selfL = -1.0;
double foodR = -1.0, wallR, selfR = -1.0;
double foodU = -1.0, wallU, selfU = -1.0;
double foodD = -1.0, wallD, selfD = -1.0;
//The 4 directions
//Left Direction
if (head.getY() == food.getY() && head.getX() > food.getX()) {
foodL = head.getX() - food.getX();
}
wallL = head.getX() - size;
for (Rectangle part : mask.getSnakeParts()) {
if (head.getY() == part.getY() && head.getX() > part.getX()) {
selfL = head.getX() - part.getX();
break;
}
}
//Right Direction
if (head.getY() == food.getY() && head.getX() < food.getX()) {
foodR = food.getX() - head.getX();
}
wallR = size * width - head.getX();
for (Rectangle part : mask.getSnakeParts()) {
if (head.getY() == part.getY() && head.getX() < part.getX()) {
selfR = part.getX() - head.getX();
break;
}
}
//Up Direction
if (head.getX() == food.getX() && head.getY() < food.getY()) {
foodU = food.getY() - head.getY();
}
wallU = size * height - head.getY();
for (Rectangle part : mask.getSnakeParts()) {
if (head.getX() == part.getX() && head.getY() < part.getY()) {
selfU = part.getY() - head.getY();
break;
}
}
//Down Direction
if (head.getX() == food.getX() && head.getY() > food.getY()) {
foodD = head.getY() - food.getY();
}
wallD = head.getY() - size;
for (Rectangle part : mask.getSnakeParts()) {
if (head.getX() == part.getX() && head.getY() > part.getY()) {
selfD = head.getY() - food.getY();
break;
}
}
mlp.setInput(
headX, headY, foodX, foodY, stepsLeft,
foodL, wallL, selfL,
foodR, wallR, selfR,
foodU, wallU, selfU,
foodD, wallD, selfD, 1);
mlp.calculate();
if (getIndexOfLargest(mlp.getOutput()) == 0) {
mask.setDirection(Direction.UP);
} else if (getIndexOfLargest(mlp.getOutput()) == 1) {
mask.setDirection(Direction.DOWN);
} else if (getIndexOfLargest(mlp.getOutput()) == 2) {
mask.setDirection(Direction.LEFT);
} else if (getIndexOfLargest(mlp.getOutput()) == 3) {
mask.setDirection(Direction.RIGHT);
}
}
}
public double[][] breed(GAMLPAgent agent, int num) {
//Converts Double[] to double[]
//https://stackoverflow.com/questions/1109988/how-do-i-convert-double-to-double
double[] parent1 = Stream.of(mlp.getWeights()).mapToDouble(Double::doubleValue).toArray();
double[] parent2 = Stream.of(agent.getMLP().getWeights()).mapToDouble(Double::doubleValue).toArray();
double[][] childGenes = new double[num][parent1.length];
for (int r = 0; r < num; r++) {
for (int c = 0; c < childGenes[r].length; c++) {
if (new Random().nextInt(100) <= mutationRate * 100) {
childGenes[r][c] = ThreadLocalRandom.current().nextDouble(-1.0, 1.0);
//childGenes[r][c] += childGenes[r][c] * 0.1;
} else {
childGenes[r][c] = new Random().nextDouble() < 0.5 ? parent1[c] : parent2[c];
}
}
}
return childGenes;
}
public MultiLayerPerceptron getMLP() {
return mlp;
}
public void setMask(Snake mask) {
this.mask = mask;
}
public Snake getMask() {
return mask;
}
public int getIndexOfLargest(double[] array) {
if (array == null || array.length == 0) {
return -1; // null or empty
}
int largest = 0;
for (int i = 1; i < array.length; i++) {
if (array[i] > array[largest]) {
largest = i;
}
}
return largest; // position of the first largest found
}
@Override
public int compareTo(GAMLPAgent t) {
if (this.getMask().getScore() < t.getMask().getScore()) {
return -1;
} else if (t.getMask().getScore() < this.getMask().getScore()) {
return 1;
}
return 0;
}
public void debugLocation() {
Rectangle head = mask.getSnakeParts().get(0);
Rectangle food = mask.getFood();
System.out.println(head.getX() + " " + head.getY() + " " + food.getX() + " " + food.getY());
System.out.println(mask.getName() + ": " + Arrays.toString(mlp.getOutput()));
}
public void debugInput() {
String s = "";
for (int i = 0; i < mlp.getInputNeurons().size(); i++) {
s += mlp.getInputNeurons().get(i).getOutput() + " ";
}
System.out.println(s);
}
public double[] getOutput() {
return mlp.getOutput();
}
}
This is the main class of my code, GeneticSnake2.java, where the game loop is located, and where I assign genes to the child snakes (I know that it could be done more cleanly).
package main;
import agents.GAMLPAgent;
import ui.InfoBar;
import graphics.Snake;
import graphics.SnakeGrid;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import static javafx.application.Application.launch;
import javafx.scene.Scene;
import javafx.scene.control.Slider;
import javafx.scene.layout.Pane;
import javafx.scene.paint.Color;
import javafx.stage.Stage;
/**
*
* @author Preston Tang
*/
public class GeneticSnake2 extends Application {
private final int width = 45;
private final int height = 40;
private final double displaySize = 120;
private final double size = 12;
private final Color pathColor = Color.rgb(120, 120, 120);
private final Color wallColor = Color.rgb(50, 50, 50);
private final int initSnakeLength = 2;
private final int populationSize = 1000;
private int generation = 0;
private int initSteps = 100;
private int stepsIncrease = 50;
private double parentPercentage = 0.2;
private final ArrayList<Color> snakeColors = new ArrayList() {
{
add(Color.GREEN);
add(Color.RED);
add(Color.YELLOW);
add(Color.BLUE);
add(Color.MAGENTA);
add(Color.PINK);
add(Color.ORANGERED);
add(Color.BLACK);
add(Color.GOLDENROD);
add(Color.WHITE);
}
};
private final ArrayList<Snake> snakes = new ArrayList<>();
private final ArrayList<GAMLPAgent> agents = new ArrayList<>();
private long initTime = System.nanoTime();
@Override
public void start(Stage stage) {
Pane root = new Pane();
Pane graphics = new Pane();
graphics.setPrefHeight(height * size);
graphics.setPrefWidth(width * size);
graphics.setTranslateX(0);
graphics.setTranslateY(displaySize);
Pane display = new Pane();
display.setStyle("-fx-background-color: BLACK");
display.setPrefHeight(displaySize);
display.setPrefWidth(width * size);
display.setTranslateX(0);
display.setTranslateY(0);
root.getChildren().add(display);
SnakeGrid sg = new SnakeGrid(pathColor, wallColor, width, height, size);
//Parsing "adjectives.txt" and "nouns.txt" to form possible names
ArrayList<String> adjectives = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/adjectives.txt").getFile())).split("\n")));
ArrayList<String> nouns = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/nouns.txt").getFile())).split("\n")));
//Initializing the population
for (int i = 0; i < populationSize; i++) {
//Get random String from lists and capitalize first letter
String adj = adjectives.get(new Random().nextInt(adjectives.size()));
adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);
String noun = nouns.get(new Random().nextInt(nouns.size()));
noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);
Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));
//We want to see the first snake
if (i == 0) {
InfoBar bar = new InfoBar();
bar.getStatusText().setText("Status: Alive");
bar.getStatusText().setFill(Color.GREENYELLOW);
bar.getSizeText().setText("Population Size: " + populationSize);
Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
bar.getNameText().setText("Name: " + snake.getName());
snakes.add(snake);
agents.add(new GAMLPAgent(snake, width, height, size));
} else {
Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
snakes.add(snake);
agents.add(new GAMLPAgent(snake, width, height, size));
}
}
//Focused on original snake
display.getChildren().add(snakes.get(0).getInfoBar());
graphics.getChildren().addAll(sg);
graphics.getChildren().addAll(snakes.get(0));
root.getChildren().add(graphics);
//Add the speed controller (slider)
Slider slider = new Slider(1, 10, 10);
slider.setTranslateX(205);
slider.setTranslateY(75);
slider.setDisable(true);
root.getChildren().add(slider);
Scene scene = new Scene(root, width * size, height * size + displaySize);
stage.setScene(scene);
//Fixes the setResizable bug
//https://stackoverflow.com/questions/20732100/javafx-why-does-stage-setresizablefalse-cause-additional-margins
stage.setTitle("21-GeneticSnake2 Cause the First Version Got Deleted ;-; Started on 6/8/2020");
stage.setResizable(false);
stage.sizeToScene();
stage.show();
AnimationTimer timer = new AnimationTimer() {
private long lastUpdate = 0;
@Override
public void handle(long now) {
if (now - lastUpdate >= (10 - (int) slider.getValue()) * 50_000_000) {
lastUpdate = now;
int alive = populationSize;
for (int i = 0; i < snakes.size(); i++) {
Snake snake = snakes.get(i); //Current snake
if (i == 0) {
Collections.sort(agents);
snake.getInfoBar().getScoreText().setText("Score: " + snake.getScore() + " (" + agents.get(agents.size() - 1).getMask().getScore() + ")");
}
if (!snake.isAlive()) {
alive--;
//Update graphics for main snake
if (i == 0) {
snake.getInfoBar().getStatusText().setText("Status: Dead");
snake.getInfoBar().getStatusText().setFill(Color.RED);
graphics.getChildren().remove(snake);
}
} else {
//If out of steps
if (snake.getSteps() <= 0) {
snake.setAlive(false);
}
//Bounds Detection (left right up down)
if (snake.getSnakeParts().get(0).getX() >= width * size
|| snake.getSnakeParts().get(0).getX() <= 0
|| snake.getSnakeParts().get(0).getY() >= height * size
|| snake.getSnakeParts().get(0).getY() <= 0) {
snake.setAlive(false);
}
//Self-Collision Detection
for (int o = 1; o < snakes.get(o).getSnakeParts().size(); o++) {
if (snakes.get(o).getSnakeParts().get(0).getX() == snakes.get(o).getSnakeParts().get(o).getX()
&& snakes.get(o).getSnakeParts().get(0).getY() == snakes.get(o).getSnakeParts().get(o).getY()) {
snakes.get(o).setAlive(false);
}
}
int rate = (int) slider.getValue();
int seconds = (int) ((System.nanoTime() - initTime) * rate / 1_000_000_000);
agents.get(i).compute();
snake.manageMovement();
snake.setSecondsAlive(seconds);
// agents.get(0);
// System.out.println(Arrays.toString(agents.get(0).getOutput()));
//
// System.out.println("\n\n\n\n\n\n\n");
//Expression to calculate score
double exp = (snake.getConsumed() * 5 + snake.getSecondsAlive() / 2.0D);
//double exp = snake.getSteps() + (Math.pow(2, snake.getConsumed()) + Math.pow(snake.getConsumed(), 2.1) * 500)
// - (Math.pow(snake.getConsumed(), 1.2) * Math.pow(0.25 * snake.getSteps(), 1.3));
snake.setScore(Math.round(exp * 100.0) / 100.0);
//Update graphics for main snake
if (i == 0) {
snake.getInfoBar().getTimeText().setText("Time Survived: " + snake.getSecondsAlive() + "s");
snake.getInfoBar().getFoodText().setText("Food Consumed: " + snake.getConsumed());
snake.getInfoBar().getGenerationText().setText("Generation: " + generation);
snake.getInfoBar().getStepsText().setText("Steps Remaining: " + snake.getSteps());
}
}
}
//Reset and breed
if (alive == 0) {
//Ascending order
initTime = System.nanoTime();
generation++;
graphics.getChildren().clear();
graphics.getChildren().addAll(sg);
snakes.clear();
//x% of snakes are parents
int parentNum = (int) (populationSize * parentPercentage);
//Faster odd number check
if ((parentNum & 1) != 0) {
//If odd make even
parentNum += 1;
}
for (int i = 0; i < parentNum; i += 2) {
//Get the 2 parents, sorted by score
GAMLPAgent p1 = agents.get(populationSize - (i + 2));
GAMLPAgent p2 = agents.get(populationSize - (i + 1));
//Produce the next generation
double[][] childGenes = p1.breed(p2, ((populationSize - parentNum) / parentNum) * 2);
//Debugs Genes
// System.out.println(Arrays
// .stream(childGenes)
// .map(Arrays::toString)
// .collect(Collectors.joining(System.lineSeparator())));
//Soft copy
ArrayList<GAMLPAgent> temp = new ArrayList<>(agents);
for (int o = 0; o < childGenes.length; o++) {
temp.get(o).getMLP().setWeights(childGenes[o]);
}
//Add the genes of every pair of parents to the children
for (int o = 0; o < childGenes.length; o++) {
//Useful debug message
// System.out.println("ParentNum: " + parentNum
// + " ChildPerParent: " + (populationSize - parentNum) / parentNum
// + " Index: " + (o + (i / 2 * childGenes.length))
// + " ChildGenesNum: " + childGenes.length
// + " Var O: " + o);
//Adds the genes of the temp to the agents
agents.set((o + (i / 2 * childGenes.length)), temp.get(o));
}
// System.out.println("\n\n\n\n\n\n");
}
//Debugging the snakes' genes to a file
// String str = "";
// for (int i = 0; i < agents.size(); i++) {
// str += "Index: " + i + "\t" + Arrays.toString(agents.get(i).getMLP().getWeights())+ "\n\n\n";
// }
//
// printToFile(str, "gen" + generation);
for (int i = 0; i < populationSize; i++) {
//Get random String from lists and capitalize first letter
String adj = adjectives.get(new Random().nextInt(adjectives.size()));
adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);
String noun = nouns.get(new Random().nextInt(nouns.size()));
noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);
Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));
//We want to see the first snake
if (i == 0) {
InfoBar bar = new InfoBar();
bar.getStatusText().setText("Status: Alive");
bar.getStatusText().setFill(Color.GREENYELLOW);
bar.getSizeText().setText("Population Size: " + populationSize);
Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
bar.getNameText().setText("Name: " + snake.getName());
snakes.add(snake);
agents.get(i).setMask(snake);
} else {
Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
snakes.add(snake);
agents.get(i).setMask(snake);
}
}
graphics.getChildren().add(snakes.get(0));
display.getChildren().clear();
//Focused on original snake at first
display.getChildren().add(snakes.get(0).getInfoBar());
}
}
}
};
//Starts the infinite loop
timer.start();
}
public String readFile(File f) {
String content = "";
try {
content = new Scanner(f).useDelimiter("\\Z").next();
} catch (FileNotFoundException ex) {
System.err.println("Error: Unable to read " + f.getName());
}
return content;
}
public void printToFile(String str, String name) {
FileWriter fileWriter;
try {
fileWriter = new FileWriter(name + ".txt");
try (BufferedWriter bufferedWriter = new BufferedWriter(fileWriter)) {
bufferedWriter.write(str);
}
} catch (IOException ex) {
ex.printStackTrace();
}
}
public static void main(String[] args) {
launch(args);
}
}
The main problem is that even after a few thousand generations, the snakes are still simply suiciding into the wall. In the videos I linked above, the snakes were avoiding walls and getting food at like generation 5. I suspect the problem is located in the main class where I'm assigning genes to the snakes that have been born.
I've actually been stuck on this for a few weeks. Before, one of the problems I suspected was a lack of inputs, since I had way less back then. But now, I think that it is no longer the case. If needed, I can try to look in the 4 diagonal directions to add another 12 inputs to the snake's MLP. I've also went to the Artificial Intelligence Discord to ask for help, but a solution hasn't really been found.
If needed, I'm willing to send my entire code so you could run the simulation for yourself.
If you've read up to here, thank you for taking time out of your day to help me! I greatly appreciate it.
Solution
I'm not surprised your snakes are dying.
Let's take a step back. What is AI exactly? Well, it's a search problem. We're searching through some parameter space to find the set of parameters that solve snake given the current state of the game. You can imagine a space of parameters that has a global minimum: the best possible snake, the snake that makes the fewest mistakes.
All learning algorithms start at some point in this parameters space and attempt to find that global maximum over time. First, let's think about MLPs. MLPs learn by trying a set of weights, computing a loss function, and then taking a step in the direction that would further minimize the loss (gradient descent). It's fairly obvious that an MLP will find a minimum, but whether it can find a good enough minimum is a question and there are a lot of training techniques that exist to improve that chance.
Genetic algorithms, on the other hand, have very poor convergence characteristics. First, let's stop calling these genetic algorithms. Let's call these smorgasbord algorithms instead. A smorgasbord algorithm takes two sets of parameters from two parents, jumbles them, and then yields a new smorgasbord. What makes you think this would be a better smorgasbord than either of the two? What are you minimizing here? How do you know it's approaching anything better? If you attach a loss function, how do you know you're in a space that can actually be minimized?
The point I'm trying to make is that genetic algorithms are unprincipled, unlike nature. Nature does not just put codons in a blender to make a new strand of DNA, but that's exactly what genetic algorithms do. There are techniques to add some time of hill climbing, but still genetic algorithms have tons of problems.
Point is, don't get swept up in the name. Genetic algorithms are simply smorgasbord algorithms. My view is that your approach doesn't work because GAs have no guarantees of converging after infinite iterations and MLPs have no guarantees of converging to a good global minimum.
What to do? Well, a better approach would be to use a learning paradigm that fits your problem. That better approach would be to use reinforcement learning. There's a very good course on Udacity from Georgia Tech on the subject.
Answered By - fny