#include "compute/matmul/gemm_internal.h"
Functions
template<typename T> bool magmadnn::internal::gemm_check(Tensor<T> *A, Tensor<T> *B, Tensor<T> *C, unsigned int &M, unsigned int &N, unsigned int &K)
template<> void magmadnn::internal::gemm_full(int alpha, Tensor<int> *A, Tensor<int> *B, int beta, Tensor<int> *C)
template<> void magmadnn::internal::gemm_full(float alpha, Tensor<float> *A, Tensor<float> *B, float beta, Tensor<float> *C)
template<> void magmadnn::internal::gemm_full(double alpha, Tensor<double> *A, Tensor<double> *B, double beta, Tensor<double> *C)
template<typename T>
bool magmadnn::internal::gemm_check(Tensor<T> *A, Tensor<T> *B, Tensor<T> *C, unsigned int &M, unsigned int &N, unsigned int &K)
Returns true if A, B, and C are valid parameters for gemm_full. It also sets M, N, and K to A.get_shape(0), B.get_shape(1), and A.get_shape(1), respectively.
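Template Parameters
  T — element type of the tensors.
Parameters
  A — tensor whose shape supplies M (A.get_shape(0)) and K (A.get_shape(1)).
  B — tensor whose shape supplies N (B.get_shape(1)).
  C — third tensor passed to gemm_full; validated together with A and B.
  M — [out] set to A.get_shape(0).
  N — [out] set to B.get_shape(1).
  K — [out] set to A.get_shape(1).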