diff --git a/dsn.go b/dsn.go index ec2d56c..276eec6 100644 --- a/dsn.go +++ b/dsn.go @@ -23,6 +23,8 @@ type Config struct { Password string // Password (requires User) Database string // Database name + Role string // Role is the databend role you want to use for the current connection + AccessToken string AccessTokenFile string // path to file containing access token, it can be used to rotate access token AccessTokenLoader AccessTokenLoader @@ -84,6 +86,10 @@ func (cfg *Config) FormatDSN() string { if cfg.Warehouse != "" { query.Set("warehouse", cfg.Warehouse) } + + if len(cfg.Role) > 0 { + query.Set("role", cfg.Role) + } if cfg.AccessToken != "" { query.Set("access_token", cfg.AccessToken) } @@ -151,6 +157,8 @@ func (cfg *Config) AddParams(params map[string]string) (err error) { cfg.Tenant = v case "warehouse": cfg.Warehouse = v + case "role": + cfg.Role = v case "access_token": cfg.AccessToken = v case "access_token_file": diff --git a/dsn_test.go b/dsn_test.go index b96f213..6b7883d 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -9,7 +9,7 @@ import ( ) func TestFormatDSN(t *testing.T) { - dsn := "databend+https://username:password@tn3ftqihs.ch.aws-us-east-2.default.databend.com/test?timeout=1s&wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&tls_config=tls-settings&warehouse=wh" + dsn := "databend+https://username:password@tn3ftqihs.ch.aws-us-east-2.default.databend.com/test?role=test_role&timeout=1s&wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&tls_config=tls-settings&warehouse=wh" cfg, err := ParseDSN(dsn) require.Nil(t, err) @@ -21,6 +21,7 @@ func TestFormatDSN(t *testing.T) { assert.Equal(t, int64(10000), cfg.MaxRowsPerPage) assert.Equal(t, int64(10), cfg.WaitTimeSecs) assert.Equal(t, int64(5000000), cfg.MaxRowsInBuffer) + assert.Equal(t, "test_role", cfg.Role) dsn1 := cfg.FormatDSN() cfg1, err := ParseDSN(dsn1) diff --git a/query.go b/query.go index d56f76c..410f774 100644 --- a/query.go +++ b/query.go @@ -80,6 +80,7 @@ type PaginationConfig struct { type SessionConfig struct { Database string `json:"database,omitempty"` + Role string `json:"role,omitempty"` // Since we use client session, this should not be used // KeepServerSessionSecs uint64 `json:"keep_server_session_secs,omitempty"` diff --git a/restful.go b/restful.go index 00ab28f..5a2f00d 100644 --- a/restful.go +++ b/restful.go @@ -66,6 +66,7 @@ type APIClient struct { database string user string password string + role string accessTokenLoader AccessTokenLoader sessionSettings map[string]string statsTracker QueryStatsTracker @@ -98,6 +99,7 @@ func NewAPIClientFromConfig(cfg *Config) *APIClient { database: cfg.Database, user: cfg.User, password: cfg.Password, + role: cfg.Role, accessTokenLoader: initAccessTokenLoader(cfg), sessionSettings: cfg.Params, statsTracker: cfg.StatsTracker, @@ -267,6 +269,7 @@ func (c *APIClient) getPagenationConfig() *PaginationConfig { func (c *APIClient) getSessionConfig() *SessionConfig { return &SessionConfig{ Database: c.database, + Role: c.role, Settings: c.sessionSettings, } } @@ -303,6 +306,9 @@ func (c *APIClient) applySessionConfig(response *QueryResponse) { if response.Session.Database != "" { c.database = response.Session.Database } + if len(response.Session.Role) > 0 { + c.role = response.Session.Role + } if response.Session.Settings != nil { for k, v := range response.Session.Settings { c.sessionSettings[k] = v diff --git a/restful_test.go b/restful_test.go index 19922ca..11675a6 100644 --- a/restful_test.go +++ b/restful_test.go @@ -15,11 +15,14 @@ func TestMakeHeadersUserPassword(t *testing.T) { password: "root", host: "localhost:8000", tenant: "default", + role: "role1", } headers, err := c.makeHeaders() assert.Nil(t, err) assert.Equal(t, headers["Authorization"], []string{"Basic cm9vdDpyb290"}) assert.Equal(t, headers["X-Databend-Tenant"], []string{"default"}) + session := c.getSessionConfig() + assert.Equal(t, session.Role, "role1") } func TestMakeHeadersAccessToken(t *testing.T) { diff --git a/tests/main_test.go b/tests/main_test.go index 1f458d3..b867cc0 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -43,7 +43,6 @@ func (s *DatabendTestSuite) SetupSuite() { dsn := os.Getenv("TEST_DATABEND_DSN") s.NotEmpty(dsn) - s.db, err = sql.Open("databend", dsn) s.Nil(err) diff --git a/tests/session_test.go b/tests/session_test.go index 9694ad8..5f93e11 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -1,6 +1,11 @@ package tests -import "github.com/stretchr/testify/require" +import ( + "database/sql" + "fmt" + "github.com/stretchr/testify/require" + "os" +) func (s *DatabendTestSuite) TestChangeDatabase() { r := require.New(s.T()) @@ -19,6 +24,43 @@ func (s *DatabendTestSuite) TestChangeDatabase() { r.Equal("default", result) } +func (s *DatabendTestSuite) TestChangeRole() { + r := require.New(s.T()) + var result string + err := s.db.QueryRow("select version()").Scan(&result) + r.Nil(err) + println(result) + _, err = s.db.Exec("create role if not exists test_role") + r.Nil(err) + dsn := os.Getenv("TEST_DATABEND_DSN") + s.NotEmpty(dsn) + dsn = fmt.Sprintf("%s&role=test_role", dsn) + s.db, err = sql.Open("databend", dsn) + s.Nil(err) + + err = s.db.QueryRow("select current_role()").Scan(&result) + r.Nil(err) + r.Equal("test_role", result) + + dsn = os.Getenv("TEST_DATABEND_DSN") + s.NotEmpty(dsn) + s.db, err = sql.Open("databend", dsn) + s.Nil(err) + // + //defer s.db.Exec("drop role if exists test_role") + //_, err = s.db.Exec("set role 'test_role'") + //r.Nil(err) + // + //_, err = s.db.Exec("create role if not exists test_role_2") + //r.Nil(err) + //defer s.db.Exec("drop role if exists test_role_2") + //_, err = s.db.Exec("set role 'test_role_2'") + //r.Nil(err) + //err = s.db.QueryRow("select current_role()").Scan(&result) + //r.Nil(err) + //r.Equal("test_role_2", result) +} + func (s *DatabendTestSuite) TestSessionConfig() { r := require.New(s.T())