TensorFlow: shape error with the default dataset when running the Temporal Fusion Transformer

hgqdbh6s posted on 2022-12-13 in Other

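I ran the default Temporal Fusion Transformer code, downloaded from GitHub, in Google Colab.
After cloning the repository, I ran the second step, but the training/test run fails: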

python3 -m script_train_fixed_params volatility outputs yes

The shape error is shown below.

Computing best validation loss
Computing test loss
/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py:2079: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates=self.state_updates,
Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/content/drive/MyDrive/tft_tf2/script_train_fixed_params.py", line 239, in <module>
    use_testing_mode=True)  # Change to false to use original default params
  File "/content/drive/MyDrive/tft_tf2/script_train_fixed_params.py", line 156, in main
    targets = data_formatter.format_predictions(output_map["targets"])
  File "/content/drive/MyDrive/tft_tf2/data_formatters/volatility.py", line 183, in format_predictions
    output[col] = self._target_scaler.inverse_transform(predictions[col])
  File "/usr/local/lib/python3.7/dist-packages/sklearn/preprocessing/_data.py", line 1022, in inverse_transform
    force_all_finite="allow-nan",
  File "/usr/local/lib/python3.7/dist-packages/sklearn/utils/validation.py", line 773, in check_array
    "if it contains a single sample.".format(array)
ValueError: Expected 2D array, got 1D array instead:
array=[-1.43120418  1.58885804  0.28558148 ... -1.50945972 -0.16713021
 -0.57365613].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

I have tried modifying the code that builds the predictions DataFrame in format_predictions in data_formatters/volatility.py (line 183), since I guess that is where the problem occurs, but I could not fix it.

Answer #1 by bf1o4zei

You have to change line 183 in volatility.py to:

output[col] = self._target_scaler.inverse_transform(predictions[col].values.reshape(-1, 1))
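For context, a minimal sketch of what the patched method does, written as a standalone function rather than the repository's VolatilityFormatter method. The structure (copy the frame, loop over the columns, skip "forecast_time" and "identifier") is an assumption about the original code, and the toy scaler and data are hypothetical:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


def format_predictions(predictions, target_scaler):
    """Undo target normalisation column by column (sketch of the patched method).

    `target_scaler` stands in for self._target_scaler; the skipped column
    names are assumed, not taken from the repository.
    """
    output = predictions.copy()
    for col in predictions.columns:
        if col not in {"forecast_time", "identifier"}:
            # reshape(-1, 1): n samples of one feature, as scikit-learn expects.
            restored = target_scaler.inverse_transform(
                predictions[col].values.reshape(-1, 1))
            # Flatten back to 1D before assigning to a single DataFrame column.
            output[col] = restored.ravel()
    return output


# Hypothetical toy data: scaler fitted on one target column.
scaler = StandardScaler().fit(np.array([[10.0], [20.0], [30.0]]))
preds = pd.DataFrame({
    "forecast_time": pd.date_range("2020-01-01", periods=2),
    "identifier": ["A", "A"],
    "t+0": [0.0, 1.0],
    "t+1": [-1.0, 0.5],
})
print(format_predictions(preds, scaler))
```

The answer's one-liner assigns the (n, 1) result directly; the sketch flattens it with .ravel() so the column stays 1D regardless of how a given pandas version handles 2D values on assignment.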

and line 216 in electricity.py to:

sliced_copy[col] = target_scaler.inverse_transform(sliced_copy[col].values.reshape(-1, 1))
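The electricity formatter appears to differ in that it keeps one target scaler per identifier and inverse-transforms each identifier's slice separately, which would explain the sliced_copy variable. A sketch under that assumption, with a hypothetical dict of fitted scalers and toy identifiers "A" and "B":

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


def format_predictions_per_id(predictions, target_scalers):
    """Sketch: one fitted scaler per identifier (assumed electricity-style layout).

    `target_scalers` is a hypothetical dict {identifier: fitted scaler}.
    """
    df_list = []
    for identifier in predictions["identifier"].unique():
        # Work on an explicit copy of this identifier's rows.
        sliced_copy = predictions[predictions["identifier"] == identifier].copy()
        scaler = target_scalers[identifier]
        for col in predictions.columns:
            if col not in {"forecast_time", "identifier"}:
                # Same fix as above: pass a 2D single-feature array.
                values = sliced_copy[col].values.reshape(-1, 1)
                sliced_copy[col] = scaler.inverse_transform(values).ravel()
        df_list.append(sliced_copy)
    return pd.concat(df_list, axis=0)


scalers = {
    "A": StandardScaler().fit(np.array([[1.0], [3.0]])),    # mean 2, std 1
    "B": StandardScaler().fit(np.array([[10.0], [30.0]])),  # mean 20, std 10
}
preds = pd.DataFrame({
    "identifier": ["A", "B"],
    "t+0": [1.0, -1.0],
})
print(format_predictions_per_id(preds, scalers))
```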

After that, the electricity example worked fine. I think the same fix should work for volatility.
