Skip to content

Commit

Permalink
feat: http handler support request forwarding. (#16637)
Browse files Browse the repository at this point in the history
* feat: http handler support request forwarding.

* small refactor.
  • Loading branch information
youngsofun authored Oct 18, 2024
1 parent 5ec5c02 commit 912267c
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 90 deletions.
1 change: 1 addition & 0 deletions src/common/base/src/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub const HEADER_NODE_ID: &str = "X-DATABEND-NODE-ID";
pub const HEADER_QUERY_STATE: &str = "X-DATABEND-QUERY-STATE";
pub const HEADER_QUERY_PAGE_ROWS: &str = "X-DATABEND-QUERY-PAGE-ROWS";
pub const HEADER_VERSION: &str = "X-DATABEND-VERSION";
pub const HEADER_STICKY: &str = "X-DATABEND-STICKY-NODE";

pub const HEADER_SIGNATURE: &str = "X-DATABEND-SIGNATURE";
pub const HEADER_AUTH_METHOD: &str = "X-DATABEND-AUTH-METHOD";
33 changes: 32 additions & 1 deletion src/query/service/src/clusters/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use futures::Future;
use futures::StreamExt;
use log::error;
use log::warn;
use parking_lot::RwLock;
use rand::thread_rng;
use rand::Rng;
use serde::Deserialize;
Expand All @@ -66,6 +67,7 @@ pub struct ClusterDiscovery {
cluster_id: String,
tenant_id: String,
flight_address: String,
cached_cluster: RwLock<Option<Arc<Cluster>>>,
}

// avoid leak FlightClient to common-xxx
Expand Down Expand Up @@ -200,6 +202,7 @@ impl ClusterDiscovery {
cluster_id: cfg.query.cluster_id.clone(),
tenant_id: cfg.query.tenant_id.tenant_name().to_string(),
flight_address: cfg.query.flight_api_address.clone(),
cached_cluster: Default::default(),
}))
}

Expand Down Expand Up @@ -261,11 +264,39 @@ impl ClusterDiscovery {
&self.flight_address,
cluster_nodes.len() as f64,
);
Ok(Cluster::create(res, self.local_id.clone()))
let res = Cluster::create(res, self.local_id.clone());
*self.cached_cluster.write() = Some(res.clone());
Ok(res)
}
}
}

fn cached_cluster(self: &Arc<Self>) -> Option<Arc<Cluster>> {
(*self.cached_cluster.read()).clone()
}

pub async fn find_node_by_id(
self: Arc<Self>,
id: &str,
config: &InnerConfig,
) -> Result<Option<Arc<NodeInfo>>> {
let (mut cluster, mut is_cached) = if let Some(cluster) = self.cached_cluster() {
(cluster, true)
} else {
(self.discover(config).await?, false)
};
while is_cached {
for node in cluster.get_nodes() {
if node.id == id {
return Ok(Some(node.clone()));
}
}
cluster = self.discover(config).await?;
is_cached = false;
}
Ok(None)
}

#[async_backtrace::framed]
async fn drop_invalid_nodes(self: &Arc<Self>, node_info: &NodeInfo) -> Result<()> {
let current_nodes_info = match self.api_provider.get_nodes().await {
Expand Down
72 changes: 1 addition & 71 deletions src/query/service/src/servers/http/http_services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,18 @@ use poem::listener::OpensslTlsConfig;
use poem::middleware::CatchPanic;
use poem::middleware::NormalizePath;
use poem::middleware::TrailingSlash;
use poem::post;
use poem::put;
use poem::Endpoint;
use poem::EndpointExt;
use poem::IntoEndpoint;
use poem::IntoResponse;
use poem::Route;

use super::v1::discovery_nodes;
use super::v1::logout_handler;
use super::v1::upload_to_stage;
use super::v1::HttpQueryContext;
use crate::servers::http::middleware::json_response;
use crate::servers::http::middleware::EndpointKind;
use crate::servers::http::middleware::HTTPSessionMiddleware;
use crate::servers::http::middleware::PanicHandler;
use crate::servers::http::v1::clickhouse_router;
use crate::servers::http::v1::list_suggestions;
use crate::servers::http::v1::login_handler;
use crate::servers::http::v1::query_route;
use crate::servers::http::v1::refresh_handler;
use crate::servers::Server;

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -98,70 +89,9 @@ impl HttpHandler {
})
}

pub fn wrap_auth<E>(&self, ep: E, auth_type: EndpointKind) -> impl Endpoint
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
let session_middleware = HTTPSessionMiddleware::create(self.kind, auth_type);
ep.with(session_middleware).boxed()
}

#[allow(clippy::let_with_type_underscore)]
#[async_backtrace::framed]
async fn build_router(&self, sock: SocketAddr) -> impl Endpoint {
let ep_v1 = Route::new()
.nest("/query", query_route(self.kind))
.at(
"/session/login",
post(login_handler).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Login,
)),
)
.at(
"/session/logout",
post(logout_handler).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Logout,
)),
)
.at(
"/session/refresh",
post(refresh_handler).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Refresh,
)),
)
.at(
"/auth/verify",
get(verify_handler).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::Verify,
)),
)
.at(
"/upload_to_stage",
put(upload_to_stage).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::StartQuery,
)),
)
.at(
"/suggested_background_tasks",
get(list_suggestions).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::StartQuery,
)),
)
.at(
"/discovery_nodes",
get(discovery_nodes).with(HTTPSessionMiddleware::create(
self.kind,
EndpointKind::NoAuth,
)),
);

