Skip to content

Commit

Permalink
[AMD] hipify torchaudio
Browse files Browse the repository at this point in the history
Differential Revision: D64184710

Pull Request resolved: pytorch#3840
  • Loading branch information
xw285cornell authored Oct 16, 2024
1 parent 3f05699 commit b4a286a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ __global__ void ReduceMax2D(

CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
#ifndef USE_ROCM
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
#else
shf = __shfl_down(val, stride);
#endif
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shf > val) {
val = shf;
Expand Down Expand Up @@ -81,7 +85,11 @@ __global__ void ReduceLogSumExpGivenMax2D(

CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
#ifndef USE_ROCM
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
#else
shf = __shfl_down(val, stride);
#endif
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = val + shf;
}
Expand Down
16 changes: 16 additions & 0 deletions src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ __device__ void ComputeAlphas(

#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, skip_prob, i);
#else
val = __shfl_up(skip_prob, i);
#endif
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
Expand All @@ -150,7 +154,11 @@ __device__ void ComputeAlphas(
CAST_DTYPE out = val;

for (int i = 1; i < warpSize; ++i) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, val, 1);
#else
val = __shfl_up(val, 1);
#endif
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
Expand Down Expand Up @@ -225,7 +233,11 @@ __device__ void ComputeBetasCosts(

#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, skip_prob, i);
#else
val = __shfl_up(skip_prob, i);
#endif
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
Expand All @@ -248,7 +260,11 @@ __device__ void ComputeBetasCosts(
CAST_DTYPE out = val;

for (int i = 1; i < warpSize; ++i) {
#ifndef USE_ROCM
val = __shfl_up_sync(0xffffffff, val, 1);
#else
val = __shfl_up(val, 1);
#endif
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
Expand Down

0 comments on commit b4a286a

Please sign in to comment.