我是OOP和RL的初学者,我需要一些关于我的connect 4游戏的建议:slight_smile:
首先,如果你在我的代码中看到任何令人震惊的东西,请不要犹豫,啊哈,但最重要的是,我想知道在哪里保存我的训练数据记录,在哪里加载它,以便它是有效的?我有点麻烦。
在此先谢谢您!
import numpy as np
from colorama import Fore, Style
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import torch.nn.functional as F
import pickle
import os
repo = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(repo, "modeleIA.pth")
import matplotlib.pyplot as plt
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
class Plateau:
""" Classe représentant le plateau de jeu """
def __init__(self, rows, columns):
""" Constructeur de la classe, initialise le plateau et ses dimensions """
self.rows = rows
self.columns = columns
self.plato = np.zeros((rows, columns))
def colonne_check(self, col):
""" Vérifie si la colonne est pleine """
for i in range(self.rows):
if self.plato[i][col] == 0:
return True
return False
def placement_jeton(self, col, joueur):
""" Place le jeton du joueur dans la colonne qu'il sélectionne """
for i in np.r_[:self.rows][::-1]:
if self.plato[i][col] == 0:
self.plato[i][col] = joueur
return True
return False
def affichage(self, joueur):
""" Affiche le plateau de jeu à l'instant t """
couleur = Fore.RED if joueur == 1 else Fore.YELLOW
print(couleur + str(self.plato) + Style.RESET_ALL)
def check_victoire(self):
""" Vérifie si un joueur a gagné """
rows, columns = self.rows, self.columns
# vérification en ligne
for r in np.r_[:rows]:
for d in np.r_[:columns-3]:
f = d + 4
s = np.prod(self.plato[r, d:f])
if s == 1 or s == 16:
return True
# vérification en colonne
for c in np.r_[:columns]:
for d in np.r_[:rows-3]:
f = d + 4
s = np.prod(self.plato[d:f, c])
if s == 1 or s == 16:
return True
# vérification en diagonale (bas gauche vers haut droite)
for r in np.r_[:rows-3]:
for c in np.r_[:columns-3]:
f = c + 4
s = np.prod([self.plato[r+i, c+i] for i in range(4)])
if s == 1 or s == 16:
return True
# vérification en diagonale (haut gauche vers bas droite)
for r in np.r_[3:rows]:
for c in np.r_[:columns-3]:
f = c + 4
s = np.prod([self.plato[r-i, c+i] for i in range(4)])
if s == 1 or s == 16:
return True
return False
def get_etat(self):
""" Obtient l'état actuel du plateau sous forme de tableau 1D """
return self.plato.flatten()
def get_actions(self):
""" Obtient les actions possibles à partir de l'état actuel du plateau """
return np.where(self.plato[0] == 0)[0]
class Joueur:
""" Classe représentant un joueur humain """
def __init__(self, numero, max_choix):
""" Initialise le joueur et son numéro et le nombre de choix possibles """
self.numero = numero
self.max_choix = max_choix
def jouer(self, state, actions):
""" Demande au joueur de choisir une colonne """
while True:
try:
choix = int(input(f'Joueur {self.numero}, à vous de jouer (entre 1 et {self.max_choix}): ')) - 1
if choix in actions:
return choix
else:
print("Choix invalide. Essayez à nouveau.")
except ValueError:
print("Ce n'est pas un nombre. Essayez encore.")
class DQNAgent:
# Dans le RL, l'agent DQN utilise une mémoire appelée "replay memory"
# pour stocker les XP passées (état/action/rec/prochain état etc)
# afin de les réutiliser lors de l'apprentissage
# Initialisation de l'agent DQN
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
self.memory = deque(maxlen=10000) #pareil que sur le github
self.gamma = 0.95 # facteur d'actualisation, équilibre recomp im et future
self.epsilon = 1.0 # taux d'exploration initial
self.epsilon_min = 0.01 # taux d'exploration minimum
self.epsilon_decay = 0.995 # taux de décroissance de l'exploration
self.lr = 0.001 # taux d'apprentissage
self.model = self._build_model() # Construire le modèle de réseau neuronal
self.batch_size = 64
self.update_every = 5
# réseau principal (d'évaluation) pour choisir les actions et réseau cible pour générer les Q-values cibles
self.dqn_network = self._build_model().to(device)
self.target_network = self._build_model().to(device)
#pour l'instant même dim d'entrée, cachée (64) et de sortie
self.optimizer = optim.Adam(self.dqn_network.parameters(), lr=self.lr) # Optimiseur courrament utilisé en RL
self.t_step = 0 # Compteur pour la mise à jour du réseau cible
def charger_modele(self, chemin):
self.modele.load_state_dict(torch.load(chemin))
self.modele.eval()
def _build_model(self):
'''modele de réseau de neurones pour l'apprentissage'''
model = nn.Sequential(
nn.Linear(self.state_size, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, self.action_size)
)
return model
# L'agent choisit l'action selon l'état et la politique epsilon-greedy
def act(self, state, eps=0.1):
state = torch.from_numpy(state).float().unsqueeze(0).to(device) # On convertit l'état en tenseur
self.dqn_network.eval() # mode évaluation càd pas de mise à jour des poids
with torch.no_grad():
action_values = self.dqn_network(state) # Calculer la Q-valeur pour chaque action
self.dqn_network.train() # Repasser le réseau en mode entrainement
# politique epsilon-greedy (exploration/exploitation), à améliorer car
#pour l'instant génère un nombre aléatoire
if random.random() > eps:
return np.argmax(action_values.cpu().data.numpy()) # Choix de l'action avec la plus grande Q-valeur (greedy)
else:
return random.choice(np.arange(self.action_size)) # Choix d'une action aléatoire
# Stock l'expérience dans la mémoire de remise en état
def remember(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))
# Prend une action et apprendre à partir de l'expérience
def step(self, state, action, reward, next_state, done):
self.remember(state, action, reward, next_state, done) # Stocker l'expérience
self.epsilon *= self.epsilon_decay #on multiplie par le facteur de décroissance
self.learn() # Apprendre de l'expérience
# Apprentissage à partir de l'expérience (implémentation de l'équation de Belman)
def learn(self):
# mémoire de remise en état soit assez grande
if len(self.memory) < self.batch_size:
return
# on échantillonne un batch d'expériences de taille aléatoire
experiences = random.sample(self.memory, self.batch_size)
# dézip les élements de l'échantillon
states, actions, rewards, next_states, dones = zip(*experiences)
# Convertion des expériences numpy --> tenseur pytorch
states = torch.from_numpy(np.vstack(states)).float().to(device)
actions = torch.from_numpy(np.vstack(actions)).long().to(device)
rewards = torch.from_numpy(np.vstack(rewards)).float().to(device)
next_states = torch.from_numpy(np.vstack(next_states)).float().to(device)
dones = torch.from_numpy(np.vstack(dones).astype(np.int64)).float().to(device)
# Calcul les Q-valeurs cibles et attendues
#on obtient les qvaleur prédites à partir du modèle cible
Q_cible_next = self.target_network(next_states).detach().max(1)[0].unsqueeze(1) # Q-valeurs cibles pour les prochains états
#on calcule les Q cibles pour les états actuels
Q_cible = rewards + (self.gamma * Q_cible_next * (1 - dones)) # Q-valeurs cibles pour les états actuels
#on calcul les q attendus à partir du modèle
Q_expected = self.dqn_network(states).gather(1, actions) # Q-valeurs attendues pour les états actuels
# Calcul de la perte et rétropropagation de l'erreur
loss = F.mse_loss(Q_expected, Q_cible) # Calcul la perte w/ MSE
# On minimise la fonction de perte
self.optimizer.zero_grad() # Réinitialisation gradients
loss.backward() # Rétropropagation de l'erreur
self.optimizer.step() # Mise à jour des poids du réseau
# Mise à jour du réseau cible on le copie du réseau DQN
self.t_step = (self.t_step + 1) % self.update_every
if self.t_step == 0:
self.target_network.load_state_dict(self.dqn_network.state_dict())
def sauvegarder_modele(self, path):
# Sauvegarde du modèle PyTorch avec pickle
state = {
'dqn_network_state_dict': self.dqn_network.state_dict(),
'target_network_state_dict': self.target_network.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'memory': self.memory,
'epsilon': self.epsilon,
}
#with open(chemin, "wb") as fichier:
#pickle.dump(self.model, fichier)
def charger_modele(self, path):
# Chergament du modèle Pytorch avec pickle
#with open(chemin, "rb") as fichier:
# self.model = pickle.load(fichier)
if os.path.exists(path):
state = torch.load(path)
self.dqn_network.load_state_dict(state['dqn_network_state_dict'])
self.target_network.load_state_dict(state['target_network_state_dict'])
self.optimizer.load_state_dict(state['optimizer_state_dict'])
self.memory = state['memory']
self.epsilon = state['epsilon']
print(f"Modèle chargé depuis {path}")
else:
print(f"Aucun modèle trouvé à {path}")
class IA:
def __init__(self, numero, max_choix, state_size, agent=None):
self.numero = numero
self.max_choix = max_choix
if agent is None:
self.agent= DQNAgent(state_size, max_choix)
else:
self.agent= agent
def jouer(self, state, actions):
""" Choisit une colonne en utilisant l'agent IA """
proba_victoires = self.calculer_proba_victoires(state, actions)
# Trie en fonction des probabilités de victoire
tri_indice_action = np.argsort(proba_victoires)
# actions de la plus probable à la moins probable
for indice in reversed(tri_indice_action):
action = actions[indice]
# Si l'action est possible, on la retourne
if action in actions:
return action
#(ne devrait pas arriver), retourne une action aléatoire
return np.random.choice(actions)
def calculer_proba_victoires(self, state, actions):
proba_victoires = np.zeros(len(actions))
for i, action in enumerate(actions):
next_state = state.copy()
next_state[action] = self.numero
proba_victoires[i] = self.agent.act(next_state)
return proba_victoires
def apprendre(self, state, action, reward, next_state, done):
self.agent.step(state, action, reward, next_state, done)
class Jeu:
""" Classe représentant le jeu en lui-même """
def __init__(self, rows, columns, joueurs):
""" Initialise le jeu et les joueurs """
self.plato = Plateau(rows, columns)
self.joueurs = joueurs
def play(self):
""" Lance le jeu et vérifie si un joueur a gagné ou si la partie est nulle """
state = self.plato.get_etat()
while True:
for joueur in self.joueurs:
print(f"Joueur {joueur.numero}")
actions = self.plato.get_actions()
choix = joueur.jouer(state, actions)
self.plato.placement_jeton(choix, joueur.numero)
self.plato.affichage(joueur.numero)
if self.plato.check_victoire():
print(f"Joueur {joueur.numero} a gagné!")
self.plato.affichage(joueur.numero) # Afficher le plateau final
return joueur.numero
elif np.all(self.plato.plato != 0):
print("Match nul!")
return None
state = self.plato.get_etat()
class Entrainement:
def __init__(self, lignes, colonnes, episodes):
self.lignes = lignes
self.colonnes = colonnes
self.episodes = episodes
self.victoires = {1: 0, 2: 0, 'Nulles': 0}
### Liste vides pour enregistrer les données
self.episodes_list = []
self.victoires_joueur1 = []
self.victoires_joueur2 = []
self.parties_nulles = []
def commencer(self):
print("Choisissez le mode :")
print("1. Jouer contre l'IA 1")
print("2. Jouer contre l'IA 2")
print("3. IA 1 vs IA 2")
print("4. Jouer humain contre humain")
print("5. Entraîner les deux IA entre elles")
choix = int(input("Votre choix : "))
if choix == 1:
joueur_humain = Joueur(1, colonnes)
agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
joueur_IA1 = IA(2, colonnes, colonnes * lignes, agent_IA1)
joueurs = [joueur_humain, joueur_IA1]
elif choix == 2:
joueur_humain = Joueur(1, colonnes)
agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
joueurs = [joueur_humain, joueur_IA2]
elif choix == 3:
agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
joueurs = [joueur_IA1, joueur_IA2]
elif choix == 4:
joueur_humain1 = Joueur(1, colonnes)
joueur_humain2 = Joueur(2, colonnes)
joueurs = [joueur_humain1, joueur_humain2]
elif choix == 5:
agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
joueurs = [joueur_IA1, joueur_IA2]
self.entrainement_IA(joueurs)
return
else:
print("Mode invalide. Veuillez choisir 1, 2, 3, 4 ou 5.")
return
for i in range(self.episodes):
print(f"Épisode {i+1}/{self.episodes}")
jeu = Jeu(lignes, colonnes, joueurs)
vainqueur = jeu.play()
if vainqueur is not None:
self.victoires[vainqueur] += 1
else:
self.victoires['Nulles'] += 1
print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")
def entrainement_IA(self, joueurs):
for i in range(self.episodes):
print(f"Épisode {i+1}/{self.episodes}")
jeu = Jeu(lignes, colonnes, joueurs)
vainqueur = jeu.play()
if vainqueur is not None:
self.victoires[vainqueur] += 1
else:
self.victoires['Nulles'] += 1
# Sauvefarde des données
self.episodes_list.append(i + 1)
self.victoires_joueur1.append(self.victoires[1] / (i + 1))
self.victoires_joueur2.append(self.victoires[2] / (i + 1))
self.parties_nulles.append(self.victoires['Nulles'] / (i + 1))
print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")
# graphique inutile à changer + tard
plt.plot(self.episodes_list, self.victoires_joueur1, label='Taux de victoire Joueur 1')
plt.plot(self.episodes_list, self.victoires_joueur2, label='Taux de victoire Joueur 2')
plt.plot(self.episodes_list, self.parties_nulles, label='Parties nulles')
plt.xlabel('Épisodes')
plt.ylabel('Taux de victoire')
plt.legend()
plt.savefig('graphique_evolution.png')
plt.show()
#Initialise et démarre le jeu
if __name__ == "__main__":
lignes, colonnes = 6, 7 # dimensions standard du puissance 4
episodes = 100 # nombre d'épisodes pour l'entraînement
entrainement = Entrainement(lignes, colonnes, episodes)
entrainement.commencer()
字符串
1条答案
按热度按时间t3psigkw1#
确定这些函数调用的正确位置取决于您想要实现的目标(例如在多个会话、检查点中训练模型...)以及每个类应该负责什么。
保存模型应该在训练后的某个时间点(或在检查点的训练步骤之间)进行。
一种方法是让
DQNAgent
类负责自己保存的状态。在这种情况下,加载应该发生在方法的构造函数中。保存可以在learn()
函数结束时或在step()
中调用learn()
函数之后进行。另一种方法是在
IA
类中处理此功能。与前面的方法类似,加载可以在构造函数或IA
的方法中进行。在apprendre()
中调用self.agent.step()
后保存最有意义。创建/调用
IA
的类也可以进行加载并将加载的状态传递给IA
。