diff --git a/restful.go b/client.go similarity index 98% rename from restful.go rename to client.go index 361c293..54bf561 100644 --- a/restful.go +++ b/client.go @@ -602,7 +602,6 @@ func (c *APIClient) UploadToStage(ctx context.Context, stage *StageLocation, inp } func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) (*PresignedResponse, error) { - var headers string presignUploadSQL := fmt.Sprintf("PRESIGN UPLOAD %s", stage) resp, err := c.QuerySync(ctx, presignUploadSQL, nil) if err != nil { @@ -611,17 +610,21 @@ func (c *APIClient) GetPresignedURL(ctx context.Context, stage *StageLocation) ( if len(resp.Data) < 1 || len(resp.Data[0]) < 2 { return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data) } - - result := &PresignedResponse{ - Method: resp.Data[0][0], - Headers: make(map[string]string), - URL: resp.Data[0][2], + if resp.Data[0][0] == nil || resp.Data[0][1] == nil || resp.Data[0][2] == nil { + return nil, errors.Errorf("generate presign url invalid response: %+v", resp.Data) } - headers = resp.Data[0][1] - err = json.Unmarshal([]byte(headers), &result.Headers) + method := *resp.Data[0][0] + url := *resp.Data[0][2] + headers := map[string]string{} + err = json.Unmarshal([]byte(*resp.Data[0][1]), &headers) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal headers") } + result := &PresignedResponse{ + Method: method, + Headers: headers, + URL: url, + } return result, nil } diff --git a/restful_test.go b/client_test.go similarity index 96% rename from restful_test.go rename to client_test.go index f86ba37..803cd59 100644 --- a/restful_test.go +++ b/client_test.go @@ -87,8 +87,8 @@ func TestDoQuery(t *testing.T) { } c := APIClient{ - host: "tn3ftqihs--bl.ch.aws-us-east-2.default.databend.com", - tenant: "tn3ftqihs", + host: "tnxxxxxxx.gw.aws-us-east-2.default.databend.com", + tenant: "tnxxxxxxx", accessTokenLoader: NewStaticAccessTokenLoader("abc123"), warehouse: "small-abc", doRequestFunc: mockDoRequest, diff --git a/helpers.go b/helpers.go index c57667d..e37ffac 100644 --- a/helpers.go +++ b/helpers.go @@ -1,9 +1,7 @@ package godatabend import ( - "bytes" "fmt" - "net/http" "strings" "time" @@ -33,17 +31,6 @@ func formatDate(value time.Time) string { return quote(value.Format(dateFormat)) } -func readResponse(response *http.Response) (result []byte, err error) { - if response.ContentLength > 0 { - result = make([]byte, 0, response.ContentLength) - } - buf := bytes.NewBuffer(result) - defer response.Body.Close() - _, err = buf.ReadFrom(response.Body) - result = buf.Bytes() - return -} - func getTableFromInsertQuery(query string) (string, error) { if !strings.Contains(query, "insert") && !strings.Contains(query, "INSERT") { return "", errors.New("wrong insert statement") @@ -62,31 +49,3 @@ func generateDescTable(query string) (string, error) { } return fmt.Sprintf("DESC %s", table), nil } - -func databendParquetReflect(databendType string) string { - - var parquetType string - switch databendType { - case "VARCHAR": - parquetType = "type=BYTE_ARRAY, convertedtype=UTF8, encoding=PLAIN_DICTIONARY" - - case "BOOLEAN": - parquetType = "type=BOOLEAN" - case "TINYINT", "SMALLINT", "INT": - parquetType = "type=INT32" - case "BIGINT": - parquetType = "type=INT64" - case "FLOAT": - parquetType = "type=FLOAT" - case "DOUBLE": - parquetType = "type=DOUBLE" - case "DATE": - parquetType = "type=INT32, convertedtype=DATE" - case "TIMESTAMP": - parquetType = "type=INT64" - case "ARRAY": - parquetType = "type=LIST, convertedtype=LIST" - - } - return parquetType -} diff --git a/log.go b/log.go index 74d0097..f68d5f2 100644 --- a/log.go +++ b/log.go @@ -11,6 +11,8 @@ import ( rlog "github.com/sirupsen/logrus" ) +type contextKey string + // DBSessionIDKey is context key of session id const DBSessionIDKey contextKey = "LOG_SESSION_ID" diff --git a/query.go b/query.go index ef3a54d..7433abd 100644 --- a/query.go +++ b/query.go @@ -33,7 +33,7 @@ type QueryResponse struct { NodeID string `json:"node_id"` Session *json.RawMessage `json:"session"` Schema *[]DataField `json:"schema"` - Data [][]string `json:"data"` + Data [][]*string `json:"data"` State string `json:"state"` Error *QueryError `json:"error"` Stats *QueryStats `json:"stats"` diff --git a/rows.go b/rows.go index b798372..03ce2b3 100644 --- a/rows.go +++ b/rows.go @@ -156,7 +156,12 @@ func (r *nextRows) Next(dest []driver.Value) error { r.respData.Data = r.respData.Data[1:] for j := range lineData { - reader := strings.NewReader(lineData[j]) + val := lineData[j] + if val == nil { + dest[j] = nil + continue + } + reader := strings.NewReader(*val) v, err := r.parsers[j].Parse(reader) if err != nil { r.dc.log("fail to parse field", j, ", error: ", err) @@ -177,18 +182,17 @@ func (r *nextRows) ColumnTypeDatabaseTypeName(index int) string { return r.types[index] } +// ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeNullable +func (r *nextRows) ColumnTypeNullable(index int) (bool, bool) { + return r.parsers[index].Nullable(), true +} + // // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeLength // func (r *nextRows) ColumnTypeLength(index int) (int64, bool) { // // TODO: implement this // return 10, true // } -// // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeNullable -// func (r *nextRows) ColumnTypeNullable(index int) (bool, bool) { -// // TODO: implement this -// return true, true -// } - // // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypePrecisionScale // func (r *nextRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { // // TODO: implement this diff --git a/rows_test.go b/rows_test.go index 8196740..683f1f5 100644 --- a/rows_test.go +++ b/rows_test.go @@ -9,8 +9,11 @@ import ( ) func TestTextRows(t *testing.T) { + ptr1 := strPtr("1") + ptr2 := strPtr("2") + ptr3 := strPtr("2") rows, err := newNextRows(context.Background(), &DatabendConn{}, &QueryResponse{ - Data: [][]string{{"1", "2", "3"}, {"3", "2", "1"}}, + Data: [][]*string{{ptr1, ptr2, ptr3}, {ptr3, ptr2, ptr1}}, Schema: &[]DataField{ {Name: "age", Type: "Int32"}, {Name: "height", Type: "Int64"}, @@ -28,3 +31,7 @@ func TestTextRows(t *testing.T) { assert.Equal(t, "Int32", rows.ColumnTypeDatabaseTypeName(0)) assert.Equal(t, "String", rows.ColumnTypeDatabaseTypeName(2)) } + +func strPtr(s string) *string { + return &s +} diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 1493207..d3ede21 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -6,7 +6,7 @@ services: volumes: - ./data:/data databend: - image: datafuselabs/databend + image: datafuselabs/databend:nightly environment: - QUERY_DEFAULT_USER=databend - QUERY_DEFAULT_PASSWORD=databend diff --git a/tests/main_test.go b/tests/main_test.go index 082f9c6..d0f0407 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -32,16 +32,16 @@ const ( ) var ( - dsn = "http://databend:databend@localhost:8000?presigned_url_disabled=true" + dsn = "http://databend:databend@localhost:8000?presign=on" ) func init() { dsn = os.Getenv("TEST_DATABEND_DSN") // databend default - // dsn = "http://root:@localhost:8000?presigned_url_disabled=true" + // dsn = "http://root:@localhost:8000?presign=on" // add user databend by uncommenting corresponding [[query.users]] section scripts/ci/deploy/config/databend-query-node-1.toml - //dsn = "http://databend:databend@localhost:8000?presigned_url_disabled=true" + //dsn = "http://databend:databend@localhost:8000?presign=on" } func TestDatabendSuite(t *testing.T) { @@ -65,13 +65,6 @@ func (s *DatabendTestSuite) SetupSuite() { err = s.db.Ping() s.Nil(err) - - rows, err := s.db.Query("select version()") - s.Nil(err) - result, err := scanValues(rows) - s.Nil(err) - - s.T().Logf("connected to databend: %s\n", result) } func (s *DatabendTestSuite) TearDownSuite() { @@ -103,6 +96,14 @@ func (s *DatabendTestSuite) TearDownTest() { s.r.Nil(err) } +func (s *DatabendTestSuite) TestVersion() { + rows, err := s.db.Query("select version()") + s.Nil(err) + result, err := scanValues(rows) + s.Nil(err) + s.T().Logf("connected to databend: %s\n", result) +} + // For load balance test func (s *DatabendTestSuite) TestCycleExec() { rows, err := s.db.Query("SELECT number from numbers(200) order by number") @@ -115,9 +116,11 @@ func (s *DatabendTestSuite) TestQuoteStringQuery() { m := make(map[string]string, 0) m["message"] = "this is action 'with quote string'" x, err := json.Marshal(m) + s.r.Nil(err) _, err = s.db.Exec(fmt.Sprintf("insert into %s values(?)", s.table2), string(x)) s.r.Nil(err) rows, err := s.db.Query(fmt.Sprintf("select * from %s", s.table2)) + s.r.Nil(err) for rows.Next() { var t string _ = rows.Scan(&t) @@ -272,17 +275,6 @@ func (s *DatabendTestSuite) TestServerError() { s.Contains(err.Error(), "error") } -func (s *DatabendTestSuite) TestQueryNull() { - rows, err := s.db.Query("SELECT NULL") - s.r.Nil(err) - - result, err := scanValues(rows) - s.r.Nil(err) - s.r.Equal([][]interface{}{{"NULL"}}, result) - - s.r.NoError(rows.Close()) -} - func (s *DatabendTestSuite) TestTransactionCommit() { tx, err := s.db.Begin() s.r.Nil(err) @@ -298,7 +290,7 @@ func (s *DatabendTestSuite) TestTransactionCommit() { result, err := scanValues(rows) s.r.Nil(err) - s.r.Equal([][]interface{}{{"1", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL"}}, result) + s.r.Equal([][]interface{}{{int64(1), nil, nil, "NULL", "NULL", nil, nil, nil, nil}}, result) s.r.NoError(rows.Close()) } @@ -343,6 +335,13 @@ func (s *DatabendTestSuite) TestLongExec() { } } +func getNullableType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + return t.Elem() + } + return t +} + func scanValues(rows *sql.Rows) (interface{}, error) { var err error var result [][]interface{} @@ -350,25 +349,25 @@ func scanValues(rows *sql.Rows) (interface{}, error) { if err != nil { return nil, err } - types := make([]reflect.Type, len(ct)) - for i, v := range ct { - types[i] = v.ScanType() - } - ptrs := make([]interface{}, len(types)) + vals := make([]any, len(ct)) for rows.Next() { if err = rows.Err(); err != nil { return nil, err } - for i, t := range types { - ptrs[i] = reflect.New(t).Interface() + for i := range ct { + vals[i] = &dc.NullableValue{} } - err = rows.Scan(ptrs...) + err = rows.Scan(vals...) if err != nil { return nil, err } - values := make([]interface{}, len(types)) - for i, p := range ptrs { - values[i] = reflect.ValueOf(p).Elem().Interface() + values := make([]interface{}, len(ct)) + for i, p := range vals { + val, err := p.(*dc.NullableValue).Value() + if err != nil { + return nil, fmt.Errorf("failed to get value: %w", err) + } + values[i] = val } result = append(result, values) } diff --git a/tests/nullable_test.go b/tests/nullable_test.go new file mode 100644 index 0000000..0c0b1cc --- /dev/null +++ b/tests/nullable_test.go @@ -0,0 +1,55 @@ +package tests + +import ( + "database/sql" + "fmt" +) + +func (s *DatabendTestSuite) TestNullable() { + _, err := s.db.Exec(fmt.Sprintf("INSERT INTO %s (i64) VALUES (?)", s.table), int64(1)) + s.r.Nil(err) + + rows, err := s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table)) + s.r.Nil(err) + result, err := scanValues(rows) + s.r.Nil(err) + s.r.Equal([][]interface{}{{int64(1), nil, nil, "NULL", "NULL", nil, nil, nil, nil}}, result) + s.r.NoError(rows.Close()) + + _, err = s.db.Exec("SET GLOBAL format_null_as_str=0") + s.r.Nil(err) + + rows, err = s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table)) + s.r.Nil(err) + result, err = scanValues(rows) + s.r.Nil(err) + s.r.Equal([][]interface{}{{int64(1), nil, nil, nil, nil, nil, nil, nil, nil}}, result) + s.r.NoError(rows.Close()) + + _, err = s.db.Exec("UNSET format_null_as_str") + s.r.Nil(err) +} + +func (s *DatabendTestSuite) TestQueryNullAsStr() { + row := s.db.QueryRow("SELECT NULL") + var val sql.NullString + err := row.Scan(&val) + s.r.Nil(err) + s.r.True(val.Valid) + s.r.Equal("NULL", val.String) +} + +func (s *DatabendTestSuite) TestQueryNull() { + _, err := s.db.Exec("SET GLOBAL format_null_as_str=0") + s.r.Nil(err) + + row := s.db.QueryRow("SELECT NULL") + var val sql.NullString + err = row.Scan(&val) + s.r.Nil(err) + s.r.False(val.Valid) + s.r.Equal("", val.String) + + _, err = s.db.Exec("UNSET format_null_as_str") + s.r.Nil(err) +} diff --git a/tests/txn_test.go b/tests/txn_test.go index 1bc6f8c..8296845 100644 --- a/tests/txn_test.go +++ b/tests/txn_test.go @@ -53,13 +53,13 @@ func TestTnx(t *testing.T) { assert.NoError(t, err) if rows1 != nil { res1, _ := scanValues(rows1) - assert.Equal(t, [][]interface{}{{"2"}}, res1) + assert.Equal(t, [][]interface{}{{int32(2)}}, res1) } rows2, err = db2.Query(selectT) assert.NoError(t, err) if rows2 != nil { res2, _ := scanValues(rows2) - assert.Equal(t, [][]interface{}{{"2"}}, res2) + assert.Equal(t, [][]interface{}{{int32(2)}}, res2) } // test rollback diff --git a/tokenizer.go b/tokenizer.go index e5e0480..4bf54d3 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -4,32 +4,12 @@ import ( "bytes" "fmt" "io" - "strings" ) const ( eof = rune(0) ) -type token struct { - kind rune - data string -} - -func skipWhiteSpace(s io.RuneScanner) { - for { - r := read(s) - switch r { - case ' ', '\t', '\n': - continue - case eof: - return - } - _ = s.UnreadRune() - return - } -} - func read(s io.RuneScanner) rune { r, _, err := s.ReadRune() if err != nil { @@ -75,91 +55,3 @@ func readRaw(s io.RuneScanner) *bytes.Buffer { return &data } - -func readQuoted(s io.RuneScanner) (*token, error) { - var data bytes.Buffer - -loop: - for { - r := read(s) - - switch r { - case eof: - return nil, fmt.Errorf("unexpected eof inside quoted string") - case '\\': - escaped, err := readEscaped(s) - if err != nil { - return nil, fmt.Errorf("incorrect escaping in quoted string: %v", err) - } - r = escaped - case '\'': - break loop - } - - data.WriteRune(r) - } - - return &token{'q', data.String()}, nil -} - -func readNumberOrID(s io.RuneScanner) *token { - var data bytes.Buffer - -loop: - for { - r := read(s) - - switch r { - case eof, ' ', '\t', '\n': - break loop - case '(', ')', ',': - _ = s.UnreadRune() - break loop - default: - data.WriteRune(r) - } - } - - return &token{'s', data.String()} -} - -func tokenize(s io.RuneScanner) ([]*token, error) { - var tokens []*token - -loop: - for { - var t *token - var err error - - switch read(s) { - case eof: - break loop - case ' ', '\t', '\n': - skipWhiteSpace(s) - continue - case '(': - t = &token{kind: '('} - case ')': - t = &token{kind: ')'} - case ',': - t = &token{kind: ','} - case '\'': - t, err = readQuoted(s) - if err != nil { - return nil, err - } - default: - _ = s.UnreadRune() - t = readNumberOrID(s) - } - - tokens = append(tokens, t) - } - - tokens = append(tokens, &token{kind: eof}) - return tokens, nil -} - -func tokenizeString(s string) ([]*token, error) { - return tokenize(strings.NewReader(s)) -} diff --git a/tokenizer_test.go b/tokenizer_test.go deleted file mode 100644 index 45d7ff3..0000000 --- a/tokenizer_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package godatabend - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestTokenize(t *testing.T) { - type testCase struct { - name string - input string - output []*token - fail bool - } - testCases := []*testCase{ - { - name: "empty", - input: "", - output: []*token{{eof, ""}}, - }, - { - name: "only whitespace", - input: "", - output: []*token{{eof, ""}}, - }, - { - name: "whitespace all over the place", - input: " \t\nhello \t \n world \n", - output: []*token{ - {'s', "hello"}, - {'s', "world"}, - {eof, ""}, - }, - }, - { - name: "complex with quotes and escaping", - input: `Array(Tuple(FixedString(5), Float32, 'hello, \') world'))`, - output: []*token{ - {'s', "Array"}, - {'(', ""}, - {'s', "Tuple"}, - {'(', ""}, - {'s', "FixedString"}, - {'(', ""}, - {'s', "5"}, - {')', ""}, - {',', ""}, - {'s', "Float32"}, - {',', ""}, - {'q', `hello, ') world`}, - {')', ""}, - {')', ""}, - {eof, ""}, - }, - }, - { - name: "unclosed quote", - input: "Array(')", - fail: true, - }, - { - name: "unfinished escape", - input: `Array('\`, - fail: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(tt *testing.T) { - output, err := tokenizeString(tc.input) - if tc.fail { - assert.Error(tt, err) - } else { - assert.NoError(tt, err) - assert.Equal(tt, tc.output, output) - } - }) - } -} diff --git a/typeparser.go b/typeparser.go index a30279b..a385781 100644 --- a/typeparser.go +++ b/typeparser.go @@ -1,96 +1,80 @@ package godatabend -import ( - "fmt" -) +import "fmt" // TypeDesc describes a (possibly nested) data type returned by Databend. type TypeDesc struct { - Name string - Args []*TypeDesc + Name string + Nullable bool + Args []*TypeDesc } -func parseTypeDesc(tokens []*token) (*TypeDesc, []*token, error) { - var name string - if tokens[0].kind == 's' || tokens[0].kind == 'q' { - name = tokens[0].data - tokens = tokens[1:] - } else { - return nil, nil, fmt.Errorf("failed to parse type name: wrong token type '%c'", tokens[0].kind) - } - - desc := TypeDesc{Name: name} - if tokens[0].kind != '(' { - return &desc, tokens, nil - } - - tokens = tokens[1:] - - if tokens[0].kind == ')' { - return &desc, tokens[1:], nil - } +func ParseTypeDesc(s string) (*TypeDesc, error) { + var ( + name = "" + args = []*TypeDesc{} + depth = 0 + start = 0 + nullable = false + ) - if name == "Enum8" || name == "Enum16" { - // TODO: an Enum's arguments get completely ignored - for i := range tokens { - if tokens[i].kind == ')' { - return &desc, tokens[i+1:], nil + for i, c := range s { + switch c { + case '(': + if depth == 0 { + name = s[start:i] + start = i + 1 } - } - return nil, nil, fmt.Errorf("unfinished enum type description") - } - - for { - var arg *TypeDesc - var err error - - arg, tokens, err = parseTypeDesc(tokens) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse subtype: %v", err) - } - desc.Args = append(desc.Args, arg) - - switch tokens[0].kind { - case ',': - tokens = tokens[1:] - continue + depth++ case ')': - return &desc, tokens[1:], nil + depth-- + if depth == 0 { + s := s[start:i] + if s != "" { + desc, err := ParseTypeDesc(s) + if err != nil { + return nil, err + } + args = append(args, desc) + } + start = i + 1 + } + case ',': + if depth == 1 { + s := s[start:i] + if s != "" { + desc, err := ParseTypeDesc(s) + if err != nil { + return nil, err + } + args = append(args, desc) + } + start = i + 1 + } + case ' ': + if depth == 0 { + s := s[start:i] + if s != "" { + name = s + } + start = i + 1 + } } } -} - -// ParseTypeDesc parses the type description that Databend provides. -// -// The grammar is quite simple: -// -// desc -// name -// name() -// name(args) -// args -// desc -// desc, args -// -// Examples: -// -// String -// Nullable(Nothing) -// Array(Tuple(Tuple(String, String), Tuple(String, UInt64))) -func ParseTypeDesc(s string) (*TypeDesc, error) { - tokens, err := tokenizeString(s) - if err != nil { - return nil, fmt.Errorf("failed to tokenize type description: %v", err) + if depth != 0 { + return nil, fmt.Errorf("invalid type desc: %s", s) } - - desc, tail, err := parseTypeDesc(tokens) - if err != nil { - return nil, fmt.Errorf("failed to parse type description: %v", err) - } - - if len(tail) != 1 || tail[0].kind != eof { - return nil, fmt.Errorf("unexpected tail after type description") + if start < len(s) { + s := s[start:] + if s != "" { + if name == "" { + name = s + } else if s == "NULL" { + nullable = true + } else { + return nil, fmt.Errorf("invalid type arg for %s: %s", name, s) + } + } } - - return desc, nil + return &TypeDesc{Name: name, Nullable: nullable, Args: args}, nil } diff --git a/typeparser_test.go b/typeparser_test.go index c2959c5..8999a50 100644 --- a/typeparser_test.go +++ b/typeparser_test.go @@ -6,85 +6,125 @@ import ( "github.com/stretchr/testify/assert" ) +type typeparserTestCase struct { + desc string + input string + output *TypeDesc + fail bool +} + func TestParseTypeDesc(t *testing.T) { - type testCase struct { - name string - input string - output *TypeDesc - fail bool - } - testCases := []*testCase{ - { - name: "plain type", - input: "String", - output: &TypeDesc{Name: "String"}, - }, + testCases := []*typeparserTestCase{ { - name: "nullable type", - input: "Nullable(Nothing)", + desc: "plain type", + input: "String", output: &TypeDesc{ - Name: "Nullable", - Args: []*TypeDesc{{Name: "Nothing"}}, + Name: "String", + Nullable: false, + Args: []*TypeDesc{}, }, }, { - name: "empty arg", - input: "DateTime()", - output: &TypeDesc{Name: "DateTime"}, - }, - { - name: "numeric arg", - input: "FixedString(42)", + desc: "decimal type", + input: "Decimal(42, 42)", output: &TypeDesc{ - Name: "FixedString", - Args: []*TypeDesc{{Name: "42"}}, + Name: "Decimal", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "42", + Nullable: false, + Args: []*TypeDesc{}, + }, + { + Name: "42", + Nullable: false, + Args: []*TypeDesc{}, + }, + }, }, }, { - name: "args are ignored for Enum", - input: "Enum8(you can = put, 'whatever' here)", - output: &TypeDesc{Name: "Enum8"}, - }, - { - name: "quoted arg", - input: "DateTime('UTC')", + desc: "nullable type", + input: "Nullable(Nothing)", output: &TypeDesc{ - Name: "DateTime", - Args: []*TypeDesc{{Name: "UTC"}}, + Name: "Nullable", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "Nothing", + Nullable: false, + Args: []*TypeDesc{}, + }, + }, }, }, { - name: "decimal", - input: "Decimal(9,4)", + desc: "empty arg", + input: "DateTime()", output: &TypeDesc{ - Name: "Decimal", - Args: []*TypeDesc{{Name: "9"}, {Name: "4"}}, + Name: "DateTime", + Nullable: false, + Args: []*TypeDesc{}, }, }, { - name: "quoted escaped arg", - input: `DateTime('UTC\b\r\n\'\f\t\0')`, + desc: "numeric arg", + input: "FixedString(42)", output: &TypeDesc{ - Name: "DateTime", - Args: []*TypeDesc{{Name: "UTC\b\r\n'\f\t\x00"}}, + Name: "FixedString", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "42", + Nullable: false, + Args: []*TypeDesc{}, + }, + }, }, }, { - name: "nested args", + desc: "multiple args", input: "Array(Tuple(Tuple(String, String), Tuple(String, UInt64)))", output: &TypeDesc{ - Name: "Array", + Name: "Array", + Nullable: false, Args: []*TypeDesc{ { - Name: "Tuple", + Name: "Tuple", + Nullable: false, Args: []*TypeDesc{ { - Name: "Tuple", - Args: []*TypeDesc{{Name: "String"}, {Name: "String"}}, + Name: "Tuple", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "String", + Nullable: false, + Args: []*TypeDesc{}, + }, + { + Name: "String", + Nullable: false, + Args: []*TypeDesc{}, + }, + }, }, { - Name: "Tuple", - Args: []*TypeDesc{{Name: "String"}, {Name: "UInt64"}}, + Name: "Tuple", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "String", + Nullable: false, + Args: []*TypeDesc{}, + }, + { + Name: "UInt64", + Nullable: false, + Args: []*TypeDesc{}, + }, + }, }, }, }, @@ -92,50 +132,126 @@ func TestParseTypeDesc(t *testing.T) { }, }, { - name: "map args", + desc: "map args", input: "Map(String, Array(Int64))", output: &TypeDesc{ - Name: "Map", + Name: "Map", + Nullable: false, Args: []*TypeDesc{ { - Name: "String", + Name: "String", + Nullable: false, + Args: []*TypeDesc{}, }, { - Name: "Array", - Args: []*TypeDesc{{Name: "Int64"}}, + Name: "Array", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "Int64", + Nullable: false, + Args: []*TypeDesc{}, + }, + }, }, }, }, }, { - name: "unfinished arg list", - input: "Array(Tuple(Tuple(String, String), Tuple(String, UInt64))", - fail: true, - }, - { - name: "left paren without name", - input: "(", - fail: true, - }, - { - name: "unfinished quote", - input: "Array(')", - fail: true, - }, - { - name: "unfinished escape", - input: `Array(\`, - fail: true, + desc: "map nullable value args", + input: "Map(String, String NULL)", + output: &TypeDesc{ + Name: "Map", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "String", + Nullable: false, + Args: []*TypeDesc{}, + }, + { + Name: "String", + Nullable: true, + Args: []*TypeDesc{}, + }, + }, + }, }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(tt *testing.T) { + output, err := ParseTypeDesc(tc.input) + if tc.fail { + assert.Error(tt, err) + } else { + assert.NoError(tt, err) + } + assert.Equal(tt, tc.output, output) + }) + } +} + +func TestParseComplexTypeWithNull(t *testing.T) { + testCases := []*typeparserTestCase{ { - name: "stuff after end", - input: `Array() String`, - fail: true, + desc: "complex nullable type", + input: "Nullable(Tuple(String NULL, Array(Tuple(Array(Int32 NULL) NULL, Array(String NULL) NULL) NULL) NULL))", + output: &TypeDesc{ + Name: "Nullable", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "Tuple", + Nullable: false, + Args: []*TypeDesc{ + { + Name: "String", + Nullable: true, + Args: []*TypeDesc{}, + }, + { + Name: "Array", + Nullable: true, + Args: []*TypeDesc{ + { + Name: "Tuple", + Nullable: true, + Args: []*TypeDesc{ + { + Name: "Array", + Nullable: true, + Args: []*TypeDesc{ + { + Name: "Int32", + Nullable: true, + Args: []*TypeDesc{}, + }, + }, + }, + { + Name: "Array", + Nullable: true, + Args: []*TypeDesc{ + { + Name: "String", + Nullable: true, + Args: []*TypeDesc{}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, }, } - for _, tc := range testCases { - t.Run(tc.name, func(tt *testing.T) { + t.Run(tc.desc, func(tt *testing.T) { output, err := ParseTypeDesc(tc.input) if tc.fail { assert.Error(tt, err) diff --git a/types.go b/types.go index 16e5e46..c6e8537 100644 --- a/types.go +++ b/types.go @@ -112,3 +112,18 @@ type tuple struct { func (t tuple) Value() (driver.Value, error) { return textEncode.Encode(t) } + +type NullableValue struct { + val any +} + +// Scan implements the [Scanner] interface. +func (nv *NullableValue) Scan(value any) error { + nv.val = value + return nil +} + +// Value implements the [driver.Valuer] interface. +func (nv NullableValue) Value() (driver.Value, error) { + return nv.val, nil +} diff --git a/util.go b/util.go deleted file mode 100644 index 3b5390b..0000000 --- a/util.go +++ /dev/null @@ -1,3 +0,0 @@ -package godatabend - -type contextKey string diff --git a/dataparser.go b/valueparser.go similarity index 88% rename from dataparser.go rename to valueparser.go index 5dd7068..4e32478 100644 --- a/dataparser.go +++ b/valueparser.go @@ -100,17 +100,62 @@ func readString(s io.RuneScanner, length int, unquote bool) (string, error) { return str, nil } +func peakNull(s io.RuneScanner) bool { + r := read(s) + if r != 'N' { + _ = s.UnreadRune() + return false + } + + r = read(s) + if r != 'U' { + _ = s.UnreadRune() + _ = s.UnreadRune() + return false + } + + r = read(s) + if r != 'L' { + _ = s.UnreadRune() + _ = s.UnreadRune() + _ = s.UnreadRune() + return false + } + + r = read(s) + if r != 'L' { + _ = s.UnreadRune() + _ = s.UnreadRune() + _ = s.UnreadRune() + _ = s.UnreadRune() + return false + } + + r = read(s) + if r != eof { + _ = s.UnreadRune() + _ = s.UnreadRune() + _ = s.UnreadRune() + _ = s.UnreadRune() + _ = s.UnreadRune() + return false + } + + return true +} + // DataParser implements parsing of a driver value and reporting its type. type DataParser interface { Parse(io.RuneScanner) (driver.Value, error) + Nullable() bool Type() reflect.Type } -type nullableParser struct { +type nullParser struct { DataParser } -func (p *nullableParser) Parse(s io.RuneScanner) (driver.Value, error) { +func (p *nullParser) Parse(s io.RuneScanner) (driver.Value, error) { var dB *bytes.Buffer dType := p.DataParser.Type() @@ -238,6 +283,10 @@ func (p *stringParser) Type() reflect.Type { return reflectTypeString } +func (p *stringParser) Nullable() bool { + return false +} + type booleanParser struct { length int } @@ -254,6 +303,10 @@ func (p *booleanParser) Type() reflect.Type { return reflectTypeBool } +func (p *booleanParser) Nullable() bool { + return false +} + type dateTimeParser struct { unquote bool format string @@ -290,6 +343,10 @@ func (p *dateTimeParser) Type() reflect.Type { return reflectTypeTime } +func (p *dateTimeParser) Nullable() bool { + return false +} + type tupleParser struct { args []DataParser } @@ -303,6 +360,10 @@ func (p *tupleParser) Type() reflect.Type { return reflect.StructOf(fields) } +func (p *tupleParser) Nullable() bool { + return false +} + func (p *tupleParser) Parse(s io.RuneScanner) (driver.Value, error) { r := read(s) if r != '(' { @@ -342,6 +403,10 @@ func (p *arrayParser) Type() reflect.Type { return reflect.SliceOf(p.arg.Type()) } +func (p *arrayParser) Nullable() bool { + return false +} + func (p *arrayParser) Parse(s io.RuneScanner) (driver.Value, error) { r := read(s) if r != '[' { @@ -362,7 +427,7 @@ func (p *arrayParser) Parse(s io.RuneScanner) (driver.Value, error) { } if v == nil { - if reflect.TypeOf(p.arg) != reflect.TypeOf(&nullableParser{}) { + if reflect.TypeOf(p.arg) != reflect.TypeOf(&nullParser{}) { //need check if v is nil: panic otherwise return nil, fmt.Errorf("unexpected nil element") } @@ -393,6 +458,10 @@ func (p *mapParser) Type() reflect.Type { return reflect.MapOf(p.key.Type(), p.value.Type()) } +func (p *mapParser) Nullable() bool { + return false +} + func (p *mapParser) Parse(s io.RuneScanner) (driver.Value, error) { r := read(s) if r != '{' { @@ -438,30 +507,6 @@ func (p *mapParser) Parse(s io.RuneScanner) (driver.Value, error) { return m.Interface(), nil } -type lowCardinalityParser struct { - arg DataParser -} - -func (p *lowCardinalityParser) Type() reflect.Type { - return p.arg.Type() -} - -func (p *lowCardinalityParser) Parse(s io.RuneScanner) (driver.Value, error) { - return p.arg.Parse(s) -} - -type simpleAggregateFunctionParser struct { - arg DataParser -} - -func (p *simpleAggregateFunctionParser) Type() reflect.Type { - return p.arg.Type() -} - -func (p *simpleAggregateFunctionParser) Parse(s io.RuneScanner) (driver.Value, error) { - return p.arg.Parse(s) -} - func newDateTimeParser(format string, loc *time.Location, precision int, unquote bool) (DataParser, error) { return &dateTimeParser{ unquote: unquote, @@ -545,6 +590,10 @@ func (p *intParser) Type() reflect.Type { } } +func (p *intParser) Nullable() bool { + return false +} + type floatParser struct { bitSize int } @@ -577,6 +626,10 @@ func (p *floatParser) Type() reflect.Type { } } +func (p *floatParser) Nullable() bool { + return false +} + type nothingParser struct{} func (p *nothingParser) Parse(s io.RuneScanner) (driver.Value, error) { @@ -587,6 +640,10 @@ func (p *nothingParser) Type() reflect.Type { return reflectTypeEmptyStruct } +func (p *nothingParser) Nullable() bool { + return true +} + // DataParserOptions describes DataParser options. // Ex.: Fields Location and UseDBLocation specify timezone options. type DataParserOptions struct { @@ -602,14 +659,53 @@ func NewDataParser(t *TypeDesc, opt *DataParserOptions) (DataParser, error) { return newDataParser(t, false, opt) } +type nullableParser struct { + innerParser DataParser + innerType string +} + +func (p *nullableParser) Parse(s io.RuneScanner) (driver.Value, error) { + switch p.innerType { + case "String": + return p.innerParser.Parse(s) + default: + // for compatibility with old databend versions + if peakNull(s) { + return nil, nil + } + return p.innerParser.Parse(s) + } +} + +func (p *nullableParser) Type() reflect.Type { + return p.innerParser.Type() +} + +func (p *nullableParser) Nullable() bool { + return true +} + func newDataParser(t *TypeDesc, unquote bool, opt *DataParserOptions) (DataParser, error) { + if t.Nullable { + t.Nullable = false + inner, err := newDataParser(t, unquote, opt) + if err != nil { + return nil, err + } + return &nullableParser{innerParser: inner, innerType: t.Name}, nil + } switch t.Name { case "Nothing": return ¬hingParser{}, nil case "Nullable": - return &stringParser{unquote: unquote}, nil + inner, err := newDataParser(t.Args[0], unquote, opt) + if err != nil { + return nil, err + } + return &nullableParser{innerParser: inner, innerType: t.Args[0].Name}, nil case "NULL": - return &stringParser{unquote: unquote}, nil + inner := &stringParser{unquote: unquote} + return &nullableParser{innerParser: inner, innerType: "String"}, nil case "Date": loc := time.UTC if opt != nil && opt.Location != nil { @@ -717,24 +813,6 @@ func newDataParser(t *TypeDesc, unquote bool, opt *DataParserOptions) (DataParse subParsers[i] = subParser } return &tupleParser{subParsers}, nil - case "LowCardinality": - if len(t.Args) != 1 { - return nil, fmt.Errorf("element type not specified for LowCardinality") - } - subParser, err := newDataParser(t.Args[0], unquote, opt) - if err != nil { - return nil, fmt.Errorf("failed to create parser for LowCardinality elements: %v", err) - } - return &lowCardinalityParser{subParser}, nil - case "SimpleAggregateFunction": - if len(t.Args) != 2 { - return nil, fmt.Errorf("incorrect number of arguments for SimpleAggregateFunction") - } - subParser, err := newDataParser(t.Args[1], unquote, opt) - if err != nil { - return nil, fmt.Errorf("failed to create parser for SimpleAggregateFunction element: %v", err) - } - return &simpleAggregateFunctionParser{subParser}, nil case "Map": if len(t.Args) != 2 { return nil, fmt.Errorf("incorrect number of arguments for Map")