MagmaDNN  1.0
C++ Neural Network Framework
variable.h
Go to the documentation of this file.
1 
9 #pragma once
10 #include <string>
11 #include <tensor/tensor.h>
12 #include "operation.h"
13 
14 namespace magmadnn {
15 namespace op {
16 
20 template <typename T>
21 class Variable : public Operation<T> {
22 public:
23  Variable (std::string name, std::vector<unsigned int> shape, tensor_filler_t<T> filler, memory_t mem_type);
24  Variable (std::string name, Tensor<T> *val);
25  ~Variable();
26 
27  std::string to_string() { return name; }
28  std::string get_name() { return name; }
29 
30 protected:
31  Tensor<T> *_eval(bool recompute=true);
33 
34  std::string name;
35  Tensor<T> *val;
36  bool delete_tensor;
37 
38 };
39 
46 template <typename T>
47 Variable<T>* var(std::string name, Tensor<T>* val);
48 
57 template <typename T>
58 Variable<T>* var(std::string name, std::vector<unsigned int> shape, tensor_filler_t<T> filler, memory_t mem_type);
59 
67 template <typename T>
68 Variable<T> *scalar(std::string name, T val, memory_t mem_type);
69 
70 } // namespace op
71 } // namespace magmadnn
std::string get_name()
Definition: variable.h:28
Definition: types.h:64
Definition: addop.cpp:11
Definition: variable.h:21
Tensor< T > * _eval(bool recompute=true)
Definition: variable.cpp:44
Definition: tensor.h:34
Tensor< T > * _grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad)
Definition: variable.cpp:49
virtual Tensor< T > * grad(Operation< T > *consumer, Operation< T > *var, Tensor< T > *grad, bool recompute=true)
Definition: operation.h:93
Definition: operation.h:18
Variable< T > * var(std::string name, Tensor< T > *val)
Definition: variable.cpp:73
Variable< T > * scalar(std::string name, T val, memory_t mem_type)
Definition: variable.cpp:90
std::string to_string()
Definition: variable.h:27