Skip to content

Commit

Permalink
data_kind: Add more tests to demonstrate the data kind of various dat…
Browse files Browse the repository at this point in the history
…a types (#3480)

Co-authored-by: Wei Ji <[email protected]>
  • Loading branch information
seisman and weiji14 authored Oct 3, 2024
1 parent 2d1a8cc commit 68a17a0
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 48 deletions.
84 changes: 62 additions & 22 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,11 @@ def data_kind(
Parameters
----------
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes} or an option argument.
data
The data to be passed to a GMT module.
required
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
Whether 'data' is required. Set to ``False`` when dealing with optional virtual
files.
Returns
-------
Expand All @@ -222,30 +220,72 @@ def data_kind(
Examples
--------
>>> import io
>>> from pathlib import Path
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> import pathlib
>>> import io
>>> data_kind(data=None)
'vectors'
>>> data_kind(data=np.arange(10).reshape((5, 2)))
'matrix'
>>> data_kind(data="my-data-file.txt")
'file'
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
'file'
The "arg" kind:
>>> [data_kind(data=data, required=False) for data in (2, 2.0, True, False)]
['arg', 'arg', 'arg', 'arg']
>>> data_kind(data=None, required=False)
'arg'
>>> data_kind(data=2.0, required=False)
'arg'
>>> data_kind(data=True, required=False)
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
The "file" kind:
>>> [data_kind(data=data) for data in ("file.txt", ("file1.txt", "file2.txt"))]
['file', 'file']
>>> data_kind(data=Path("file.txt"))
'file'
>>> data_kind(data=(Path("file1.txt"), Path("file2.txt")))
'file'
The "grid" kind:
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3))) # 2-D xarray.DataArray
'grid'
>>> data_kind(data=xr.DataArray(np.arange(12))) # 1-D xarray.DataArray
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(2, 3, 4, 5))) # 4-D xarray.DataArray
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
The "image" kind:
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5))) # 3-D xarray.DataArray
'image'
The "stringio"`` kind:
>>> data_kind(data=io.StringIO("TEXT1\nTEXT23\n"))
'stringio'
The "matrix"`` kind:
>>> data_kind(data=np.arange(10)) # 1-D numpy.ndarray
'matrix'
>>> data_kind(data=np.arange(10).reshape((5, 2))) # 2-D numpy.ndarray
'matrix'
>>> data_kind(data=np.arange(60).reshape((3, 4, 5))) # 3-D numpy.ndarray
'matrix'
>>> data_kind(xr.DataArray(np.arange(12), name="x").to_dataset()) # xarray.Dataset
'matrix'
>>> data_kind(data=[1, 2, 3]) # 1-D sequence
'matrix'
>>> data_kind(data=[[1, 2, 3], [4, 5, 6]]) # sequence of sequences
'matrix'
>>> data_kind(data={"x": [1, 2, 3], "y": [4, 5, 6]}) # dictionary
'matrix'
>>> data_kind(data=pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) # pd.DataFrame
'matrix'
>>> data_kind(data=pd.Series([1, 2, 3], name="x")) # pd.Series
'matrix'
The "vectors" kind:
>>> data_kind(data=None)
'vectors'
"""
kind: Literal[
"arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
Expand Down
3 changes: 1 addition & 2 deletions pygmt/tests/test_blockm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pygmt import blockmean, blockmode
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="dataframe")
Expand Down Expand Up @@ -68,7 +68,6 @@ def test_blockmean_wrong_kind_of_input_table_grid(dataframe):
Run blockmean using table input that is not a pandas.DataFrame or file but a grid.
"""
invalid_table = dataframe.bathymetry.to_xarray()
assert data_kind(invalid_table) == "grid"
with pytest.raises(GMTInvalidInput):
blockmean(data=invalid_table, spacing="5m", region=[245, 255, 20, 30])

