MagmaDNN  1.0
C++ Neural Network Framework
dataloader.h
Go to the documentation of this file.
1 
10 #pragma once
#include <cassert>

#include "tensor/tensor.h"
12 
13 namespace magmadnn {
14 namespace dataloader {
15 
16 template <typename T>
17 class DataLoader {
18 public:
24  DataLoader(Tensor<T> *x, Tensor<T> *y, unsigned int batch_size): x(x), y(y), batch_size(batch_size) {
25  num_batches = unsigned(x->get_shape(0) / batch_size);
26  assert(num_batches > 0);
27  assert(num_batches == unsigned(y->get_shape(0) / batch_size));
28 
29  sample_size_x = x->get_size() / x->get_shape(0);
30  sample_size_y = y->get_size() / y->get_shape(0);
31  }
32 
37  virtual void next(Tensor<T> *x_batch, Tensor<T> *y_batch) = 0;
38 
41  virtual void reset() = 0;
42 
43  virtual unsigned int get_batch_size() const {return batch_size;}
44  virtual void set_batch_size(unsigned int size) {
45  batch_size = size;
46  num_batches = unsigned(x->get_shape(0) / batch_size);
47  assert(num_batches > 0);
48  assert(num_batches == unsigned(y->get_shape(0) / batch_size));
49  }
50  virtual unsigned int get_num_batches() const {return num_batches;}
51 
52 protected:
53  Tensor<T> *x;
54  Tensor<T> *y;
55  unsigned int batch_size;
56  unsigned int sample_size_x;
57  unsigned int sample_size_y;
58  unsigned int num_batches;
59 };
60 
61 } // namespace dataloader
62 } // namespace magmadnn
std::vector< unsigned int > get_shape() const
Definition: tensor.h:181
unsigned int get_size() const
Definition: tensor.h:192
Definition: addop.cpp:11
Definition: tensor.h:34
Definition: dataloader.h:17
virtual void next(Tensor< T > *x_batch, Tensor< T > *y_batch)=0
DataLoader(Tensor< T > *x, Tensor< T > *y, unsigned int batch_size)
Definition: dataloader.h:24