MagmaDNN 1.0
C++ Neural Network Framework
productop.h
Go to the documentation of this file.
1 
9 #pragma once
10 
11 #include <vector>
12 #include "compute/operation.h"
13 #include "compute/variable.h"
14 #include "tensor/tensor.h"
15 #include "utilities_internal.h"
18 
19 namespace magmadnn {
20 
namespace internal {

/**
 * @brief Operand-kind selector used by ProductOp (stored in ProductOp::op_type).
 *
 * The explicit values below are identical to the implicit 0, 1, 2 the
 * enumerators would receive from their declaration order.
 */
enum product_op_t {
    SCALAR_PROD_TENSOR = 0,  ///< per the name: scalar left operand, tensor right operand
    TENSOR_PROD_SCALAR = 1,  ///< per the name: tensor left operand, scalar right operand
    TENSOR_PROD_TENSOR = 2   ///< per the name: tensor on both sides
};

}  // namespace internal
28 
29 namespace op {
30 
31 template <typename T>
32 class ProductOp : public Operation<T> {
33 public:
34  ProductOp(T alpha, Operation<T>* a, Operation<T>* b, bool copy=true, bool needs_grad=true);
35 
36 
37  std::string to_string() { return "(" + a->to_string() + " * " + b->to_string() + ")"; }
38 protected:
39  Tensor<T> *_eval(bool recompute=true);
40  Tensor<T> *_grad(Operation<T> *consumer, Operation<T> *var, Tensor<T> *grad);
41 
42  T alpha;
43  Operation<T> *a;
44  Operation<T> *b;
45 
46  Tensor<T> *a_tensor;
47  Tensor<T> *b_tensor;
48 
49  internal::product_op_t op_type;
50 
51  bool copy;
52 };
53 
54 template <typename T>
55 ProductOp<T>* product(Operation<T> *a, Operation<T> *b, bool copy=true, bool needs_grad=true);
56 
57 template <typename T>
58 ProductOp<T>* product(T alpha, Operation<T> *a, Operation<T> *b, bool copy=true, bool needs_grad=true);
59 
60 } // namespace op
61 } // namespace magmadnn
Definition: addop.cpp:11
std::string to_string()
Definition: productop.h:37
Definition: productop.h:32
Definition: tensor.h:34
virtual std::string to_string()=0
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73