pytorch 如何在OOP中使用torch.save和torch.load进行RL?

ie3xauqp  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(118)

我是OOP和RL的初学者,我需要一些关于我的connect 4游戏的建议:slight_smile:
首先,如果你在我的代码中看到任何令人震惊的东西,请不要犹豫,啊哈,但最重要的是,我想知道在哪里保存我的训练数据记录,在哪里加载它,以便它是有效的?我有点麻烦。
在此先谢谢您!

import numpy as np
from colorama import Fore, Style
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import torch.nn.functional as F
import pickle
import os

repo = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(repo, "modeleIA.pth")

import matplotlib.pyplot as plt

if torch.cuda.is_available():
    device = torch.device('cuda') 
else:
    device = torch.device('cpu')

class Plateau:
    """ Classe représentant le plateau de jeu """

    def __init__(self, rows, columns):
        """ Constructeur de la classe, initialise le plateau et ses dimensions """
        self.rows = rows
        self.columns = columns
        self.plato = np.zeros((rows, columns))

    def colonne_check(self, col):
        """ Vérifie si la colonne est pleine """
        for i in range(self.rows):
            if self.plato[i][col] == 0:
                return True
        return False

    def placement_jeton(self, col, joueur):
        """ Place le jeton du joueur dans la colonne qu'il sélectionne """
        for i in np.r_[:self.rows][::-1]:
            if self.plato[i][col] == 0:
                self.plato[i][col] = joueur
                return True
        return False

    def affichage(self, joueur):
        """ Affiche le plateau de jeu à l'instant t """
        couleur = Fore.RED if joueur == 1 else Fore.YELLOW
        print(couleur + str(self.plato) + Style.RESET_ALL)

    def check_victoire(self):
        """ Vérifie si un joueur a gagné """
        rows, columns = self.rows, self.columns

        # vérification en ligne
        for r in np.r_[:rows]:
            for d in np.r_[:columns-3]:
                f = d + 4
                s = np.prod(self.plato[r, d:f])
                if s == 1 or s == 16:
                    return True

        # vérification en colonne
        for c in np.r_[:columns]:
            for d in np.r_[:rows-3]:
                f = d + 4
                s = np.prod(self.plato[d:f, c])
                if s == 1 or s == 16:
                    return True

        # vérification en diagonale (bas gauche vers haut droite)
        for r in np.r_[:rows-3]:
            for c in np.r_[:columns-3]:
                f = c + 4
                s = np.prod([self.plato[r+i, c+i] for i in range(4)])
                if s == 1 or s == 16:
                    return True

        # vérification en diagonale (haut gauche vers bas droite)
        for r in np.r_[3:rows]:
            for c in np.r_[:columns-3]:
                f = c + 4
                s = np.prod([self.plato[r-i, c+i] for i in range(4)])
                if s == 1 or s == 16:
                    return True
        return False

    def get_etat(self):
        """ Obtient l'état actuel du plateau sous forme de tableau 1D """
        return self.plato.flatten()

    def get_actions(self):
        """ Obtient les actions possibles à partir de l'état actuel du plateau """
        return np.where(self.plato[0] == 0)[0]

class Joueur:
    """ Classe représentant un joueur humain """

    def __init__(self, numero, max_choix):
        """ Initialise le joueur et son numéro et le nombre de choix possibles """
        self.numero = numero
        self.max_choix = max_choix

    def jouer(self, state, actions):
        """ Demande au joueur de choisir une colonne """
        while True:
            try:
                choix = int(input(f'Joueur {self.numero}, à vous de jouer (entre 1 et {self.max_choix}): ')) - 1
                if choix in actions:
                    return choix
                else:
                    print("Choix invalide. Essayez à nouveau.")
            except ValueError:
                print("Ce n'est pas un nombre. Essayez encore.")

