diff --git a/flow/retriever/parent/parent.go b/flow/retriever/parent/parent.go index 979afeb4f..19ed6734c 100644 --- a/flow/retriever/parent/parent.go +++ b/flow/retriever/parent/parent.go @@ -74,6 +74,9 @@ func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, err if config.OrigDocGetter == nil { return nil, fmt.Errorf("orig doc getter is required") } + if config.ParentIDKey == "" { + return nil, fmt.Errorf("parent id key is required") + } return &parentRetriever{ retriever: config.Retriever, parentIDKey: config.ParentIDKey, diff --git a/flow/retriever/parent/parent_test.go b/flow/retriever/parent/parent_test.go index 6487e25fa..10321ddfd 100644 --- a/flow/retriever/parent/parent_test.go +++ b/flow/retriever/parent/parent_test.go @@ -41,6 +41,27 @@ func (t *testRetriever) Retrieve(ctx context.Context, query string, opts ...retr return ret, nil } +func TestNewRetrieverValidation(t *testing.T) { + ctx := context.Background() + tr := &testRetriever{} + getter := func(ctx context.Context, ids []string) ([]*schema.Document, error) { return nil, nil } + + _, err := NewRetriever(ctx, &Config{Retriever: nil, ParentIDKey: "k", OrigDocGetter: getter}) + if err == nil { + t.Error("expected error when Retriever is nil") + } + + _, err = NewRetriever(ctx, &Config{Retriever: tr, ParentIDKey: "k", OrigDocGetter: nil}) + if err == nil { + t.Error("expected error when OrigDocGetter is nil") + } + + _, err = NewRetriever(ctx, &Config{Retriever: tr, ParentIDKey: "", OrigDocGetter: getter}) + if err == nil { + t.Error("expected error when ParentIDKey is empty") + } +} + func TestParentRetriever(t *testing.T) { tests := []struct { name string