MagmaDNN  1.0
C++ Neural Network Framework
operation.h
Go to the documentation of this file.
1 
9 #pragma once
10 #include <string>
11 #include <map>
12 #include "tensor/tensor.h"
13 
14 namespace magmadnn {
15 namespace op {
16 
17 template <typename T>
18 class Operation {
19 public:
23  Operation() : has_been_computed(false) {}
24  Operation(std::vector<Operation<T> *> inputs, bool needs_grad=true) : inputs(inputs), needs_grad(needs_grad) {
25  for (typename std::vector<Operation<T> *>::iterator vit = inputs.begin(); vit != inputs.end(); vit++) {
26  if (needs_grad) { /* TODO : verify this is necessary */
27  (*vit)->add_consumer(this);
28  }
29  this->_grad_cache.insert( std::make_pair((uintptr_t) (*vit), (Tensor<T> *) NULL) );
30  }
31  }
32  virtual ~Operation() {
33  for (unsigned int i = 0; i < inputs.size(); i++)
34  delete inputs[i];
35 
36  /* TODO : figure out why this peice of code caused SEGFAULTS
37  if (this->output_tensor != NULL) {
38  delete this->output_tensor;
39  }
40  */
41  }
42 
46  virtual std::vector<unsigned int> get_output_shape() const { return this->output_shape; }
47 
52  virtual unsigned int get_output_shape(unsigned int idx) const {
53  assert( idx < this->output_shape.size() );
54  return this->output_shape[idx];
55  }
56 
60  virtual unsigned int get_output_size() const {
61  unsigned int size = 1;
62  for (unsigned int i = 0; i < this->output_shape.size(); i++) size *= this->output_shape[i];
63  return size;
64  }
65 
69  virtual memory_t get_memory_type() const { return this->mem_type; }
70 
75  virtual Tensor<T>* eval(bool recompute=true) {
76  if (!recompute && this->has_been_computed && this->output_tensor != NULL) {
77  return this->output_tensor;
78  } else {
79  this->has_been_computed = true;
80  return _eval(recompute);
81  }
82  }
83 
86  virtual void reset() { this->has_been_computed = false; this->has_grad_been_computed = false; }
87 
93  virtual Tensor<T>* grad(Operation<T> *consumer, Operation<T> *var, Tensor<T> *grad, bool recompute=true) {
94  if (!recompute) {
95  Tensor<T> *ret;
96  ret = this->_grad_cache[(uintptr_t)var];
97 
98  if (ret != NULL) {
99  return ret;
100  } else {
101  return _grad(consumer, var, grad);
102  }
103  } else {
104  return _grad(consumer, var, grad);
105  }
106  }
107 
111  virtual void add_consumer(Operation<T> *consumer) { this->consumers.push_back(consumer); }
112 
116  virtual std::vector<Operation<T> *> get_consumers() { return this->consumers; }
117 
121  virtual std::vector<Operation<T> *> get_inputs() { return this->inputs; }
122 
123 
127  virtual Tensor<T> *get_output_tensor() { return this->output_tensor; }
128 
133  virtual Tensor<T> *get_grad_tensor(Operation<T> *wrt) { return this->_grad_cache.find((uintptr_t)wrt)->second; }
134 
138  virtual std::string to_string() = 0;
139 
143  virtual std::string get_name() { return this->name; }
144 
145 protected:
149  virtual Tensor<T> *_eval(bool recompute=true) = 0;
150 
158  virtual Tensor<T> *_grad(Operation<T> *consumer, Operation<T> *var, Tensor<T> *grad) = 0;
159 
160  std::vector<Operation<T>*> inputs;
161  std::vector<Operation<T>*> consumers;
162  std::vector<unsigned int> output_shape;
163  memory_t mem_type;
164  std::map<uintptr_t, Tensor<T> *> _grad_cache; /* this will cache the tensors for the gradient computation */
165  std::string name = "DefaultOpName";
166 
167  Tensor<T> *output_tensor; /* the return tensor */
168 
169  bool needs_grad;
170  bool has_been_computed;
171  bool has_grad_been_computed;
172 };
173 
174 } // namespace op
175 } // namespace magmadnn
virtual Tensor< T > * get_output_tensor()
Definition: operation.h:127
virtual unsigned int get_output_size() const
Definition: operation.h:60
virtual memory_t get_memory_type() const
Definition: operation.h:69
virtual std::vector< unsigned int > get_output_shape() const
Definition: operation.h:46
Definition: addop.cpp:11
virtual void add_consumer(Operation< T > *consumer)
Definition: operation.h:111
Operation()
Definition: operation.h:23
virtual void reset()
Definition: operation.h:86
Definition: tensor.h:34
virtual Tensor< T > * get_grad_tensor(Operation< T > *wrt)
Definition: operation.h:133
virtual std::vector< Operation< T > * > get_consumers()
Definition: operation.h:116
virtual std::string get_name()
Definition: operation.h:143
virtual Tensor< T > * _eval(bool recompute=true)=0
virtual Tensor< T > * grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true)
Definition: operation.h:93
virtual std::vector< Operation< T > * > get_inputs()
Definition: operation.h:121
virtual std::string to_string()=0
virtual unsigned int get_output_shape(unsigned int idx) const
Definition: operation.h:52
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73
virtual Tensor< T > * eval(bool recompute=true)
Definition: operation.h:75
virtual Tensor< T > * _grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad)=0