diff --git a/client/bloodhound/client.go b/client/bloodhound/client.go index 807b4382..dfcb5d73 100644 --- a/client/bloodhound/client.go +++ b/client/bloodhound/client.go @@ -182,7 +182,7 @@ func (s *BHEClient) Ingest(ctx context.Context, in <-chan []any) bool { s.log.Error(err, unrecoverableErrMsg) return true } else { - req.Header.Set("User-Agent", constants.UserAgent()) + req.Header.Set("User-Agent", rest.UserAgent()) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Encoding", "gzip") diff --git a/client/bloodhound/client_test.go b/client/bloodhound/client_test.go index a7e5c26d..4d1f6dd3 100644 --- a/client/bloodhound/client_test.go +++ b/client/bloodhound/client_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/bloodhoundad/azurehound/v2/config" + "github.com/bloodhoundad/azurehound/v2/constants" "github.com/go-logr/logr" "github.com/stretchr/testify/require" "golang.org/x/net/http2" @@ -141,4 +143,51 @@ func TestBHEClient_Ingest(t *testing.T) { require.True(t, hadErrors) }) + + t.Run("custom user agent applied", func(t *testing.T) { + const custom = "test-agent/9.9.9" + config.UserAgent.Set(custom) + t.Cleanup(func() { config.UserAgent.Set("") }) + + var got string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusAccepted) + })) + defer testServer.Close() + + testUrl, _ := url.Parse(testServer.URL) + client, err := NewBHEClient(*testUrl, "tokenId", "token", "", 1, 1, logr.Logger{}) + require.NoError(t, err) + + data := make(chan []any, 1) + data <- []any{"test"} + close(data) + + require.False(t, client.Ingest(context.Background(), data)) + require.Equal(t, custom, got) + }) + + t.Run("default user agent when config empty", func(t *testing.T) { + config.UserAgent.Set("") + t.Cleanup(func() { config.UserAgent.Set("") }) + + var got string + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusAccepted) + })) + defer testServer.Close() + + testUrl, _ := url.Parse(testServer.URL) + client, err := NewBHEClient(*testUrl, "tokenId", "token", "", 1, 1, logr.Logger{}) + require.NoError(t, err) + + data := make(chan []any, 1) + data <- []any{"test"} + close(data) + + require.False(t, client.Ingest(context.Background(), data)) + require.Equal(t, constants.UserAgent(), got) + }) } diff --git a/client/rest/http.go b/client/rest/http.go index 52e8aba9..997db972 100644 --- a/client/rest/http.go +++ b/client/rest/http.go @@ -129,12 +129,17 @@ func NewRequest( } // set azurehound as user-agent, use custom if set in config - ua := config.UserAgent.Value() - if s, ok := ua.(string); ok && s != "" { - req.Header.Set("User-Agent", s) - } else { - req.Header.Set("User-Agent", constants.UserAgent()) - } + req.Header.Set("User-Agent", UserAgent()) return req, nil } } + +// UserAgent returns the configured User-Agent header value, falling back to +// the default azurehound/ string when the --user-agent flag is unset +// or empty. +func UserAgent() string { + if s, ok := config.UserAgent.Value().(string); ok && s != "" { + return s + } + return constants.UserAgent() +} diff --git a/client/rest/http_test.go b/client/rest/http_test.go new file mode 100644 index 00000000..9babea5d --- /dev/null +++ b/client/rest/http_test.go @@ -0,0 +1,75 @@ +// Copyright (C) 2026 Specter Ops, Inc. +// +// This file is part of AzureHound. +// +// AzureHound is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// AzureHound is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package rest + +import ( + "context" + "net/url" + "testing" + + "github.com/bloodhoundad/azurehound/v2/config" + "github.com/bloodhoundad/azurehound/v2/constants" +) + +func TestUserAgent_DefaultsWhenUnset(t *testing.T) { + config.UserAgent.Set("") + t.Cleanup(func() { config.UserAgent.Set("") }) + + if got, want := UserAgent(), constants.UserAgent(); got != want { + t.Fatalf("UserAgent() = %q, want default %q", got, want) + } +} + +func TestUserAgent_HonorsConfigValue(t *testing.T) { + const custom = "my-custom-agent/1.2.3" + config.UserAgent.Set(custom) + t.Cleanup(func() { config.UserAgent.Set("") }) + + if got := UserAgent(); got != custom { + t.Fatalf("UserAgent() = %q, want %q", got, custom) + } +} + +func TestNewRequest_AppliesCustomUserAgent(t *testing.T) { + const custom = "my-custom-agent/1.2.3" + config.UserAgent.Set(custom) + t.Cleanup(func() { config.UserAgent.Set("") }) + + endpoint, _ := url.Parse("http://example.com/") + req, err := NewRequest(context.Background(), "GET", endpoint, nil, nil, nil) + if err != nil { + t.Fatalf("NewRequest error: %v", err) + } + if got := req.Header.Get("User-Agent"); got != custom { + t.Fatalf("User-Agent header = %q, want %q", got, custom) + } +} + +func TestNewRequest_DefaultUserAgentWhenConfigEmpty(t *testing.T) { + config.UserAgent.Set("") + t.Cleanup(func() { config.UserAgent.Set("") }) + + endpoint, _ := url.Parse("http://example.com/") + req, err := NewRequest(context.Background(), "GET", endpoint, nil, nil, nil) + if err != nil { + t.Fatalf("NewRequest error: %v", err) + } + if got, want := req.Header.Get("User-Agent"), constants.UserAgent(); got != want { + t.Fatalf("User-Agent header = %q, want default %q", got, want) + } +}