在PyTorch代码库中,真正实现交叉熵损失的代码在哪里?
从www.example.com开始loss.py,我在PyTorch中追踪交叉熵损失的源代码到loss. h,但它只包含以下内容:
struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
explicit CrossEntropyLossImpl(const CrossEntropyLossOptions& options_ = {});
void reset() override;
/// Pretty prints the `CrossEntropyLoss` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
Tensor forward(const Tensor& input, const Tensor& target);
/// The options with which this `Module` was constructed.
CrossEntropyLossOptions options;
/// A manual rescaling weight given to to each class.
Tensor weight;
};
/// A `ModuleHolder` subclass for `CrossEntropyLossImpl`.
/// See the documentation for `CrossEntropyLossImpl` class to learn what methods
/// it provides, and examples of how to use `CrossEntropyLoss` with
/// `torch::nn::CrossEntropyLossOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(CrossEntropyLoss);
作为一个C++新手,看过ModuleHolder模板类之后,我有点迷惑。
有人能帮我建立一个准确的心理模型吗?
1条答案
按热度按时间cczfrluj1#
正如luk2302在评论中所说的,实现在src文件夹中,而不是在include中:https://github.com/pytorch/pytorch/blob/30fb2c4abaaaa966999eab11674f25b18460e609/torch/csrc/api/src/nn/modules/loss.cpp .