From 9808711e8380313d97cece136989618134a9bb53 Mon Sep 17 00:00:00 2001 From: flaneur Date: Thu, 2 Nov 2023 16:32:09 +0800 Subject: [PATCH] feat: process context in queries (#75) * add contxt * wrap context in http restful api * ls * fix comment --- con_batch.go | 4 ++-- connection.go | 4 ++-- restful.go | 55 ++++++++++++++++++++++++++++--------------------- restful_test.go | 3 ++- rows.go | 16 ++++++++++---- 5 files changed, 49 insertions(+), 33 deletions(-) diff --git a/con_batch.go b/con_batch.go index 33e47a2..a7154ed 100644 --- a/con_batch.go +++ b/con_batch.go @@ -62,7 +62,7 @@ func (b *httpBatch) BatchInsert() error { if err != nil { return errors.Wrap(err, "upload to stage failed") } - _, err = b.conn.rest.InsertWithStage(b.query, stage, nil, nil) + _, err = b.conn.rest.InsertWithStage(b.ctx, b.query, stage, nil, nil) if err != nil { return errors.Wrap(err, "insert with stage failed") } @@ -107,7 +107,7 @@ func (b *httpBatch) UploadToStage() (*StageLocation, error) { Name: "~", Path: fmt.Sprintf("batch/%d-%s", time.Now().Unix(), filepath.Base(b.batchFile)), } - return stage, b.conn.rest.UploadToStage(stage, input, size) + return stage, b.conn.rest.UploadToStage(b.ctx, stage, input, size) } var _ ldriver.Batch = (*httpBatch)(nil) diff --git a/connection.go b/connection.go index db1e1c9..97f9c88 100644 --- a/connection.go +++ b/connection.go @@ -38,7 +38,7 @@ func (dc *DatabendConn) exec(ctx context.Context, query string, args ...driver.V respCh := make(chan QueryResponse) errCh := make(chan error) go func() { - err := dc.rest.QuerySync(query, args, respCh) + err := dc.rest.QuerySync(ctx, query, args, respCh) errCh <- err }() @@ -64,7 +64,7 @@ func (dc *DatabendConn) query(ctx context.Context, query string, args ...driver. var r0 *QueryResponse err := retry.Do( func() error { - r, err := dc.rest.DoQuery(query, args) + r, err := dc.rest.DoQuery(ctx, query, args) if err != nil { return err } diff --git a/restful.go b/restful.go index 13a272f..00ab28f 100644 --- a/restful.go +++ b/restful.go @@ -120,7 +120,7 @@ func initAccessTokenLoader(cfg *Config) AccessTokenLoader { return nil } -func (c *APIClient) doRequest(method, path string, req interface{}, resp interface{}) error { +func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, resp interface{}) error { if c.doRequestFunc != nil { return c.doRequestFunc(method, path, req, resp) } @@ -139,6 +139,7 @@ func (c *APIClient) doRequest(method, path string, req interface{}, resp interfa if err != nil { return errors.Wrap(err, "failed to create http request") } + httpReq = httpReq.WithContext(ctx) maxRetries := 2 for i := 1; i <= maxRetries; i++ { @@ -270,7 +271,7 @@ func (c *APIClient) getSessionConfig() *SessionConfig { } } -func (c *APIClient) DoQuery(query string, args []driver.Value) (*QueryResponse, error) { +func (c *APIClient) DoQuery(ctx context.Context, query string, args []driver.Value) (*QueryResponse, error) { q, err := buildQuery(query, args) if err != nil { return nil, err @@ -283,7 +284,7 @@ func (c *APIClient) DoQuery(query string, args []driver.Value) (*QueryResponse, path := "/v1/query" var result QueryResponse - err = c.doRequest("POST", path, request, &result) + err = c.doRequest(ctx, "POST", path, request, &result) if err != nil { return nil, errors.Wrap(err, "failed to do query request") } @@ -309,7 +310,7 @@ func (c *APIClient) applySessionConfig(response *QueryResponse) { } } -func (c *APIClient) WaitForQuery(result *QueryResponse) (*QueryResponse, error) { +func (c *APIClient) WaitForQuery(ctx context.Context, result *QueryResponse) (*QueryResponse, error) { if result.Error != nil { return nil, errors.Wrap(result.Error, "query failed") } @@ -317,7 +318,7 @@ func (c *APIClient) WaitForQuery(result *QueryResponse) (*QueryResponse, error) for result.NextURI != "" { schema := result.Schema data := result.Data - result, err = c.QueryPage(result.NextURI) + result, err = c.QueryPage(ctx, result.NextURI) if err != nil { return nil, errors.Wrap(err, "failed to query page") } @@ -332,13 +333,13 @@ func (c *APIClient) WaitForQuery(result *QueryResponse) (*QueryResponse, error) return result, nil } -func (c *APIClient) QuerySingle(query string, args []driver.Value) (*QueryResponse, error) { - result, err := c.DoQuery(query, args) +func (c *APIClient) QuerySingle(ctx context.Context, query string, args []driver.Value) (*QueryResponse, error) { + result, err := c.DoQuery(ctx, query, args) if err != nil { return nil, err } c.trackStats(result) - return c.WaitForQuery(result) + return c.WaitForQuery(ctx, result) } func buildQuery(query string, params []driver.Value) (string, error) { @@ -352,12 +353,12 @@ func buildQuery(query string, params []driver.Value) (string, error) { return query, nil } -func (c *APIClient) QuerySync(query string, args []driver.Value, respCh chan QueryResponse) error { +func (c *APIClient) QuerySync(ctx context.Context, query string, args []driver.Value, respCh chan QueryResponse) error { // fmt.Printf("query sync %s", query) var r0 *QueryResponse err := retry.Do( func() error { - r, err := c.DoQuery(query, args) + r, err := c.DoQuery(ctx, query, args) if err != nil { return err } @@ -387,7 +388,7 @@ func (c *APIClient) QuerySync(query string, args []driver.Value, respCh chan Que respCh <- *r0 nextUri := r0.NextURI for len(nextUri) != 0 { - p, err := c.QueryPage(nextUri) + p, err := c.QueryPage(ctx, nextUri) if err != nil { return err } @@ -400,11 +401,11 @@ func (c *APIClient) QuerySync(query string, args []driver.Value, respCh chan Que return nil } -func (c *APIClient) QueryPage(nextURI string) (*QueryResponse, error) { +func (c *APIClient) QueryPage(ctx context.Context, nextURI string) (*QueryResponse, error) { var result QueryResponse err := retry.Do( func() error { - return c.doRequest("GET", nextURI, nil, &result) + return c.doRequest(ctx, "GET", nextURI, nil, &result) }, retry.RetryIf(func(err error) bool { if err == nil { @@ -426,7 +427,13 @@ func (c *APIClient) QueryPage(nextURI string) (*QueryResponse, error) { return &result, nil } -func (c *APIClient) InsertWithStage(sql string, stage *StageLocation, fileFormatOptions, copyOptions map[string]string) (*QueryResponse, error) { +func (c *APIClient) KillQuery(ctx context.Context, killURI string) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + return c.doRequest(ctx, "POST", killURI, nil, nil) +} + +func (c *APIClient) InsertWithStage(ctx context.Context, sql string, stage *StageLocation, fileFormatOptions, copyOptions map[string]string) (*QueryResponse, error) { if stage == nil { return nil, errors.New("stage location required for insert with stage") } @@ -449,26 +456,26 @@ func (c *APIClient) InsertWithStage(sql string, stage *StageLocation, fileFormat path := "/v1/query" var result QueryResponse - err := c.doRequest("POST", path, request, &result) + err := c.doRequest(ctx, "POST", path, request, &result) if err != nil { return nil, errors.Wrap(err, "failed to insert with stage") } c.trackStats(&result) - return c.WaitForQuery(&result) + return c.WaitForQuery(ctx, &result) } -func (c *APIClient) UploadToStage(stage *StageLocation, input *bufio.Reader, size int64) error { +func (c *APIClient) UploadToStage(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error { if c.PresignedURLDisabled { - return c.UploadToStageByAPI(stage, input, size) + return c.UploadToStageByAPI(ctx, stage, input, size) } else { - return c.UploadToStageByPresignURL(stage, input, size) + return c.UploadToStageByPresignURL(ctx, stage, input, size) } } -func (c *APIClient) GetPresignedURL(stage *StageLocation) (*PresignedResponse, error) { +func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (*PresignedResponse, error) { var headers string presignUploadSQL := fmt.Sprintf("PRESIGN UPLOAD %s", stage) - resp, err := c.QuerySingle(presignUploadSQL, nil) + resp, err := c.QuerySingle(ctx, presignUploadSQL, nil) if err != nil { return nil, errors.Wrap(err, "failed to query presign url") } @@ -489,8 +496,8 @@ func (c *APIClient) GetPresignedURL(stage *StageLocation) (*PresignedResponse, e return result, nil } -func (c *APIClient) UploadToStageByPresignURL(stage *StageLocation, input *bufio.Reader, size int64) error { - presigned, err := c.GetPresignedURL(stage) +func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error { + presigned, err := c.GetPresignedURL(ctx, stage) if err != nil { return errors.Wrap(err, "failed to get presigned url") } @@ -522,7 +529,7 @@ func (c *APIClient) UploadToStageByPresignURL(stage *StageLocation, input *bufio return nil } -func (c *APIClient) UploadToStageByAPI(stage *StageLocation, input *bufio.Reader, size int64) error { +func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error { body := new(bytes.Buffer) writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("upload", stage.Path) diff --git a/restful_test.go b/restful_test.go index d62ec56..19922ca 100644 --- a/restful_test.go +++ b/restful_test.go @@ -1,6 +1,7 @@ package godatabend import ( + "context" "database/sql/driver" "encoding/json" "testing" @@ -57,7 +58,7 @@ func TestDoQuery(t *testing.T) { doRequestFunc: mockDoRequest, statsTracker: statsTracker, } - _, err := c.DoQuery("SELECT 1", []driver.Value{}) + _, err := c.DoQuery(context.Background(), "SELECT 1", []driver.Value{}) assert.NoError(t, err) assert.Equal(t, gotQueryID, "mockid1") } diff --git a/rows.go b/rows.go index 22396c3..c96e9d8 100644 --- a/rows.go +++ b/rows.go @@ -1,7 +1,9 @@ package godatabend import ( + "context" "database/sql/driver" + "errors" "fmt" "io" "reflect" @@ -20,13 +22,19 @@ func waitForQueryResult(dc *DatabendConn, result *QueryResponse) (*QueryResponse if result.Error != nil { return nil, result.Error } - // save schema for final result + // save schema to use in the final response schema := result.Schema var err error for result.NextURI != "" && len(result.Data) == 0 { dc.log("wait for query result", result.NextURI) - result, err = dc.rest.QueryPage(result.NextURI) - if err != nil { + result, err = dc.rest.QueryPage(dc.ctx, result.NextURI) + if errors.Is(err, context.Canceled) { + // context might be canceled due to timeout or canceled. if it's canceled, we need call + // the kill url to tell the backend it's killed. + dc.log("query canceled", result.ID) + dc.rest.KillQuery(context.Background(), result.KillURI) + return nil, err + } else if err != nil { return nil, err } if result.Error != nil { @@ -80,7 +88,7 @@ func (r *nextRows) Columns() []string { func (r *nextRows) Close() error { if len(r.respData.NextURI) != 0 { - _, err := r.dc.rest.QueryPage(r.respData.NextURI) + _, err := r.dc.rest.QueryPage(r.dc.ctx, r.respData.NextURI) if err != nil { return err }