class DQNAgent:
    # Dans le RL, l'agent DQN utilise une mémoire appelée "replay memory"
    # pour stocker les XP passées (état/action/rec/prochain état etc)
    # afin de les réutiliser lors de l'apprentissage

    # Initialisation de l'agent DQN
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=10000) #pareil que sur le github
        self.gamma = 0.95  # facteur d'actualisation, équilibre recomp im et future
        self.epsilon = 1.0  # taux d'exploration initial
        self.epsilon_min = 0.01  # taux d'exploration minimum
        self.epsilon_decay = 0.995  # taux de décroissance de l'exploration
        self.lr = 0.001  # taux d'apprentissage
        self.model = self._build_model()  # Construire le modèle de réseau neuronal
        self.batch_size = 64
        self.update_every = 5
       
        # réseau principal (d'évaluation) pour choisir les actions et réseau cible pour générer les Q-values cibles
        self.dqn_network = self._build_model().to(device)
        self.target_network = self._build_model().to(device)
        #pour l'instant même dim d'entrée, cachée (64) et de sortie

        self.optimizer = optim.Adam(self.dqn_network.parameters(), lr=self.lr)  # Optimiseur courrament utilisé en RL
        self.t_step = 0  # Compteur pour la mise à jour du réseau cible

    def charger_modele(self, chemin):
            self.modele.load_state_dict(torch.load(chemin))
            self.modele.eval()

    def _build_model(self):
            '''modele de réseau de neurones pour l'apprentissage'''
            model = nn.Sequential(
                nn.Linear(self.state_size, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, self.action_size)
            )
            return model

    # L'agent choisit l'action selon l'état et la politique epsilon-greedy
    def act(self, state, eps=0.1):
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)  # On convertit l'état en tenseur
        self.dqn_network.eval()  # mode évaluation càd pas de mise à jour des poids

        with torch.no_grad():
            action_values = self.dqn_network(state)  # Calculer la Q-valeur pour chaque action
        self.dqn_network.train()  # Repasser le réseau en mode entrainement

        # politique epsilon-greedy (exploration/exploitation), à améliorer car
        #pour l'instant génère un nombre aléatoire
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())  # Choix de l'action avec la plus grande Q-valeur (greedy)
        else:
            return random.choice(np.arange(self.action_size))  # Choix d'une action aléatoire

    # Stock l'expérience dans la mémoire de remise en état
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    # Prend une action et apprendre à partir de l'expérience
    def step(self, state, action, reward, next_state, done):
        self.remember(state, action, reward, next_state, done)  # Stocker l'expérience
        self.epsilon *= self.epsilon_decay #on multiplie par le facteur de décroissance
        self.learn()  # Apprendre de l'expérience

    # Apprentissage à partir de l'expérience (implémentation de l'équation de Belman)
    def learn(self):
        # mémoire de remise en état soit assez grande
        if len(self.memory) < self.batch_size:
            return

        # on échantillonne un batch d'expériences de taille aléatoire
        experiences = random.sample(self.memory, self.batch_size)
        # dézip les élements de l'échantillon 
        states, actions, rewards, next_states, dones = zip(*experiences)

        # Convertion des expériences numpy --> tenseur pytorch
        states = torch.from_numpy(np.vstack(states)).float().to(device)
        actions = torch.from_numpy(np.vstack(actions)).long().to(device)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(device)
        next_states = torch.from_numpy(np.vstack(next_states)).float().to(device)
        dones = torch.from_numpy(np.vstack(dones).astype(np.int64)).float().to(device)

        # Calcul les Q-valeurs cibles et attendues

        #on obtient les qvaleur prédites à partir du modèle cible
        Q_cible_next = self.target_network(next_states).detach().max(1)[0].unsqueeze(1)  # Q-valeurs cibles pour les prochains états
        
        #on calcule les Q cibles pour les états actuels
        Q_cible = rewards + (self.gamma * Q_cible_next * (1 - dones))  # Q-valeurs cibles pour les états actuels
        
        #on calcul les q attendus à partir du modèle
        Q_expected = self.dqn_network(states).gather(1, actions)  # Q-valeurs attendues pour les états actuels
        
        # Calcul de la perte et rétropropagation de l'erreur
        loss = F.mse_loss(Q_expected, Q_cible)  # Calcul la perte w/ MSE
        # On minimise la fonction de perte
        self.optimizer.zero_grad()  # Réinitialisation gradients
        loss.backward()  # Rétropropagation de l'erreur
        self.optimizer.step()  # Mise à jour des poids du réseau

        # Mise à jour du réseau cible on le copie du réseau DQN
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            self.target_network.load_state_dict(self.dqn_network.state_dict())

    def sauvegarder_modele(self, path):
        # Sauvegarde du modèle PyTorch avec pickle
        state = {
                'dqn_network_state_dict': self.dqn_network.state_dict(),
                'target_network_state_dict': self.target_network.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'memory': self.memory,
                'epsilon': self.epsilon,
            }

        #with open(chemin, "wb") as fichier:
           #pickle.dump(self.model, fichier)

    def charger_modele(self, path):
        # Chergament du modèle Pytorch avec pickle

        #with open(chemin, "rb") as fichier:
        #    self.model = pickle.load(fichier)
        if os.path.exists(path):
            state = torch.load(path)
            self.dqn_network.load_state_dict(state['dqn_network_state_dict'])
            self.target_network.load_state_dict(state['target_network_state_dict'])
            self.optimizer.load_state_dict(state['optimizer_state_dict'])
            self.memory = state['memory']
            self.epsilon = state['epsilon']
            print(f"Modèle chargé depuis {path}")
        else:
            print(f"Aucun modèle trouvé à {path}")

