MagmaDNN  1.0
C++ Neural Network Framework
gradients.h
Go to the documentation of this file.
1 
10 #pragma once
11 
12 #include <vector>
13 #include "compute/operation.h"
14 #include "compute/variable.h"
15 #include "compute/gradtable.h"
16 #include "compute/variable.h"
17 #include "compute/add/addop.h"
18 #include "compute/sum/sumop.h"
19 #include "utilities_internal.h"
20 
// NOTE(review): "compute/variable.h" is included twice in the include list
// above; harmless under #pragma once, but the duplicate can be dropped.
21 namespace magmadnn {
22 namespace op {
23 
/**
 * @brief Public gradient entry point (declared here, defined in gradients.cpp).
 *
 * NOTE(review): judging only by the name and signature, this presumably
 * computes the gradient of @p graph with respect to each operation in
 * @p vars and records the results in @p table -- confirm against
 * gradients.cpp before relying on that.
 *
 * @tparam T      template parameter shared by Operation and GradTable.
 * @param  vars   operations to differentiate with respect to; the vector is
 *                taken by const reference and is not modified. Its elements
 *                are raw pointers -- ownership is not expressed by the
 *                signature.
 * @param  graph  operation whose gradients are requested.
 * @param  table  gradient table, taken by non-const reference so the call
 *                can write into it.
 * @return a magmadnn_error_t status code (values are defined elsewhere).
 */
31 template <typename T>
32 magmadnn_error_t get_grad_table(const std::vector<Operation<T> *>& vars, Operation<T> *graph, GradTable<T> &table);
33 
34 } // namespace op
35 
36 // build_grad should only be used internally
37 namespace internal {
38 
/**
 * @brief Internal gradient helper (declared here, defined in gradients.cpp).
 *
 * NOTE(review): presumably builds the gradient of @p graph with respect to
 * the single operation @p var, consulting and/or filling @p table, and hands
 * the result back through @p grad -- confirm against gradients.cpp.
 *
 * @tparam T      template parameter shared by Operation, GradTable and Tensor.
 * @param  var    operation to differentiate with respect to.
 * @param  graph  operation whose gradient is requested.
 * @param  table  gradient table (op::GradTable), passed by non-const
 *                reference, so the call may write into it.
 * @param  grad   out-parameter (pointer-to-pointer): the callee can store a
 *                Tensor<T>* through it. Who owns that tensor is not expressed
 *                by the signature -- check the definition.
 * @return a magmadnn_error_t status code.
 */
46 template <typename T>
47 magmadnn_error_t build_grad(op::Operation<T>* var, op::Operation<T> *graph, op::GradTable<T> &table, Tensor<T> **grad);
48 
49 } // namespace internal
50 } // namespace magmadnn
Definition: addop.cpp:11
magmadnn_error_t get_grad_table(const std::vector< Operation< T > *> &vars, Operation< T > *graph, GradTable< T > &table)
Definition: gradients.cpp:16
magmadnn_error_t build_grad(op::Operation< T > *var, op::Operation< T > *graph, op::GradTable< T > &table, Tensor< T > **grad)
Definition: gradients.cpp:52