1 #ifndef CAFFE_UTIL_MKL_ALTERNATE_H_
2 #define CAFFE_UTIL_MKL_ALTERNATE_H_
8 #else // If use MKL, simply include the MKL header
11 #include <Accelerate/Accelerate.h>
16 #endif // USE_ACCELERATE
24 #define DEFINE_VSL_UNARY_FUNC(name, operation) \
25 template<typename Dtype> \
26 void v##name(const int n, const Dtype* a, Dtype* y) { \
27 CHECK_GT(n, 0); CHECK(a); CHECK(y); \
28 for (int i = 0; i < n; ++i) { operation; } \
30 inline void vs##name( \
31 const int n, const float* a, float* y) { \
32 v##name<float>(n, a, y); \
34 inline void vd##name( \
35 const int n, const double* a, double* y) { \
36 v##name<double>(n, a, y); \
39 DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i])
40 DEFINE_VSL_UNARY_FUNC(Sqrt, y[i] = sqrt(a[i]))
41 DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]))
42 DEFINE_VSL_UNARY_FUNC(Ln, y[i] = log(a[i]))
43 DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]))
47 #define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \
48 template<typename Dtype> \
49 void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
50 CHECK_GT(n, 0); CHECK(a); CHECK(y); \
51 for (int i = 0; i < n; ++i) { operation; } \
53 inline void vs##name( \
54 const int n, const float* a, const float b, float* y) { \
55 v##name<float>(n, a, b, y); \
57 inline void vd##name( \
58 const int n, const double* a, const float b, double* y) { \
59 v##name<double>(n, a, b, y); \
62 DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b))
66 #define DEFINE_VSL_BINARY_FUNC(name, operation) \
67 template<typename Dtype> \
68 void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \
69 CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \
70 for (int i = 0; i < n; ++i) { operation; } \
72 inline void vs##name( \
73 const int n, const float* a, const float* b, float* y) { \
74 v##name<float>(n, a, b, y); \
76 inline void vd##name( \
77 const int n, const double* a, const double* b, double* y) { \
78 v##name<double>(n, a, b, y); \
81 DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i])
82 DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i])
83 DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i])
84 DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i])
89 inline void cblas_saxpby(
const int N,
const float alpha,
const float* X,
90 const int incX,
const float beta,
float* Y,
92 cblas_sscal(N, beta, Y, incY);
93 cblas_saxpy(N, alpha, X, incX, Y, incY);
95 inline void cblas_daxpby(
const int N,
const double alpha,
const double* X,
96 const int incX,
const double beta,
double* Y,
98 cblas_dscal(N, beta, Y, incY);
99 cblas_daxpy(N, alpha, X, incX, Y, incY);
103 #endif // CAFFE_UTIL_MKL_ALTERNATE_H_