class IA:
    def __init__(self, numero, max_choix, state_size, agent=None):
        self.numero = numero
        self.max_choix = max_choix
        if agent is None:
            self.agent= DQNAgent(state_size, max_choix)
        else:
            self.agent= agent

    def jouer(self, state, actions):
        """ Choisit une colonne en utilisant l'agent IA """
        proba_victoires = self.calculer_proba_victoires(state, actions)
        # Trie en fonction des probabilités de victoire
        tri_indice_action = np.argsort(proba_victoires)
        # actions de la plus probable à la moins probable
        for indice in reversed(tri_indice_action):
            action = actions[indice]
            # Si l'action est possible, on la retourne
            if action in actions:
                return action
        #(ne devrait pas arriver), retourne une action aléatoire
        return np.random.choice(actions)

    def calculer_proba_victoires(self, state, actions):
        proba_victoires = np.zeros(len(actions))
        for i, action in enumerate(actions):
            next_state = state.copy()
            next_state[action] = self.numero
            proba_victoires[i] = self.agent.act(next_state)
        return proba_victoires

    def apprendre(self, state, action, reward, next_state, done):
        self.agent.step(state, action, reward, next_state, done)

class Jeu:
    """ Classe représentant le jeu en lui-même """

    def __init__(self, rows, columns, joueurs):
        """ Initialise le jeu et les joueurs """
        self.plato = Plateau(rows, columns)
        self.joueurs = joueurs

    def play(self):
            """ Lance le jeu et vérifie si un joueur a gagné ou si la partie est nulle """
            state = self.plato.get_etat()
            while True:
                for joueur in self.joueurs:
                    print(f"Joueur {joueur.numero}")
                    actions = self.plato.get_actions()
                    choix = joueur.jouer(state, actions)
                    self.plato.placement_jeton(choix, joueur.numero)
                    self.plato.affichage(joueur.numero)

                    if self.plato.check_victoire():
                        print(f"Joueur {joueur.numero} a gagné!")
                        self.plato.affichage(joueur.numero)  # Afficher le plateau final
                        return joueur.numero
                    elif np.all(self.plato.plato != 0):
                        print("Match nul!")
                        return None

                    state = self.plato.get_etat()

