Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Mypy & JSON models friends #18

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
101 changes: 101 additions & 0 deletions jsonmodels/mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Callable, List, Type
import mypy
from mypy.plugin import Plugin, AttributeContext, FunctionContext
from mypy.types import Type as MypyType

class JSONModelsPlugin(Plugin):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool beans 👀

def get_function_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
if fullname == "jsonmodels.fields.StringField":
return self._string_field_callback
if fullname == "jsonmodels.fields.IntField":
return self._int_field_callback
if fullname == "jsonmodels.fields.FloatField":
return self._float_field_callback
if fullname == "jsonmodels.fields.BoolField":
return self._bool_field_callback
if fullname == "jsonmodels.fields.TimeField":
return self._time_field_callback
if fullname == "jsonmodels.fields.DateField":
return self._date_field_callback
if fullname == "jsonmodels.fields.DateTimeField":
return self._datetime_field_callback
if fullname == "jsonmodels.fields.EmbeddedField":
return self._embedded_field_callback
if fullname == "jsonmodels.fields.ListField":
return self._list_field_callback
if fullname == "jsonmodels.fields.DerivedListField":
return self._list_field_callback

return None

def _wrap_nullable(self, ctx: FunctionContext, core_type: MypyType) -> MypyType:
try:
nullable_index = ctx.callee_arg_names.index("nullable")
except ValueError:
return core_type

arg_value = ctx.args[nullable_index]
if len(arg_value) == 0:
return core_type

nullable_value = arg_value[0]
if isinstance(nullable_value, mypy.nodes.NameExpr) and nullable_value.fullname == "builtins.True":
return mypy.types.UnionType([core_type, mypy.types.NoneType()])

return core_type

def _string_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("builtins.str"))

def _int_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("builtins.int"))

def _float_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("builtins.float"))

def _bool_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("builtins.bool"))

def _time_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("datetime.time"))

def _date_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("datetime.date"))

def _datetime_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, ctx.api.named_type("datetime.datetime"))

def _get_type_from_arg(self, ctx: FunctionContext, arg_name: str) -> MypyType:
try:
model_types_index = ctx.callee_arg_names.index(arg_name)
except ValueError:
return mypy.types.NoneType()

arg_value = ctx.args[model_types_index]
if len(arg_value) == 0:
return mypy.types.NoneType()

model_types_value = arg_value[0]

if isinstance(model_types_value, mypy.nodes.NameExpr):
return ctx.api.named_type(model_types_value.fullname)

if isinstance(model_types_value, mypy.nodes.TupleExpr):
accepted_types: List[MypyType] = []
for item in model_types_value.items:
if isinstance(item, mypy.nodes.NameExpr):
accepted_types.append(ctx.api.named_type(item.fullname))
return mypy.types.UnionType(accepted_types)

return mypy.types.NoneType()

def _embedded_field_callback(self, ctx: FunctionContext) -> MypyType:
return self._wrap_nullable(ctx, self._get_type_from_arg(ctx, "model_types"))

def _list_field_callback(self, ctx: FunctionContext) -> MypyType:
item_type = self._get_type_from_arg(ctx, "items_types")
list_type = ctx.api.named_generic_type("list", [item_type])
return self._wrap_nullable(ctx, list_type)

def plugin(version: str):
return JSONModelsPlugin
9 changes: 9 additions & 0 deletions mypy_plugin.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[mypy]
disallow_untyped_defs = True
disallow_any_unimported = True
no_implicit_optional = True
warn_return_any = True
warn_unused_configs = True
warn_unused_ignores = True
show_error_codes = True
plugins = jsonmodels/mypy_plugin.py
Empty file added tests_mypy/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions tests_mypy/case_date.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from models import person

reveal_type(person.dob)
# expect: datetime.date
4 changes: 4 additions & 0 deletions tests_mypy/case_datetime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from models import person

reveal_type(person.last_update)
# expect: datetime.datetime
4 changes: 4 additions & 0 deletions tests_mypy/case_derivedlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from models import person

reveal_type(person.nicknames)
# expect: builtins.list[builtins.str]
7 changes: 7 additions & 0 deletions tests_mypy/case_embedded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from models import person

reveal_type(person.address)
# expect: models.Address

reveal_type(person.transport)
# expect: Union[models.Car, models.Boat]
4 changes: 4 additions & 0 deletions tests_mypy/case_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from models import person

reveal_type(person.age)
# expect: builtins.int
4 changes: 4 additions & 0 deletions tests_mypy/case_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from models import person

reveal_type(person.pet_names)
# expect: builtins.list[builtins.str]
7 changes: 7 additions & 0 deletions tests_mypy/case_nullable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from models import address

reveal_type(address.line_1)
# expect: builtins.str

reveal_type(address.line_2)
# expect: Union[builtins.str, None]
5 changes: 5 additions & 0 deletions tests_mypy/case_str.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

from models import person

reveal_type(person.name)
# expect: builtins.str
27 changes: 27 additions & 0 deletions tests_mypy/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from jsonmodels import models, fields

class Address(models.Base):
line_1 = fields.StringField()
line_2 = fields.StringField(nullable=True)
city = fields.StringField()

class Car(models.Base):
registration = fields.StringField()

class Boat(models.Base):
name = fields.StringField()

class Person(models.Base):
name = fields.StringField()
surname = fields.StringField()
age = fields.IntField()
dob = fields.DateField()
alive = fields.BoolField()
last_update = fields.DateTimeField()
address = fields.EmbeddedField(model_types=Address)
transport = fields.EmbeddedField(model_types=(Car, Boat))
pet_names = fields.ListField(items_types=str)
nicknames = fields.DerivedListField(fields.StringField())

person = Person()
address = Address()
49 changes: 49 additions & 0 deletions tests_mypy/test_mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from mypy import api
import os


EXPECT_LINE = "# expect: "
EXPECT_LINE_OUTPUT = "Revealed type is "


def test_file(directory: str, file_name: str) -> bool:
expected: list[str] = []
file_path = os.path.join(directory, file_name)
with open(file_path, 'r') as f:
lines = f.readlines()
for line in lines:
if line.startswith(EXPECT_LINE):
expected.append(line[len(EXPECT_LINE):].strip())

result = api.run([
"--config-file=../mypy_plugin.ini",
"--show-traceback",
file_path])

output_expected: list[str] = []
for output_line in result[0].splitlines():
index = output_line.find(EXPECT_LINE_OUTPUT)
if index > 0:
output_expected.append(output_line[index + len(EXPECT_LINE_OUTPUT):].strip().strip('"'))

if expected == output_expected:
print(f"PASS {file_name}")
return True
else:
print(f"FAIL {file_name}\n")
print(f"Expected: {repr(expected)}")
print(f"Received: {repr(output_expected)}")
print("STDOUT----------------")
print(result[0])
print(result[1])
print("----------------------")
return False

def main() -> None:
directory = '.'
files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.startswith("case_")]

for file_name in files:
test_file(directory, file_name)

main()