Paddle 2.0.0rc1: the model.fit() high-level API raises a Tensor error

Posted by d5vmydt9 on 2021-11-29 in Java

Training a custom network with model.fit(train_X, epochs=5, batch_size=64, verbose=2) fails with the following error:
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5
---------------------------------------------------------------------------
ValueError Traceback (most recent call last) in
14 model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),loss=paddle.nn.CrossEntropyLoss(),metrics=paddle.metric.Accuracy())
15
---> 16 model.fit(train_X, epochs=5, batch_size=64, verbose=2)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in fit(self, train_data, eval_data, batch_size, epochs, eval_freq, log_freq, save_dir, save_freq, verbose, drop_last, shuffle, num_workers, callbacks)
1490 for epoch in range(epochs):
1491 cbks.on_epoch_begin(epoch)
-> 1492 logs = self._run_one_epoch(train_loader, cbks, 'train')
1493 cbks.on_epoch_end(epoch, logs)
1494
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in _run_one_epoch(self, data_loader, callbacks, mode, logs)
1797 if mode != 'predict':
1798 outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
-> 1799 data[len(self._inputs):])
1800 if self._metrics and self._loss:
1801 metrics = l[0] for l in outs[0]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in train_batch(self, inputs, labels)
938 print(loss)
939 """
--> 940 loss = self._adapter.train_batch(inputs, labels)
941 if fluid.in_dygraph_mode() and self._input_info is None:
942 self._update_inputs()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in train_batch(self, inputs, labels)
652 else:
653 outputs = self.model.network.forward(
--> 654 * [to_variable(x) for x in inputs])
655
656 losses = self.model._loss(*(to_list(outputs) + labels))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/container.py in forward(self, input)
84 def forward(self, input):
85 for layer in self._sub_layers.values():
---> 86 input = layer(input)
87 return input
88
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)
882 self._built = True
883
--> 884 outputs = self.forward(*inputs,**kwargs)
885
886 for forward_post_hook in self._forward_post_hooks.values():
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/rnn.py in forward(self, inputs, initial_states, sequence_length)
1033 if self.could_use_cudnn:
1034 # Add CPU kernel and dispatch in backend later
-> 1035 return self._cudnn_impl(inputs, initial_states, sequence_length)
1036
1037 states = split_states(initial_states, self.num_directions == 2,
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/rnn.py in _cudnn_impl(self, inputs, initial_states, sequence_length)
977 def _cudnn_impl(self, inputs, initial_states, sequence_length):
978 if not self.time_major:
--> 979 inputs = paddle.tensor.transpose(inputs, [1, 0, 2])
980 out = self._helper.create_variable_for_type_inference(inputs.dtype)
981 state = [
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/nn.py in transpose(x, perm, name)
5469 """
5470 if in_dygraph_mode():
-> 5471 out, _ = core.ops.transpose2(x, 'axis', perm)
5472 return out
5473
ValueError: (InvalidArgument) The input tensor's dimension should be equal to the axis's size. But received input tensor's dimension is 2, axis's size is 3
[Hint: Expected x_rank == axis_size, but received x_rank:2 != axis_size:3.] (at /paddle/paddle/fluid/operators/transpose_op.cc:47)
[Hint: If you need C++ stacktraces for debugging, please set FLAGS_call_stack_level=2.]
[operator < transpose2 > error]

For reference, the shape of the Tensor is:
(965, 1, 6)
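The error says the tensor reaching `transpose` has rank 2 while the permutation `[1, 0, 2]` has length 3, i.e. the recurrent layer expects `[batch, time, feature]` input. One plausible explanation, assuming `train_X` is handed to `fit` as a raw Tensor/array rather than a `paddle.io.Dataset`: iterating over a `(965, 1, 6)` array yields single samples of shape `(1, 6)`, so the batch axis is lost before the RNN sees the data. The NumPy sketch below reproduces only this shape logic (it does not call Paddle); the names `first_item` and `batch` are illustrative.

```python
import numpy as np

# Stand-in for train_X with the shape reported in the question: (N, time_steps, features).
train_X = np.zeros((965, 1, 6), dtype="float32")

# Iterating over the array directly yields one sample at a time,
# so the leading batch axis is gone.
first_item = next(iter(train_X))
print(first_item.shape)  # (1, 6): rank 2

# With time_major=False the RNN transposes [batch, time, feature] -> [time, batch, feature];
# a rank-2 input cannot be permuted with a length-3 axis list.
try:
    np.transpose(first_item, (1, 0, 2))
except ValueError as err:
    print("rank-2 input fails:", err)

# Stacking 64 samples, as a DataLoader with batch_size=64 would, restores the batch axis.
batch = np.stack([train_X[i] for i in range(64)])
print(batch.shape)                           # (64, 1, 6)
print(np.transpose(batch, (1, 0, 2)).shape)  # (1, 64, 6)
```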


mwg9r5ms (answer #2):

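Judging from the error, the input to transpose has rank 2 (the permutation expects rank 3). You could check the shape of what is being fed into the network.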
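Following the suggestion to check the rank: one way to make sure each batch arrives as `[batch, time, feature]` is to pass `fit` a `paddle.io.Dataset` instead of a raw Tensor, so that `fit` builds the DataLoader and stacks `batch_size` samples. This is a minimal sketch under stated assumptions: the class name `SequenceDataset` and the label array `train_y` are hypothetical (the question does not show its labels, but `CrossEntropyLoss` needs them), and the manual `np.stack` lines only imitate the DataLoader's default batching so the sketch also runs without Paddle.

```python
import numpy as np

try:
    from paddle.io import Dataset  # real base class when Paddle is installed
except ImportError:
    Dataset = object  # fallback so this sketch also runs without Paddle


class SequenceDataset(Dataset):
    """Hypothetical dataset that returns one (features, label) pair per index."""

    def __init__(self, features, labels):
        super().__init__()
        self.features = np.asarray(features, dtype="float32")  # (N, time_steps, feature_dim)
        self.labels = np.asarray(labels, dtype="int64")        # (N, 1) class indices

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

    def __len__(self):
        return len(self.features)


# Placeholder data: train_X uses the shape from the question; train_y is hypothetical.
train_X = np.random.rand(965, 1, 6)
train_y = np.random.randint(0, 2, size=(965, 1))
train_ds = SequenceDataset(train_X, train_y)

# Imitate a DataLoader with batch_size=64: stack 64 consecutive samples.
x_batch = np.stack([train_ds[i][0] for i in range(64)])
y_batch = np.stack([train_ds[i][1] for i in range(64)])
print(x_batch.shape, y_batch.shape)  # (64, 1, 6) (64, 1)

# With Paddle installed, the call from the question would become:
# model.fit(train_ds, epochs=5, batch_size=64, verbose=2)
```

A related point, stated as an assumption rather than something confirmed in the thread: the traceback shows the network is a `Sequential`-style container, and Paddle's RNN layers such as `paddle.nn.LSTM` return an `(outputs, final_states)` tuple, so a layer placed after the RNN inside the container may receive a tuple instead of a tensor. A small custom `paddle.nn.Layer` whose `forward` unpacks the RNN output avoids that.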
