Skip to content

Commit

Permalink
feat: process context in queries (#75)
Browse files Browse the repository at this point in the history
* add contxt

* wrap context in http restful api

* ls

* fix comment
  • Loading branch information
flaneur2020 authored Nov 2, 2023
1 parent 0582a8b commit 9808711
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 33 deletions.
4 changes: 2 additions & 2 deletions con_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (b *httpBatch) BatchInsert() error {
if err != nil {
return errors.Wrap(err, "upload to stage failed")
}
_, err = b.conn.rest.InsertWithStage(b.query, stage, nil, nil)
_, err = b.conn.rest.InsertWithStage(b.ctx, b.query, stage, nil, nil)
if err != nil {
return errors.Wrap(err, "insert with stage failed")
}
Expand Down Expand Up @@ -107,7 +107,7 @@ func (b *httpBatch) UploadToStage() (*StageLocation, error) {
Name: "~",
Path: fmt.Sprintf("batch/%d-%s", time.Now().Unix(), filepath.Base(b.batchFile)),
}
return stage, b.conn.rest.UploadToStage(stage, input, size)
return stage, b.conn.rest.UploadToStage(b.ctx, stage, input, size)
}

var _ ldriver.Batch = (*httpBatch)(nil)
4 changes: 2 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (dc *DatabendConn) exec(ctx context.Context, query string, args ...driver.V
respCh := make(chan QueryResponse)
errCh := make(chan error)
go func() {
err := dc.rest.QuerySync(query, args, respCh)
err := dc.rest.QuerySync(ctx, query, args, respCh)
errCh <- err
}()

Expand All @@ -64,7 +64,7 @@ func (dc *DatabendConn) query(ctx context.Context, query string, args ...driver.
var r0 *QueryResponse
err := retry.Do(
func() error {
r, err := dc.rest.DoQuery(query, args)
r, err := dc.rest.DoQuery(ctx, query, args)
if err != nil {
return err
}
Expand Down
55 changes: 31 additions & 24 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func initAccessTokenLoader(cfg *Config) AccessTokenLoader {
return nil
}

func (c *APIClient) doRequest(method, path string, req interface{}, resp interface{}) error {
func (c *APIClient) doRequest(ctx context.Context, method, path string, req interface{}, resp interface{}) error {
if c.doRequestFunc != nil {
return c.doRequestFunc(method, path, req, resp)
}
Expand All @@ -139,6 +139,7 @@ func (c *APIClient) doRequest(method, path string, req interface{}, resp interfa
if err != nil {
return errors.Wrap(err, "failed to create http request")
}
httpReq = httpReq.WithContext(ctx)

maxRetries := 2
for i := 1; i <= maxRetries; i++ {
Expand Down Expand Up @@ -270,7 +271,7 @@ func (c *APIClient) getSessionConfig() *SessionConfig {
}
}

func (c *APIClient) DoQuery(query string, args []driver.Value) (*QueryResponse, error) {
func (c *APIClient) DoQuery(ctx context.Context, query string, args []driver.Value) (*QueryResponse, error) {
q, err := buildQuery(query, args)
if err != nil {
return nil, err
Expand All @@ -283,7 +284,7 @@ func (c *APIClient) DoQuery(query string, args []driver.Value) (*QueryResponse,

path := "/v1/query"
var result QueryResponse
err = c.doRequest("POST", path, request, &result)
err = c.doRequest(ctx, "POST", path, request, &result)
if err != nil {
return nil, errors.Wrap(err, "failed to do query request")
}
Expand All @@ -309,15 +310,15 @@ func (c *APIClient) applySessionConfig(response *QueryResponse) {
}
}

func (c *APIClient) WaitForQuery(result *QueryResponse) (*QueryResponse, error) {
func (c *APIClient) WaitForQuery(ctx context.Context, result *QueryResponse) (*QueryResponse, error) {
if result.Error != nil {
return nil, errors.Wrap(result.Error, "query failed")
}
var err error
for result.NextURI != "" {
schema := result.Schema
data := result.Data
result, err = c.QueryPage(result.NextURI)
result, err = c.QueryPage(ctx, result.NextURI)
if err != nil {
return nil, errors.Wrap(err, "failed to query page")
}
Expand All @@ -332,13 +333,13 @@ func (c *APIClient) WaitForQuery(result *QueryResponse) (*QueryResponse, error)
return result, nil
}

func (c *APIClient) QuerySingle(query string, args []driver.Value) (*QueryResponse, error) {
result, err := c.DoQuery(query, args)
func (c *APIClient) QuerySingle(ctx context.Context, query string, args []driver.Value) (*QueryResponse, error) {
result, err := c.DoQuery(ctx, query, args)
if err != nil {
return nil, err
}
c.trackStats(result)
return c.WaitForQuery(result)
return c.WaitForQuery(ctx, result)
}

func buildQuery(query string, params []driver.Value) (string, error) {
Expand All @@ -352,12 +353,12 @@ func buildQuery(query string, params []driver.Value) (string, error) {
return query, nil
}

func (c *APIClient) QuerySync(query string, args []driver.Value, respCh chan QueryResponse) error {
func (c *APIClient) QuerySync(ctx context.Context, query string, args []driver.Value, respCh chan QueryResponse) error {
// fmt.Printf("query sync %s", query)
var r0 *QueryResponse
err := retry.Do(
func() error {
r, err := c.DoQuery(query, args)
r, err := c.DoQuery(ctx, query, args)
if err != nil {
return err
}
Expand Down Expand Up @@ -387,7 +388,7 @@ func (c *APIClient) QuerySync(query string, args []driver.Value, respCh chan Que
respCh <- *r0
nextUri := r0.NextURI
for len(nextUri) != 0 {
p, err := c.QueryPage(nextUri)
p, err := c.QueryPage(ctx, nextUri)
if err != nil {
return err
}
Expand All @@ -400,11 +401,11 @@ func (c *APIClient) QuerySync(query string, args []driver.Value, respCh chan Que
return nil
}

func (c *APIClient) QueryPage(nextURI string) (*QueryResponse, error) {
func (c *APIClient) QueryPage(ctx context.Context, nextURI string) (*QueryResponse, error) {
var result QueryResponse
err := retry.Do(
func() error {
return c.doRequest("GET", nextURI, nil, &result)
return c.doRequest(ctx, "GET", nextURI, nil, &result)
},
retry.RetryIf(func(err error) bool {
if err == nil {
Expand All @@ -426,7 +427,13 @@ func (c *APIClient) QueryPage(nextURI string) (*QueryResponse, error) {
return &result, nil
}

func (c *APIClient) InsertWithStage(sql string, stage *StageLocation, fileFormatOptions, copyOptions map[string]string) (*QueryResponse, error) {
func (c *APIClient) KillQuery(ctx context.Context, killURI string) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
return c.doRequest(ctx, "POST", killURI, nil, nil)
}

func (c *APIClient) InsertWithStage(ctx context.Context, sql string, stage *StageLocation, fileFormatOptions, copyOptions map[string]string) (*QueryResponse, error) {
if stage == nil {
return nil, errors.New("stage location required for insert with stage")
}
Expand All @@ -449,26 +456,26 @@ func (c *APIClient) InsertWithStage(sql string, stage *StageLocation, fileFormat

path := "/v1/query"
var result QueryResponse
err := c.doRequest("POST", path, request, &result)
err := c.doRequest(ctx, "POST", path, request, &result)
if err != nil {
return nil, errors.Wrap(err, "failed to insert with stage")
}
c.trackStats(&result)
return c.WaitForQuery(&result)
return c.WaitForQuery(ctx, &result)
}

func (c *APIClient) UploadToStage(stage *StageLocation, input *bufio.Reader, size int64) error {
func (c *APIClient) UploadToStage(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error {
if c.PresignedURLDisabled {
return c.UploadToStageByAPI(stage, input, size)
return c.UploadToStageByAPI(ctx, stage, input, size)
} else {
return c.UploadToStageByPresignURL(stage, input, size)
return c.UploadToStageByPresignURL(ctx, stage, input, size)
}
}

func (c *APIClient) GetPresignedURL(stage *StageLocation) (*PresignedResponse, error) {
func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (*PresignedResponse, error) {
var headers string
presignUploadSQL := fmt.Sprintf("PRESIGN UPLOAD %s", stage)
resp, err := c.QuerySingle(presignUploadSQL, nil)
resp, err := c.QuerySingle(ctx, presignUploadSQL, nil)
if err != nil {
return nil, errors.Wrap(err, "failed to query presign url")
}
Expand All @@ -489,8 +496,8 @@ func (c *APIClient) GetPresignedURL(stage *StageLocation) (*PresignedResponse, e
return result, nil
}

func (c *APIClient) UploadToStageByPresignURL(stage *StageLocation, input *bufio.Reader, size int64) error {
presigned, err := c.GetPresignedURL(stage)
func (c *APIClient) UploadToStageByPresignURL(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error {
presigned, err := c.GetPresignedURL(ctx, stage)
if err != nil {
return errors.Wrap(err, "failed to get presigned url")
}
Expand Down Expand Up @@ -522,7 +529,7 @@ func (c *APIClient) UploadToStageByPresignURL(stage *StageLocation, input *bufio
return nil
}

func (c *APIClient) UploadToStageByAPI(stage *StageLocation, input *bufio.Reader, size int64) error {
func (c *APIClient) UploadToStageByAPI(ctx context.Context, stage *StageLocation, input *bufio.Reader, size int64) error {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("upload", stage.Path)
Expand Down
3 changes: 2 additions & 1 deletion restful_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package godatabend

import (
"context"
"database/sql/driver"
"encoding/json"
"testing"
Expand Down Expand Up @@ -57,7 +58,7 @@ func TestDoQuery(t *testing.T) {
doRequestFunc: mockDoRequest,
statsTracker: statsTracker,
}
_, err := c.DoQuery("SELECT 1", []driver.Value{})
_, err := c.DoQuery(context.Background(), "SELECT 1", []driver.Value{})
assert.NoError(t, err)
assert.Equal(t, gotQueryID, "mockid1")
}
16 changes: 12 additions & 4 deletions rows.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package godatabend

import (
"context"
"database/sql/driver"
"errors"
"fmt"
"io"
"reflect"
Expand All @@ -20,13 +22,19 @@ func waitForQueryResult(dc *DatabendConn, result *QueryResponse) (*QueryResponse
if result.Error != nil {
return nil, result.Error
}
// save schema for final result
// save schema to use in the final response
schema := result.Schema
var err error
for result.NextURI != "" && len(result.Data) == 0 {
dc.log("wait for query result", result.NextURI)
result, err = dc.rest.QueryPage(result.NextURI)
if err != nil {
result, err = dc.rest.QueryPage(dc.ctx, result.NextURI)
if errors.Is(err, context.Canceled) {
// context might be canceled due to timeout or canceled. if it's canceled, we need call
// the kill url to tell the backend it's killed.
dc.log("query canceled", result.ID)
dc.rest.KillQuery(context.Background(), result.KillURI)
return nil, err
} else if err != nil {
return nil, err
}
if result.Error != nil {
Expand Down Expand Up @@ -80,7 +88,7 @@ func (r *nextRows) Columns() []string {

func (r *nextRows) Close() error {
if len(r.respData.NextURI) != 0 {
_, err := r.dc.rest.QueryPage(r.respData.NextURI)
_, err := r.dc.rest.QueryPage(r.dc.ctx, r.respData.NextURI)
if err != nil {
return err
}
Expand Down

0 comments on commit 9808711

Please sign in to comment.