Skip to content

Commit

Permalink
[CODEMOD][pytorch] replace uses of np.ndarray with npt.NDArray (pytor…
Browse files Browse the repository at this point in the history
…ch#3845)

Summary:
X-link: pytorch/opacus#680
X-link: pytorch/captum#1387
X-link: pytorch/botorch#2584

This replaces uses of `numpy.ndarray` in type annotations with `numpy.typing.NDArray`. In Numpy-1.24.0+ `numpy.ndarray` is annotated as generic type. Without template parameters it triggers static analysis errors: 
```counterexample
Generic type `ndarray` expects 2 type parameters.
```
`numpy.typing.NDArray` is an alias that provides default template parameters.

Differential Revision: D64619891
  • Loading branch information
igorsugak authored and facebook-github-bot committed Oct 18, 2024
1 parent 79047bf commit 6a59e92
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test/torchaudio_unittest/prototype/functional/dsp_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import numpy.typing as npt


def oscillator_bank(
Expand Down Expand Up @@ -43,8 +44,8 @@ def freq_ir(magnitudes):


def exp_sigmoid(
input: np.ndarray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
) -> np.ndarray:
input: npt.NDArray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
) -> npt.NDArray:
"""Exponential Sigmoid pointwise nonlinearity (Numpy version).
Implements the equation:
``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``
Expand Down

0 comments on commit 6a59e92

Please sign in to comment.