diff --git a/index.go b/index.go index 7c2cb5d1..1fc396d9 100644 --- a/index.go +++ b/index.go @@ -13,6 +13,7 @@ import ( "github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index/store" + "golang.org/x/net/context" ) // A Batch groups together multiple Index and Delete @@ -167,6 +168,7 @@ type Index interface { DocCount() (uint64, error) Search(req *SearchRequest) (*SearchResult, error) + SearchInContext(ctx context.Context, req *SearchRequest) (*SearchResult, error) Fields() ([]string, error) diff --git a/index_alias_impl.go b/index_alias_impl.go index bd6efac0..7df91855 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -14,6 +14,8 @@ import ( "sync" "time" + "golang.org/x/net/context" + "github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index/store" @@ -132,6 +134,10 @@ func (i *indexAliasImpl) DocCount() (uint64, error) { } func (i *indexAliasImpl) Search(req *SearchRequest) (*SearchResult, error) { + return i.SearchInContext(context.Background(), req) +} + +func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest) (*SearchResult, error) { i.mutex.RLock() defer i.mutex.RUnlock() @@ -145,10 +151,10 @@ func (i *indexAliasImpl) Search(req *SearchRequest) (*SearchResult, error) { // short circuit the simple case if len(i.indexes) == 1 { - return i.indexes[0].Search(req) + return i.indexes[0].SearchInContext(ctx, req) } - return MultiSearch(req, i.indexes...) + return MultiSearch(ctx, req, i.indexes...) } func (i *indexAliasImpl) Fields() ([]string, error) { @@ -456,70 +462,81 @@ func createChildSearchRequest(req *SearchRequest) *SearchRequest { return &rv } -type errWrap struct { - Name string - Err error +type asyncSearchResult struct { + Name string + Result *SearchResult + Err error +} + +func wrapSearch(ctx context.Context, in Index, req *SearchRequest) *asyncSearchResult { + rv := asyncSearchResult{Name: in.Name()} + rv.Result, rv.Err = in.SearchInContext(ctx, req) + return &rv +} + +func wrapSearchTimeout(ctx context.Context, in Index, req *SearchRequest) *asyncSearchResult { + reschan := make(chan *asyncSearchResult) + go func() { reschan <- wrapSearch(ctx, in, req) }() + select { + case res := <-reschan: + return res + case <-ctx.Done(): + return &asyncSearchResult{Name: in.Name(), Err: ctx.Err()} + } } // MultiSearch executes a SearchRequest across multiple // Index objects, then merges the results. -func MultiSearch(req *SearchRequest, indexes ...Index) (*SearchResult, error) { +func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*SearchResult, error) { + searchStart := time.Now() - results := make(chan *SearchResult) - errs := make(chan *errWrap) + asyncResults := make(chan *asyncSearchResult) // run search on each index in separate go routine var waitGroup sync.WaitGroup - var searchChildIndex = func(waitGroup *sync.WaitGroup, in Index, results chan *SearchResult, errs chan *errWrap) { - go func() { - defer waitGroup.Done() - childReq := createChildSearchRequest(req) - searchResult, err := in.Search(childReq) - if err != nil { - errs <- &errWrap{ - Name: in.Name(), - Err: err, - } - } else { - results <- searchResult - } - }() + var searchChildIndex = func(waitGroup *sync.WaitGroup, in Index, asyncResults chan *asyncSearchResult) { + childReq := createChildSearchRequest(req) + if ia, ok := in.(IndexAlias); ok { + // if the child index is another alias, trust it returns promptly on timeout/cancel + go func() { + defer waitGroup.Done() + asyncResults <- wrapSearch(ctx, ia, childReq) + }() + } else { + // if the child index is not an alias, enforce timeout here + go func() { + defer waitGroup.Done() + asyncResults <- wrapSearchTimeout(ctx, in, childReq) + }() + } } for _, in := range indexes { waitGroup.Add(1) - searchChildIndex(&waitGroup, in, results, errs) + searchChildIndex(&waitGroup, in, asyncResults) } // on another go routine, close after finished go func() { waitGroup.Wait() - close(results) - close(errs) + close(asyncResults) }() var sr *SearchResult - var ew *errWrap - var result *SearchResult indexErrors := make(map[string]error) - ok := true - for ok { - select { - case result, ok = <-results: - if ok { - if sr == nil { - // first result - sr = result - } else { - // merge with previous - sr.Merge(result) - } - } - case ew, ok = <-errs: - if ok { - indexErrors[ew.Name] = ew.Err + + for asr := range asyncResults { + if asr.Err == nil { + if sr == nil { + // first result + sr = asr.Result + } else { + // merge with previous + sr.Merge(asr.Result) } + } else { + indexErrors[asr.Name] = asr.Err } } diff --git a/index_alias_impl_test.go b/index_alias_impl_test.go index 95dd71d5..fd07bc55 100644 --- a/index_alias_impl_test.go +++ b/index_alias_impl_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index/store" @@ -650,7 +652,7 @@ func TestMultiSearchNoError(t *testing.T) { MaxScore: 2.0, } - results, err := MultiSearch(sr, ei1, ei2) + results, err := MultiSearch(context.Background(), sr, ei1, ei2) if err != nil { t.Error(err) } @@ -681,7 +683,7 @@ func TestMultiSearchSomeError(t *testing.T) { }} ei2 := &stubIndex{name: "ei2", err: fmt.Errorf("deliberate error")} sr := NewSearchRequest(NewTermQuery("test")) - res, err := MultiSearch(sr, ei1, ei2) + res, err := MultiSearch(context.Background(), sr, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -708,7 +710,7 @@ func TestMultiSearchAllError(t *testing.T) { ei1 := &stubIndex{name: "ei1", err: fmt.Errorf("deliberate error")} ei2 := &stubIndex{name: "ei2", err: fmt.Errorf("deliberate error")} sr := NewSearchRequest(NewTermQuery("test")) - res, err := MultiSearch(sr, ei1, ei2) + res, err := MultiSearch(context.Background(), sr, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -764,13 +766,383 @@ func TestMultiSearchSecondPage(t *testing.T) { checkRequest: checkRequest, } sr := NewSearchRequestOptions(NewTermQuery("test"), 10, 10, false) - _, err := MultiSearch(sr, ei1, ei2) + _, err := MultiSearch(context.Background(), sr, ei1, ei2) if err != nil { t.Errorf("unexpected error %v", err) } } +// TestMultiSearchTimeout tests simple timeout cases +// 1. all searches finish successfully before timeout +// 2. no searchers finish before the timeout +// 3. no searches finish before cancellation +func TestMultiSearchTimeout(t *testing.T) { + ei1 := &stubIndex{ + name: "ei1", + checkRequest: func(req *SearchRequest) error { + time.Sleep(50 * time.Millisecond) + return nil + }, + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "1", + ID: "a", + Score: 1.0, + }, + }, + MaxScore: 1.0, + }} + ei2 := &stubIndex{ + name: "ei2", + checkRequest: func(req *SearchRequest) error { + time.Sleep(50 * time.Millisecond) + return nil + }, + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "2", + ID: "b", + Score: 2.0, + }, + }, + MaxScore: 2.0, + }} + + // first run with absurdly long time out, should succeed + ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + query := NewTermQuery("test") + sr := NewSearchRequest(query) + res, err := MultiSearch(ctx, sr, ei1, ei2) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if res.Status.Total != 2 { + t.Errorf("expected 2 total, got %d", res.Status.Failed) + } + if res.Status.Successful != 2 { + t.Errorf("expected 0 success, got %d", res.Status.Successful) + } + if res.Status.Failed != 0 { + t.Errorf("expected 2 failed, got %d", res.Status.Failed) + } + if len(res.Status.Errors) != 0 { + t.Errorf("expected 0 errors, got %v", res.Status.Errors) + } + + // now run a search again with an absurdly low timeout (should timeout) + ctx, _ = context.WithTimeout(context.Background(), 1*time.Microsecond) + res, err = MultiSearch(ctx, sr, ei1, ei2) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if res.Status.Total != 2 { + t.Errorf("expected 2 failed, got %d", res.Status.Failed) + } + if res.Status.Successful != 0 { + t.Errorf("expected 0 success, got %d", res.Status.Successful) + } + if res.Status.Failed != 2 { + t.Errorf("expected 2 failed, got %d", res.Status.Failed) + } + if len(res.Status.Errors) != 2 { + t.Errorf("expected 2 errors, got %v", res.Status.Errors) + } else { + if res.Status.Errors["ei1"].Error() != context.DeadlineExceeded.Error() { + t.Errorf("expected err for 'ei1' to be '%s' got '%s'", context.DeadlineExceeded.Error(), res.Status.Errors["ei1"]) + } + if res.Status.Errors["ei2"].Error() != context.DeadlineExceeded.Error() { + t.Errorf("expected err for 'ei2' to be '%s' got '%s'", context.DeadlineExceeded.Error(), res.Status.Errors["ei2"]) + } + } + + // now run a search again with a normal timeout, but cancel it first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + cancel() + res, err = MultiSearch(ctx, sr, ei1, ei2) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if res.Status.Total != 2 { + t.Errorf("expected 2 failed, got %d", res.Status.Failed) + } + if res.Status.Successful != 0 { + t.Errorf("expected 0 success, got %d", res.Status.Successful) + } + if res.Status.Failed != 2 { + t.Errorf("expected 2 failed, got %d", res.Status.Failed) + } + if len(res.Status.Errors) != 2 { + t.Errorf("expected 2 errors, got %v", res.Status.Errors) + } else { + if res.Status.Errors["ei1"].Error() != context.Canceled.Error() { + t.Errorf("expected err for 'ei1' to be '%s' got '%s'", context.Canceled.Error(), res.Status.Errors["ei1"]) + } + if res.Status.Errors["ei2"].Error() != context.Canceled.Error() { + t.Errorf("expected err for 'ei2' to be '%s' got '%s'", context.Canceled.Error(), res.Status.Errors["ei2"]) + } + } +} + +// TestMultiSearchTimeoutPartial tests the case where some indexes exceed +// the timeout, while others complete successfully +func TestMultiSearchTimeoutPartial(t *testing.T) { + ei1 := &stubIndex{ + name: "ei1", + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "1", + ID: "a", + Score: 1.0, + }, + }, + MaxScore: 1.0, + }} + ei2 := &stubIndex{ + name: "ei2", + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "2", + ID: "b", + Score: 2.0, + }, + }, + MaxScore: 2.0, + }} + + ei3 := &stubIndex{ + name: "ei3", + checkRequest: func(req *SearchRequest) error { + time.Sleep(50 * time.Millisecond) + return nil + }, + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "3", + ID: "c", + Score: 3.0, + }, + }, + MaxScore: 3.0, + }} + + // ei3 is set to take >50ms, so run search with timeout less than + // this, this should return partial results + ctx, _ := context.WithTimeout(context.Background(), 25*time.Millisecond) + query := NewTermQuery("test") + sr := NewSearchRequest(query) + expected := &SearchResult{ + Status: &SearchStatus{ + Total: 3, + Successful: 2, + Failed: 1, + Errors: map[string]error{ + "ei3": context.DeadlineExceeded, + }, + }, + Request: sr, + Total: 2, + Hits: search.DocumentMatchCollection{ + &search.DocumentMatch{ + Index: "2", + ID: "b", + Score: 2.0, + }, + &search.DocumentMatch{ + Index: "1", + ID: "a", + Score: 1.0, + }, + }, + MaxScore: 2.0, + } + + res, err := MultiSearch(ctx, sr, ei1, ei2, ei3) + if err != nil { + t.Fatalf("expected no err, got %v", err) + } + expected.Took = res.Took + if !reflect.DeepEqual(res, expected) { + t.Errorf("expected %#v, got %#v", expected, res) + } +} + +func TestIndexAliasMultipleLayer(t *testing.T) { + ei1 := &stubIndex{ + name: "ei1", + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "1", + ID: "a", + Score: 1.0, + }, + }, + MaxScore: 1.0, + }} + ei2 := &stubIndex{ + name: "ei2", + checkRequest: func(req *SearchRequest) error { + time.Sleep(50 * time.Millisecond) + return nil + }, + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "2", + ID: "b", + Score: 2.0, + }, + }, + MaxScore: 2.0, + }} + + ei3 := &stubIndex{ + name: "ei3", + checkRequest: func(req *SearchRequest) error { + time.Sleep(50 * time.Millisecond) + return nil + }, + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "3", + ID: "c", + Score: 3.0, + }, + }, + MaxScore: 3.0, + }} + + ei4 := &stubIndex{ + name: "ei4", + err: nil, + searchResult: &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + Errors: make(map[string]error), + }, + Total: 1, + Hits: []*search.DocumentMatch{ + &search.DocumentMatch{ + Index: "4", + ID: "d", + Score: 4.0, + }, + }, + MaxScore: 4.0, + }} + + alias1 := NewIndexAlias(ei1, ei2) + alias2 := NewIndexAlias(ei3, ei4) + aliasTop := NewIndexAlias(alias1, alias2) + + // ei2 and ei3 have 50ms delay + // search across aliasTop should still get results from ei1 and ei4 + // total should still be 4 + + ctx, _ := context.WithTimeout(context.Background(), 25*time.Millisecond) + query := NewTermQuery("test") + sr := NewSearchRequest(query) + expected := &SearchResult{ + Status: &SearchStatus{ + Total: 4, + Successful: 2, + Failed: 2, + Errors: map[string]error{ + "ei2": context.DeadlineExceeded, + "ei3": context.DeadlineExceeded, + }, + }, + Request: sr, + Total: 2, + Hits: search.DocumentMatchCollection{ + &search.DocumentMatch{ + Index: "4", + ID: "d", + Score: 4.0, + }, + &search.DocumentMatch{ + Index: "1", + ID: "a", + Score: 1.0, + }, + }, + MaxScore: 4.0, + } + + res, err := aliasTop.SearchInContext(ctx, sr) + if err != nil { + t.Fatalf("expected no err, got %v", err) + } + expected.Took = res.Took + if !reflect.DeepEqual(res, expected) { + t.Errorf("expected %#v, got %#v", expected, res) + } +} + // stubIndex is an Index impl for which all operations // return the configured error value, unless the // corresponding operation result value has been @@ -811,6 +1183,10 @@ func (i *stubIndex) DocCount() (uint64, error) { } func (i *stubIndex) Search(req *SearchRequest) (*SearchResult, error) { + return i.SearchInContext(context.Background(), req) +} + +func (i *stubIndex) SearchInContext(ctx context.Context, req *SearchRequest) (*SearchResult, error) { if i.checkRequest != nil { err := i.checkRequest(req) if err != nil { diff --git a/index_impl.go b/index_impl.go index 57832376..7c1ad5d5 100644 --- a/index_impl.go +++ b/index_impl.go @@ -17,6 +17,8 @@ import ( "sync/atomic" "time" + "golang.org/x/net/context" + "github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index/store" @@ -364,6 +366,12 @@ func (i *indexImpl) DocCount() (uint64, error) { // Search executes a search request operation. // Returns a SearchResult object or an error. func (i *indexImpl) Search(req *SearchRequest) (sr *SearchResult, err error) { + return i.SearchInContext(context.Background(), req) +} + +// SearchInContext executes a search request operation within the provided +// Context. Returns a SearchResult object or an error. +func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr *SearchResult, err error) { i.mutex.RLock() defer i.mutex.RUnlock() @@ -424,7 +432,7 @@ func (i *indexImpl) Search(req *SearchRequest) (sr *SearchResult, err error) { collector.SetFacetsBuilder(facetsBuilder) } - err = collector.Collect(searcher) + err = collector.Collect(ctx, searcher) if err != nil { return nil, err } diff --git a/index_test.go b/index_test.go index 7e9c0739..19d9b9fa 100644 --- a/index_test.go +++ b/index_test.go @@ -21,6 +21,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "encoding/json" "strconv" @@ -1440,3 +1442,57 @@ func TestBooleanFieldMappingIssue109(t *testing.T) { t.Fatal(err) } } + +func TestSearchTimeout(t *testing.T) { + defer func() { + err := os.RemoveAll("testidx") + if err != nil { + t.Fatal(err) + } + }() + + index, err := New("testidx", NewIndexMapping()) + if err != nil { + t.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + t.Fatal(err) + } + }() + + // first run a search with an absurdly long timeout (should succeeed) + ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + query := NewTermQuery("water") + req := NewSearchRequest(query) + _, err = index.SearchInContext(ctx, req) + if err != nil { + t.Fatal(err) + } + + // now run a search again with an absurdly low timeout (should timeout) + ctx, _ = context.WithTimeout(context.Background(), 1*time.Microsecond) + sq := &slowQuery{ + actual: query, + delay: 50 * time.Millisecond, // on Windows timer resolution is 15ms + } + req.Query = sq + _, err = index.SearchInContext(ctx, req) + if err != context.DeadlineExceeded { + t.Fatalf("exected %v, got: %v", context.DeadlineExceeded, err) + } + + // now run a search with a long timeout, but with a long query, and cancel it + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + sq = &slowQuery{ + actual: query, + delay: 100 * time.Millisecond, // on Windows timer resolution is 15ms + } + req = NewSearchRequest(sq) + cancel() + _, err = index.SearchInContext(ctx, req) + if err != context.Canceled { + t.Fatalf("exected %v, got: %v", context.Canceled, err) + } +} diff --git a/search/collector.go b/search/collector.go index 86e901af..773c8d55 100644 --- a/search/collector.go +++ b/search/collector.go @@ -11,10 +11,12 @@ package search import ( "time" + + "golang.org/x/net/context" ) type Collector interface { - Collect(searcher Searcher) error + Collect(ctx context.Context, searcher Searcher) error Results() DocumentMatchCollection Total() uint64 MaxScore() float64 diff --git a/search/collectors/collector_top_score.go b/search/collectors/collector_top_score.go index 5a31c3e9..2f00d133 100644 --- a/search/collectors/collector_top_score.go +++ b/search/collectors/collector_top_score.go @@ -13,6 +13,8 @@ import ( "container/list" "time" + "golang.org/x/net/context" + "github.com/blevesearch/bleve/search" ) @@ -54,19 +56,31 @@ func (tksc *TopScoreCollector) Took() time.Duration { return tksc.took } -func (tksc *TopScoreCollector) Collect(searcher search.Searcher) error { +func (tksc *TopScoreCollector) Collect(ctx context.Context, searcher search.Searcher) error { startTime := time.Now() - next, err := searcher.Next() - for err == nil && next != nil { - tksc.collectSingle(next) - if tksc.facetsBuilder != nil { - err = tksc.facetsBuilder.Update(next) - if err != nil { - break - } - } + var err error + var next *search.DocumentMatch + select { + case <-ctx.Done(): + return ctx.Err() + default: next, err = searcher.Next() } + for err == nil && next != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + tksc.collectSingle(next) + if tksc.facetsBuilder != nil { + err = tksc.facetsBuilder.Update(next) + if err != nil { + break + } + } + next, err = searcher.Next() + } + } // compute search duration tksc.took = time.Since(startTime) if err != nil { diff --git a/search/collectors/collector_top_score_test.go b/search/collectors/collector_top_score_test.go index ee21140b..d37140ae 100644 --- a/search/collectors/collector_top_score_test.go +++ b/search/collectors/collector_top_score_test.go @@ -14,6 +14,8 @@ import ( "strconv" "testing" + "golang.org/x/net/context" + "github.com/blevesearch/bleve/search" ) @@ -84,7 +86,7 @@ func TestTop10Scores(t *testing.T) { } collector := NewTopScorerCollector(10) - err := collector.Collect(searcher) + err := collector.Collect(context.Background(), searcher) if err != nil { t.Fatal(err) } @@ -192,7 +194,7 @@ func TestTop10ScoresSkip10(t *testing.T) { } collector := NewTopScorerSkipCollector(10, 10) - err := collector.Collect(searcher) + err := collector.Collect(context.Background(), searcher) if err != nil { t.Fatal(err) } @@ -238,7 +240,7 @@ func BenchmarkTop10of100000Scores(b *testing.B) { collector := NewTopScorerCollector(10) b.ResetTimer() - err := collector.Collect(searcher) + err := collector.Collect(context.Background(), searcher) if err != nil { b.Fatal(err) }