Skip to content

Commit

Permalink
feat: support nullable api response data (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Aug 7, 2024
1 parent 00ad41a commit 6675d5e
Show file tree
Hide file tree
Showing 18 changed files with 523 additions and 492 deletions.
19 changes: 11 additions & 8 deletions restful.go → client.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ func (c *APIClient) UploadToStage(ctx context.Context, stage *StageLocation, inp
}

func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (*PresignedResponse, error) {
var headers string
presignUploadSQL := fmt.Sprintf("PRESIGN UPLOAD %s", stage)
resp, err := c.QuerySync(ctx, presignUploadSQL, nil)
if err != nil {
Expand All @@ -611,17 +610,21 @@ func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (
if len(resp.Data) < 1 || len(resp.Data[0]) < 2 {
return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data)
}

result := &PresignedResponse{
Method: resp.Data[0][0],
Headers: make(map[string]string),
URL: resp.Data[0][2],
if resp.Data[0][0] == nil || resp.Data[0][1] == nil || resp.Data[0][2] == nil {
return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data)
}
headers = resp.Data[0][1]
err = json.Unmarshal([]byte(headers), &result.Headers)
method := *resp.Data[0][0]
url := *resp.Data[0][2]
headers := map[string]string{}
err = json.Unmarshal([]byte(*resp.Data[0][1]), &headers)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal headers")
}
result := &PresignedResponse{
Method: method,
Headers: headers,
URL: url,
}
return result, nil
}

Expand Down
4 changes: 2 additions & 2 deletions restful_test.go → client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func TestDoQuery(t *testing.T) {
}

c := APIClient{
host: "tn3ftqihs--bl.ch.aws-us-east-2.default.databend.com",
tenant: "tn3ftqihs",
host: "tnxxxxxxx.gw.aws-us-east-2.default.databend.com",
tenant: "tnxxxxxxx",
accessTokenLoader: NewStaticAccessTokenLoader("abc123"),
warehouse: "small-abc",
doRequestFunc: mockDoRequest,
Expand Down
41 changes: 0 additions & 41 deletions helpers.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package godatabend

import (
"bytes"
"fmt"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -33,17 +31,6 @@ func formatDate(value time.Time) string {
return quote(value.Format(dateFormat))
}

func readResponse(response *http.Response) (result []byte, err error) {
if response.ContentLength > 0 {
result = make([]byte, 0, response.ContentLength)
}
buf := bytes.NewBuffer(result)
defer response.Body.Close()
_, err = buf.ReadFrom(response.Body)
result = buf.Bytes()
return
}

func getTableFromInsertQuery(query string) (string, error) {
if !strings.Contains(query, "insert") && !strings.Contains(query, "INSERT") {
return "", errors.New("wrong insert statement")
Expand All @@ -62,31 +49,3 @@ func generateDescTable(query string) (string, error) {
}
return fmt.Sprintf("DESC %s", table), nil
}

func databendParquetReflect(databendType string) string {

var parquetType string
switch databendType {
case "VARCHAR":
parquetType = "type=BYTE_ARRAY, convertedtype=UTF8, encoding=PLAIN_DICTIONARY"

case "BOOLEAN":
parquetType = "type=BOOLEAN"
case "TINYINT", "SMALLINT", "INT":
parquetType = "type=INT32"
case "BIGINT":
parquetType = "type=INT64"
case "FLOAT":
parquetType = "type=FLOAT"
case "DOUBLE":
parquetType = "type=DOUBLE"
case "DATE":
parquetType = "type=INT32, convertedtype=DATE"
case "TIMESTAMP":
parquetType = "type=INT64"
case "ARRAY":
parquetType = "type=LIST, convertedtype=LIST"

}
return parquetType
}
2 changes: 2 additions & 0 deletions log.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
rlog "github.com/sirupsen/logrus"
)

type contextKey string

// DBSessionIDKey is context key of session id
const DBSessionIDKey contextKey = "LOG_SESSION_ID"

Expand Down
2 changes: 1 addition & 1 deletion query.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type QueryResponse struct {
NodeID string `json:"node_id"`
Session *json.RawMessage `json:"session"`
Schema *[]DataField `json:"schema"`
Data [][]string `json:"data"`
Data [][]*string `json:"data"`
State string `json:"state"`
Error *QueryError `json:"error"`
Stats *QueryStats `json:"stats"`
Expand Down
18 changes: 11 additions & 7 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ func (r *nextRows) Next(dest []driver.Value) error {
r.respData.Data = r.respData.Data[1:]

for j := range lineData {
reader := strings.NewReader(lineData[j])
val := lineData[j]
if val == nil {
dest[j] = nil
continue
}
reader := strings.NewReader(*val)
v, err := r.parsers[j].Parse(reader)
if err != nil {
r.dc.log("fail to parse field", j, ", error: ", err)
Expand All @@ -177,18 +182,17 @@ func (r *nextRows) ColumnTypeDatabaseTypeName(index int) string {
return r.types[index]
}

// ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeNullable
func (r *nextRows) ColumnTypeNullable(index int) (bool, bool) {
return r.parsers[index].Nullable(), true
}

// // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeLength
// func (r *nextRows) ColumnTypeLength(index int) (int64, bool) {
// // TODO: implement this
// return 10, true
// }

// // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeNullable
// func (r *nextRows) ColumnTypeNullable(index int) (bool, bool) {
// // TODO: implement this
// return true, true
// }

// // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypePrecisionScale
// func (r *nextRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
// // TODO: implement this
Expand Down
9 changes: 8 additions & 1 deletion rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ import (
)

func TestTextRows(t *testing.T) {
ptr1 := strPtr("1")
ptr2 := strPtr("2")
ptr3 := strPtr("2")
rows, err := newNextRows(context.Background(), &DatabendConn{}, &QueryResponse{
Data: [][]string{{"1", "2", "3"}, {"3", "2", "1"}},
Data: [][]*string{{ptr1, ptr2, ptr3}, {ptr3, ptr2, ptr1}},
Schema: &[]DataField{
{Name: "age", Type: "Int32"},
{Name: "height", Type: "Int64"},
Expand All @@ -28,3 +31,7 @@ func TestTextRows(t *testing.T) {
assert.Equal(t, "Int32", rows.ColumnTypeDatabaseTypeName(0))
assert.Equal(t, "String", rows.ColumnTypeDatabaseTypeName(2))
}

func strPtr(s string) *string {
return &s
}
2 changes: 1 addition & 1 deletion tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
volumes:
- ./data:/data
databend:
image: datafuselabs/databend
image: datafuselabs/databend:nightly
environment:
- QUERY_DEFAULT_USER=databend
- QUERY_DEFAULT_PASSWORD=databend
Expand Down
65 changes: 32 additions & 33 deletions tests/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ const (
)

var (
dsn = "http://databend:databend@localhost:8000?presigned_url_disabled=true"
dsn = "http://databend:databend@localhost:8000?presign=on"
)

func init() {
dsn = os.Getenv("TEST_DATABEND_DSN")
// databend default
// dsn = "http://root:@localhost:8000?presigned_url_disabled=true"
// dsn = "http://root:@localhost:8000?presign=on"

// add user databend by uncommenting corresponding [[query.users]] section scripts/ci/deploy/config/databend-query-node-1.toml
//dsn = "http://databend:databend@localhost:8000?presigned_url_disabled=true"
//dsn = "http://databend:databend@localhost:8000?presign=on"
}

func TestDatabendSuite(t *testing.T) {
Expand All @@ -65,13 +65,6 @@ func (s *DatabendTestSuite) SetupSuite() {

err = s.db.Ping()
s.Nil(err)

rows, err := s.db.Query("select version()")
s.Nil(err)
result, err := scanValues(rows)
s.Nil(err)

s.T().Logf("connected to databend: %s\n", result)
}

func (s *DatabendTestSuite) TearDownSuite() {
Expand Down Expand Up @@ -103,6 +96,14 @@ func (s *DatabendTestSuite) TearDownTest() {
s.r.Nil(err)
}

func (s *DatabendTestSuite) TestVersion() {
rows, err := s.db.Query("select version()")
s.Nil(err)
result, err := scanValues(rows)
s.Nil(err)
s.T().Logf("connected to databend: %s\n", result)
}

// For load balance test
func (s *DatabendTestSuite) TestCycleExec() {
rows, err := s.db.Query("SELECT number from numbers(200) order by number")
Expand All @@ -115,9 +116,11 @@ func (s *DatabendTestSuite) TestQuoteStringQuery() {
m := make(map[string]string, 0)
m["message"] = "this is action 'with quote string'"
x, err := json.Marshal(m)
s.r.Nil(err)
_, err = s.db.Exec(fmt.Sprintf("insert into %s values(?)", s.table2), string(x))
s.r.Nil(err)
rows, err := s.db.Query(fmt.Sprintf("select * from %s", s.table2))
s.r.Nil(err)
for rows.Next() {
var t string
_ = rows.Scan(&t)
Expand Down Expand Up @@ -272,17 +275,6 @@ func (s *DatabendTestSuite) TestServerError() {
s.Contains(err.Error(), "error")
}

func (s *DatabendTestSuite) TestQueryNull() {
rows, err := s.db.Query("SELECT NULL")
s.r.Nil(err)

result, err := scanValues(rows)
s.r.Nil(err)
s.r.Equal([][]interface{}{{"NULL"}}, result)

s.r.NoError(rows.Close())
}

func (s *DatabendTestSuite) TestTransactionCommit() {
tx, err := s.db.Begin()
s.r.Nil(err)
Expand All @@ -298,7 +290,7 @@ func (s *DatabendTestSuite) TestTransactionCommit() {

result, err := scanValues(rows)
s.r.Nil(err)
s.r.Equal([][]interface{}{{"1", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL"}}, result)
s.r.Equal([][]interface{}{{int64(1), nil, nil, "NULL", "NULL", nil, nil, nil, nil}}, result)

s.r.NoError(rows.Close())
}
Expand Down Expand Up @@ -343,32 +335,39 @@ func (s *DatabendTestSuite) TestLongExec() {
}
}

func getNullableType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return t.Elem()
}
return t
}

func scanValues(rows *sql.Rows) (interface{}, error) {
var err error
var result [][]interface{}
ct, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
types := make([]reflect.Type, len(ct))
for i, v := range ct {
types[i] = v.ScanType()
}
ptrs := make([]interface{}, len(types))
vals := make([]any, len(ct))
for rows.Next() {
if err = rows.Err(); err != nil {
return nil, err
}
for i, t := range types {
ptrs[i] = reflect.New(t).Interface()
for i := range ct {
vals[i] = &dc.NullableValue{}
}
err = rows.Scan(ptrs...)
err = rows.Scan(vals...)
if err != nil {
return nil, err
}
values := make([]interface{}, len(types))
for i, p := range ptrs {
values[i] = reflect.ValueOf(p).Elem().Interface()
values := make([]interface{}, len(ct))
for i, p := range vals {
val, err := p.(*dc.NullableValue).Value()
if err != nil {
return nil, fmt.Errorf("failed to get value: %w", err)
}
values[i] = val
}
result = append(result, values)
}
Expand Down
55 changes: 55 additions & 0 deletions tests/nullable_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package tests

import (
"database/sql"
"fmt"
)

func (s *DatabendTestSuite) TestNullable() {
_, err := s.db.Exec(fmt.Sprintf("INSERT INTO %s (i64) VALUES (?)", s.table), int64(1))
s.r.Nil(err)

rows, err := s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table))
s.r.Nil(err)
result, err := scanValues(rows)
s.r.Nil(err)
s.r.Equal([][]interface{}{{int64(1), nil, nil, "NULL", "NULL", nil, nil, nil, nil}}, result)
s.r.NoError(rows.Close())

_, err = s.db.Exec("SET GLOBAL format_null_as_str=0")
s.r.Nil(err)

rows, err = s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table))
s.r.Nil(err)
result, err = scanValues(rows)
s.r.Nil(err)
s.r.Equal([][]interface{}{{int64(1), nil, nil, nil, nil, nil, nil, nil, nil}}, result)
s.r.NoError(rows.Close())

_, err = s.db.Exec("UNSET format_null_as_str")
s.r.Nil(err)
}

func (s *DatabendTestSuite) TestQueryNullAsStr() {
row := s.db.QueryRow("SELECT NULL")
var val sql.NullString
err := row.Scan(&val)
s.r.Nil(err)
s.r.True(val.Valid)
s.r.Equal("NULL", val.String)
}

func (s *DatabendTestSuite) TestQueryNull() {
_, err := s.db.Exec("SET GLOBAL format_null_as_str=0")
s.r.Nil(err)

row := s.db.QueryRow("SELECT NULL")
var val sql.NullString
err = row.Scan(&val)
s.r.Nil(err)
s.r.False(val.Valid)
s.r.Equal("", val.String)

_, err = s.db.Exec("UNSET format_null_as_str")
s.r.Nil(err)
}
Loading

0 comments on commit 6675d5e

Please sign in to comment.