qlearning网络每次都会选择相同的操作,尽管有很重的负面回报

twh00eeo  于 2021-07-03  发布在  Java
关注(0)|答案(0)|浏览(400)

所以我插上了插头 QLearningDiscreteDense 变成我做的一个点和盒子的游戏。我创建了一个自定义 MDP 它的环境。问题是它每次都选择动作0,第一次起作用,但之后就不再是可用动作,所以这是非法移动。我给非法的动作一个 Integer.MIN_VALUE ,但它不会影响任何东西。这是你的答案 MDP 班级:

public class testEnv implements MDP<testState, Integer, DiscreteSpace> {
        final private int maxStep;
        DiscreteSpace actionSpace = new DiscreteSpace(Graph.getEdgeList().size());
        // takes amount of possible edges      ^
        ObservationSpace<testState> observationSpace = new ArrayObservationSpace(new int[] {1});
        private testState state = new testState(Graph.getMatrix(),0,0,0);
        private NeuralNetFetchable<IDQN> fetchable;
        boolean illegal=false;
        public testEnv(int maxStep){
            this.maxStep=maxStep;
        }
        @Override
        public ObservationSpace<testState> getObservationSpace() {
            return observationSpace;
        }

        @Override
        public DiscreteSpace getActionSpace() {
            return actionSpace;
        }

