MagmaDNN  1.0
C++ Neural Network Framework
layer.h
Go to the documentation of this file.
1 
9 #pragma once
#include <cassert>
#include <string>
#include <vector>
#include "compute/operation.h"
13 
14 namespace magmadnn {
15 namespace layer {
16 
17 template <typename T>
18 class Layer {
19 public:
20 
21  virtual std::vector<op::Operation<T> *> get_weights() = 0;
22 
23  virtual op::Operation<T>* out() {
24  return output;
25  }
26 
30  op::Operation<T> *get_input() { return input; }
34  op::Operation<T> *get_output() { return output; }
35 
39  std::vector<unsigned int> get_input_shape() const { return input_shape; }
43  std::vector<unsigned int> get_output_shape() const { return this->output->get_output_shape(); }
44 
49  unsigned int get_input_shape(unsigned int i) const {
50  assert( i < input_shape.size() );
51  return input_shape[i];
52  }
53 
58  unsigned int get_output_shape(unsigned int i) const {
59  assert( i < output_shape.size() );
60  return output_shape[i];
61  }
62 
66  void set_name(std::string name) { this->name = name; }
67 
72  std::string get_name() const { return this->name; }
73 
74 protected:
75  Layer(std::vector<unsigned int> input_shape, op::Operation<T> *input) :
76  input_shape(input_shape), input(input) {}
77 
78  std::vector<unsigned int> input_shape;
79  std::vector<unsigned int> output_shape;
80 
81  op::Operation<T> *input;
82  op::Operation<T> *output;
83 
84  std::string name;
85 
86 };
87 
88 } // namespace layer
89 } // namespace magmadnn
unsigned int get_input_shape(unsigned int i) const
Definition: layer.h:49
std::vector< unsigned int > get_input_shape() const
Definition: layer.h:39
Definition: addop.cpp:11
void set_name(std::string name)
Definition: layer.h:66
std::string get_name() const
Definition: layer.h:72
unsigned int get_output_shape(unsigned int i) const
Definition: layer.h:58
std::vector< unsigned int > get_output_shape() const
Definition: layer.h:43
Definition: layer.h:18
op::Operation< T > * get_output()
Definition: layer.h:34
Definition: operation.h:18
op::Operation< T > * get_input()
Definition: layer.h:30