Expand Down
3 changes: 1 addition & 2 deletions pygmt/tests/test_blockmedian.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pygmt import blockmedian
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="dataframe")
Expand Down Expand Up @@ -65,7 +65,6 @@ def test_blockmedian_wrong_kind_of_input_table_grid(dataframe):
Run blockmedian using table input that is not a pandas.DataFrame or file but a grid.
"""
invalid_table = dataframe.bathymetry.to_xarray()
assert data_kind(invalid_table) == "grid"
with pytest.raises(GMTInvalidInput):
blockmedian(data=invalid_table, spacing="5m", region=[245, 255, 20, 30])

Expand Down
16 changes: 16 additions & 0 deletions pygmt/tests/test_geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest
from pygmt import Figure, info, makecpt, which
from pygmt.helpers import data_kind
from pygmt.helpers.testing import skip_if_no

gpd = pytest.importorskip("geopandas")
Expand Down Expand Up @@ -243,3 +244,18 @@ def test_geopandas_plot_int64_as_float(gdf_ridge):
makecpt(cmap="lisbon", series=[10, 60, 10], continuous=True)
fig.colorbar()
return fig


def test_geopandas_data_kind_geopandas(gdf):
"""
Check if geopandas.GeoDataFrame object is recognized as a "geojson" kind.
"""
assert data_kind(data=gdf) == "geojson"


def test_geopandas_data_kind_shapely():
"""
Check if shapely.geometry object is recognized as a "geojson" kind.
"""
polygon = shapely.geometry.Polygon([(20, 10), (23, 10), (23, 14), (20, 14)])
assert data_kind(data=polygon) == "geojson"
10 changes: 3 additions & 7 deletions pygmt/tests/test_grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from pygmt import grdtrack
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile
from pygmt.helpers.testing import load_static_earth_relief

POINTS_DATA = Path(__file__).parent / "data" / "track.txt"
Expand Down Expand Up @@ -126,22 +126,18 @@ def test_grdtrack_profile(dataarray):

def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
"""
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file.
Run grdtrack using points input that is not a pandas.DataFrame or file.
"""
invalid_points = dataframe.longitude.to_xarray()

assert data_kind(invalid_points) == "grid"
with pytest.raises(GMTInvalidInput):
grdtrack(points=invalid_points, grid=dataarray, newcolname="bathymetry")


def test_grdtrack_wrong_kind_of_grid_input(dataarray, dataframe):
"""
Run grdtrack using grid input that is not as xarray.DataArray (grid) or file.
Run grdtrack using grid input that is not an xarray.DataArray or file.
"""
invalid_grid = dataarray.to_dataset()

assert data_kind(invalid_grid) == "matrix"
with pytest.raises(GMTInvalidInput):
grdtrack(points=dataframe, grid=invalid_grid, newcolname="bathymetry")

Expand Down
6 changes: 1 addition & 5 deletions pygmt/tests/test_grdview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from pygmt import Figure, grdcut
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile
from pygmt.helpers.testing import load_static_earth_relief


Expand Down Expand Up @@ -58,8 +58,6 @@ def test_grdview_wrong_kind_of_grid(xrgrid):
Run grdview using grid input that is not an xarray.DataArray or file.
"""
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
assert data_kind(dataset) == "matrix"

fig = Figure()
with pytest.raises(GMTInvalidInput):
fig.grdview(grid=dataset)
Expand Down Expand Up @@ -238,8 +236,6 @@ def test_grdview_wrong_kind_of_drapegrid(xrgrid):
Run grdview using drapegrid input that is not an xarray.DataArray or file.
"""
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
assert data_kind(dataset) == "matrix"

fig = Figure()
with pytest.raises(GMTInvalidInput):
fig.grdview(grid=xrgrid, drapegrid=dataset)
3 changes: 1 addition & 2 deletions pygmt/tests/test_nearneighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pygmt import nearneighbor
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="ship_data")
Expand Down Expand Up @@ -61,7 +61,6 @@ def test_nearneighbor_wrong_kind_of_input(ship_data):
Run nearneighbor using grid input that is not file/matrix/vectors.
"""
data = ship_data.bathymetry.to_xarray() # convert pandas.Series to xarray.DataArray
assert data_kind(data) == "grid"
with pytest.raises(GMTInvalidInput):
nearneighbor(
data=data, spacing="5m", region=[245, 255, 20, 30], search_radius="10m"
Expand Down
3 changes: 1 addition & 2 deletions pygmt/tests/test_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray as xr
from pygmt import surface, which
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="data")
Expand Down Expand Up @@ -125,7 +125,6 @@ def test_surface_wrong_kind_of_input(data, region, spacing):
Run surface using grid input that is not file/matrix/vectors.
"""
data = data.z.to_xarray() # convert pandas.Series to xarray.DataArray
assert data_kind(data) == "grid"
with pytest.raises(GMTInvalidInput):
surface(data=data, spacing=spacing, region=region)

Expand Down
3 changes: 1 addition & 2 deletions pygmt/tests/test_triangulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import xarray as xr
from pygmt import triangulate, which
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile, data_kind
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="dataframe")
Expand Down Expand Up @@ -93,7 +93,6 @@ def test_delaunay_triples_wrong_kind_of_input(dataframe):
Run triangulate.delaunay_triples using grid input that is not file/matrix/vectors.
"""
data = dataframe.z.to_xarray() # convert pandas.Series to xarray.DataArray
assert data_kind(data) == "grid"
with pytest.raises(GMTInvalidInput):
triangulate.delaunay_triples(data=data)

Expand Down
6 changes: 2 additions & 4 deletions pygmt/tests/test_x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pygmt.clib import __gmt_version__
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import data_kind


@pytest.fixture(name="mock_x2sys_home")
Expand Down Expand Up @@ -237,11 +236,10 @@ def test_x2sys_cross_input_two_filenames():

def test_x2sys_cross_invalid_tracks_input_type(tracks):
"""
Run x2sys_cross using tracks input that is not a pandas.DataFrame (matrix) or str
(file) type, which would raise a GMTInvalidInput error.
Run x2sys_cross using tracks input that is not a pandas.DataFrame or str type,
which would raise a GMTInvalidInput error.
"""
invalid_tracks = tracks[0].to_xarray().z
assert data_kind(invalid_tracks) == "grid"
with pytest.raises(GMTInvalidInput):
x2sys_cross(tracks=[invalid_tracks])

Expand Down

0 comments on commit 68a17a0

Please sign in to comment.