MagmaDNN 1.0
C++ Neural Network Framework
linearforwardop.h
1 
2 #pragma once
3 
4 #include "compute/operation.h"
5 #include "tensor/tensor.h"
6 #include "math/matmul.h"
7 #include "math/bias_add.h"
8 #include "math/reduce_sum.h"
9 
10 namespace magmadnn {
11 namespace op {
12 
13 template <typename T>
14 class LinearForwardOp : public Operation<T> {
15 public:
16  LinearForwardOp(Operation<T> *input, Operation<T> *weights, bool copy=true, bool needs_grad=true);
17  LinearForwardOp(Operation<T> *input, Operation<T> *weights, Operation<T> *bias, bool copy=true, bool needs_grad=true);
18  virtual ~LinearForwardOp();
19 
20  std::string to_string() { return "LinearForward(" + input->to_string() + ", " + weights->to_string() + ")"; }
21 protected:
22  Tensor<T> *_eval(bool recompute);
24 
25  void init_bias_settings(); /* init ones and bias_reduce_settings */
26 
27  Operation<T> *input, *weights, *bias;
28  Tensor<T> *input_tensor, *weights_tensor, *bias_tensor, *bias_ones;
29 
30  bool copy;
31  bool use_bias;
32 
33  #if defined(_HAS_CUDA_)
34  math::reduce_sum_cudnn_settings_t bias_reduce_settings;
35  #endif
36 
37 };
38 
47 template <typename T>
48 LinearForwardOp<T>* linearforward(Operation<T> *input, Operation<T> *weights, bool copy=true, bool needs_grad=true);
49 
59 template <typename T>
60 LinearForwardOp<T>* linearforward(Operation<T> *input, Operation<T> *weights, Operation<T> *bias, bool copy=true, bool needs_grad=true);
61 
62 } // namespace op
63 } // namespace magmadnn
Definition: linearforwardop.h:14
Definition: addop.cpp:11
Definition: tensor.h:34
virtual Tensor< T > * grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true)
Definition: operation.h:93
Tensor< T > * _grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad)
Definition: linearforwardop.cpp:58
virtual std::string to_string()=0
Tensor< T > * _eval(bool recompute)
Definition: linearforwardop.cpp:42
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73
std::string to_string()
Definition: linearforwardop.h:20