MagmaDNN 1.0
C++ Neural Network Framework
gradients.h File Reference
#include <vector>
#include "compute/operation.h"
#include "compute/variable.h"
#include "compute/gradtable.h"
#include "compute/add/addop.h"
#include "compute/sum/sumop.h"
#include "utilities_internal.h"


Functions

template<typename T >
magmadnn_error_t magmadnn::op::get_grad_table (const std::vector< Operation< T > *> &vars, Operation< T > *graph, GradTable< T > &table)
 
template<typename T >
magmadnn_error_t magmadnn::internal::build_grad (op::Operation< T > *var, op::Operation< T > *graph, op::GradTable< T > &table, Tensor< T > **grad)
 

Detailed Description

Author: Daniel Nichols
Version: 0.1
Date: 2019-05-17

Function Documentation

◆ build_grad()

template<typename T >
magmadnn_error_t magmadnn::internal::build_grad ( op::Operation< T > *  var,
op::Operation< T > *  graph,
op::GradTable< T > &  table,
Tensor< T > **  grad 
)

Computes the gradient of the compute graph's output with respect to var, records it in table, and returns it through grad.

Template Parameters
T: numeric type
Parameters
var: Variable to compute gradients for
graph: Compute graph that contains var
table: GradTable to put gradients in
grad: Output; receives a pointer to the computed gradient tensor
Returns
magmadnn_error_t non-zero on error
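One way to picture what build_grad does is memoized reverse-mode differentiation: the gradient of a node is the sum, over the nodes that consume it, of each consumer's gradient times that consumer's local partial derivative, with every result cached in the table so shared subgraphs are visited only once. The sketch below assumes that reading (suggested by the signature and the add/sum includes, but not confirmed by this page) and works on scalars rather than Tensor<T>. Node, Kind, make_binary, local_grad, run_example and the std::map standing in for GradTable are all hypothetical; the table must be seeded with 1 for the graph's head node before the first call.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical scalar stand-in for op::Operation<T>; not MagmaDNN's API.
enum class Kind { Var, Add, Mul };

struct Node {
    Kind kind;
    double value;                  // forward value, filled in before differentiation
    std::vector<Node*> inputs;     // operands of Add/Mul (exactly two)
    std::vector<Node*> consumers;  // nodes that read this node
};

// Make `op` a binary node over a and b, registering op once as a consumer
// of each distinct input (so x * x lists op only once in x's consumers).
void make_binary(Node& op, Node* a, Node* b) {
    assert(op.kind != Kind::Var);
    op.inputs = {a, b};
    a->consumers.push_back(&op);
    if (b != a) b->consumers.push_back(&op);
}

// Local partial derivative d(consumer)/d(input), summed over every operand
// slot in which `input` appears.
double local_grad(const Node* consumer, const Node* input) {
    double d = 0.0;
    for (std::size_t i = 0; i < consumer->inputs.size(); ++i) {
        if (consumer->inputs[i] != input) continue;
        if (consumer->kind == Kind::Add) {
            d += 1.0;                              // d(a + b)/da = 1
        } else if (consumer->kind == Kind::Mul) {
            d += consumer->inputs[1 - i]->value;   // d(a * b)/da = b
        }
    }
    return d;
}

// Hypothetical analogue of internal::build_grad for an acyclic scalar graph.
// Returns 0 on success and non-zero on error, like magmadnn_error_t.
int build_grad(Node* var, Node* graph, std::map<Node*, double>& table, double* grad) {
    if (var == nullptr || graph == nullptr || grad == nullptr) return 1;

    auto hit = table.find(var);
    if (hit != table.end()) {      // already computed (or seeded): reuse it
        *grad = hit->second;
        return 0;
    }

    double sum = 0.0;              // accumulate contributions from all consumers
    for (Node* consumer : var->consumers) {
        double consumer_grad = 0.0;
        int err = build_grad(consumer, graph, table, &consumer_grad);
        if (err != 0) return err;
        sum += consumer_grad * local_grad(consumer, var);
    }
    table[var] = sum;
    *grad = sum;
    return 0;
}

// Builds f = x * y + x at x = 3, y = 4 and writes df/dx and df/dy.
int run_example(double* gx, double* gy) {
    Node x{Kind::Var, 3.0, {}, {}};
    Node y{Kind::Var, 4.0, {}, {}};
    Node m{Kind::Mul, 0.0, {}, {}};
    Node f{Kind::Add, 0.0, {}, {}};
    make_binary(m, &x, &y);        // m = x * y
    make_binary(f, &m, &x);        // f = m + x
    m.value = x.value * y.value;   // forward pass
    f.value = m.value + x.value;

    std::map<Node*, double> table;
    table[&f] = 1.0;               // seed: df/df = 1
    int err = build_grad(&x, &f, table, gx);
    if (err != 0) return err;
    return build_grad(&y, &f, table, gy);
}
```

run_example reports df/dx = 5 (that is, y + 1) and df/dy = 3 (that is, x). Because results are cached in the table, the gradient of the intermediate node m is computed once while differentiating x and simply looked up when differentiating y.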

◆ get_grad_table()

template<typename T >
magmadnn_error_t magmadnn::op::get_grad_table ( const std::vector< Operation< T > * > &  vars,
Operation< T > *  graph,
GradTable< T > &  table 
)

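Given a list of variables and the head node of a compute graph, computes the gradient of the graph's output with respect to each variable and stores the results in table.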

Template Parameters
T: numeric type
Parameters
vars: A list of variables whose gradients will be computed
graph: Head node of the compute graph that contains vars
table: GradTable to be filled in
Returns
magmadnn_error_t non-zero on error
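The driver role of get_grad_table can be sketched separately from the per-variable recursion. This is a hypothetical sketch, not MagmaDNN code: the template parameters, the std::map standing in for GradTable<T>, seeding the head node's gradient with 1, and stopping at the first non-zero error code are all assumptions rather than behavior documented on this page.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical driver in the spirit of op::get_grad_table.
//  * NodeT stands in for op::Operation<T>; std::map stands in for GradTable<T>.
//  * Build is any callable int(NodeT* var, NodeT* graph, std::map<NodeT*, double>&)
//    that stores d(graph)/d(var) in the table and returns 0 on success.
template <typename NodeT, typename Build>
int get_grad_table(const std::vector<NodeT*>& vars, NodeT* graph,
                   std::map<NodeT*, double>& table, Build build) {
    if (graph == nullptr) return 1;    // nothing to differentiate
    table[graph] = 1.0;                // seed: d(graph)/d(graph) = 1

    for (NodeT* var : vars) {
        if (var == nullptr) return 1;  // reject null entries
        int err = build(var, graph, table);
        if (err != 0) return err;      // stop at the first failure
        assert(table.count(var) != 0 && "build must record d(graph)/d(var)");
    }
    return 0;
}
```

In MagmaDNN the per-variable step is internal::build_grad; keeping it as the Build parameter here lets the driver logic be read and tested on its own, for example with a lambda that writes a placeholder gradient for each variable.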