MagmaDNN
1.0
C++ Neural Network Framework
gradientdescent.h
Go to the documentation of this file.
1
9
#pragma once
10
11
#include "
optimizer/optimizer.h
"
12
#include "
compute/gradtable.h
"
13
#include "
compute/gradients.h
"
14
#include "
math/add.h
"
15
#include "
optimizer/gradientdescent/gradientdescent_internal.h
"
16
17
namespace
magmadnn
{
18
namespace
optimizer {
19
20
template
<
typename
T>
21
class
GradientDescent
:
public
Optimizer
<T> {
22
public
:
23
GradientDescent
(T learning_rate);
24
25
virtual
void
minimize(
op::Operation<T>
*obj_func,
const
std::vector<
op::Operation<T>
*>& wrt);
26
27
void
set_learning_rate(T learning_rate) { this->learning_rate = learning_rate; }
28
T get_learning_rate() {
return
this->learning_rate; }
29
30
protected
:
31
virtual
void
update(
op::Operation<T>
*var,
Tensor<T>
*grad);
32
33
T learning_rate;
34
op::GradTable<T>
table;
35
};
36
37
}
// namespace optimizer
38
}
// namespace magmadnn
gradients.h
magmadnn
Definition:
addop.cpp:11
add.h
magmadnn::optimizer::GradientDescent
Definition:
gradientdescent.h:21
magmadnn::optimizer::Optimizer
Definition:
optimizer.h:19
magmadnn::Tensor
Definition:
tensor.h:34
gradtable.h
optimizer.h
magmadnn::op::GradTable
Definition:
gradtable.h:24
magmadnn::op::Operation
Definition:
operation.h:18
gradientdescent_internal.h
include
optimizer
gradientdescent
gradientdescent.h
Generated by
1.8.13