diff --git a/libs/auth/credentials.go b/libs/auth/credentials.go new file mode 100644 index 0000000000..6adf0ab9c0 --- /dev/null +++ b/libs/auth/credentials.go @@ -0,0 +1,133 @@ +package auth + +import ( + "context" + "errors" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/config/credentials" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" + "github.com/databricks/databricks-sdk-go/credentials/u2m" +) + +// The credentials chain used by the CLI. It is a custom implementation +// that differs from the SDK's default credentials chain. This guarantees +// that the CLI remain stable despite the evolution of the SDK while +// allowing the customization of some strategies such as "databricks-cli" +// which has a different behavior than the SDK. +// +// Modifying this order could break authentication for users whose +// environments are compatible with multiple strategies and who rely +// on the current priority for tie-breaking. +var credentialChain = []config.CredentialsStrategy{ + config.PatCredentials{}, + config.BasicCredentials{}, + config.M2mCredentials{}, + CLICredentials{}, // custom + config.MetadataServiceCredentials{}, + // OIDC Strategies. + config.GitHubOIDCCredentials{}, + config.AzureDevOpsOIDCCredentials{}, + config.EnvOIDCCredentials{}, + config.FileOIDCCredentials{}, + // Azure strategies. + config.AzureGithubOIDCCredentials{}, + config.AzureMsiCredentials{}, + config.AzureClientSecretCredentials{}, + config.AzureCliCredentials{}, + // Google strategies. + config.GoogleCredentials{}, + config.GoogleDefaultCredentials{}, +} + +func init() { + // Sets the credentials chain for the CLI. + config.DefaultCredentialStrategyProvider = func() config.CredentialsStrategy { + return &defaultCredentials{chain: config.NewCredentialsChain(credentialChain...)} + } +} + +// defaultCredentials wraps the CLI credential chain and provides "default" +// as the fallback name, matching the SDK's DefaultCredentials behavior. +type defaultCredentials struct { + chain config.CredentialsStrategy +} + +func (d *defaultCredentials) Name() string { + if name := d.chain.Name(); name != "" { + return name + } + return "default" +} + +func (d *defaultCredentials) Configure(ctx context.Context, cfg *config.Config) (credentials.CredentialsProvider, error) { + return d.chain.Configure(ctx, cfg) +} + +// CLICredentials is a credentials strategy that reads OAuth tokens directly +// from the local token store. It replaces the SDK's default "databricks-cli" +// strategy, which shells out to `databricks auth token` as a subprocess. +type CLICredentials struct { + // persistentAuth is a function to override the default implementation + // of the persistent auth client. It exists for testing purposes only. + persistentAuthFn func(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error) +} + +// Name implements [config.CredentialsStrategy]. +func (c CLICredentials) Name() string { + return "databricks-cli" +} + +var errNoHost = errors.New("no host provided") + +// Configure implements [config.CredentialsStrategy]. +// +// IMPORTANT: This credentials strategy ignores the scopes specified in the +// config and purely relies on the scopes from the loaded CLI token. This can +// lead to mismatches if the token was obtained with different scopes than the +// ones configured in the current profile. This is a temporary limitation that +// will be addressed in a future release by adding support for dynamic token +// downscoping. +func (c CLICredentials) Configure(ctx context.Context, cfg *config.Config) (credentials.CredentialsProvider, error) { + if cfg.Host == "" { + return nil, errNoHost + } + oauthArg, err := authArgumentsFromConfig(cfg).ToOAuthArgument() + if err != nil { + return nil, err + } + ts, err := c.persistentAuth(ctx, u2m.WithOAuthArgument(oauthArg)) + if err != nil { + return nil, err + } + cp := credentials.NewOAuthCredentialsProviderFromTokenSource( + auth.NewCachedTokenSource(ts, auth.WithAsyncRefresh(!cfg.DisableOAuthRefreshToken)), + ) + return cp, nil +} + +// persistentAuth returns a token source. It is a convenience function that +// overrides the default implementation of the persistent auth client if +// an alternative implementation is provided for testing. +func (c CLICredentials) persistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error) { + if c.persistentAuthFn != nil { + return c.persistentAuthFn(ctx, opts...) + } + ts, err := u2m.NewPersistentAuth(ctx, opts...) + if err != nil { + return nil, err + } + return authconv.AuthTokenSource(ts), nil +} + +// authArgumentsFromConfig converts an SDK config to AuthArguments. +func authArgumentsFromConfig(cfg *config.Config) AuthArguments { + return AuthArguments{ + Host: cfg.Host, + AccountID: cfg.AccountID, + WorkspaceID: cfg.WorkspaceID, + IsUnifiedHost: cfg.Experimental_IsUnifiedHost, + Profile: cfg.Profile, + } +} diff --git a/libs/auth/credentials_test.go b/libs/auth/credentials_test.go new file mode 100644 index 0000000000..0d8973f853 --- /dev/null +++ b/libs/auth/credentials_test.go @@ -0,0 +1,192 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "slices" + "testing" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" + "github.com/databricks/databricks-sdk-go/credentials/u2m" + "golang.org/x/oauth2" +) + +// TestCredentialChainOrder purely exists as an extra measure to catch +// accidental change in the ordering. +func TestCredentialChainOrder(t *testing.T) { + names := make([]string, len(credentialChain)) + for i, s := range credentialChain { + names[i] = s.Name() + } + want := []string{ + "pat", + "basic", + "oauth-m2m", + "databricks-cli", + "metadata-service", + "github-oidc", + "azure-devops-oidc", + "env-oidc", + "file-oidc", + "github-oidc-azure", + "azure-msi", + "azure-client-secret", + "azure-cli", + "google-credentials", + "google-id", + } + if !slices.Equal(names, want) { + t.Errorf("credential chain order: want %v, got %v", want, names) + } +} + +func TestCLICredentialsName(t *testing.T) { + c := CLICredentials{} + if got := c.Name(); got != "databricks-cli" { + t.Errorf("Name(): want %q, got %q", "databricks-cli", got) + } +} + +func TestAuthArgumentsFromConfig(t *testing.T) { + tests := []struct { + name string + cfg *config.Config + want AuthArguments + }{ + { + name: "empty config", + cfg: &config.Config{}, + want: AuthArguments{}, + }, + { + name: "workspace host only", + cfg: &config.Config{ + Host: "https://myworkspace.cloud.databricks.com", + }, + want: AuthArguments{ + Host: "https://myworkspace.cloud.databricks.com", + }, + }, + { + name: "account host with account ID", + cfg: &config.Config{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "test-account-id", + }, + want: AuthArguments{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "test-account-id", + }, + }, + { + name: "all fields", + cfg: &config.Config{ + Host: "https://myhost.com", + AccountID: "acc-123", + WorkspaceID: "ws-456", + Profile: "my-profile", + Experimental_IsUnifiedHost: true, + }, + want: AuthArguments{ + Host: "https://myhost.com", + AccountID: "acc-123", + WorkspaceID: "ws-456", + Profile: "my-profile", + IsUnifiedHost: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := authArgumentsFromConfig(tt.cfg) + if got != tt.want { + t.Errorf("want %v, got %v", tt.want, got) + } + }) + } +} + +func TestCLICredentialsConfigure(t *testing.T) { + testErr := errors.New("test error") + + tests := []struct { + name string + cfg *config.Config + persistentAuthFn func(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error) + wantErr error + wantToken string + }{ + { + name: "empty host returns error", + cfg: &config.Config{}, + wantErr: errNoHost, + }, + { + name: "persistentAuthFn error is propagated", + cfg: &config.Config{ + Host: "https://myworkspace.cloud.databricks.com", + }, + persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) { + return nil, testErr + }, + wantErr: testErr, + }, + { + name: "workspace host", + cfg: &config.Config{ + Host: "https://myworkspace.cloud.databricks.com", + }, + persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) { + return auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: "workspace-token"}, nil + }), nil + }, + wantToken: "workspace-token", + }, + { + name: "account host", + cfg: &config.Config{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "test-account-id", + }, + persistentAuthFn: func(_ context.Context, _ ...u2m.PersistentAuthOption) (auth.TokenSource, error) { + return auth.TokenSourceFn(func(_ context.Context) (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: "account-token"}, nil + }), nil + }, + wantToken: "account-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + c := CLICredentials{persistentAuthFn: tt.persistentAuthFn} + + got, err := c.Configure(ctx, tt.cfg) + + if !errors.Is(err, tt.wantErr) { + t.Fatalf("want error %v, got %v", tt.wantErr, err) + } + if tt.wantErr != nil { + return + } + + // Verify the credentials provider sets the correct Bearer token. + req, err := http.NewRequest("GET", tt.cfg.Host, nil) + if err != nil { + t.Fatalf("creating request: %v", err) + } + if err := got.SetHeaders(req); err != nil { + t.Fatalf("SetHeaders: want no error, got %v", err) + } + want := "Bearer " + tt.wantToken + if gotHeader := req.Header.Get("Authorization"); gotHeader != want { + t.Errorf("Authorization header: want %q, got %q", want, gotHeader) + } + }) + } +}