NNFS — Neural network library from scratch
CCE.hpp
Go to the documentation of this file.
1#pragma once
2
3#include "Loss.hpp"
4
5namespace NNFS
6{
12 class CCE : public Loss
13 {
14 public:
19
27 void forward(Eigen::MatrixXd &sample_losses, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels) const
28 {
29 Eigen::MatrixXd clipped_predictions = predictions.array().max(1e-7).min(1 - 1e-7); // clip data to prevent division by zero
30 Eigen::MatrixXd correct_confidences = (labels.array() * clipped_predictions.array()).rowwise().sum();
31 sample_losses = -(correct_confidences.array().log());
32 }
33
41 void backward(Eigen::MatrixXd &out, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels) const
42 {
43 int m = labels.rows();
44 out = -labels.array() / predictions.array();
45 out /= m;
46 }
47 };
48} // namespace NNFS
Cross-entropy loss function.
Definition CCE.hpp:13
void backward(Eigen::MatrixXd &out, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels) const
Backward pass of the CCE loss function.
Definition CCE.hpp:41
CCE()
Construct a new CCE object.
Definition CCE.hpp:18
void forward(Eigen::MatrixXd &sample_losses, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels) const
Forward pass of the CCE loss function.
Definition CCE.hpp:27
Base class for all loss functions.
Definition Loss.hpp:24
Definition Activation.hpp:6
LossType
Enum class for loss types.
Definition Loss.hpp:13