diff --git a/index_alias_impl.go b/index_alias_impl.go index 94d67a8b..ebce2034 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -443,53 +443,26 @@ type asyncSearchResult struct { Err error } -func wrapSearch(ctx context.Context, in Index, req *SearchRequest) *asyncSearchResult { - rv := asyncSearchResult{Name: in.Name()} - rv.Result, rv.Err = in.SearchInContext(ctx, req) - return &rv -} - -func wrapSearchTimeout(ctx context.Context, in Index, req *SearchRequest) *asyncSearchResult { - reschan := make(chan *asyncSearchResult) - go func() { reschan <- wrapSearch(ctx, in, req) }() - select { - case res := <-reschan: - return res - case <-ctx.Done(): - return &asyncSearchResult{Name: in.Name(), Err: ctx.Err()} - } -} - -// MultiSearch executes a SearchRequest across multiple -// Index objects, then merges the results. +// MultiSearch executes a SearchRequest across multiple Index objects, +// then merges the results. The indexes must honor any ctx deadline. func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*SearchResult, error) { searchStart := time.Now() - asyncResults := make(chan *asyncSearchResult) + asyncResults := make(chan *asyncSearchResult, len(indexes)) // run search on each index in separate go routine var waitGroup sync.WaitGroup - var searchChildIndex = func(waitGroup *sync.WaitGroup, in Index, asyncResults chan *asyncSearchResult) { - childReq := createChildSearchRequest(req) - if ia, ok := in.(IndexAlias); ok { - // if the child index is another alias, trust it returns promptly on timeout/cancel - go func() { - defer waitGroup.Done() - asyncResults <- wrapSearch(ctx, ia, childReq) - }() - } else { - // if the child index is not an alias, enforce timeout here - go func() { - defer waitGroup.Done() - asyncResults <- wrapSearchTimeout(ctx, in, childReq) - }() - } + var searchChildIndex = func(in Index, childReq *SearchRequest) { + rv := asyncSearchResult{Name: in.Name()} + rv.Result, rv.Err = in.SearchInContext(ctx, childReq) + asyncResults <- &rv + waitGroup.Done() } + waitGroup.Add(len(indexes)) for _, in := range indexes { - waitGroup.Add(1) - searchChildIndex(&waitGroup, in, asyncResults) + go searchChildIndex(in, createChildSearchRequest(req)) } // on another go routine, close after finished diff --git a/index_alias_impl_test.go b/index_alias_impl_test.go index 8393513e..a5940664 100644 --- a/index_alias_impl_test.go +++ b/index_alias_impl_test.go @@ -724,11 +724,16 @@ func TestMultiSearchSecondPage(t *testing.T) { func TestMultiSearchTimeout(t *testing.T) { score1, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(1.0), 0) score2, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(2.0), 0) + var ctx context.Context ei1 := &stubIndex{ name: "ei1", checkRequest: func(req *SearchRequest) error { - time.Sleep(50 * time.Millisecond) - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + return nil + } }, err: nil, searchResult: &SearchResult{ @@ -751,8 +756,12 @@ func TestMultiSearchTimeout(t *testing.T) { ei2 := &stubIndex{ name: "ei2", checkRequest: func(req *SearchRequest) error { - time.Sleep(50 * time.Millisecond) - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + return nil + } }, err: nil, searchResult: &SearchResult{ @@ -774,7 +783,7 @@ func TestMultiSearchTimeout(t *testing.T) { }} // first run with absurdly long time out, should succeed - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, _ = context.WithTimeout(context.Background(), 10*time.Second) query := NewTermQuery("test") sr := NewSearchRequest(query) res, err := MultiSearch(ctx, sr, ei1, ei2) @@ -821,7 +830,8 @@ func TestMultiSearchTimeout(t *testing.T) { } // now run a search again with a normal timeout, but cancel it first - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) cancel() res, err = MultiSearch(ctx, sr, ei1, ei2) if err != nil { @@ -854,6 +864,7 @@ func TestMultiSearchTimeoutPartial(t *testing.T) { score1, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(1.0), 0) score2, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(2.0), 0) score3, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(3.0), 0) + var ctx context.Context ei1 := &stubIndex{ name: "ei1", err: nil, @@ -898,8 +909,12 @@ func TestMultiSearchTimeoutPartial(t *testing.T) { ei3 := &stubIndex{ name: "ei3", checkRequest: func(req *SearchRequest) error { - time.Sleep(50 * time.Millisecond) - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + return nil + } }, err: nil, searchResult: &SearchResult{ @@ -922,7 +937,7 @@ func TestMultiSearchTimeoutPartial(t *testing.T) { // ei3 is set to take >50ms, so run search with timeout less than // this, this should return partial results - ctx, _ := context.WithTimeout(context.Background(), 25*time.Millisecond) + ctx, _ = context.WithTimeout(context.Background(), 25*time.Millisecond) query := NewTermQuery("test") sr := NewSearchRequest(query) expected := &SearchResult{ @@ -968,6 +983,7 @@ func TestIndexAliasMultipleLayer(t *testing.T) { score2, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(2.0), 0) score3, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(3.0), 0) score4, _ := numeric.NewPrefixCodedInt64(numeric.Float64ToInt64(4.0), 0) + var ctx context.Context ei1 := &stubIndex{ name: "ei1", err: nil, @@ -991,8 +1007,12 @@ func TestIndexAliasMultipleLayer(t *testing.T) { ei2 := &stubIndex{ name: "ei2", checkRequest: func(req *SearchRequest) error { - time.Sleep(50 * time.Millisecond) - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + return nil + } }, err: nil, searchResult: &SearchResult{ @@ -1016,8 +1036,12 @@ func TestIndexAliasMultipleLayer(t *testing.T) { ei3 := &stubIndex{ name: "ei3", checkRequest: func(req *SearchRequest) error { - time.Sleep(50 * time.Millisecond) - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + return nil + } }, err: nil, searchResult: &SearchResult{ @@ -1067,7 +1091,7 @@ func TestIndexAliasMultipleLayer(t *testing.T) { // search across aliasTop should still get results from ei1 and ei4 // total should still be 4 - ctx, _ := context.WithTimeout(context.Background(), 25*time.Millisecond) + ctx, _ = context.WithTimeout(context.Background(), 25*time.Millisecond) query := NewTermQuery("test") sr := NewSearchRequest(query) expected := &SearchResult{