Skip to content
Merged
133 changes: 133 additions & 0 deletions libs/auth/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package auth

import (
"context"
"errors"

"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
)

// The credentials chain used by the CLI. It is a custom implementation
// that differs from the SDK's default credentials chain. This guarantees
// that the CLI remain stable despite the evolution of the SDK while
// allowing the customization of some strategies such as "databricks-cli"
// which has a different behavior than the SDK.
//
// Modifying this order could break authentication for users whose
// environments are compatible with multiple strategies and who rely
// on the current priority for tie-breaking.
var credentialChain = []config.CredentialsStrategy{
config.PatCredentials{},
config.BasicCredentials{},
config.M2mCredentials{},
CLICredentials{}, // custom
config.MetadataServiceCredentials{},
// OIDC Strategies.
config.GitHubOIDCCredentials{},
config.AzureDevOpsOIDCCredentials{},
config.EnvOIDCCredentials{},
config.FileOIDCCredentials{},
// Azure strategies.
config.AzureGithubOIDCCredentials{},
config.AzureMsiCredentials{},
config.AzureClientSecretCredentials{},
config.AzureCliCredentials{},
// Google strategies.
config.GoogleCredentials{},
config.GoogleDefaultCredentials{},
}

func init() {
// Sets the credentials chain for the CLI.
config.DefaultCredentialStrategyProvider = func() config.CredentialsStrategy {
return &defaultCredentials{chain: config.NewCredentialsChain(credentialChain...)}
}
}

// defaultCredentials wraps the CLI credential chain and provides "default"
// as the fallback name, matching the SDK's DefaultCredentials behavior.
type defaultCredentials struct {
chain config.CredentialsStrategy
}

func (d *defaultCredentials) Name() string {
if name := d.chain.Name(); name != "" {
return name
}
return "default"
}

func (d *defaultCredentials) Configure(ctx context.Context, cfg *config.Config) (credentials.CredentialsProvider, error) {
return d.chain.Configure(ctx, cfg)
}

// CLICredentials is a credentials strategy that reads OAuth tokens directly
// from the local token store. It replaces the SDK's default "databricks-cli"
// strategy, which shells out to `databricks auth token` as a subprocess.
type CLICredentials struct {
// persistentAuth is a function to override the default implementation
// of the persistent auth client. It exists for testing purposes only.
persistentAuthFn func(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error)
}

// Name implements [config.CredentialsStrategy].
func (c CLICredentials) Name() string {
return "databricks-cli"
}

var errNoHost = errors.New("no host provided")

// Configure implements [config.CredentialsStrategy].
//
// IMPORTANT: This credentials strategy ignores the scopes specified in the
// config and purely relies on the scopes from the loaded CLI token. This can
// lead to mismatches if the token was obtained with different scopes than the
// ones configured in the current profile. This is a temporary limitation that
// will be addressed in a future release by adding support for dynamic token
// downscoping.
func (c CLICredentials) Configure(ctx context.Context, cfg *config.Config) (credentials.CredentialsProvider, error) {
if cfg.Host == "" {
return nil, errNoHost
}
oauthArg, err := authArgumentsFromConfig(cfg).ToOAuthArgument()
if err != nil {
return nil, err
}
ts, err := c.persistentAuth(ctx, u2m.WithOAuthArgument(oauthArg))
if err != nil {
return nil, err
}
cp := credentials.NewOAuthCredentialsProviderFromTokenSource(
auth.NewCachedTokenSource(ts, auth.WithAsyncRefresh(!cfg.DisableOAuthRefreshToken)),
)
return cp, nil
}

// persistentAuth returns a token source. It is a convenience function that
// overrides the default implementation of the persistent auth client if
// an alternative implementation is provided for testing.
func (c CLICredentials) persistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
if c.persistentAuthFn != nil {
return c.persistentAuthFn(ctx, opts...)
}
ts, err := u2m.NewPersistentAuth(ctx, opts...)
if err != nil {
return nil, err
}
return authconv.AuthTokenSource(ts), nil
}

// authArgumentsFromConfig converts an SDK config to AuthArguments.
func authArgumentsFromConfig(cfg *config.Config) AuthArguments {
return AuthArguments{
Host: cfg.Host,
AccountID: cfg.AccountID,
WorkspaceID: cfg.WorkspaceID,
IsUnifiedHost: cfg.Experimental_IsUnifiedHost,
Profile: cfg.Profile,
}
}
192 changes: 192 additions & 0 deletions libs/auth/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package auth

