diff --git a/benchmark/reference.cpp b/benchmark/reference.cpp
index a74318a..13dd108 100644
--- a/benchmark/reference.cpp
+++ b/benchmark/reference.cpp
@@ -13,6 +13,35 @@
 #include "defines.h"
 
+// Check if 'floatType' is complex
+template <typename floatType>
+class is_complex {
+  template <
+      template <typename> typename _floatType, typename T,
+      typename = typename std::enable_if<
+          std::is_same<_floatType<T>, std::complex<T>>::value, bool>::type>
+  bool static constexpr _is_complex(_floatType<T>) {
+    return true;
+  }
+  template <typename _floatType>
+  bool static constexpr _is_complex(_floatType) {
+    return false;
+  }
+
+ public:
+  static constexpr auto value =
+      _is_complex(typename std::remove_reference<
+                  typename std::remove_cv<floatType>::type>::type());
+};
+
+
+// Re-define 'conj' to make sure it returns a floating point if the
+// argument is a floating point
+template <typename floatType,
+          typename = typename std::enable_if<!is_complex<floatType>::value, bool>::type>
+floatType conj(floatType &&x) {
+  return x;
+}
 
 
 template<typename floatType>
 void transpose_ref( uint32_t *size, uint32_t *perm, int dim,
@@ -57,13 +86,13 @@ void transpose_ref( uint32_t *size, uint32_t *perm, int dim,
    if( beta == (floatType) 0 )
       for(int i=0; i < sizeInner; ++i)
         if( conjA )
-            B_[i] = alpha * std::conj(A_[i * strideAinner]);
+            B_[i] = alpha * conj(A_[i * strideAinner]);
         else
            B_[i] = alpha * A_[i * strideAinner];
    else
      for(int i=0; i < sizeInner; ++i)
         if( conjA )
-            B_[i] = alpha * std::conj(A_[i * strideAinner]) + beta * B_[i];
+            B_[i] = alpha * conj(A_[i * strideAinner]) + beta * B_[i];
         else
            B_[i] = alpha * A_[i * strideAinner] + beta * B_[i];
 }
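(Aside, not part of the patch: a minimal standalone sketch of the problem the conj overload above solves. std::conj always returns a std::complex, so in the real-typed instantiations of transpose_ref the expression alpha * std::conj(A_[i * strideAinner]) would produce a complex value that cannot be assigned to a plain float or double B_[i]. The guarded identity overload covers real types, while an unqualified conj call on a complex argument still reaches std::conj through argument-dependent lookup. The is_complex here is a simplified stand-in built from partial specialization rather than the overload-resolution trait in the diff, and main is illustrative only.)

// Standalone sketch, not part of HPTT.
#include <complex>
#include <iostream>
#include <type_traits>

// Simplified is_complex: partial specialization instead of the
// overload-resolution trait used in the diff.
template <typename T>
struct is_complex : std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : std::true_type {};

// Identity conj, enabled only for non-complex types (mirrors the diff).
template <typename floatType,
          typename = typename std::enable_if<
              !is_complex<typename std::decay<floatType>::type>::value,
              bool>::type>
floatType conj(floatType &&x) {
  return x;
}

int main() {
  double d = 2.0;
  std::complex<double> z(1.0, -3.0);
  std::cout << conj(d) << '\n';  // 2: identity, stays a double
  std::cout << conj(z) << '\n';  // (1,3): ADL picks std::conj
  // std::conj(d) would instead return std::complex<double>(2, 0),
  // which cannot be assigned back into a plain double.
}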
""" if axes is None: - axes = reversed(range(a.ndim)) + axes = list(reversed(range(a.ndim))) return tensorTranspose(axes, 1.0, a) diff --git a/src/hptt.cpp b/src/hptt.cpp index ea761c8..41c4551 100644 --- a/src/hptt.cpp +++ b/src/hptt.cpp @@ -171,8 +171,8 @@ void dTensorTranspose( const int *perm, const int dim, } void cTensorTranspose( const int *perm, const int dim, - const float _Complex alpha, bool conjA, const float _Complex *A, const int *sizeA, const int *outerSizeA, - const float _Complex beta, float _Complex *B, const int *outerSizeB, + const float alpha, bool conjA, const float *A, const int *sizeA, const int *outerSizeA, + const float beta, float *B, const int *outerSizeB, const int numThreads, const int useRowMajor) { auto plan(std::make_shared >(sizeA, perm, outerSizeA, outerSizeB, dim, @@ -182,8 +182,8 @@ void cTensorTranspose( const int *perm, const int dim, } void zTensorTranspose( const int *perm, const int dim, - const double _Complex alpha, bool conjA, const double _Complex *A, const int *sizeA, const int *outerSizeA, - const double _Complex beta, double _Complex *B, const int *outerSizeB, + const double alpha, bool conjA, const double *A, const int *sizeA, const int *outerSizeA, + const double beta, double *B, const int *outerSizeB, const int numThreads, const int useRowMajor) { auto plan(std::make_shared >(sizeA, perm, outerSizeA, outerSizeB, dim,