MagmaDNN 1.0
C++ Neural Network Framework
model.h
Go to the documentation of this file.
1 
9 #pragma once
10 
11 
12 #include <string>
13 #include "types.h"
14 #include "tensor/tensor.h"
15 #include "compute/operation.h"
16 #include "optimizer/optimizer.h"
17 
18 
19 namespace magmadnn {
20 namespace model {
21 
22 struct metric_t {
23  double accuracy;
24  double loss;
25  double training_time;
26 };
27 
28 template <typename T>
29 class Model {
30 public:
33  Model() {}
34 
42  virtual magmadnn_error_t fit(Tensor<T> *x, Tensor<T> *y, metric_t& metric_out, bool verbose=false) = 0;
43 
48  virtual Tensor<T> *predict(Tensor<T> *sample) = 0;
49 
54  virtual unsigned int predict_class(Tensor<T> *sample) = 0;
55 
59  virtual double get_accuracy() { return _last_training_metric.accuracy; }
60 
64  virtual double get_loss() { return _last_training_metric.loss; }
65 
69  virtual double get_training_time() { return _last_training_metric.training_time; }
70 
74  virtual std::string get_name() { return this->_name; }
75 
76 protected:
77  std::string _name;
78  metric_t _last_training_metric;
79 
80 };
81 
82 } // namespace model
83 } // namespace magmadnn
84 
virtual double get_training_time()
Definition: model.h:69
Definition: model.h:29
Model()
Definition: model.h:33
virtual std::string get_name()
Definition: model.h:74
virtual double get_accuracy()
Definition: model.h:59
Definition: addop.cpp:11
double training_time
Definition: model.h:25
Definition: tensor.h:34
double accuracy
Definition: model.h:23
Definition: model.h:22
double loss
Definition: model.h:24
virtual double get_loss()
Definition: model.h:64