From 7ca341e99a3bc300313addd18abfc5a213909f20 Mon Sep 17 00:00:00 2001 From: Tanmay Munjal Date: Sun, 14 Apr 2024 01:39:36 -0600 Subject: [PATCH 1/2] Fixed failing test test_jax_svd --- .../test_frontends/test_jax/test_numpy/test_linalg.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py index fb081e1cd406a..0a7389468272a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py @@ -268,8 +268,8 @@ def test_jax_eig( ): dtype, x = dtype_and_x x = np.array(x[0], dtype=dtype[0]) - """Make symmetric positive-definite since ivy does not support complex data - dtypes currently.""" + """Make symmetric positive-definite since ivy does not support complex data dtypes + currently.""" x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 ret, frontend_ret = helpers.test_frontend_function( @@ -899,7 +899,7 @@ def test_jax_svd( if compute_uv: with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x) for x in ret] + ret = [ivy_backend.to_numpy(x).astype(np.float64) for x in ret] frontend_ret = [np.asarray(x) for x in frontend_ret] u, s, vh = ret @@ -915,10 +915,11 @@ def test_jax_svd( ) else: with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = ivy_backend.to_numpy(ret) + ret = ivy_backend.to_numpy(ret).astype(np.float64) + frontend_ret = np.asarray(frontend_ret) assert_all_close( ret_np=ret, - ret_from_gt_np=np.asarray(frontend_ret[0]), + ret_from_gt_np=frontend_ret, rtol=1e-2, atol=1e-2, backend=backend_fw, From 765cae863ff2ef2930e4a18059613a8628db39e2 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Sun, 14 Apr 2024 07:59:46 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_frontends/test_jax/test_numpy/test_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py index 0a7389468272a..8657cc660b96c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py @@ -268,8 +268,8 @@ def test_jax_eig( ): dtype, x = dtype_and_x x = np.array(x[0], dtype=dtype[0]) - """Make symmetric positive-definite since ivy does not support complex data dtypes - currently.""" + """Make symmetric positive-definite since ivy does not support complex data + dtypes currently.""" x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 ret, frontend_ret = helpers.test_frontend_function(