pytorch 在我的python discort bot中加载torch保存的模型

0s7z1bwu  于 12个月前  发布在  Python
关注(0)|答案(1)|浏览(126)

我想将两个torch模型加载到我的discord应用程序中。当我启动python main.py命令时,我得到错误AttributeError: Can't get attribute 'Generator' on <module '__main__' from '/home/mle/Devel/gan-discord-bot/main.py'>
但是当我运行bot.py脚本时,加载工作正常。
以下是我使用的脚本

main.py

import bot

if __name__=='__main__':
    bot.run_discord_bot()

字符串

bot.py

import discord
import torch
from torchvision.utils import save_image

from model import *

# the models have been trained on gpus
netG = torch.load(netG_name, map_location=torch.device('cpu'))
netD = torch.load(netD_name, map_location=torch.device('cpu'))

netG.eval()
netD.eval()

async def generate_image(user_message, is_private):
    try: 
        z = torch.randn(1, nz, 1, 1)
        torch_image = netG(z)
        save_image(torch_image, 'tmp.png') 
        image = discord.File('tmp.png')
        await user_message.author.send(file=image) if is_private else await user_message.channel.send(file=image)

    except Exception as e:
        print(e)

def run_discord_bot():
    TOKEN = 'MyToken'
    intents = discord.Intents.default()
    intents.message_content = True
    client = discord.Client(intents=intents)

    @client.event
    async def on_ready():
        print(f'{client.user} is now running!')

    @client.event
    async def on_message(message):
        if message.author == client.user:
            return 
        username = str(message.author)
        user_message = str(message.content)
        channel = str(message.channel)

        if str.find(user_message, '! generate') != -1:
            await generate_image(user_message, is_private=False)

    client.run(TOKEN)

model.py

我要加载的两个网络是GAN的Generator和Discriminator。使用的参数在以下脚本中:

from torch import nn

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 128

# Size of feature maps in discriminator
ndf = 128

# saved models
netG_name = "netG128.pt"
netD_name = "netD128.pt"

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            # state size. ``(nc) x 64 x 64``
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

p_dropout = 0.5
bias_discriminator = False

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=bias_discriminator),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=bias_discriminator),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=p_dropout),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=bias_discriminator),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=p_dropout),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=bias_discriminator),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(p=p_dropout),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=bias_discriminator),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

减少pip冻结输出

discord==2.3.2
discord.py==2.3.2
torch==2.1.0
torchvision==0.16.0

juud5qan

juud5qan1#

通过在generate_image_function中加载模型,它可以工作。但这不是最佳的。这个错误对我来说很奇怪。我愿意接受任何关于这个解决方案的错误或改进的解释。
以下是更新后的bot.py

import discord
import torch
import torchvision
from torchvision.utils import save_image
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
from model import *

async def generate_image(message, user_message, is_private):
    try: 

        netG = torch.load(netG_name, map_location=torch.device('cpu'))
        netD = torch.load(netD_name, map_location=torch.device('cpu'))
        netG.eval()
        netD.eval()

        z = torch.randn(1, nz, 1, 1)
        torch_image = netG(z)

        """
        transformations = transforms.Compose([
            torchvision.transforms.Resize(
                size=256, 
                antialias = True, 
                interpolation=InterpolationMode.BILINEAR
            ),
            transforms.ToPILImage()
        ])

        torch_image = transformations(torch_image)
        torch_image.save("tmp.png")
        """

        resize = torchvision.transforms.Resize(
                size=256, 
                antialias = True, 
                interpolation=InterpolationMode.BILINEAR
            )
        torch_image = resize(torch_image)

        save_image(torch_image, 'tmp.png', normalize=True) 
        image = discord.File('tmp.png')

        await message.author.send(file=image) if is_private else await message.channel.send(file=image)

    except Exception as e:
        print(e)

def run_discord_bot():
    TOKEN = 'MTE4MzQ0MTI1ODczNjI2MzIxOQ.G1Xv0k.igv_KI-1oHRKaY5631Mi8ANBczWwMxeqDFXecU'
    intents = discord.Intents.default()
    intents.message_content = True
    client = discord.Client(intents=intents)

    @client.event
    async def on_ready():
        print(f'{client.user} is now running!')

    @client.event
    async def on_message(message):
        if message.author == client.user:
            return 
        username = str(message.author)
        user_message = str(message.content)
        channel = str(message.channel)

        if str.find(user_message, '! generate') != -1:
            await generate_image(message, user_message, is_private=False)

    client.run(TOKEN)

字符串

编辑

Karl评论中提出的解决方案:
bot.py

import discord
import torch
import torchvision
from torchvision.utils import save_image
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
from model import *

async def generate_image(message, user_message, is_private, netG):
    try: 

        #netG = torch.load(netG_name, map_location=torch.device('cpu'))
        #netD = torch.load(netD_name, map_location=torch.device('cpu'))
        #netG.eval()
        #netD.eval()

        z = torch.randn(1, nz, 1, 1)
        torch_image = netG(z)

        """
        transformations = transforms.Compose([
            torchvision.transforms.Resize(
                size=256, 
                antialias = True, 
                interpolation=InterpolationMode.BILINEAR
            ),
            transforms.ToPILImage()
        ])

        torch_image = transformations(torch_image)
        torch_image.save("tmp.png")
        """

        resize = torchvision.transforms.Resize(
                size=256, 
                antialias = True, 
                interpolation=InterpolationMode.BILINEAR
            )
        torch_image = resize(torch_image)

        save_image(torch_image, 'tmp.png', normalize=True) 
        image = discord.File('tmp.png')

        await message.author.send(file=image) if is_private else await message.channel.send(file=image)

    except Exception as e:
        print(e)

def run_discord_bot():
    TOKEN = 'MTE4MzQ0MTI1ODczNjI2MzIxOQ.G1Xv0k.igv_KI-1oHRKaY5631Mi8ANBczWwMxeqDFXecU'
    intents = discord.Intents.default()
    intents.message_content = True
    client = discord.Client(intents=intents)

    netG = torch.load(netG_name, map_location=torch.device('cpu'))
    netG.eval()

    @client.event
    async def on_ready():
        print(f'{client.user} is now running!')

    @client.event
    async def on_message(message):
        if message.author == client.user:
            return 
        username = str(message.author)
        user_message = str(message.content)
        channel = str(message.channel)

        if str.find(user_message, '! generate') != -1:
            await generate_image(message, user_message, is_private=False, netG=netG)

    client.run(TOKEN)

相关问题