MagmaDNN 1.0
C++ Neural Network Framework
conv2dforwardop.h
1 
2 #pragma once
3 
4 #include "compute/operation.h"
5 #include "tensor/tensor.h"
6 #include "compute/conv2dforward/conv2dforward_internal.h"
7 #include "math/conv2d.h"
8 
9 namespace magmadnn {
10 namespace op {
11 
12 template <typename T>
13 class Conv2DForwardOp : public Operation<T> {
14 public:
15  Conv2DForwardOp(Operation<T> *input, Operation<T> *filter, int pad_h=0, int pad_w=0, int vertical_stride=1, int horizontal_stride=1, int dilation_h=1, int dilation_w=1, bool use_cross_correlation=true, bool needs_grad=true);
16  ~Conv2DForwardOp();
17 
18 
19  std::string to_string() { return "Conv2DForward(" + input->to_string() + ")"; }
20 protected:
21  Tensor<T> *_eval(bool recompute);
23 
24 
25  void init_settings();
26  void calculate_and_set_output_shape();
27 
28  Operation<T> *input, *filter;
29  Tensor<T> *input_tensor, *filter_tensor;
30 
31  int pad_h, pad_w, vertical_stride, horizontal_stride, dilation_h, dilation_w;
32  bool use_cross_correlation;
33 
34  #if defined(_HAS_CUDA_)
35  ::magmadnn::math::conv2d_cudnn_settings cudnn_settings;
36  #endif
37 
38 };
39 
40 template <typename T>
41 Conv2DForwardOp<T>* conv2dforward(Operation<T> *input, Operation<T> *filter, int pad_h=0, int pad_w=0, int vertical_stride=1, int horizontal_stride=1, int dilation_h=1, int dilation_w=1, bool use_cross_correlation=true, bool needs_grad=true);
42 
43 } // namespace op
44 } // namespace magmadnn
std::string to_string()
Definition: conv2dforwardop.h:19
Definition: addop.cpp:11
Definition: tensor.h:34
Tensor< T > * _eval(bool recompute)
Definition: conv2dforwardop.cpp:41
virtual Tensor< T > * grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true)
Definition: operation.h:93
virtual std::string to_string()=0
Tensor< T > * _grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad)
Definition: conv2dforwardop.cpp:59
Definition: conv2dforwardop.h:13
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73