Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add hooks #3029

Open
wants to merge 4 commits into
base: dash-3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from ._patch import Patch # noqa: F401,E402
from ._jupyter import jupyter_dash # noqa: F401,E402

from . import _hooks as hooks # noqa: F401,E402

ctx = callback_context


Expand Down
93 changes: 93 additions & 0 deletions dash/_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import typing as _t
from importlib import metadata as _importlib_metadata

import flask as _f

_ns = {
"setup": [],
"layout": [],
"routes": [],
"error": [],
"callback": [],
}


def layout(func):
"""
Run a function when serving the layout, the return value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it saves the function func but doesn't run it - is the comment incorrect?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is only run when serving the layout, the docstring are placeholder, we can improve the wording.

will be used as the layout.
"""
_ns["layout"].append(func)
return func


def setup(func):
"""
Can be used to get a reference to the app after it is instantiated.
"""
_ns["setup"].append(func)
return func


def route(name: _t.Optional[str] = None, methods: _t.Sequence[str] = ("GET",)):
"""
Add a route to the Dash server.
"""

def wrap(func: _t.Callable[[], _f.Response]):
_name = name or func.__name__
_ns["routes"].append((_name, func, methods))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm - so some of the _ns entries are functions and some are tuples, so the structure of entries in _ns["key"] depends on the value of "key"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _ns just saves the info for the different hooks, it isn't meant for public usage. For routes/callbacks it needs more arguments so it's saved in a tuple for easy unpacking.

return func

return wrap


def error(func: _t.Callable[[Exception], _t.Any]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're declaring the type of func here but you don't declare types in the previous registration functions - declare them all for consistency?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some issue with cyclic dependencies with the other types, I'll see what I can do.

"""Automatically add an error handler to the dash app."""
_ns["error"].append(func)
return func


def callback(*args, **kwargs):
"""
Add a callback to all the apps with the hook installed.
"""

def wrap(func):
_ns["callback"].append((list(args), dict(kwargs), func))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, it looks like the structure of values stored in _ns depends on the key, which feels like it should be documented somewhere in this file to help the next person reading the code.

return func

return wrap


class HooksManager:
_registered = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a comment explaining what _registered is for - in particular, why is it a class-level variable instead of an instance variable?


# pylint: disable=too-few-public-methods
class HookErrorHandler:
def __init__(self, original):
self.original = original

def __call__(self, err: Exception):
result = None
if self.original:
result = self.original(err)
hook_result = None
for hook in HooksManager.get_hooks("error"):
hook_result = hook(err)
return result or hook_result

@staticmethod
def get_hooks(hook: str):
return _ns.get(hook, []).copy()

@classmethod
def register_setuptools(cls):
if cls._registered:
return

for dist in _importlib_metadata.distributions():
for entry in dist.entry_points:
if entry.group != "dash":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm - hard-coded string embedded in the file - define as a constant at the top of the file to make it easier to find? and a comment here explaining what this filtering is doing would be welcome - I don't really understand it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just a filter on the entry_points of a setup.py, the entry_points are used to add cli and other functionalities to python packages. I usually don't define a variable if it's only used one time.

continue
entry.load()
26 changes: 26 additions & 0 deletions dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,8 @@ def __init__( # pylint: disable=too-many-statements
for plugin in plugins:
plugin.plug(self)

self._setup_hooks()

# tracks internally if a function already handled at least one request.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need/want this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's the setup that load the plugins and add the setup/callback/error hooks.

self._got_first_request = {"pages": False, "setup_server": False}

Expand All @@ -588,6 +590,24 @@ def __init__( # pylint: disable=too-many-statements
)
self.setup_startup_routes()

def _setup_hooks(self):
# pylint: disable=import-outside-toplevel
from ._hooks import HooksManager

self._hooks = HooksManager
self._hooks.register_setuptools()

for setup in self._hooks.get_hooks("setup"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this where the difference in structure between the things that are stored in different _ns entries shows up? this loop is getting one value per list item; the loop below is getting three.

setup(self)

for callback_args, callback_kwargs, callback in self._hooks.get_hooks(
"callback"
):
self.callback(*callback_args, **callback_kwargs)(callback)

if self._hooks.get_hooks("error"):
self._on_error = self._hooks.HookErrorHandler(self._on_error)

def init_app(self, app=None, **kwargs):
"""Initialize the parts of Dash that require a flask app."""

Expand Down Expand Up @@ -688,6 +708,9 @@ def _setup_routes(self):
"_alive_" + jupyter_dash.alive_token, jupyter_dash.serve_alive
)

for name, func, methods in self._hooks.get_hooks("routes"):
self._add_url(name, func, methods)

# catch-all for front-end routes, used by dcc.Location
self._add_url("<path:path>", self.index)

Expand Down Expand Up @@ -754,6 +777,9 @@ def index_string(self, value):
def serve_layout(self):
layout = self._layout_value()

for hook in self._hooks.get_hooks("layout"):
layout = hook(layout)

# TODO - Set browser cache limit - pass hash into frontend
return flask.Response(
to_json(layout),
Expand Down
94 changes: 94 additions & 0 deletions tests/integration/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from flask import jsonify
import requests
import pytest

from dash import Dash, Input, Output, html, hooks, set_props


@pytest.fixture(scope="module", autouse=True)
def hook_cleanup():
yield
hooks._ns["layout"] = []
hooks._ns["setup"] = []
hooks._ns["route"] = []
hooks._ns["error"] = []
hooks._ns["callback"] = []


def test_hook001_layout(dash_duo):
@hooks.layout
def on_layout(layout):
return [html.Div("Header", id="header")] + layout

app = Dash()
app.layout = [html.Div("Body", id="body")]

dash_duo.start_server(app)

dash_duo.wait_for_text_to_equal("#header", "Header")
dash_duo.wait_for_text_to_equal("#body", "Body")


def test_hook002_setup():
setup_title = None

@hooks.setup
def on_setup(app: Dash):
nonlocal setup_title
setup_title = app.title

app = Dash(title="setup-test")
app.layout = html.Div("setup")

assert setup_title == "setup-test"


def test_hook003_route(dash_duo):
@hooks.route(methods=("POST",))
def hook_route():
return jsonify({"success": True})

app = Dash()
app.layout = html.Div("hook route")

dash_duo.start_server(app)
response = requests.post(f"{dash_duo.server_url}/hook_route")
data = response.json()
assert data["success"]


def test_hook004_error(dash_duo):
@hooks.error
def on_error(error):
set_props("error", {"children": str(error)})

app = Dash()
app.layout = [html.Button("start", id="start"), html.Div(id="error")]

@app.callback(Input("start", "n_clicks"), prevent_initial_call=True)
def on_click(_):
raise Exception("hook error")

dash_duo.start_server(app)
dash_duo.wait_for_element("#start").click()
dash_duo.wait_for_text_to_equal("#error", "hook error")


def test_hook005_callback(dash_duo):
@hooks.callback(
Output("output", "children"),
Input("start", "n_clicks"),
prevent_initial_call=True,
)
def on_hook_cb(n_clicks):
return f"clicked {n_clicks}"

app = Dash()
app.layout = [
html.Button("start", id="start"),
html.Div(id="output"),
]

dash_duo.start_server(app)
dash_duo.wait_for_element("#start").click()
dash_duo.wait_for_text_to_equal("#output", "clicked 1")