        @Override
        public testState reset() {
           // System.out.println("RESET");
            try {
                GameBoard r = new GameBoard(3,3);
            } catch (IOException e) {
                e.printStackTrace();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            return new testState(Graph.getMatrix(),0,0,0);
        }

        @Override
        public void close() { }

        @Override
        public StepReply<testState> step(Integer action) {
         //   System.out.println(Arrays.deepToString(Graph.getMatrix()));
            int reward=0;
            try {
                placeEdge(action);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            // change the getPlayer1 to whichever player the neural is
         //   System.out.println("step: "+state.step);
            if(!illegal) {
                System.out.println("Not Illegal");
                if (gameThread.checkFinished()) {
                    if (Graph.getPlayer1Score() > Graph.getPlayer2Score()) {
                        reward = 5;
                    } else {
                        reward = -5;
                    }
                }
                if (Graph.numOfMoves < 1) {
                    if (!isDone()) {
                        if (Graph.player1Turn) {
                            Graph.player1Turn = false;
                        } else {
                            Graph.player1Turn = true;
                        }
                        Graph.setNumOfMoves(1);
                        while (Graph.numOfMoves > 0) {
                     //       System.out.println(Arrays.deepToString(Graph.getMatrix()));
                            if (!isDone()) {
                                Graph.getRandomBot().placeRandomEdge();
                            } else {
                                Graph.numOfMoves = 0;
                            }
                        }
                        if (!isDone()) {
                            if (Graph.player1Turn) {
                                Graph.player1Turn = false;
                            } else {
                                Graph.player1Turn = true;
                            }
                            Graph.setNumOfMoves(1);
                        }
                    }
                }
            }else{
                reward=Integer.MIN_VALUE;
                illegal=false;
            }
            testState t = new testState(Graph.getMatrix(), Graph.getPlayer1Score(), Graph.getPlayer2Score(), state.step + 1);
            state=t;
            // perform action in game, get reward
            return new StepReply<>(t, reward, isDone(), null);
        }

        @Override
        public boolean isDone() {
            return gameThread.checkFinished();
        }

        @Override
        public MDP<testState, Integer, DiscreteSpace> newInstance() {
            testEnv test = new testEnv(maxStep);
            test.setFetchable(fetchable);
            return test;
        }
        public void setFetchable(NeuralNetFetchable<IDQN> fetchable) {
            this.fetchable = fetchable;
        }

        public void placeEdge(int index) throws InterruptedException {
            ELine line = Graph.getEdgeList().get(index).getEline();
            System.out.println("NChosen: "+line.vertices.get(0).id+"--"+line.vertices.get(1).id);
            if(!line.isActivated()) {
                line.setActivated(true);
                // make it black
                line.setBackground(Color.BLACK);
                line.repaint();
                // set the adjacency matrix to 2, 2==is a line, 1==is a possible line
                Graph.matrix[line.vertices.get(0).getID()][line.vertices.get(1).getID()] = 2;
                Graph.matrix[line.vertices.get(1).getID()][line.vertices.get(0).getID()] = 2;
                // gets an arrayList of each box the ELine creates. The box is an arrayList of 4 vertices.
                ArrayList<ArrayList<Vertex>> boxes = gameThread.checkBox(line);
                if (boxes != null) {
                    for (ArrayList<Vertex> box : boxes) {
                        // looks through the counterBoxes arrayList and sets the matching one visible.
                        gameThread.checkMatching(box);
                        // updates the score board
                        if (Graph.getPlayer1Turn()) {
                            Graph.setPlayer1Score(Graph.getPlayer1Score() + 1);
                            Graph.getScore1().setScore();
                        } else {
                            Graph.setPlayer2Score(Graph.getPlayer2Score() + 1);
                            Graph.getScore2().setScore();
                        }
                    }
                    // if every counterBox has been activated, the game is over
                } else {
                    Graph.setNumOfMoves(0);
                    // switches turn. If randomBot is active switches to their turn.
                }
            }else{
                System.out.println("ILLEGAL");
                illegal=true;
            }
        }
    }

以下是我为美国制作的课程:

public class testState implements Encodable {
        int[][] matrix;
        int playerScore;
        int otherPlayerScore;
        int step;
        public testState(int[][] m,int p,int op,int step){
            matrix=m;
            playerScore=p;
            otherPlayerScore=op;
            this.step=step;
        }
        @Override
        public double[] toArray() {
            double[] array = new double[matrix.length*matrix[0].length];
            int i=0;
            for(int a=0;a< matrix.length;a++){
                for(int b=0;b<matrix[0].length;b++){
                    array[i]= matrix[a][b];
                    i++;
                }
            }
            return array;
        }

        @Override
        public boolean isSkipped() {
            return false;
        }

        @Override
        public INDArray getData() {
            return null;
        }

        @Override
        public Encodable dup() {
            return null;
        }
    }

下面是我运行它的类:

public class testNeural {
        public static void main(String args[]) throws IOException, InterruptedException {
            GameBoard r = new GameBoard(3,3);
            DQNPolicy<testState> t = dots();
        }
        static QLearning.QLConfiguration DOTS_QL = QLearning.QLConfiguration.builder()
                .seed(123)                //Random seed (for reproducability)
                .maxEpochStep(10)        // Max step By epoch
                .maxStep(10000)           // Max step
                .expRepMaxSize(150000)    // Max size of experience replay
                .batchSize(128)            // size of batches
                .targetDqnUpdateFreq(500) // target update (hard)
                .updateStart(10)          // num step noop warmup
                .rewardFactor(0.01)       // reward scaling
                .gamma(0.99)              // gamma
                .errorClamp(1.0)          // /td-error clipping
                .minEpsilon(0.1f)         // min epsilon
                .epsilonNbStep(1000)      // num step for eps greedy anneal
                .doubleDQN(false)          // double DQN
                .build();
        static DQNFactoryStdDense.Configuration DOTS_NET =
                DQNFactoryStdDense.Configuration.builder()
                        .l2(0)
                        .updater(new RmsProp(0.000025))
                        .numHiddenNodes(300)
                        .numLayer(10)
                        .build();
        private static DQNPolicy<testState> dots() throws IOException {
            DataManager dataManager = new DataManager(true);

            // The neural network used by the agent. Note that there is no need to specify the number of inputs/outputs.
            // These will be read from the gym environment at the start of training.

            testEnv env = new testEnv(10000);
            QLearningDiscreteDense<testState> dql = new QLearningDiscreteDense<>(env, DOTS_NET, DOTS_QL,dataManager);
            dql.train();
            return dql.getPolicy();
        }
    }

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题