MagmaDNN 1.0
C++ Neural Network Framework
gradientdescent.h
Go to the documentation of this file.
1 
9 #pragma once
10 
11 #include "optimizer/optimizer.h"
12 #include "compute/gradtable.h"
13 #include "compute/gradients.h"
14 #include "math/add.h"
16 
17 namespace magmadnn {
18 namespace optimizer {
19 
20 template <typename T>
21 class GradientDescent : public Optimizer<T> {
22 public:
23  GradientDescent(T learning_rate);
24 
25  virtual void minimize(op::Operation<T> *obj_func, const std::vector<op::Operation<T> *>& wrt);
26 
27  void set_learning_rate(T learning_rate) { this->learning_rate = learning_rate; }
28  T get_learning_rate() { return this->learning_rate; }
29 
30 protected:
31  virtual void update(op::Operation<T> *var, Tensor<T> *grad);
32 
33  T learning_rate;
34  op::GradTable<T> table;
35 };
36 
37 } // namespace optimizer
38 } // namespace magmadnn
Definition: addop.cpp:11
Definition: gradientdescent.h:21
Definition: optimizer.h:19
Definition: tensor.h:34
Definition: gradtable.h:24
Definition: operation.h:18