我有一个模块列表,我想使用另一个索引列表将其编入索引
import torch
import torchvision.transforms as T
transforms = torch.nn.ModuleList([T.ColorJitter(), T.ColorJitter()])
order = [1,0]
字符串
I cannot do transforms[order[0]]
。但是,我可以迭代ModuleList:for t in transforms:
和for i,t in enumerate(transforms):
工作
如何通过迭代或枚举有效地索引到ModuleList中?
我尝试了以下方法,但它们不起作用
# Permute/Change ordering of the ModuleList using a ModuleDict and then iterate the ModuleDict
permuted_transforms = torch.nn.ModuleDict({order[i]:t for i,t in enumerate(transforms)})
给出FrontendError: Cannot instantiate class 'ModuleDict' in a script function
*
# Permute/Change ordering of the ModuleList using torch.take
permuted_transforms = torch.take(self.transforms, order)
但是torch.take只对torch.tensors有效,而对ModuleLists无效
*
# Permute/Change ordering of the ModuleList using map
permuted_transforms = map(self.transforms.__getitem__, order)
型
*
# Permute/Change ordering of the ModuleList using sorted
permuted_transforms = sorted(self.transforms, key=order.__getitem__)
型
*
# Have 2 for loops work but is extremely ineffecient
for o in order:
for i,t in enumerate(transforms):
if i==o: apply(t)
型
1条答案
按热度按时间fkvaft9z1#
如果我没理解错的话,这应该就是你要找的:
第一个月