diff --git a/src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h b/src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h index 084a5eeae5..6b5d4e6786 100644 --- a/src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h +++ b/src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h @@ -26,7 +26,6 @@ #ifndef __ctc_prefix_decoder_h_ #define __ctc_prefix_decoder_h_ -#include #include #include #include diff --git a/src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h b/src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h index a67ca7d01c..2d6574e36b 100644 --- a/src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h +++ b/src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h @@ -26,24 +26,6 @@ #ifndef __ctc_prefix_decoder_host_h_ #define __ctc_prefix_decoder_host_h_ -#include - -#define CUDA_CHECK(X) \ - do { \ - auto result = X; \ - if (result != cudaSuccess) { \ - const char* p_err_str = cudaGetErrorName(result); \ - fprintf( \ - stderr, \ - "File %s Line %d %s returned %s.\n", \ - __FILE__, \ - __LINE__, \ - #X, \ - p_err_str); \ - abort(); \ - } \ - } while (0) - #define CHECK(X, ERROR_INFO) \ do { \ auto result = (X); \ diff --git a/src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh b/src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh index 076804bb0b..f68c56f3cf 100644 --- a/src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh +++ b/src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh @@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) { } inline __device__ int laneId() { +#ifndef USE_ROCM int id; asm("mov.s32 %0, %%laneid;" : "=r"(id)); return id; +#else + return __lane_id(); +#endif } /** * @brief Shuffle the data inside a warp diff --git a/src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh b/src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh index d0c2754732..5f3f5a690b 100644 --- a/src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh +++ b/src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh @@ -12,7 +12,8 @@ namespace cu_ctc { * @tparam IntType data type (checked only for integers) */ template -constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) { +constexpr __host__ __device__ IntType +log2(IntType num, IntType ret = IntType(0)) { return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); } diff --git a/src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh b/src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh index f9bc880a41..d206fe9b0a 100644 --- a/src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh +++ b/src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh @@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort { __device__ __forceinline__ void merge_buf_() { topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); - this->merge_in(val_buf_, idx_buf_); + this->template merge_in(val_buf_, idx_buf_); buf_len_ = 0; set_k_th_(); // contains warp sync #pragma unroll @@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort { if (buf_len_ == kMaxArrLen) { topk::bitonic(!Ascending, kWarpWidth) .sort(val_buf_, idx_buf_); - this->merge_in(val_buf_, idx_buf_); + this->template merge_in(val_buf_, idx_buf_); #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { val_buf_[i] = kDummy; @@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort { if (buf_len_ != 0) { topk::bitonic(!Ascending, kWarpWidth) .sort(val_buf_, idx_buf_); - this->merge_in(val_buf_, idx_buf_); + this->template merge_in(val_buf_, idx_buf_); } } @@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) { return (a + b - 1) / b; } template -constexpr inline __device__ IntType roundUp256(IntType num) { +constexpr inline __host__ __device__ IntType roundUp256(IntType num) { // return (num + 255) / 256 * 256; constexpr int MASK = 255; return (num + MASK) & (~MASK); diff --git a/src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp b/src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp index 25057f9258..70fc801d0f 100644 --- a/src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp +++ b/src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp @@ -25,8 +25,8 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include "include/ctc_prefix_decoder.h" -#include "include/ctc_prefix_decoder_host.h" +#include "../include/ctc_prefix_decoder.h" +#include "../include/ctc_prefix_decoder_host.h" #include "device_data_wrap.h" #include "device_log_prob.cuh" diff --git a/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu b/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu index 4ca8f1bf24..97a2742691 100644 --- a/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu +++ b/src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu @@ -23,12 +23,13 @@ // OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include #include +#include "../include/ctc_prefix_decoder_host.h" #include "ctc_fast_divmod.cuh" #include "cub/cub.cuh" #include "device_data_wrap.h" #include "device_log_prob.cuh" -#include "include/ctc_prefix_decoder_host.h" #include "bitonic_topk/warpsort_topk.cuh" @@ -630,7 +631,8 @@ int CTC_prob_first_step_V2( num_of_subwarp, beam)); int smem_size = block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int); - FirstMatrixFuns[fun_idx]<<>>( + auto kernel = FirstMatrixFuns[fun_idx]; + kernel<<>>( (*log_prob_struct), step, pprev, @@ -766,7 +768,8 @@ int CTC_prob_topK_V2( int num_of_subwarp = threads_per_block0 / std::min(32, actual_capacity); int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide( num_of_subwarp, beam); - BitonicTopkFuns[fun_idx]<<>>( + auto kernel = BitonicTopkFuns[fun_idx]; + kernel<<>>( (*log_prob_struct), step, ptable, diff --git a/src/libtorchaudio/cuctc/src/device_data_wrap.h b/src/libtorchaudio/cuctc/src/device_data_wrap.h index 2fee836943..0b2ca6f1e0 100644 --- a/src/libtorchaudio/cuctc/src/device_data_wrap.h +++ b/src/libtorchaudio/cuctc/src/device_data_wrap.h @@ -24,9 +24,26 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#include #include #include -#include "include/ctc_prefix_decoder_host.h" +#include "../include/ctc_prefix_decoder_host.h" + +#define CUDA_CHECK(X) \ + do { \ + auto result = X; \ + if (result != cudaSuccess) { \ + const char* p_err_str = cudaGetErrorName(result); \ + fprintf( \ + stderr, \ + "File %s Line %d %s returned %s.\n", \ + __FILE__, \ + __LINE__, \ + #X, \ + p_err_str); \ + abort(); \ + } \ + } while (0) namespace cu_ctc { constexpr size_t ALIGN_BYTES = 128;