MagmaDNN
1.0
C++ Neural Network Framework
productop.h
Go to the documentation of this file.
1
9
#pragma once
10
11
#include <vector>
12
#include "
compute/operation.h
"
13
#include "
compute/variable.h
"
14
#include "
tensor/tensor.h
"
15
#include "
utilities_internal.h
"
16
#include "
compute/product/product_internal.h
"
17
#include "
math/scalar_tensor_product.h
"
18
19
namespace
magmadnn
{
20
21
namespace internal {

/**
 * Kind of product a ProductOp performs; stored in ProductOp::op_type.
 *
 * The enumerator names indicate which operand(s) of the product are scalars.
 * Which kind gets selected is decided in ProductOp's implementation, which is
 * not part of this header.
 * Values are spelled out explicitly; they match the implicit 0, 1, 2 numbering.
 */
enum product_op_t {
    SCALAR_PROD_TENSOR = 0, /**< scalar operand times tensor operand */
    TENSOR_PROD_SCALAR = 1, /**< tensor operand times scalar operand */
    TENSOR_PROD_TENSOR = 2  /**< tensor operand times tensor operand */
};

}  // namespace internal
28
29
namespace
op {
30
31
template
<
typename
T>
32
class
ProductOp
:
public
Operation
<T> {
33
public
:
34
ProductOp
(T alpha,
Operation<T>
* a,
Operation<T>
* b,
bool
copy=
true
,
bool
needs_grad=
true
);
35
36
37
std::string
to_string
() {
return
"("
+ a->
to_string
() +
" * "
+ b->
to_string
() +
")"
; }
38
protected
:
39
Tensor<T>
*_eval(
bool
recompute=
true
);
40
Tensor<T>
*_grad(
Operation<T>
*consumer,
Operation<T>
*
var
,
Tensor<T>
*grad);
41
42
T alpha;
43
Operation<T>
*a;
44
Operation<T>
*b;
45
46
Tensor<T>
*a_tensor;
47
Tensor<T>
*b_tensor;
48
49
internal::product_op_t op_type;
50
51
bool
copy;
52
};
53
54
template
<
typename
T>
55
ProductOp<T>
* product(
Operation<T>
*a,
Operation<T>
*b,
bool
copy=
true
,
bool
needs_grad=
true
);
56
57
template
<
typename
T>
58
ProductOp<T>
* product(T alpha,
Operation<T>
*a,
Operation<T>
*b,
bool
copy=
true
,
bool
needs_grad=
true
);
59
60
}
// namespace op
61
}
// namespace magmadnn
product_internal.h
scalar_tensor_product.h
magmadnn
Definition:
addop.cpp:11
magmadnn::op::ProductOp::to_string
std::string to_string()
Definition:
productop.h:37
magmadnn::op::ProductOp
Definition:
productop.h:32
tensor.h
magmadnn::Tensor
Definition:
tensor.h:34
utilities_internal.h
variable.h
operation.h
magmadnn::op::Operation::to_string
virtual std::string to_string()=0
magmadnn::op::Operation
Definition:
operation.h:18
magmadnn::op::var
Variable< T > * var(std::string name, Tensor< T > *val)
Definition:
variable.cpp:73
include
compute
product
productop.h
Generated by
1.8.13