Skip to content

Commit

Permalink
[AMD] Hipify torchaudio_decoder
Browse files Browse the repository at this point in the history
Differential Revision: D64298970

Pull Request resolved: pytorch#3843
  • Loading branch information
xw285cornell authored Oct 17, 2024
1 parent b4a286a commit 79047bf
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 30 deletions.
1 change: 0 additions & 1 deletion src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#ifndef __ctc_prefix_decoder_h_
#define __ctc_prefix_decoder_h_

#include <cuda_runtime.h>
#include <cstdint>
#include <tuple>
#include <vector>
Expand Down
18 changes: 0 additions & 18 deletions src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,6 @@
#ifndef __ctc_prefix_decoder_host_h_
#define __ctc_prefix_decoder_host_h_

#include <cuda_runtime.h>

// Evaluates the CUDA runtime call X exactly once. If the returned status is
// not cudaSuccess, prints the file, line, stringified expression and the name
// from cudaGetErrorName to stderr, then calls abort().
// The status is held in a macro-private name so that an argument expression
// mentioning a caller variable called `result` is not captured by the macro's
// own local; the argument is parenthesised for the usual macro hygiene.
#define CUDA_CHECK(X) \
  do { \
    auto cuda_check_result_ = (X); \
    if (cuda_check_result_ != cudaSuccess) { \
      const char* p_err_str = cudaGetErrorName(cuda_check_result_); \
      fprintf( \
          stderr, \
          "File %s Line %d %s returned %s.\n", \
          __FILE__, \
          __LINE__, \
          #X, \
          p_err_str); \
      abort(); \
    } \
  } while (0)

#define CHECK(X, ERROR_INFO) \
do { \
auto result = (X); \
Expand Down
4 changes: 4 additions & 0 deletions src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) {
}

// Returns the calling thread's lane index.
// ROCm builds use the __lane_id() builtin; all other builds read the PTX
// %laneid special register through inline assembly.
inline __device__ int laneId() {
#ifdef USE_ROCM
  return __lane_id();
#else
  int lane;
  asm("mov.s32 %0, %%laneid;" : "=r"(lane));
  return lane;
#endif
}
/**
* @brief Shuffle the data inside a warp
Expand Down
3 changes: 2 additions & 1 deletion src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace cu_ctc {
* @tparam IntType data type (checked only for integers)
*/
template <typename IntType>
constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) {
constexpr __host__ __device__ IntType
log2(IntType num, IntType ret = IntType(0)) {
return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
}

Expand Down
8 changes: 4 additions & 4 deletions src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {

__device__ __forceinline__ void merge_buf_() {
topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
this->merge_in<kMaxBufLen>(val_buf_, idx_buf_);
this->template merge_in<kMaxBufLen>(val_buf_, idx_buf_);
buf_len_ = 0;
set_k_th_(); // contains warp sync
#pragma unroll
Expand Down Expand Up @@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
if (buf_len_ == kMaxArrLen) {
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
.sort(val_buf_, idx_buf_);
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
#pragma unroll
for (int i = 0; i < kMaxArrLen; i++) {
val_buf_[i] = kDummy;
Expand All @@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
if (buf_len_ != 0) {
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
.sort(val_buf_, idx_buf_);
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
}
}

Expand All @@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) {
return (a + b - 1) / b;
}
template <typename IntType>
constexpr inline __device__ IntType roundUp256(IntType num) {
constexpr inline __host__ __device__ IntType roundUp256(IntType num) {
// return (num + 255) / 256 * 256;
constexpr int MASK = 255;
return (num + MASK) & (~MASK);
Expand Down
4 changes: 2 additions & 2 deletions src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cuda_runtime.h>

#include "include/ctc_prefix_decoder.h"
#include "include/ctc_prefix_decoder_host.h"
#include "../include/ctc_prefix_decoder.h"
#include "../include/ctc_prefix_decoder_host.h"

#include "device_data_wrap.h"
#include "device_log_prob.cuh"
Expand Down
9 changes: 6 additions & 3 deletions src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <float.h>
#include <algorithm>
#include "../include/ctc_prefix_decoder_host.h"
#include "ctc_fast_divmod.cuh"
#include "cub/cub.cuh"
#include "device_data_wrap.h"
#include "device_log_prob.cuh"
#include "include/ctc_prefix_decoder_host.h"

#include "bitonic_topk/warpsort_topk.cuh"

Expand Down Expand Up @@ -630,7 +631,8 @@ int CTC_prob_first_step_V2(
num_of_subwarp, beam));
int smem_size =
block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int);
FirstMatrixFuns[fun_idx]<<<grid, threads_per_block, smem_size, stream>>>(
auto kernel = FirstMatrixFuns[fun_idx];
kernel<<<grid, threads_per_block, smem_size, stream>>>(
(*log_prob_struct),
step,
pprev,
Expand Down Expand Up @@ -766,7 +768,8 @@ int CTC_prob_topK_V2(
int num_of_subwarp = threads_per_block0 / std::min<int>(32, actual_capacity);
int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide<float, int>(
num_of_subwarp, beam);
BitonicTopkFuns[fun_idx]<<<grid, block, smem_size, stream>>>(
auto kernel = BitonicTopkFuns[fun_idx];
kernel<<<grid, block, smem_size, stream>>>(
(*log_prob_struct),
step,
ptable,
Expand Down
19 changes: 18 additions & 1 deletion src/libtorchaudio/cuctc/src/device_data_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,26 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <cuda_runtime.h>
#include <iostream>
#include <vector>
#include "include/ctc_prefix_decoder_host.h"
#include "../include/ctc_prefix_decoder_host.h"

// Evaluates the CUDA runtime call X exactly once. If the returned status is
// not cudaSuccess, prints the file, line, stringified expression and the name
// from cudaGetErrorName to stderr, then calls abort().
// The status is held in a macro-private name so that an argument expression
// mentioning a caller variable called `result` is not captured by the macro's
// own local; the argument is parenthesised for the usual macro hygiene.
#define CUDA_CHECK(X) \
  do { \
    auto cuda_check_result_ = (X); \
    if (cuda_check_result_ != cudaSuccess) { \
      const char* p_err_str = cudaGetErrorName(cuda_check_result_); \
      fprintf( \
          stderr, \
          "File %s Line %d %s returned %s.\n", \
          __FILE__, \
          __LINE__, \
          #X, \
          p_err_str); \
      abort(); \
    } \
  } while (0)

namespace cu_ctc {
constexpr size_t ALIGN_BYTES = 128;
Expand Down

0 comments on commit 79047bf

Please sign in to comment.