class Entrainement:
    
    def __init__(self, lignes, colonnes, episodes):
        self.lignes = lignes
        self.colonnes = colonnes
        self.episodes = episodes
        self.victoires = {1: 0, 2: 0, 'Nulles': 0}

        ### Liste vides pour enregistrer les données

        self.episodes_list = []
        self.victoires_joueur1 = []
        self.victoires_joueur2 = []
        self.parties_nulles = []

    def commencer(self):
        print("Choisissez le mode :")
        print("1. Jouer contre l'IA 1")
        print("2. Jouer contre l'IA 2")
        print("3. IA 1 vs IA 2")
        print("4. Jouer humain contre humain")
        print("5. Entraîner les deux IA entre elles")
        choix = int(input("Votre choix : "))

        if choix == 1:
            joueur_humain = Joueur(1, colonnes)
            agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA1 = IA(2, colonnes, colonnes * lignes, agent_IA1)
            joueurs = [joueur_humain, joueur_IA1]
        elif choix == 2:
            joueur_humain = Joueur(1, colonnes)
            agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
            joueurs = [joueur_humain, joueur_IA2]
        elif choix == 3:
            agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
            agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
            joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
            joueurs = [joueur_IA1, joueur_IA2]
        elif choix == 4:
            joueur_humain1 = Joueur(1, colonnes)
            joueur_humain2 = Joueur(2, colonnes)
            joueurs = [joueur_humain1, joueur_humain2]
        elif choix == 5:
            agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
            agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
            joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
            joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
            joueurs = [joueur_IA1, joueur_IA2]
            self.entrainement_IA(joueurs)
            return
        else:
            print("Mode invalide. Veuillez choisir 1, 2, 3, 4 ou 5.")
            return

        for i in range(self.episodes):
            print(f"Épisode {i+1}/{self.episodes}")
            jeu = Jeu(lignes, colonnes, joueurs)
            vainqueur = jeu.play()
            if vainqueur is not None:
                self.victoires[vainqueur] += 1
            else:
                self.victoires['Nulles'] += 1
            print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
            print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
            print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")

    def entrainement_IA(self, joueurs):
    
        for i in range(self.episodes):
            print(f"Épisode {i+1}/{self.episodes}")
            jeu = Jeu(lignes, colonnes, joueurs)
            vainqueur = jeu.play()
            if vainqueur is not None:
                self.victoires[vainqueur] += 1
            else:
                self.victoires['Nulles'] += 1

            # Sauvefarde des données
            self.episodes_list.append(i + 1)
            self.victoires_joueur1.append(self.victoires[1] / (i + 1))
            self.victoires_joueur2.append(self.victoires[2] / (i + 1))
            self.parties_nulles.append(self.victoires['Nulles'] / (i + 1))

            print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
            print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
            print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")

    # graphique inutile à changer + tard
        plt.plot(self.episodes_list, self.victoires_joueur1, label='Taux de victoire Joueur 1')
        plt.plot(self.episodes_list, self.victoires_joueur2, label='Taux de victoire Joueur 2')
        plt.plot(self.episodes_list, self.parties_nulles, label='Parties nulles')
        plt.xlabel('Épisodes')
        plt.ylabel('Taux de victoire')
        plt.legend()
        plt.savefig('graphique_evolution.png')
        plt.show()

#Initialise et démarre le jeu
if __name__ == "__main__":
    lignes, colonnes = 6, 7  # dimensions standard du puissance 4
    episodes = 100  # nombre d'épisodes pour l'entraînement
    entrainement = Entrainement(lignes, colonnes, episodes)
    entrainement.commencer()

字符串

t3psigkw

t3psigkw1#

确定这些函数调用的正确位置取决于您想要实现的目标(例如在多个会话、检查点中训练模型...)以及每个类应该负责什么。
保存模型应该在训练后的某个时间点(或在检查点的训练步骤之间)进行。
一种方法是让DQNAgent类负责自己保存的状态。在这种情况下,加载应该发生在方法的构造函数中。保存可以在learn()函数结束时或在step()中调用learn()函数之后进行。
另一种方法是在IA类中处理此功能。与前面的方法类似,加载可以在构造函数或IA的方法中进行。在apprendre()中调用self.agent.step()后保存最有意义。
创建/调用IA的类也可以进行加载并将加载的状态传递给IA

相关问题