let ep_clickhouse =
Route::new()
.nest("/", clickhouse_router())
Expand All @@ -182,7 +112,7 @@ impl HttpHandler {
HttpHandlerKind::Query => Route::new()
.at("/", ep_usage)
.nest("/health", ep_health)
.nest("/v1", ep_v1)
.nest("/v1", query_route())
.nest("/clickhouse", ep_clickhouse),
HttpHandlerKind::Clickhouse => Route::new()
.nest("/", ep_clickhouse)
Expand Down
92 changes: 91 additions & 1 deletion src/query/service/src/servers/http/middleware/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use databend_common_base::headers::HEADER_DEDUPLICATE_LABEL;
use databend_common_base::headers::HEADER_NODE_ID;
use databend_common_base::headers::HEADER_QUERY_ID;
use databend_common_base::headers::HEADER_SESSION_ID;
use databend_common_base::headers::HEADER_STICKY;
use databend_common_base::headers::HEADER_TENANT;
use databend_common_base::headers::HEADER_VERSION;
use databend_common_base::runtime::ThreadTracker;
Expand All @@ -28,6 +29,7 @@ use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_meta_app::principal::user_token::TokenType;
use databend_common_meta_app::tenant::Tenant;
use databend_common_meta_types::NodeInfo;
use fastrace::func_name;
use headers::authorization::Basic;
use headers::authorization::Bearer;
Expand All @@ -47,6 +49,7 @@ use poem::error::Result as PoemResult;
use poem::web::Json;
use poem::Addr;
use poem::Endpoint;
use poem::Error;
use poem::IntoResponse;
use poem::Middleware;
use poem::Request;
Expand All @@ -55,6 +58,7 @@ use uuid::Uuid;

use crate::auth::AuthMgr;
use crate::auth::Credential;
use crate::clusters::ClusterDiscovery;
use crate::servers::http::error::HttpErrorCode;
use crate::servers::http::error::JsonErrorOnly;
use crate::servers::http::error::QueryError;
Expand Down Expand Up @@ -82,6 +86,12 @@ impl EndpointKind {
pub fn need_user_info(&self) -> bool {
!matches!(self, EndpointKind::NoAuth | EndpointKind::PollQuery)
}
pub fn may_need_sticky(&self) -> bool {
matches!(
self,
EndpointKind::StartQuery | EndpointKind::PollQuery | EndpointKind::Logout
)
}
pub fn require_databend_token_type(&self) -> Result<Option<TokenType>> {
match self {
EndpointKind::Verify => Ok(None),
Expand Down Expand Up @@ -372,14 +382,94 @@ impl<E> HTTPSessionEndpoint<E> {
}
}

async fn forward_request(mut req: Request, node: Arc<NodeInfo>) -> PoemResult<Response> {
let addr = node.http_address.clone();
let config = GlobalConfig::instance();
let scheme = if config.query.http_handler_tls_server_key.is_empty()
|| config.query.http_handler_tls_server_cert.is_empty()
{
"http"
} else {
"https"
};
let url = format!("{scheme}://{addr}/v1{}", req.uri());

let client = reqwest::Client::new();
let reqwest_request = client
.request(req.method().clone(), &url)
.headers(req.headers().clone())
.body(req.take_body().into_bytes().await?)
.build()
.map_err(|e| {
HttpErrorCode::bad_request(ErrorCode::BadArguments(format!(
"fail to build forward request: {e}"
)))
})?;

let response = client.execute(reqwest_request).await.map_err(|e| {
HttpErrorCode::server_error(ErrorCode::Internal(format!(
"fail to send forward request: {e}",
)))
})?;

let status = StatusCode::from_u16(response.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let headers = response.headers().clone();
let body = response.bytes().await.map_err(|e| {
HttpErrorCode::server_error(ErrorCode::Internal(format!(
"fail to send forward request: {e}",
)))
})?;
let mut poem_resp = Response::builder().status(status).body(body);
let headers_ref = poem_resp.headers_mut();
for (key, value) in headers.iter() {
headers_ref.insert(key, value.to_owned());
}
Ok(poem_resp)
}

impl<E: Endpoint> Endpoint for HTTPSessionEndpoint<E> {
type Output = Response;

#[async_backtrace::framed]
async fn call(&self, mut req: Request) -> PoemResult<Self::Output> {
let headers = req.headers().clone();

if self.endpoint_kind.may_need_sticky()
&& let Some(sticky_node_id) = headers.get(HEADER_STICKY)
{
let sticky_node_id = sticky_node_id
.to_str()
.map_err(|e| {
HttpErrorCode::bad_request(ErrorCode::BadArguments(format!(
"Invalid Header ({HEADER_STICKY}: {sticky_node_id:?}): {e}"
)))
})?
.to_string();
let local_id = GlobalConfig::instance().query.node_id.clone();
if local_id != sticky_node_id {
let config = GlobalConfig::instance();
return if let Some(node) = ClusterDiscovery::instance()
.find_node_by_id(&sticky_node_id, &config)
.await
.map_err(HttpErrorCode::server_error)?
{
log::info!(
"forwarding {} from {local_id} to {sticky_node_id}",
req.uri()
);
forward_request(req, node).await
} else {
let msg = format!("sticky_node_id '{sticky_node_id}' not found in cluster",);
warn!("{}", msg);
Err(Error::from(HttpErrorCode::bad_request(
ErrorCode::BadArguments(msg),
)))
};
}
}
let method = req.method().clone();
let uri = req.uri().clone();
let headers = req.headers().clone();

let query_id = req
.headers()
Expand Down
Loading

0 comments on commit 912267c

Please sign in to comment.