如何在Pytorch中可视化一个网络?

pod7payv  于 2022-11-23  发布在  其他
关注(0)|答案(5)|浏览(239)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想从pytorch模型中可视化resnet。我该怎么做呢?我尝试使用torchviz,但它给出了一个错误:

'ResNet' object has no attribute 'grad_fn'
4szc88ey

4szc88ey1#

以下是使用不同工具的三种不同图形可视化。
为了生成示例可视化,我将使用一个简单的RNN来执行从online tutorial

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

下面是print()模型的输出。

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

以下是三种不同可视化工具的结果。
对于所有这些方法,您都需要一个可以通过模型的forward()方法传递的虚拟输入。获取此输入的一个简单方法是从Dataloader中检索一个批处理,如下所示:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Torch 维茨
https://github.com/szagoruyko/pytorchviz
我相信这个工具使用向后传递来生成它的图形,所以所有的盒子都使用PyTorch组件进行向后传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

此工具生成以下输出文件:

这是唯一一个清楚地提到我的模型中的三个层embeddingrnnfc的输出。

隐藏图层

https://github.com/waleedka/hiddenlayer
我相信这个工具使用的是正向传球。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

这是输出。我喜欢蓝色的阴影。

我发现输出有太多的细节,混淆了我的架构。例如,为什么unsqueeze被提到这么多次?
耐特龙
https://github.com/lutzroeder/netron
此工具是一个适用于Mac、Windows和Linux的桌面应用程序。它依赖于首先将模型导出到ONNX format。然后应用程序读取ONNX文件并渲染它。然后有一个选项将模型导出到图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

下面是这个模型在应用程序中的样子。我认为这个工具非常灵活:你可以缩放和平移,你可以钻到层和操作符。唯一的缺点是,我发现它只做垂直布局。

xlpyo6sf

xlpyo6sf2#

make_dot需要一个变量(即grad_fn的Tensor),而不是模型本身。
尝试:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module
5w9g7ksd

5w9g7ksd3#

您可以查看PyTorchViz(https://github.com/szagoruyko/pytorchviz),“一个创建PyTorch执行图和跟踪可视化的小包。”

woobm2wo

woobm2wo4#

如果你想保存图像,下面是如何使用torchviz

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")

您获得图像的屏幕截图:

来源:http://www.bnikolic.co.uk/blog/pytorch-detach.html

kuuvgm7e

kuuvgm7e5#

这可能是一个迟来的答案。但是,特别是开发了__torch_function__之后,有可能获得更好的可视化。您可以在这里尝试我的项目torchview
对于您的resnet 50示例,您可以查看colab笔记本here,在这里我演示了resnet 18模型的可视化。
它还接受广泛的输出/输入类型(例如列表、字典)

相关问题