c++ PyTorch的交叉熵损失是如何实现的,在哪里实现的?

6ie5vjzr  于 2022-12-15  发布在  其他
关注(0)|答案(1)|浏览(147)

在PyTorch代码库中,真正实现交叉熵损失的代码在哪里?
从www.example.com开始loss.py,我在PyTorch中追踪交叉熵损失的源代码到loss. h,但它只包含以下内容:

struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
  explicit CrossEntropyLossImpl(const CrossEntropyLossOptions& options_ = {});

  void reset() override;

  /// Pretty prints the `CrossEntropyLoss` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options with which this `Module` was constructed.
  CrossEntropyLossOptions options;

  /// A manual rescaling weight given to to each class.
  Tensor weight;
};

/// A `ModuleHolder` subclass for `CrossEntropyLossImpl`.
/// See the documentation for `CrossEntropyLossImpl` class to learn what methods
/// it provides, and examples of how to use `CrossEntropyLoss` with
/// `torch::nn::CrossEntropyLossOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(CrossEntropyLoss);

作为一个C++新手,看过ModuleHolder模板类之后,我有点迷惑。
有人能帮我建立一个准确的心理模型吗?

相关问题