import (
"context"
"errors"
"net/http"
"slices"
"testing"

"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"golang.org/x/oauth2"
)

// TestCredentialChainOrder purely exists as an extra measure to catch
// accidental change in the ordering.
func TestCredentialChainOrder(t *testing.T) {
names := make([]string, len(credentialChain))
for i, s := range credentialChain {
names[i] = s.Name()
}
want := []string{
"pat",
"basic",
"oauth-m2m",
"databricks-cli",
"metadata-service",
"github-oidc",
"azure-devops-oidc",
"env-oidc",
"file-oidc",
"github-oidc-azure",
"azure-msi",
"azure-client-secret",
"azure-cli",
"google-credentials",
"google-id",
}
if !slices.Equal(names, want) {
t.Errorf("credential chain order: want %v, got %v", want, names)
}
}

func TestCLICredentialsName(t *testing.T) {
c := CLICredentials{}
if got := c.Name(); got != "databricks-cli" {
t.Errorf("Name(): want %q, got %q", "databricks-cli", got)
}
}

func TestAuthArgumentsFromConfig(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
want AuthArguments
}{
{
name: "empty config",
cfg: &config.Config{},
want: AuthArguments{},
},
{
name: "workspace host only",
cfg: &config.Config{
Host: "https://myworkspace.cloud.databricks.com",
},
want: AuthArguments{
Host: "https://myworkspace.cloud.databricks.com",
},
},
{
name: "account host with account ID",
cfg: &config.Config{
Host: "https://accounts.cloud.databricks.com",
AccountID: "test-account-id",
},
want: AuthArguments{
Host: "https://accounts.cloud.databricks.com",
AccountID: "test-account-id",
},
},
{
name: "all fields",
cfg: &config.Config{
Host: "https://myhost.com",
AccountID: "acc-123",
WorkspaceID: "ws-456",
Profile: "my-profile",
Experimental_IsUnifiedHost: true,
},
want: AuthArguments{
Host: "https://myhost.com",
AccountID: "acc-123",
WorkspaceID: "ws-456",
Profile: "my-profile",
IsUnifiedHost: true,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := authArgumentsFromConfig(tt.cfg)
if got != tt.want {
t.Errorf("want %v, got %v", tt.want, got)
}
})
}
}

func TestCLICredentialsConfigure(t *testing.T) {
testErr := errors.New("test error")

tests := []struct {
name string
cfg *config.Config
persistentAuthFn func(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error)
wantErr error
wantToken string
}{
{
name: "empty host returns error",
cfg: &config.Config{},
wantErr: errNoHost,
},
{
name: "persistentAuthFn error is propagated",
cfg: &config.Config{
Host: "https://myworkspace.cloud.databricks.com",
},
persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
return nil, testErr
},
wantErr: testErr,
},
{
name: "workspace host",
cfg: &config.Config{
Host: "https://myworkspace.cloud.databricks.com",
},
persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
return auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) {
return &oauth2.Token{AccessToken: "workspace-token"}, nil
}), nil
},
wantToken: "workspace-token",
},
{
name: "account host",
cfg: &config.Config{
Host: "https://accounts.cloud.databricks.com",
AccountID: "test-account-id",
},
persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) {
return auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) {
return &oauth2.Token{AccessToken: "account-token"}, nil
}), nil
},
wantToken: "account-token",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
c := CLICredentials{persistentAuthFn: tt.persistentAuthFn}

got, err := c.Configure(ctx, tt.cfg)

if !errors.Is(err, tt.wantErr) {
t.Fatalf("want error %v, got %v", tt.wantErr, err)
}
if tt.wantErr != nil {
return
}

// Verify the credentials provider sets the correct Bearer token.
req, err := http.NewRequest("GET", tt.cfg.Host, nil)
if err != nil {
t.Fatalf("creating request: %v", err)
}
if err := got.SetHeaders(req); err != nil {
t.Fatalf("SetHeaders: want no error, got %v", err)
}
want := "Bearer " + tt.wantToken
if gotHeader := req.Header.Get("Authorization"); gotHeader != want {
t.Errorf("Authorization header: want %q, got %q", want, gotHeader)
}
})
}
}