4#include "../Utilities/clue.hpp"
5#include "../Layer/Dense.hpp"
/**
 * @brief Forward pass of the loss function.
 *
 * Pure virtual: every concrete loss must implement it. calculate() calls
 * this and then averages the contents of sample_losses, so implementations
 * are expected to write the per-sample losses into it.
 * NOTE(review): the exact layout (row vs. column vector) is not fixed here —
 * confirm against the concrete implementations.
 *
 * @param[out] sample_losses Receives the per-sample loss values.
 * @param      predictions   Model predictions.
 * @param      labels        Labels that the predictions are compared against.
 */
48 virtual void forward(Eigen::MatrixXd &sample_losses,
const Eigen::MatrixXd &predictions,
const Eigen::MatrixXd &labels)
const = 0;
/**
 * @brief Backward pass of the loss function.
 *
 * Pure virtual: every concrete loss must implement it.
 * NOTE(review): `out` presumably receives the gradient of the loss with
 * respect to `predictions` — confirm against the concrete implementations.
 *
 * @param[out] out         Output matrix written by the implementation.
 * @param      predictions Model predictions.
 * @param      labels      Labels matching `predictions`.
 */
57 virtual void backward(Eigen::MatrixXd &out,
const Eigen::MatrixXd &predictions,
const Eigen::MatrixXd &labels)
const = 0;
/**
 * @brief Calculate the scalar loss for a set of predictions.
 *
 * Runs forward() to obtain the per-sample losses, then reduces them to one
 * value with Eigen's mean(), which averages every coefficient of
 * sample_losses.
 *
 * @param[out] loss        Receives the mean loss.
 * @param      predictions Model predictions, passed unchanged to forward().
 * @param      labels      Labels, passed unchanged to forward().
 */
66 void calculate(
double &loss,
const Eigen::MatrixXd &predictions,
const Eigen::MatrixXd &labels)
68 Eigen::MatrixXd sample_losses;
69 forward(sample_losses, predictions, labels);
// Average over all per-sample losses written by forward().
70 loss = sample_losses.mean();
83 const double weight_regularizer_l1 = layer->l1_weights_regularizer();
84 const double weight_regularizer_l2 = layer->l2_weights_regularizer();
85 const double bias_regularizer_l1 = layer->l1_biases_regularizer();
86 const double bias_regularizer_l2 = layer->l2_biases_regularizer();
88 if (weight_regularizer_l1 > 0)
93 if (weight_regularizer_l2 > 0)
95 regularization_loss += weight_regularizer_l2 * (layer->weights().array() * layer->weights().array()).sum();
98 if (bias_regularizer_l1 > 0)
103 if (bias_regularizer_l2 > 0)
105 regularization_loss += bias_regularizer_l2 * (layer->biases().array() * layer->biases().array()).sum();
Cross-entropy loss function.
Definition CCE.hpp:13
Base class for all loss functions.
Definition Loss.hpp:24
virtual void forward(Eigen::MatrixXd &sample_losses, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels) const =0
Forward pass of the loss function; writes the per-sample losses into sample_losses.
void calculate(double &loss, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels)
Calculate the loss as the mean of the per-sample losses produced by forward().
Definition Loss.hpp:66
virtual void backward(Eigen::MatrixXd &out, const Eigen::MatrixXd &predictions, const Eigen::MatrixXd &labels) const =0
Backward pass of the loss function.
virtual ~Loss()=default
Virtual, defaulted destructor, so derived losses are destroyed correctly through a Loss pointer.
Loss(LossType type)
Construct a new Loss object.
Definition Loss.hpp:34
double regularization_loss(const std::shared_ptr< Dense > &layer)
Calculate the L1 and L2 regularization loss of a Dense layer's weights and biases.
Definition Loss.hpp:80
LossType type
Definition Loss.hpp:26
Definition Activation.hpp:6
LossType
Enum class for loss types.
Definition Loss.hpp:13