Describe the Bug
Paddle:
Device:
NVIDIA A100 40GB
Cuda:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:18:20_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0
Bug:
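According to the official documentation, ParameterList behaves the same as Python's built-in list. However, slicing it with `[:]` raises a KeyError: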
```python
import paddle


class MyLayer(paddle.nn.Layer):
    def __init__(self, num_stacked_param):
        super().__init__()
        # create ParameterList with iterable Parameters
        self.params = paddle.nn.ParameterList(
            [paddle.create_parameter(
                shape=[2, 2], dtype='float32')] * num_stacked_param)

    def forward(self, x):
        for i, p in enumerate(self.params):
            tmp = self._helper.create_variable_for_type_inference('float32')
            self._helper.append_op(
                type="mul",
                inputs={"X": x,
                        "Y": p},
                outputs={"Out": tmp},
                attrs={"x_num_col_dims": 1,
                       "y_num_col_dims": 1})
            x = tmp
        return x


x = paddle.uniform(shape=[5, 2], dtype='float32')
num_stacked_param = 4
model = MyLayer(num_stacked_param)
print(len(model.params))
res = model(x)
print(res.shape)

replaced_param = paddle.create_parameter(shape=[2, 3], dtype='float32')
model.params[num_stacked_param - 1] = replaced_param  # replace last param
res = model(x)
print(res.shape)

model.params.append(paddle.create_parameter(shape=[3, 4], dtype='float32'))  # append param
print(len(model.params))
res = model(x)
print(res.shape)

# KeyError is raised here
print(model.params[:])
```
Log
```
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
W0618 05:17:49.387017 142075 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 12.0, Runtime API Version: 11.6
W0618 05:17:49.393265 142075 gpu_resources.cc:164] device: 0, cuDNN Version: 8.4.
4
[5, 2]
[5, 3]
5
[5, 4]
Traceback (most recent call last):
  File "/workspace/wangguan/gino_paddle_ahmed/gino/src/neuralop/models/test_PaddleParameterList.py", line 39, in <module>
    print(model.params[:])
  File "/root/miniconda3/envs/gino_conda/lib/python3.9/site-packages/paddle/nn/layer/container.py", line 365, in __getitem__
    return self._parameters[str(idx)]
KeyError: 'slice(None, None, None)'
```
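Per the traceback, `ParameterList.__getitem__` (line 365 of `container.py`) converts whatever index it receives to a string and uses it as a key into the `_parameters` dict. An integer index becomes a key such as `'3'`, but `[:]` passes a `slice` object, which stringifies to `'slice(None, None, None)'`, so the lookup raises `KeyError` instead of returning a list. The pure-Python sketch below models that lookup with a hypothetical `DictBackedList` class and shows one slice-aware alternative; it is an illustration, not Paddle's implementation or an official fix.

```python
from collections import OrderedDict


class DictBackedList:
    """Hypothetical stand-in (not Paddle code) for a container that stores
    its items in a dict keyed by str(index), as the traceback shows."""

    def __init__(self, items):
        self._items = OrderedDict(
            (str(i), item) for i, item in enumerate(items))

    def __len__(self):
        return len(self._items)

    def __iter__(self):
        return iter(self._items.values())

    def getitem_as_reported(self, idx):
        # Same lookup as the failing line: str(slice(None)) is the key
        # 'slice(None, None, None)', which is never in the dict.
        return self._items[str(idx)]

    def __getitem__(self, idx):
        # Slice-aware variant: copy the values into a list and let
        # Python's list slicing handle start, stop and step.
        if isinstance(idx, slice):
            return list(self._items.values())[idx]
        return self._items[str(idx)]


params = DictBackedList(["p0", "p1", "p2", "p3", "p4"])
print(params[:])     # ['p0', 'p1', 'p2', 'p3', 'p4']
print(params[1:3])   # ['p1', 'p2']
print(params[3])     # p3
try:
    params.getitem_as_reported(slice(None))
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'slice(None, None, None)'
```

Without patching Paddle, `list(model.params)[:]` (or any other slice of that list) should give the expected result, since `ParameterList` supports iteration, as the `forward()` loop in the issue shows.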
Additional Supplementary Information
No response
nwwlzxa7 (answer 1):
You can use this: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/index_select_cn.html#index-select
6xfqseft (answer 2):
> You can use this: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/index_select_cn.html#index-select

Thanks. I also found that if the slice is wrapped in a tuple
Idea from: @lijialin03