|
| LinearForwardOp (Operation< T > *input, Operation< T > *weights, bool copy=true, bool needs_grad=true) |
|
| LinearForwardOp (Operation< T > *input, Operation< T > *weights, Operation< T > *bias, bool copy=true, bool needs_grad=true) |
|
std::string | to_string () |
|
| Operation () |
|
| Operation (std::vector< Operation< T > *> inputs, bool needs_grad=true) |
|
virtual std::vector< unsigned int > | get_output_shape () const |
|
virtual unsigned int | get_output_shape (unsigned int idx) const |
|
virtual unsigned int | get_output_size () const |
|
virtual memory_t | get_memory_type () const |
|
virtual Tensor< T > * | eval (bool recompute=true) |
|
virtual void | reset () |
|
virtual Tensor< T > * | grad (Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true) |
|
virtual void | add_consumer (Operation< T > *consumer) |
|
virtual std::vector< Operation< T > * > | get_consumers () |
|
virtual std::vector< Operation< T > * > | get_inputs () |
|
virtual Tensor< T > * | get_output_tensor () |
|
virtual Tensor< T > * | get_grad_tensor (Operation< T > *wrt) |
|
virtual std::string | get_name () |
|
|
Operation< T > * | input |
|
Operation< T > * | weights |
|
Operation< T > * | bias |
|
Tensor< T > * | input_tensor |
|
Tensor< T > * | weights_tensor |
|
Tensor< T > * | bias_tensor |
|
Tensor< T > * | bias_ones |
|
bool | copy |
|
bool | use_bias |
|
std::vector< Operation< T > * > | inputs |
|
std::vector< Operation< T > * > | consumers |
|
std::vector< unsigned int > | output_shape |
|
memory_t | mem_type |
|
std::map< uintptr_t, Tensor< T > * > | _grad_cache |
|
std::string | name = "DefaultOpName" |
|
Tensor< T > * | output_tensor |
|
bool | needs_grad |
|
bool | has_been_computed |
|
bool | has_grad_been_computed |
|
◆ _eval()
◆ _grad()
Computes the gradient of this operation with respect to the output of `consumer`.
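- Parameters
  - consumer — (no description provided in source)
  - var — (no description provided in source)
  - grad — (no description provided in source)
  - recompute — (no description provided in source)
- Returns
  - Tensor<T>*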
Implements magmadnn::op::Operation< T >.
◆ to_string()
The documentation for this class was generated from the following files:
- include/compute/linearforward/linearforwardop.h
- src/compute/linearforward/linearforwardop.cpp