MagmaDNN  1.0
C++ Neural Network Framework
flattenop.h
1 
2 #pragma once
3 
4 #include "compute/operation.h"
5 #include "tensor/tensor.h"
6 
7 namespace magmadnn {
8 namespace op {
9 
10 template <typename T>
11 class FlattenOp : public Operation<T> {
12 public:
13  FlattenOp(Operation<T> *input, bool copy=true, bool needs_grad=true);
14 
15  std::string to_string() { return "Flatten(" + input->to_string() + ")"; }
16 
17 protected:
18  Tensor<T> *_eval(bool recompute);
20 
21  Operation<T> *input;
22  Tensor<T> *input_tensor;
23 
24  bool copy;
25 
26 };
27 
28 template <typename T>
29 FlattenOp<T>* flatten(Operation<T> *input, bool copy=true, bool needs_grad=true);
30 
31 } // namespace op
32 } // namespace magmadnn
std::string to_string()
Definition: flattenop.h:15
Definition: flattenop.h:11
Definition: addop.cpp:11
Definition: tensor.h:34
Tensor< T > * _eval(bool recompute)
Definition: flattenop.cpp:24
virtual Tensor< T > * grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true)
Definition: operation.h:93
Tensor< T > * _grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad)
Definition: flattenop.cpp:35
virtual std::string to_string()=0
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73