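I'm a high school student working on a project for my CS research class (I'm very fortunate to have the opportunity to take such a class)! The project is to make an AI learn the popular game Snake, using a multilayer perceptron (MLP) that learns through a genetic algorithm (GA). It is inspired by many videos I've seen on YouTube that accomplish what I just described, which you can see here and here. I wrote the project in JavaFX using an AI library called Neuroph.
This is what my program currently looks like:
The name is irrelevant, as I have a list of nouns and adjectives that I use to generate names (I thought it would make things more interesting). The number in parentheses next to the score is the best score of that generation, since only one snake is displayed at a time.
When breeding, I set x% of the snakes to be parents (in this case, 20%). The number of children is then divided evenly among each pair of parents. The "genes" in this case are the weights of the MLP. Since my library doesn't really support biases, I added a bias neuron to the input layer and connected it to all the other neurons in every layer so that its weights act as biases (as described in a thread here). Each of a child's genes has a 50% chance of coming from either parent. Each gene also has a 5% chance of mutating, in which case it is set to a random number between -1.0 and 1.0.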
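The breeding rule described above can be sketched in isolation as follows. This is only a minimal illustration under stated assumptions, not the project's actual code: the class and method names (`CrossoverSketch`, `breedChild`) are made up for the example, and genes are plain `double[]` arrays rather than Neuroph weights.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper (not part of the project): uniform crossover plus mutation. */
final class CrossoverSketch {

    /**
     * Builds one child. Each gene mutates with probability mutationRate
     * (it is replaced by a uniform value in [-1.0, 1.0)); otherwise it is
     * copied from either parent with equal probability.
     */
    static double[] breedChild(double[] parent1, double[] parent2,
                               double mutationRate, Random rng) {
        if (parent1.length != parent2.length) {
            throw new IllegalArgumentException("Parents must have the same number of genes");
        }
        double[] child = new double[parent1.length];
        for (int i = 0; i < child.length; i++) {
            if (rng.nextDouble() < mutationRate) {
                child[i] = rng.nextDouble() * 2.0 - 1.0; // uniform in [-1.0, 1.0)
            } else {
                child[i] = rng.nextBoolean() ? parent1[i] : parent2[i];
            }
        }
        return child;
    }

    public static void main(String[] args) {
        double[] a = {0.1, 0.2, 0.3, 0.4};
        double[] b = {-0.1, -0.2, -0.3, -0.4};
        System.out.println(Arrays.toString(breedChild(a, b, 0.05, new Random(42))));
    }
}
```

Comparing `rng.nextDouble() < mutationRate` gives exactly the stated mutation probability, which an integer comparison such as `nextInt(100) <= 5` would not (that one fires 6 times out of 100).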
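Each snake's MLP has 3 layers: 18 input neurons, 14 hidden neurons, and 4 output neurons (one for each direction). The inputs I feed it are the x of the head, the y of the head, the x of the food, the y of the food, and the number of steps left. It also looks in 4 directions and checks the distance to the food, the wall, and itself (if it can't see it, the value is set to -1.0). Then there is the bias neuron I mentioned, which brings the total to 18.
The way I calculate a snake's score is through my fitness function, which is (apples consumed × 5 + seconds alive / 2).
Here is my GAMLPAgent.java, where all the MLP and GA stuff happens.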
package agents;
import graphics.Snake;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Stream;
import javafx.scene.shape.Rectangle;
import org.neuroph.core.Layer;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.TransferFunctionType;
import util.Direction;
/**
*
* @author Preston Tang
*
* GAMLPAgent stands for Genetic Algorithm Multi-Layer Perceptron Agent
*/
public class GAMLPAgent implements Comparable<GAMLPAgent> {
public Snake mask;
private final MultiLayerPerceptron mlp;
private final int width;
private final int height;
private final double size;
private final double mutationRate = 0.05;
public GAMLPAgent(Snake mask, int width, int height, double size) {
this.mask = mask;
this.width = width;
this.height = height;
this.size = size;
//Input: x of head, y of head, x of food, y of food, steps left
//Input: 4 directions, check for distance to food, wall, and self + 1 bias neuron (18 total)
//14 hidden neurons (1 hidden layer)
//Output: A direction, 4 possibilities
mlp = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 18, 14, 4);
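//Adding connections. Note: receiver.addInputConnection(from) wires "from" INTO the receiver,
//so this loop feeds every neuron into the last input neuron, not the other way around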
List<Layer> layers = mlp.getLayers();
for (int r = 0; r < layers.size(); r++) {
for (int c = 0; c < layers.get(r).getNeuronsCount(); c++) {
mlp.getInputNeurons().get(mlp.getInputsCount() - 1).addInputConnection(layers.get(r).getNeuronAt(c));
}
}
// System.out.println(mlp.getInputNeurons().get(17).getInputConnections() + " " + mlp.getInputNeurons().get(17).getOutConnections());
mlp.randomizeWeights();
// System.out.println(Arrays.toString(mlp.getInputNeurons().get(17).getWeights()));
}
public void compute() {
if (mask.isAlive()) {
Rectangle head = mask.getSnakeParts().get(0);
Rectangle food = mask.getFood();
double headX = head.getX();
double headY = head.getY();
double foodX = mask.getFood().getX();
double foodY = mask.getFood().getY();
int stepsLeft = mask.getSteps();
double foodL = -1.0, wallL, selfL = -1.0;
double foodR = -1.0, wallR, selfR = -1.0;
double foodU = -1.0, wallU, selfU = -1.0;
double foodD = -1.0, wallD, selfD = -1.0;
//The 4 directions
//Left Direction
if (head.getY() == food.getY() && head.getX() > food.getX()) {
foodL = head.getX() - food.getX();
}
wallL = head.getX() - size;
for (Rectangle part : mask.getSnakeParts()) {
if (head.getY() == part.getY() && head.getX() > part.getX()) {
selfL = head.getX() - part.getX();
break;
}
}
//Right Direction
if (head.getY() == food.getY() && head.getX() < food.getX()) {
foodR = food.getX() - head.getX();
}
wallR = size * width - head.getX();
for (Rectangle part : mask.getSnakeParts()) {
if (head.getY() == part.getY() && head.getX() < part.getX()) {
selfR = part.getX() - head.getX();
break;
}
}
//Up Direction
if (head.getX() == food.getX() && head.getY() < food.getY()) {
foodU = food.getY() - head.getY();
}
wallU = size * height - head.getY();
for (Rectangle part : mask.getSnakeParts()) {
if (head.getX() == part.getX() && head.getY() < part.getY()) {
selfU = part.getY() - head.getY();
break;
}
}
//Down Direction
if (head.getX() == food.getX() && head.getY() > food.getY()) {
foodD = head.getY() - food.getY();
}
wallD = head.getY() - size;
for (Rectangle part : mask.getSnakeParts()) {
if (head.getX() == part.getX() && head.getY() > part.getY()) {
selfD = head.getY() - part.getY();
break;
}
}
mlp.setInput(
headX, headY, foodX, foodY, stepsLeft,
foodL, wallL, selfL,
foodR, wallR, selfR,
foodU, wallU, selfU,
foodD, wallD, selfD, 1);
mlp.calculate();
if (getIndexOfLargest(mlp.getOutput()) == 0) {
mask.setDirection(Direction.UP);
} else if (getIndexOfLargest(mlp.getOutput()) == 1) {
mask.setDirection(Direction.DOWN);
} else if (getIndexOfLargest(mlp.getOutput()) == 2) {
mask.setDirection(Direction.LEFT);
} else if (getIndexOfLargest(mlp.getOutput()) == 3) {
mask.setDirection(Direction.RIGHT);
}
}
}
public double[][] breed(GAMLPAgent agent, int num) {
//Converts Double[] to double[]
//https://stackoverflow.com/questions/1109988/how-do-i-convert-double-to-double
double[] parent1 = Stream.of(mlp.getWeights()).mapToDouble(Double::doubleValue).toArray();
double[] parent2 = Stream.of(agent.getMLP().getWeights()).mapToDouble(Double::doubleValue).toArray();
double[][] childGenes = new double[num][parent1.length];
for (int r = 0; r < num; r++) {
for (int c = 0; c < childGenes[r].length; c++) {
//nextDouble() < rate gives exactly a 5% chance (nextInt(100) <= 5 would give 6%)
if (ThreadLocalRandom.current().nextDouble() < mutationRate) {
childGenes[r][c] = ThreadLocalRandom.current().nextDouble(-1.0, 1.0);
//childGenes[r][c] += childGenes[r][c] * 0.1;
} else {
childGenes[r][c] = ThreadLocalRandom.current().nextBoolean() ? parent1[c] : parent2[c];
}
}
}
return childGenes;
}
public MultiLayerPerceptron getMLP() {
return mlp;
}
public void setMask(Snake mask) {
this.mask = mask;
}
public Snake getMask() {
return mask;
}
public int getIndexOfLargest(double[] array) {
if (array == null || array.length == 0) {
return -1; // null or empty
}
int largest = 0;
for (int i = 1; i < array.length; i++) {
if (array[i] > array[largest]) {
largest = i;
}
}
return largest; // position of the first largest found
}
@Override
public int compareTo(GAMLPAgent t) {
if (this.getMask().getScore() < t.getMask().getScore()) {
return -1;
} else if (t.getMask().getScore() < this.getMask().getScore()) {
return 1;
}
return 0;
}
public void debugLocation() {
Rectangle head = mask.getSnakeParts().get(0);
Rectangle food = mask.getFood();
System.out.println(head.getX() + " " + head.getY() + " " + food.getX() + " " + food.getY());
System.out.println(mask.getName() + ": " + Arrays.toString(mlp.getOutput()));
}
public void debugInput() {
String s = "";
for (int i = 0; i < mlp.getInputNeurons().size(); i++) {
s += mlp.getInputNeurons().get(i).getOutput() + " ";
}
System.out.println(s);
}
public double[] getOutput() {
return mlp.getOutput();
}
}
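Here is the main class of my code, GeneticSnake2.java, where the game loop lives and where I assign the genes to the child snakes (I know it could be done more cleanly).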
package main;
import agents.GAMLPAgent;
import ui.InfoBar;
import graphics.Snake;
import graphics.SnakeGrid;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import static javafx.application.Application.launch;
import javafx.scene.Scene;
import javafx.scene.control.Slider;
import javafx.scene.layout.Pane;
import javafx.scene.paint.Color;
import javafx.stage.Stage;
/**
*
* @author Preston Tang
*/
public class GeneticSnake2 extends Application {
private final int width = 45;
private final int height = 40;
private final double displaySize = 120;
private final double size = 12;
private final Color pathColor = Color.rgb(120, 120, 120);
private final Color wallColor = Color.rgb(50, 50, 50);
private final int initSnakeLength = 2;
private final int populationSize = 1000;
private int generation = 0;
private int initSteps = 100;
private int stepsIncrease = 50;
private double parentPercentage = 0.2;
private final ArrayList<Color> snakeColors = new ArrayList<Color>() {
{
add(Color.GREEN);
add(Color.RED);
add(Color.YELLOW);
add(Color.BLUE);
add(Color.MAGENTA);
add(Color.PINK);
add(Color.ORANGERED);
add(Color.BLACK);
add(Color.GOLDENROD);
add(Color.WHITE);
}
};
private final ArrayList<Snake> snakes = new ArrayList<>();
private final ArrayList<GAMLPAgent> agents = new ArrayList<>();
private long initTime = System.nanoTime();
@Override
public void start(Stage stage) {
Pane root = new Pane();
Pane graphics = new Pane();
graphics.setPrefHeight(height * size);
graphics.setPrefWidth(width * size);
graphics.setTranslateX(0);
graphics.setTranslateY(displaySize);
Pane display = new Pane();
display.setStyle("-fx-background-color: BLACK");
display.setPrefHeight(displaySize);
display.setPrefWidth(width * size);
display.setTranslateX(0);
display.setTranslateY(0);
root.getChildren().add(display);
SnakeGrid sg = new SnakeGrid(pathColor, wallColor, width, height, size);
//Parsing "adjectives.txt" and "nouns.txt" to form possible names
ArrayList<String> adjectives = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/adjectives.txt").getFile())).split("\n")));
ArrayList<String> nouns = new ArrayList<>(Arrays.asList(readFile(new File(getClass().getClassLoader().getResource("resources/nouns.txt").getFile())).split("\n")));
//Initializing the population
for (int i = 0; i < populationSize; i++) {
//Get random String from lists and capitalize first letter
String adj = adjectives.get(new Random().nextInt(adjectives.size()));
adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);
String noun = nouns.get(new Random().nextInt(nouns.size()));
noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);
Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));
//We want to see the first snake
if (i == 0) {
InfoBar bar = new InfoBar();
bar.getStatusText().setText("Status: Alive");
bar.getStatusText().setFill(Color.GREENYELLOW);
bar.getSizeText().setText("Population Size: " + populationSize);
Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
bar.getNameText().setText("Name: " + snake.getName());
snakes.add(snake);
agents.add(new GAMLPAgent(snake, width, height, size));
} else {
Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
snakes.add(snake);
agents.add(new GAMLPAgent(snake, width, height, size));
}
}
//Focused on original snake
display.getChildren().add(snakes.get(0).getInfoBar());
graphics.getChildren().addAll(sg);
graphics.getChildren().addAll(snakes.get(0));
root.getChildren().add(graphics);
//Add the speed controller (slider)
Slider slider = new Slider(1, 10, 10);
slider.setTranslateX(205);
slider.setTranslateY(75);
slider.setDisable(true);
root.getChildren().add(slider);
Scene scene = new Scene(root, width * size, height * size + displaySize);
stage.setScene(scene);
//Fixes the setResizable bug
//https://stackoverflow.com/questions/20732100/javafx-why-does-stage-setresizablefalse-cause-additional-margins
stage.setTitle("21-GeneticSnake2 Cause the First Version Got Deleted ;-; Started on 6/8/2020");
stage.setResizable(false);
stage.sizeToScene();
stage.show();
AnimationTimer timer = new AnimationTimer() {
private long lastUpdate = 0;
@Override
public void handle(long now) {
if (now - lastUpdate >= (10 - (int) slider.getValue()) * 50_000_000) {
lastUpdate = now;
int alive = populationSize;
for (int i = 0; i < snakes.size(); i++) {
Snake snake = snakes.get(i); //Current snake
if (i == 0) {
Collections.sort(agents);
snake.getInfoBar().getScoreText().setText("Score: " + snake.getScore() + " (" + agents.get(agents.size() - 1).getMask().getScore() + ")");
}
if (!snake.isAlive()) {
alive--;
//Update graphics for main snake
if (i == 0) {
snake.getInfoBar().getStatusText().setText("Status: Dead");
snake.getInfoBar().getStatusText().setFill(Color.RED);
graphics.getChildren().remove(snake);
}
} else {
//If out of steps
if (snake.getSteps() <= 0) {
snake.setAlive(false);
}
//Bounds Detection (left right up down)
if (snake.getSnakeParts().get(0).getX() >= width * size
|| snake.getSnakeParts().get(0).getX() <= 0
|| snake.getSnakeParts().get(0).getY() >= height * size
|| snake.getSnakeParts().get(0).getY() <= 0) {
snake.setAlive(false);
}
//Self-Collision Detection
//Compare the current snake's head against each of its own body parts
for (int o = 1; o < snake.getSnakeParts().size(); o++) {
if (snake.getSnakeParts().get(0).getX() == snake.getSnakeParts().get(o).getX()
&& snake.getSnakeParts().get(0).getY() == snake.getSnakeParts().get(o).getY()) {
snake.setAlive(false);
}
}
int rate = (int) slider.getValue();
int seconds = (int) ((System.nanoTime() - initTime) * rate / 1_000_000_000);
agents.get(i).compute();
snake.manageMovement();
snake.setSecondsAlive(seconds);
// agents.get(0);
// System.out.println(Arrays.toString(agents.get(0).getOutput()));
//
// System.out.println("\n\n\n\n\n\n\n");
//Expression to calculate score
double exp = (snake.getConsumed() * 5 + snake.getSecondsAlive() / 2.0D);
//double exp = snake.getSteps() + (Math.pow(2, snake.getConsumed()) + Math.pow(snake.getConsumed(), 2.1) * 500)
// - (Math.pow(snake.getConsumed(), 1.2) * Math.pow(0.25 * snake.getSteps(), 1.3));
snake.setScore(Math.round(exp * 100.0) / 100.0);
//Update graphics for main snake
if (i == 0) {
snake.getInfoBar().getTimeText().setText("Time Survived: " + snake.getSecondsAlive() + "s");
snake.getInfoBar().getFoodText().setText("Food Consumed: " + snake.getConsumed());
snake.getInfoBar().getGenerationText().setText("Generation: " + generation);
snake.getInfoBar().getStepsText().setText("Steps Remaining: " + snake.getSteps());
}
}
}
//Reset and breed
if (alive == 0) {
//Ascending order
initTime = System.nanoTime();
generation++;
graphics.getChildren().clear();
graphics.getChildren().addAll(sg);
snakes.clear();
//x% of snakes are parents
int parentNum = (int) (populationSize * parentPercentage);
//Faster odd number check
if ((parentNum & 1) != 0) {
//If odd make even
parentNum += 1;
}
for (int i = 0; i < parentNum; i += 2) {
//Get the 2 parents, sorted by score
GAMLPAgent p1 = agents.get(populationSize - (i + 2));
GAMLPAgent p2 = agents.get(populationSize - (i + 1));
//Produce the next generation
double[][] childGenes = p1.breed(p2, ((populationSize - parentNum) / parentNum) * 2);
//Debugs Genes
// System.out.println(Arrays
// .stream(childGenes)
// .map(Arrays::toString)
// .collect(Collectors.joining(System.lineSeparator())));
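//Shallow copy: temp holds the same GAMLPAgent objects as agents,
//so setWeights below also changes the agents in the original list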
ArrayList<GAMLPAgent> temp = new ArrayList<>(agents);
for (int o = 0; o < childGenes.length; o++) {
temp.get(o).getMLP().setWeights(childGenes[o]);
}
//Add the genes of every pair of parents to the children
for (int o = 0; o < childGenes.length; o++) {
//Useful debug message
// System.out.println("ParentNum: " + parentNum
// + " ChildPerParent: " + (populationSize - parentNum) / parentNum
// + " Index: " + (o + (i / 2 * childGenes.length))
// + " ChildGenesNum: " + childGenes.length
// + " Var O: " + o);
//Adds the genes of the temp to the agents
agents.set((o + (i / 2 * childGenes.length)), temp.get(o));
}
// System.out.println("\n\n\n\n\n\n");
}
//Debugging the snakes' genes to a file
// String str = "";
// for (int i = 0; i < agents.size(); i++) {
// str += "Index: " + i + "\t" + Arrays.toString(agents.get(i).getMLP().getWeights())+ "\n\n\n";
// }
//
// printToFile(str, "gen" + generation);
for (int i = 0; i < populationSize; i++) {
//Get random String from lists and capitalize first letter
String adj = adjectives.get(new Random().nextInt(adjectives.size()));
adj = adj.substring(0, 1).toUpperCase() + adj.substring(1);
String noun = nouns.get(new Random().nextInt(nouns.size()));
noun = noun.substring(0, 1).toUpperCase() + noun.substring(1);
Color color = snakeColors.get(new Random().nextInt(snakeColors.size()));
//We want to see the first snake
if (i == 0) {
InfoBar bar = new InfoBar();
bar.getStatusText().setText("Status: Alive");
bar.getStatusText().setFill(Color.GREENYELLOW);
bar.getSizeText().setText("Population Size: " + populationSize);
Snake snake = new Snake(bar, adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
bar.getNameText().setText("Name: " + snake.getName());
snakes.add(snake);
agents.get(i).setMask(snake);
} else {
Snake snake = new Snake(adj + " " + noun, width, height, size, initSnakeLength, color, initSteps, stepsIncrease);
snakes.add(snake);
agents.get(i).setMask(snake);
}
}
graphics.getChildren().add(snakes.get(0));
display.getChildren().clear();
//Focused on original snake at first
display.getChildren().add(snakes.get(0).getInfoBar());
}
}
}
};
//Starts the infinite loop
timer.start();
}
public String readFile(File f) {
String content = "";
try {
content = new Scanner(f).useDelimiter("\\Z").next();
} catch (FileNotFoundException ex) {
System.err.println("Error: Unable to read " + f.getName());
}
return content;
}
public void printToFile(String str, String name) {
FileWriter fileWriter;
try {
fileWriter = new FileWriter(name + ".txt");
try (BufferedWriter bufferedWriter = new BufferedWriter(fileWriter)) {
bufferedWriter.write(str);
}
} catch (IOException ex) {
ex.printStackTrace();
}
}
public static void main(String[] args) {
launch(args);
}
}
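The main problem is that even after a few thousand generations, the snakes still just run into the walls and die. In the videos I linked above, the snakes were avoiding walls and getting food by about generation 5. I suspect the problem lies in the main class, where I assign genes to the newborn snakes.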
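One thing that can be checked about that gene-assignment step: `new ArrayList<>(agents)` copies references, not agents, so writing weights through the copy also changes the originals, and the same agent object can end up in several slots of the population. The sketch below shows that behaviour and one possible way to build a generation from fresh objects. It is an illustration only: `SketchAgent` and `AssignGenesSketch` are hypothetical classes, not part of the project or of Neuroph, and whether this explains the snakes' behaviour is not established here.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for an agent; it only holds a weight array. */
final class SketchAgent {
    double[] weights;

    SketchAgent(double[] weights) {
        this.weights = weights.clone(); // defensive copy so agents never share arrays
    }
}

final class AssignGenesSketch {

    /** Copying a list copies references only: both lists point at the same agents. */
    static boolean sharesAgents(List<SketchAgent> population) {
        List<SketchAgent> copy = new ArrayList<>(population);
        return copy.get(0) == population.get(0);
    }

    /**
     * Builds the next generation as brand-new agents, one per gene row,
     * so no two slots share an object and the parents are left untouched.
     */
    static List<SketchAgent> nextGeneration(double[][] childGenes) {
        List<SketchAgent> next = new ArrayList<>(childGenes.length);
        for (double[] genes : childGenes) {
            next.add(new SketchAgent(genes));
        }
        return next;
    }

    public static void main(String[] args) {
        List<SketchAgent> population = new ArrayList<>();
        population.add(new SketchAgent(new double[] {0.5, -0.5}));
        System.out.println("Copy shares agents: " + sharesAgents(population));

        List<SketchAgent> next = nextGeneration(new double[][] {{0.1, 0.2}, {0.3, 0.4}});
        System.out.println("Distinct children: " + (next.get(0) != next.get(1)));
    }
}
```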
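I've really been stuck on this for a few weeks now. Earlier, one of the problems I suspected was a lack of inputs, since I had far fewer back then. I no longer think that's the case. If needed, I can try looking in the 4 diagonal directions as well, which would add another 12 inputs to the snake's MLP. I've also asked for help on the Artificial Intelligence Discord, but haven't really found a solution.
If needed, I'm willing to send my entire code so you can run the simulation yourself.
If you've read this far, thank you for taking the time to help me! I really appreciate it.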
1 Answer
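I'm not surprised your snakes are dying.
Let's take a step back. What exactly is AI? Well, it's a search problem. We search through some parameter space to find the set of parameters that solves Snake given the current state of the game. You can imagine a parameter space with a global minimum: the best possible snake, the one that makes the fewest mistakes.
All learning algorithms start at some point in this parameter space and try to find that global minimum over time. First, let's think about MLPs. An MLP learns by trying a set of weights, computing a loss function, and then taking a step in the direction that further minimizes the loss (gradient descent). It's fairly clear that an MLP will find a minimum, but whether it finds a good enough minimum is an open question, and there are many training techniques that exist to improve that chance.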
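To make the "take a step in the direction that minimizes the loss" idea concrete, here is a minimal sketch of gradient descent on a one-parameter toy loss. The loss function, class name, and learning rate are assumptions chosen for illustration; a real MLP would have many weights and compute the gradient by backpropagation.

```java
/** Hypothetical toy example: gradient descent on L(w) = (w - 3)^2. */
final class GradientDescentSketch {

    /** Toy loss with its minimum at w = 3. */
    static double loss(double w) {
        return (w - 3.0) * (w - 3.0);
    }

    /** Derivative of the toy loss: dL/dw = 2 (w - 3). */
    static double gradient(double w) {
        return 2.0 * (w - 3.0);
    }

    /** Repeatedly steps against the gradient, starting from w0. */
    static double minimize(double w0, double learningRate, int steps) {
        double w = w0;
        for (int i = 0; i < steps; i++) {
            w -= learningRate * gradient(w);
        }
        return w;
    }

    public static void main(String[] args) {
        double w = minimize(0.0, 0.1, 200);
        System.out.println("w = " + w + ", loss = " + loss(w));
    }
}
```

Every step here uses information about which direction lowers the loss, which is the property the answer contrasts with blind recombination.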
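Genetic algorithms, on the other hand, have very poor convergence properties. First off, let's stop calling them genetic algorithms. Let's call them smorgasbord algorithms instead. A smorgasbord algorithm takes two sets of parameters from two parents, jumbles them together, and yields a new smorgasbord. What makes you think this smorgasbord will be better than either of the two? What exactly are you doing here? How do you know it's approaching anything better? If you attach a loss function, how do you know you're in a space that can actually be minimized?
The point I'm trying to make is that genetic algorithms are unprincipled, unlike nature. Nature does not just throw codons into a blender to make new strands of DNA, but that is exactly what genetic algorithms do. There are techniques that add some hill climbing, but genetic algorithms still have plenty of problems.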
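As an illustration of what adding hill climbing can mean: a mutation is kept only when the measured fitness does not drop, so the score never decreases from one step to the next. This is a minimal sketch under assumptions; the toy `fitness` function stands in for a real game score, and `HillClimbSketch` and `climb` are hypothetical names.

```java
import java.util.Random;

/** Hypothetical example: keep a mutation only if it does not lower fitness. */
final class HillClimbSketch {

    /** Toy fitness to maximise; the peak (0.0) is at all-zero weights. */
    static double fitness(double[] weights) {
        double sum = 0.0;
        for (double w : weights) {
            sum -= w * w;
        }
        return sum;
    }

    /** Perturbs one random gene per iteration and accepts the change only if fitness is not worse. */
    static double[] climb(double[] start, int iterations, double stepSize, Random rng) {
        double[] best = start.clone();
        double bestFitness = fitness(best);
        for (int i = 0; i < iterations; i++) {
            double[] candidate = best.clone();
            int gene = rng.nextInt(candidate.length);
            candidate[gene] += (rng.nextDouble() * 2.0 - 1.0) * stepSize;
            double candidateFitness = fitness(candidate);
            if (candidateFitness >= bestFitness) {
                best = candidate;
                bestFitness = candidateFitness;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] start = {0.8, -0.6, 0.4};
        double[] result = climb(start, 1000, 0.1, new Random(7));
        System.out.println(fitness(start) + " -> " + fitness(result));
    }
}
```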
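The point is, don't get swept up in the name. Genetic algorithms are simply smorgasbord algorithms. My view is that your approach doesn't work because GAs are not guaranteed to converge even after infinitely many iterations, and MLPs are not guaranteed to converge to a good global minimum.
What to do? Well, a better approach would be to use a learning paradigm that fits your problem: reinforcement learning. There is a very good course on this subject from Georgia Tech.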
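The answer does not name a specific reinforcement learning method. As one common example, a tabular Q-learning update can be sketched as below. Everything here is an assumption for illustration: how Snake states and actions would be encoded as integer indices is left open, and `QLearningSketch` is a hypothetical class.

```java
/** Hypothetical example: a single tabular Q-learning update. */
final class QLearningSketch {

    /**
     * Applies Q(s,a) += alpha * (reward + gamma * max over a' of Q(s',a') - Q(s,a)).
     * q[state][action] holds the current value estimates.
     */
    static void update(double[][] q, int state, int action, double reward,
                       int nextState, double alpha, double gamma) {
        double bestNext = q[nextState][0];
        for (int a = 1; a < q[nextState].length; a++) {
            bestNext = Math.max(bestNext, q[nextState][a]);
        }
        q[state][action] += alpha * (reward + gamma * bestNext - q[state][action]);
    }

    public static void main(String[] args) {
        double[][] q = new double[2][4]; // 2 states, 4 actions (e.g. up, down, left, right)
        q[1][2] = 1.0;                   // pretend the next state already has a known good action
        update(q, 0, 3, 1.0, 1, 0.5, 0.9);
        System.out.println("Q(0,3) = " + q[0][3]);
    }
}
```

Unlike the fitness score in the question, which is only assigned once a snake has died, this kind of update uses the reward from each individual move.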