- template<typename T > magmadnn_error_t magmadnn::op::get_grad_table (const std::vector< Operation< T > * > &vars, Operation< T > *graph, GradTable< T > &table)
- template<typename T > magmadnn_error_t magmadnn::internal::build_grad (op::Operation< T > *var, op::Operation< T > *graph, op::GradTable< T > &table, Tensor< T > **grad)
- Author: Daniel Nichols
- Version: 0.1
- Date: 2019-05-17
- Copyright: Copyright (c) 2019
◆ build_grad()
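template<typename T >
magmadnn_error_t magmadnn::internal::build_grad (op::Operation< T > *var, op::Operation< T > *graph, op::GradTable< T > &table, Tensor< T > **grad)
Sets the gradient for var, i.e. the gradient of graph with respect to var, in table.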
- Template Parameters
  - T | Element type of the operations and tensors
- Parameters
  - var | Variable to compute the gradient for
  - graph | Compute graph that contains var
  - table | GradTable in which to store the gradients
- Returns
  - magmadnn_error_t: non-zero on error
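One way a function like build_grad() could compute the gradient for a single variable is by recursion over its consumers: the gradient of graph with respect to var is the sum, over every node that consumes var, of that consumer's gradient times the local derivative of the consumer with respect to var (the chain rule), with results cached in the table. The sketch below is not MagmaDNN code. Node, Graph, ScalarGradTable, local_grad and build_grad_sketch are hypothetical stand-ins for op::Operation<T>, op::GradTable<T> and the Tensor<T> **grad out-parameter; scalars replace tensors, and only addition and multiplication are supported.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <vector>

// Hypothetical scalar stand-ins; NOT MagmaDNN's op::Operation<T> / GradTable<T>.
enum class Kind { Var, Add, Mul };

struct Node {
    Kind kind = Kind::Var;
    double value = 0.0;            // forward value, computed when the node is built
    std::vector<Node*> inputs;     // operands of this node
    std::vector<Node*> consumers;  // nodes that take this node as an operand
};

// Owns the nodes and wires up consumer links as nodes are created.
struct Graph {
    std::vector<std::unique_ptr<Node>> pool;

    Node* var(double v) { return make(Kind::Var, v, {}); }
    Node* add(Node* a, Node* b) { return make(Kind::Add, a->value + b->value, {a, b}); }
    Node* mul(Node* a, Node* b) { return make(Kind::Mul, a->value * b->value, {a, b}); }

    Node* make(Kind k, double v, std::vector<Node*> in) {
        assert(k == Kind::Var || !in.empty());
        pool.push_back(std::make_unique<Node>());
        Node* n = pool.back().get();
        n->kind = k;
        n->value = v;
        n->inputs = in;
        for (Node* i : in)  // register n once per distinct operand
            if (std::find(i->consumers.begin(), i->consumers.end(), n) == i->consumers.end())
                i->consumers.push_back(n);
        return n;
    }
};

// Local derivative d(out)/d(in), summed over every operand slot holding `in`
// (so x*x correctly yields 2x).
double local_grad(const Node* out, const Node* in) {
    double g = 0.0;
    for (std::size_t i = 0; i < out->inputs.size(); ++i) {
        if (out->inputs[i] != in) continue;
        if (out->kind == Kind::Add) {
            g += 1.0;
        } else if (out->kind == Kind::Mul) {
            double p = 1.0;  // product of all the other operands
            for (std::size_t j = 0; j < out->inputs.size(); ++j)
                if (j != i) p *= out->inputs[j]->value;
            g += p;
        }
    }
    return g;
}

using ScalarGradTable = std::map<const Node*, double>;

// Stores d(graph)/d(var) in table and in *grad. Returns 0 on success,
// non-zero on error, mirroring the documented magmadnn_error_t contract.
int build_grad_sketch(const Node* var, const Node* graph, ScalarGradTable& table, double* grad) {
    if (var == nullptr || graph == nullptr || grad == nullptr) return 1;

    auto it = table.find(var);  // already computed: reuse the cached entry
    if (it != table.end()) { *grad = it->second; return 0; }

    if (var == graph) { table[var] = 1.0; *grad = 1.0; return 0; }  // d(graph)/d(graph) = 1

    double sum = 0.0;  // nodes with no path to graph end up with 0
    for (const Node* c : var->consumers) {
        double gc = 0.0;
        int err = build_grad_sketch(c, graph, table, &gc);
        if (err != 0) return err;
        sum += gc * local_grad(c, var);
    }
    table[var] = sum;
    *grad = sum;
    return 0;
}
```

For f = x*y + x with x = 3 and y = 4, this yields df/dx = y + 1 = 5 and df/dy = x = 3; the second call reuses the entries for f and x*y cached by the first.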
◆ get_grad_table()
template<typename T >
magmadnn_error_t magmadnn::op::get_grad_table (const std::vector< Operation< T > * > &vars, Operation< T > *graph, GradTable< T > &table)
Given a list of variables and a compute graph, fills in a GradTable with the gradient of graph with respect to each variable.
- Template Parameters
  - T | Element type of the operations and tensors
- Parameters
  - vars | A list of variables whose gradients will be computed
  - graph | Head node of the compute graph that contains vars
  - table | GradTable to be filled in
- Returns
  - magmadnn_error_t: non-zero on error
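Filling a table for several variables can also be done with a single reverse sweep rather than one recursive call per variable. The sketch below shows that alternative on a tape (a list of nodes stored in evaluation order, so the index order is already topological). It is a hypothetical illustration, not MagmaDNN's implementation: Tape, TapeNode, the extra tape argument, ScalarGradTable and get_grad_table_sketch are assumptions, nodes are scalars, and integer indices stand in for Operation<T> pointers.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical tape-based graph; NOT MagmaDNN's op::Operation<T>.
enum class Op { Var, Add, Mul };

struct TapeNode {
    Op op;
    double value;  // forward value
    int a, b;      // operand indices (-1 for Var)
};

// Nodes are appended in evaluation order, so every operand index is smaller
// than the index of the node that uses it.
struct Tape {
    std::vector<TapeNode> nodes;
    int var(double v) { nodes.push_back({Op::Var, v, -1, -1}); return last(); }
    int add(int a, int b) { nodes.push_back({Op::Add, val(a) + val(b), a, b}); return last(); }
    int mul(int a, int b) { nodes.push_back({Op::Mul, val(a) * val(b), a, b}); return last(); }
    double val(int i) const {
        assert(i >= 0 && i < static_cast<int>(nodes.size()));
        return nodes[i].value;
    }
    int last() const { return static_cast<int>(nodes.size()) - 1; }
};

using ScalarGradTable = std::map<int, double>;

// Sets table[v] = d(graph)/d(v) for every v in vars. Returns 0 on success,
// non-zero on error (here: an out-of-range node index).
int get_grad_table_sketch(const std::vector<int>& vars, const Tape& tape, int graph,
                          ScalarGradTable& table) {
    const int n = static_cast<int>(tape.nodes.size());
    if (graph < 0 || graph >= n) return 1;
    for (int v : vars)
        if (v < 0 || v >= n) return 2;

    // One reverse sweep accumulates the adjoint of every node at once.
    // Nodes after `graph` cannot affect it, so the sweep starts at `graph`.
    std::vector<double> adj(n, 0.0);
    adj[graph] = 1.0;
    for (int i = graph; i >= 0; --i) {
        const TapeNode& nd = tape.nodes[i];
        if (nd.op == Op::Add) {
            adj[nd.a] += adj[i];
            adj[nd.b] += adj[i];
        } else if (nd.op == Op::Mul) {
            adj[nd.a] += adj[i] * tape.val(nd.b);
            adj[nd.b] += adj[i] * tape.val(nd.a);
        }
    }
    for (int v : vars) table[v] = adj[v];
    return 0;
}
```

The sweep visits each node once, so its cost does not grow with the number of requested variables; the trade-off is that it relies on the nodes being stored in topological order.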