pytorch Assert错误:环境必须指定操作空间

798qvoo8  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(700)

我正在使用OpenAI健身库。我已经下载了一个新的环境,但是我得到了这个错误。

AssertionError: The environment must specify an action space. https://www.gymlibrary.dev/content/environment_creation/

它说我没有指定动作空间。但我指定了。

class MoleculeEnv(gym.Env):
    metadata = {'render.modes': ['human']}
    def init(self, docking_config=dict(), data_type='zinc',ratios=dict(),reward_step_total=1,is_normalize=0,reward_type='crystal',reward_target=0.5,has_scaffold=False, has_feature=False,is_conditional=False,conditional='low',max_action=128,min_action=20,force_final=False):
        '''
        own init function, since gym does not support passing argument
        '''
        self.is_normalize = bool(is_normalize)
        self.has_feature = has_feature

        # init smi
        self.starting_smi = 'c1([*:1])c([*:2])ccc([*:3])c1' # for hit (benzene ring)
        # self.starting_smi = '[*:1]c1ccc2[nH]c(-c3cc([*:2])cc(-c4cccc([*:3])c4)c3O)cc2c1' # fa7 scaffold
        # self.starting_smi = 'O=C(c1cccc(Cc2c([*:1])[nH]c(=O)c3cc([*:2])c([*:3])n23)c1)N1CCN([*:4])CC1' # parp1 scaffold
        # self.starting_smi = 'C1=C([*:1])C2=NC=C(CC([*:2])C[NH+]3CCC([*:3])C([*:4])C3)[C@H]2C=C1[*:5]' # 5ht1b scaffold
        self.smi = self.starting_smi

        self.mol = Chem.MolFromSmiles(self.smi)
        self.smile_list = []
        self.smile_old_list = []

        possible_atoms = ATOM_VOCAB
        possible_motifs = FRAG_VOCAB
        possible_bonds = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
        self.atom_type_num = len(possible_atoms)
        self.motif_type_num = len(possible_motifs)
        self.possible_atom_types = np.array(possible_atoms)
        self.possible_motif_types = np.array(possible_motifs)

        self.possible_bond_types = np.array(possible_bonds, dtype=object)

        self.d_n = len(self.possible_atom_types)+18 

        self.max_action = max_action
        self.min_action = min_action

        self.max_atom = 150
        self.action_space = gym.spaces.MultiDiscrete([20, len(FRAG_VOCAB), 20])

        self.counter = 0
        self.level = 0 # for curriculum learning, level starts with 0, and increase afterwards

        self.predictor = DockingVina(docking_config)

        self.attach_point = Chem.MolFromSmiles('*')
        self.Na = Chem.MolFromSmiles('[Na+]')
        self.K = Chem.MolFromSmiles('[K+]')
        self.H = Chem.MolFromSmiles('[H]')

    def seed(self,seed):
        np.random.seed(seed=seed)
        random.seed(seed)

    def normalize_adj(self,adj):
        degrees = np.sum(adj,axis=2)

        D = np.zeros((adj.shape[0],adj.shape[1],adj.shape[2]))
        for i in range(D.shape[0]):
            D[i,:,:] = np.diag(np.power(degrees[i,:],-0.5))
        adj_normal = D@adj@D
        adj_normal[np.isnan(adj_normal)]=0
        return adj_normal
    
    def reset_batch(self):
        self.smile_list = []
    
    def reward_batch(self):
        reward = []
        print('smiles list', self.smile_list)
        return reward_vina(self.smile_list, self.predictor)

    def reward_single(self, smile_list):
        reward = []
        print('smiles list', smile_list)
        return reward_vina(smile_list, self.predictor)

    def step(self, ac):
        """
        Perform a given action
        :param action:
        :param action_type:
        :return: reward of 1 if resulting molecule graph does not exceed valency,
        -1 if otherwise
        """
        ac = ac[0]
         
        ### init
        info = {}  # info we care about
        self.mol_old = copy.deepcopy(self.mol) # keep old mol
        
        stop = False    
        new = False
        
        if (self.counter >= self.max_action) or get_att_points(self.mol) == []:
            new = True
        else:
            self._add_motif(ac) # problems here

        reward_step = 0.05
        if self.mol.GetNumAtoms() > self.mol_old.GetNumAtoms():
            reward_step += 0.005
        self.counter += 1

        if new:            
            reward = 0
            # Only store for obs if attachment point exists in o2
            if get_att_points(self.mol) != []:
                mol_no_att = self.get_final_mol() 
                Chem.SanitizeMol(mol_no_att, sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE)
                smi_no_att = Chem.MolToSmiles(mol_no_att)
                info['smile'] = smi_no_att
                print("smi:", smi_no_att)
                self.smile_list.append(smi_no_att)

                # Info for old mol
                mol_old_no_att = self.get_final_mol_ob(self.mol_old)
                Chem.SanitizeMol(mol_old_no_att, sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE)
                smi_old_no_att = Chem.MolToSmiles(mol_no_att)
                info['old_smi'] = smi_old_no_att
                self.smile_old_list.append(smi_old_no_att)

                stop = True
            else:
                stop = False
            self.counter = 0      

        ### use stepwise reward
        else:
            reward = reward_step

        info['stop'] = stop

        # get observation
        ob = self.get_observation()
        return ob,reward,new,info

    def reset(self,smile=None):
        '''
        to avoid error, assume an atom already exists
        :return: ob
        '''
        if smile is not None:
            self.mol = Chem.RWMol(Chem.MolFromSmiles(smile))
            Chem.SanitizeMol(self.mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE)
        else:
            # init smi
            self.smi = self.starting_smi
            self.mol = Chem.MolFromSmiles(self.smi) 
        # self.smile_list = [] # only for single motif
        self.counter = 0
        ob = self.get_observation()
        return ob

    def render(self, mode='human', close=False):
        return

    def sample_motif(self):
        go_on = True
        while go_on:
            cur_mol_atts = get_att_points(self.mol)
            ac1 = np.random.randint(len(cur_mol_atts))
            ac2 = np.random.randint(self.motif_type_num)
            motif = FRAG_VOCAB_MOL[ac2]
            ac3 = np.random.randint(len(FRAG_VOCAB_ATT[ac2]))
            a = self.action_space.sample()
            
            a[0] = ac1
            a[1] = ac2
            a[2] = ac3

            go_on = False

        return a

    def _add_motif(self, ac): 
        
        cur_mol = Chem.ReplaceSubstructs(self.mol, self.attach_point, self.Na)[ac[0]]
        motif = FRAG_VOCAB_MOL[ac[1]]
        att_point = FRAG_VOCAB_ATT[ac[1]]
        motif_atom = map_idx(ac[2], att_point, motif) 
        motif = Chem.ReplaceSubstructs(motif, self.attach_point, self.K)[ac[2]]
        motif = Chem.DeleteSubstructs(motif, self.K)
        next_mol = Chem.ReplaceSubstructs(cur_mol, self.Na, motif, replacementConnectionPoint=motif_atom)[0]
        self.mol = next_mol

    def get_final_smiles_mol(self):
        """
        Returns a SMILES of the final molecule. Converts any radical
        electrons into hydrogens. Works only if molecule is valid
        :return: SMILES
        """
        m = Chem.DeleteSubstructs(self.mol, Chem.MolFromSmiles("*"))
        m = convert_radical_electrons_to_hydrogens(m)
        return m, Chem.MolToSmiles(m, isomericSmiles=True)

    def get_final_mol(self):
        """
        Returns a rdkit mol object of the final molecule. Converts any radical
        electrons into hydrogens. Works only if molecule is valid
        :return: SMILES
        """
        m = Chem.DeleteSubstructs(self.mol, Chem.MolFromSmiles("*"))
        return m
    
    def get_final_mol_ob(self, mol):
        m = Chem.DeleteSubstructs(mol, Chem.MolFromSmiles("*"))
        return m

    def get_observation(self, expert_smi=None):
        """
        ob['adj']:d_e*n*n --- 'E'
        ob['node']:1*n*d_n --- 'F'
        n = atom_num + atom_type_num
        """
        ob = {}

        if expert_smi:
            mol = Chem.MolFromSmiles(expert_smi)
        else:
            ob['att'] = get_att_points(self.mol)
            mol = copy.deepcopy(self.mol)
        
        try:
            Chem.SanitizeMol(mol)
        except:
            pass

        smi = Chem.MolToSmiles(mol)

        n = mol.GetNumAtoms()
        F = np.zeros((1, self.max_atom, self.d_n))

        for a in mol.GetAtoms():
            atom_idx = a.GetIdx()
            
            atom_symbol = a.GetSymbol()
            if self.has_feature:
                float_array = atom_feature(a, use_atom_meta=True)
            else:
                float_array = (atom_symbol == self.possible_atom_types).astype(float)

            F[0, atom_idx, :] = float_array

        d_e = len(self.possible_bond_types)
        E = np.zeros((d_e, self.max_atom, self.max_atom))

        for b in mol.GetBonds(): 
            begin_idx = b.GetBeginAtomIdx()
            end_idx = b.GetEndAtomIdx()
            bond_type = b.GetBondType()
            float_array = (bond_type == self.possible_bond_types).astype(float)
            try:
                assert float_array.sum() != 0
            except:
                print('error',bond_type)
            E[:, begin_idx, end_idx] = float_array
        
        if self.is_normalize:
            E = self.normalize_adj(E)
        
        ob_adj = adj2sparse(E.squeeze())
        ob_node = torch.Tensor(F)
        g = dgl.DGLGraph()

        ob_len = torch.sum(torch.sum(ob_node, dim=-1).bool().float().squeeze(-2), dim=-1)
        g.add_nodes(ob_len)
        if ob_adj is not None and len(ob_adj[0])>0 :
            g.add_edges(ob_adj[0][0], ob_adj[0][1], {'x': ob_adj[1]})
        g.ndata['x'] = ob_node[:, :int(ob_len),:].squeeze(0)
        
        ob['g'] = g
        ob['smi'] = smi
        
        return ob

    def get_observation_mol(self,mol):
        """
        ob['adj']:d_e*n*n --- 'E'
        ob['node']:1*n*d_n --- 'F'
        n = atom_num + atom_type_num
        """
        ob = {}

        ob['att'] = get_att_points(mol)
        
        try:
            Chem.SanitizeMol(mol)
        except:
            pass

        smi = Chem.MolToSmiles(mol)

        n = mol.GetNumAtoms()
        F = np.zeros((1, self.max_atom, self.d_n))

        for a in mol.GetAtoms():
            atom_idx = a.GetIdx()
            
            atom_symbol = a.GetSymbol()
            if self.has_feature:
                float_array = atom_feature(a, use_atom_meta=True)
            else:
                float_array = (atom_symbol == self.possible_atom_types).astype(float)

            F[0, atom_idx, :] = float_array

        d_e = len(self.possible_bond_types)
        E = np.zeros((d_e, self.max_atom, self.max_atom))

        for b in mol.GetBonds(): 

            begin_idx = b.GetBeginAtomIdx()
            end_idx = b.GetEndAtomIdx()
            bond_type = b.GetBondType()
            float_array = (bond_type == self.possible_bond_types).astype(float)

            try:
                assert float_array.sum() != 0
            except:
                print('error',bond_type)
            E[:, begin_idx, end_idx] = float_array
        
        if self.is_normalize:
            E = self.normalize_adj(E)
        
        ob_adj = adj2sparse(E.squeeze())
        ob_node = torch.Tensor(F)
        g = dgl.DGLGraph()

        ob_len = torch.sum(torch.sum(ob_node, dim=-1).bool().float().squeeze(-2), dim=-1)
        g.add_nodes(ob_len)
        if ob_adj is not None and len(ob_adj[0])>0 :
            g.add_edges(ob_adj[0][0], ob_adj[0][1], {'x': ob_adj[1]})
        g.ndata['x'] = ob_node[:, :int(ob_len),:].squeeze(0)
        
        ob['g'] = g
        ob['smi'] = smi
        return ob

这是我的env。我指定动作空间:[20,len(FRAG_VOCAB),20])中的一个或多个变量的值。
我查看了附加到错误的URL,但我仍然不知道如何解决此错误。
如何修复错误?

jchrr9hc

jchrr9hc1#

我把我的gym版本降低到0.18.0。这修正了我的错误。

相关问题