NNFS
Neural network library from scratch
NNFS::RMSProp Class Reference

Root Mean Square Propagation optimizer.

#include <RMSProp.hpp>

Inherits NNFS::Optimizer.

Public Member Functions

 RMSProp (double lr=1e-3, double decay=1e-3, double epsilon=1e-7, double rho=.9)
 Construct a new RMSProp object.
 
void update_params (std::shared_ptr< Dense > &layer)
 Update the parameters of the layer.
 
- Public Member Functions inherited from NNFS::Optimizer
 Optimizer (double lr, double decay)
 Construct a new Optimizer object.
 
virtual ~Optimizer ()=default
 Basic destructor.
 
virtual void update_params (std::shared_ptr< Dense > &layer)=0
 Update the parameters of the layer.
 
void pre_update_params ()
 Pre-update parameters (e.g. learning rate decay)
 
void post_update_params ()
 Post-update parameters (e.g. increase iteration count)
 
double & current_lr ()
 Get the current learning rate.
 
int & iterations ()
 Get current iteration count.
 

Additional Inherited Members

- Protected Attributes inherited from NNFS::Optimizer
const double _lr
 
double _current_lr
 
int _iterations
 
double _decay
 

Detailed Description

Root Mean Square Propagation optimizer.

This class implements the Root Mean Square Propagation (RMSProp) optimizer, which scales each parameter's step by a running average of its squared gradients.

Constructor & Destructor Documentation

◆ RMSProp()

NNFS::RMSProp::RMSProp ( double  lr = 1e-3,
double  decay = 1e-3,
double  epsilon = 1e-7,
double  rho = .9 
)
inline

Construct a new RMSProp object.

Parameters
lr       Learning rate (default: 1e-3)
decay    Learning rate decay (default: 1e-3)
epsilon  Small constant added to the denominator to avoid division by zero (default: 1e-7)
rho      Decay rate of the exponentially weighted average of the squared gradients (default: .9)
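
The roles of rho and epsilon can be illustrated with the textbook RMSProp update. The following is a minimal sketch, not the library's code: rmsprop_step and its std::vector-based signature are hypothetical, and NNFS::RMSProp may differ in detail (for example, in where epsilon is added).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone helper; not part of the NNFS API.
// Applies one textbook RMSProp step to `params` in place. `cache` holds the
// running average of squared gradients, one entry per parameter, and must
// persist between calls.
void rmsprop_step(std::vector<double>& params,
                  const std::vector<double>& grads,
                  std::vector<double>& cache,
                  double lr = 1e-3, double epsilon = 1e-7, double rho = 0.9)
{
    assert(params.size() == grads.size() && params.size() == cache.size());
    for (std::size_t i = 0; i < params.size(); ++i) {
        // Exponentially weighted average of the squared gradient.
        cache[i] = rho * cache[i] + (1.0 - rho) * grads[i] * grads[i];
        // Step scaled by the root of that average; epsilon guards against
        // division by zero while the cache is still empty.
        params[i] -= lr * grads[i] / (std::sqrt(cache[i]) + epsilon);
    }
}
```

In this sketch a larger rho makes the average remember older gradients for longer, and parameters with consistently large gradients take proportionally smaller steps.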

Member Function Documentation

◆ update_params()

void NNFS::RMSProp::update_params ( std::shared_ptr< Dense > &  layer)
inline virtual

Update the parameters of the layer.

Parameters
[in,out] layer  Layer to update

Implements NNFS::Optimizer.
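
update_params sits alongside the inherited pre_update_params and post_update_params, which are described as handling learning rate decay and the iteration count. A minimal sketch of that bookkeeping follows, assuming the common inverse-time schedule lr / (1 + decay * iterations). The source does not state the formula, and LrSchedule with its members is a hypothetical stand-in rather than the library's code.

```cpp
#include <cassert>

// Hypothetical stand-in for the learning rate bookkeeping; not the NNFS API.
struct LrSchedule {
    double lr;          // initial learning rate
    double decay;       // decay factor (0 disables decay)
    double current_lr;  // learning rate to use for the next update
    int iterations = 0; // number of completed updates

    LrSchedule(double lr_, double decay_)
        : lr(lr_), decay(decay_), current_lr(lr_)
    {
        assert(decay_ >= 0.0);
    }

    // Assumed inverse-time decay, applied before the parameters are updated.
    void pre_update()
    {
        if (decay > 0.0) {
            current_lr = lr / (1.0 + decay * iterations);
        }
    }

    // Count the completed update.
    void post_update() { ++iterations; }
};
```

Under this assumption a training step would call pre_update, then update the parameters of each layer using current_lr, then call post_update once.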


The documentation for this class was generated from the following file: