NNFS
Neural network library from scratch
Loading...
Searching...
No Matches
SGD.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <Eigen/Dense>
4#include "Optimizer.hpp"
5
6namespace NNFS
7{
13 class SGD : public Optimizer
14 {
15 public:
23 SGD(double lr, double decay = 0.0, double momentum = 0.0) : Optimizer(lr, decay),
24 _momentum(momentum) {}
25
31 void update_params(std::shared_ptr<Dense> &layer)
32 {
33 Eigen::MatrixXd weights = layer->weights();
34 Eigen::MatrixXd biases = layer->biases();
35 Eigen::MatrixXd dweights = layer->dweights();
36 Eigen::MatrixXd dbiases = layer->dbiases();
37
38 Eigen::MatrixXd weights_updates;
39 Eigen::MatrixXd bias_updates;
40 if (_momentum > 0)
41 {
42 Eigen::MatrixXd weights_momentums = layer->weights_optimizer();
43 Eigen::MatrixXd biases_momentums = layer->biases_optimizer();
44
45 weights_updates = _momentum * weights_momentums - _current_lr * dweights;
46
47 bias_updates = _momentum * biases_momentums - _current_lr * dbiases;
48
49 layer->weights_optimizer(weights_updates);
50 layer->biases_optimizer(bias_updates);
51 }
52 else
53 {
54 weights_updates = -_current_lr * dweights;
55 bias_updates = -_current_lr * dbiases;
56 }
57
58 weights += weights_updates;
59 biases += bias_updates;
60
61 layer->weights(weights);
62 layer->biases(biases);
63 }
64
65 private:
66 double _momentum; // Momentum
67 };
68} // namespace NNFS
Base class for all optimizers.
Definition Optimizer.hpp:15
double _current_lr
Definition Optimizer.hpp:78
Stochastic Gradient Descent optimizer.
Definition SGD.hpp:14
SGD(double lr, double decay=0.0, double momentum=0.0)
Construct a new SGD object.
Definition SGD.hpp:23
void update_params(std::shared_ptr< Dense > &layer)
Update the parameters of the layer.
Definition SGD.hpp:31
Definition Activation.hpp:6