/*
 * NOTE(review): this span comes from a documentation source listing. The
 * leading numbers (18, 22, 24, ...) are original line numbers fused into the
 * text, and some original lines (e.g. 19-21, 23, 29-31) are not visible here,
 * including the closing braces of cudaAssert.
 *
 * cudaErrchk(ans): evaluates `ans` and passes the result, together with the
 * call site's __FILE__ and __LINE__, to cudaAssert.
 * NOTE(review): the macro expands to a bare `{ ... }` block, so
 * `if (c) cudaErrchk(x); else ...` will not compile (the `;` ends the if).
 * Consider the `do { ... } while (0)` form.
 *
 * cudaAssert: when `code` is not cudaSuccess, writes
 * "GPUassert: <cudaGetErrorString(code)> <file> <line>" to stderr and, if
 * `abort` is true (the default), terminates the process via exit(code).
 *
 * @param code   CUDA runtime status to check
 * @param file   file name reported in the message
 * @param line   line number reported in the message
 * @param abort  exit the process on failure when true (default: true)
 */
18 #if defined(_HAS_CUDA_) 22 #include <cuda_runtime_api.h> 24 #define cudaErrchk(ans) { cudaAssert((ans), __FILE__, __LINE__); } 25 inline void cudaAssert(cudaError_t code,
const char *file,
int line,
bool abort=
true) {
26 if (code != cudaSuccess) {
27 fprintf(stderr,
"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
28 if (abort) exit(code);
/*
 * NOTE(review): documentation-listing residue. The leading numbers are fused
 * source line numbers, and the closing braces (lines 37-39) are not visible.
 *
 * cudnnErrchk(ans): evaluates `ans` and passes the result, together with the
 * call site's __FILE__ and __LINE__, to cudnnAssert. Like cudaErrchk, it
 * expands to a bare `{ ... }` block, so it cannot sit directly before an
 * `else`.
 *
 * cudnnAssert: when `code` is not CUDNN_STATUS_SUCCESS, writes
 * "CuDNNassert: <cudnnGetErrorString(code)> <file> <line>" to stderr and, if
 * `abort` is true (the default), terminates the process via exit(code).
 *
 * @param code   cuDNN status to check
 * @param file   file name reported in the message
 * @param line   line number reported in the message
 * @param abort  exit the process on failure when true (default: true)
 */
32 #define cudnnErrchk(ans) { cudnnAssert((ans), __FILE__, __LINE__); } 33 inline void cudnnAssert(cudnnStatus_t code,
const char *file,
int line,
bool abort=
true) {
34 if (code != CUDNN_STATUS_SUCCESS) {
35 fprintf(stderr,
"CuDNNassert: %s %s %d\n", cudnnGetErrorString(code), file, line);
36 if (abort) exit(code);
/*
 * NOTE(review): documentation-listing residue. The leading numbers are fused
 * source line numbers, and lines 44-50 are not visible here, so the rest of
 * this function (any abort/exit handling and the closing braces) cannot be
 * seen. Presumably it mirrors cudaAssert/cudnnAssert; confirm.
 *
 * curandErrchk(ans): evaluates `ans` and passes the result, together with the
 * call site's __FILE__ and __LINE__, to curandAssert. It expands to a bare
 * `{ ... }` block.
 *
 * curandAssert: when `code` is not CURAND_STATUS_SUCCESS, writes
 * "CuRandAssert: <code> <file> <line>" to stderr. The status is printed as
 * its numeric value (%d) because no string lookup is used here.
 *
 * @param code   cuRAND status to check
 * @param file   file name reported in the message
 * @param line   line number reported in the message
 * @param abort  defaults to true; its use is not visible in this listing
 */
40 #define curandErrchk(ans) { curandAssert((ans), __FILE__, __LINE__); } 41 inline void curandAssert(curandStatus_t code,
const char *file,
int line,
bool abort=
true) {
42 if (code != CURAND_STATUS_SUCCESS) {
43 fprintf(stderr,
"CuRandAssert: %d %s %d\n", code, file, line);
// Shape predicates on a tensor pointer. Each macro dereferences its argument
// with `->` and may evaluate it more than once (e.g. T_IS_VECTOR evaluates
// `tensor_ptr` twice), so do not pass expressions with side effects.

// True when the tensor holds exactly one element (get_size() == 1).
#define T_IS_SCALAR(tensor_ptr) ((tensor_ptr)->get_size() == 1)
// True for rank-1 tensors that are not single-element (those count as scalars).
#define T_IS_VECTOR(tensor_ptr) ((tensor_ptr)->get_size() != 1 && ((tensor_ptr)->get_shape().size() == 1))
// True when get_shape() has exactly two entries.
#define T_IS_MATRIX(tensor_ptr) ((tensor_ptr)->get_shape().size() == 2)
// True when get_shape() has exactly N entries. N is parenthesized so that an
// expression argument such as `cond ? 1 : 3` is compared as a whole.
#define T_IS_N_DIMENSIONAL(tensor_ptr, N) ((tensor_ptr)->get_shape().size() == (N))

// The same predicates for an operation pointer, using get_output_size() and
// get_output_shape() instead of get_size() and get_shape().
#define OP_IS_SCALAR(op_ptr) ((op_ptr)->get_output_size() == 1)
#define OP_IS_VECTOR(op_ptr) (((op_ptr)->get_output_size() != 1) && ((op_ptr)->get_output_shape().size() == 1))
#define OP_IS_MATRIX(op_ptr) ((op_ptr)->get_output_shape().size() == 2)
#define OP_IS_N_DIMENSIONAL(op_ptr, N) ((op_ptr)->get_output_shape().size() == (N))

// True when both pointees report the same get_memory_type().
#define T_IS_SAME_MEMORY_TYPE(x_ptr, y_ptr) ((x_ptr)->get_memory_type() == (y_ptr)->get_memory_type())
#define OP_IS_SAME_MEMORY_TYPE(x_ptr, y_ptr) ((x_ptr)->get_memory_type() == (y_ptr)->get_memory_type())

// Variadic debug-output helper defined in utilities_internal.cpp. `fmt` is
// presumably a printf-style format string; confirm against the definition.
int debugf(const char *fmt, ...);
/**
 * print_vector -- declaration only; the definition is not part of this view.
 *
 * The parameter meanings below are presumed from their names; confirm them
 * against the definition:
 *   vec    the values to print
 *   debug  defaults to true; presumably selects debug-only output
 *   begin  opening character, defaults to '{'
 *   end    closing character, defaults to '}'
 *   delim  separator between elements, defaults to ','
 */
void print_vector(const std::vector<unsigned> &vec,
                  bool debug = true,
                  char begin = '{',
                  char end = '}',
                  char delim = ',');
/*
 * NOTE(review): documentation-listing residue. `84` and `86` are fused source
 * line numbers, and line 85 is not visible. `T` is not declared in this span;
 * presumably line 85 is a `template <typename T>` header. Confirm.
 * get_cudnn_data_type is declared only when _HAS_CUDA_ is defined. It
 * presumably returns the cudnnDataType_t that corresponds to the type of
 * `val`; confirm against its definition.
 */
84 #if defined(_HAS_CUDA_) 86 cudnnDataType_t get_cudnn_data_type(T val);
// int debugf(const char *fmt, ...) -- definition: utilities_internal.cpp:14