MagmaDNN 1.0
C++ Neural Network Framework
gradtable.h
Go to the documentation of this file.
1 
9 #pragma once
10 
11 #include <map>
12 #include <string>
13 #include <cstdint>
14 #include "compute/operation.h"
15 #include "compute/variable.h"
16 
17 namespace magmadnn {
18 namespace op {
19 
23 template <typename T>
24 class GradTable {
25 public:
28  GradTable();
29 
33  unsigned int get_size();
34 
39  Tensor<T>* get(Operation<T>* var);
40 
45  void set(Operation<T>* var, Tensor<T>* grad);
46 
49  void clear();
50 
51 protected:
52  std::map<Operation<T> *, Tensor<T>* > _table; // the underlying table to store data
53  typename std::map<Operation<T> *, Tensor<T> *>::iterator tmp_map_iterator;
54 
55 };
56 
57 } // namespace op
58 } // namespace magmadnn
void clear()
Definition: gradtable.cpp:46
Definition: addop.cpp:11
GradTable()
Definition: gradtable.cpp:16
Definition: tensor.h:34
unsigned int get_size()
Definition: gradtable.cpp:21
Definition: gradtable.h:24
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73