24 Operation(std::vector<
Operation<T> *> inputs,
bool needs_grad=
true) : inputs(inputs), needs_grad(needs_grad) {
25 for (
typename std::vector<
Operation<T> *>::iterator vit = inputs.begin(); vit != inputs.end(); vit++) {
27 (*vit)->add_consumer(
this);
29 this->_grad_cache.insert( std::make_pair((uintptr_t) (*vit), (
Tensor<T> *) NULL) );
33 for (
unsigned int i = 0; i < inputs.size(); i++)
46 virtual std::vector<unsigned int>
get_output_shape()
const {
return this->output_shape; }
53 assert( idx < this->output_shape.size() );
54 return this->output_shape[idx];
61 unsigned int size = 1;
62 for (
unsigned int i = 0; i < this->output_shape.size(); i++) size *= this->output_shape[i];
76 if (!recompute && this->has_been_computed && this->output_tensor != NULL) {
77 return this->output_tensor;
79 this->has_been_computed =
true;
80 return _eval(recompute);
86 virtual void reset() { this->has_been_computed =
false; this->has_grad_been_computed =
false; }
96 ret = this->_grad_cache[(uintptr_t)var];
101 return _grad(consumer, var, grad);
104 return _grad(consumer, var, grad);
116 virtual std::vector<Operation<T> *>
get_consumers() {
return this->consumers; }
121 virtual std::vector<Operation<T> *>
get_inputs() {
return this->inputs; }
143 virtual std::string
get_name() {
return this->name; }
160 std::vector<Operation<T>*> inputs;
161 std::vector<Operation<T>*> consumers;
162 std::vector<unsigned int> output_shape;
164 std::map<uintptr_t, Tensor<T> *> _grad_cache;
165 std::string name =
"DefaultOpName";
170 bool has_been_computed;
171 bool has_grad_been_computed;
virtual Tensor< T > * get_output_tensor()
Definition: operation.h:127
virtual unsigned int get_output_size() const
Definition: operation.h:60
virtual memory_t get_memory_type() const
Definition: operation.h:69
virtual std::vector< unsigned int > get_output_shape() const
Definition: operation.h:46
virtual void add_consumer(Operation< T > *consumer)
Definition: operation.h:111
Operation()
Definition: operation.h:23
virtual void reset()
Definition: operation.h:86
virtual Tensor< T > * get_grad_tensor(Operation< T > *wrt)
Definition: operation.h:133
virtual std::vector< Operation< T > * > get_consumers()
Definition: operation.h:116
virtual std::string get_name()
Definition: operation.h:143
virtual Tensor< T > * _eval(bool recompute=true)=0
virtual Tensor< T > * grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true)
Definition: operation.h:93
virtual std::vector< Operation< T > * > get_inputs()
Definition: operation.h:121
virtual std::string to_string()=0
virtual unsigned int get_output_shape(unsigned int idx) const
Definition: operation.h:52
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73
virtual Tensor< T > * eval(bool recompute=true)
Definition: operation.h:75
virtual Tensor< T > * _grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad)=0