From 96071c085cbdf0beb7d7a74f483b1b2c53f4bd97 Mon Sep 17 00:00:00 2001 From: abhinavdangeti Date: Mon, 5 Mar 2018 16:49:55 -0800 Subject: [PATCH] MB-28163: Register a callback with context to estimate RAM for search This callback if registered with context will invoke the api to estimate the memory needed to execute a search query. The callback defined at the client side will be responsible for determining whether to continue with the search or abort based on the threshold settings. --- index_impl.go | 13 +++++++++++++ index_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/index_impl.go b/index_impl.go index df6e748d..1036aef2 100644 --- a/index_impl.go +++ b/index_impl.go @@ -50,6 +50,10 @@ const storePath = "store" var mappingInternalKey = []byte("_mapping") +const SearchMemCheckCallbackKey = "_search_mem_callback_key" + +type SearchMemCheckCallbackFn func(size uint64) error + func indexStorePath(path string) string { return path + string(os.PathSeparator) + storePath } @@ -479,6 +483,15 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr collector.SetFacetsBuilder(facetsBuilder) } + if memCb := ctx.Value(SearchMemCheckCallbackKey); memCb != nil { + if memCbFn, ok := memCb.(SearchMemCheckCallbackFn); ok { + err = memCbFn(memNeededForSearch(req, searcher, collector)) + } + } + if err != nil { + return nil, err + } + err = collector.Collect(ctx, searcher, indexReader) if err != nil { return nil, err diff --git a/index_test.go b/index_test.go index f1e53647..57429dcb 100644 --- a/index_test.go +++ b/index_test.go @@ -1870,3 +1870,50 @@ func BenchmarkUpsidedownSearchOverhead(b *testing.B) { func BenchmarkScorchSearchOverhead(b *testing.B) { benchmarkSearchOverhead(scorch.Name, b) } + +func TestSearchMemCheckCallback(t *testing.T) { + defer func() { + err := os.RemoveAll("testidx") + if err != nil { + t.Fatal(err) + } + }() + + index, err := New("testidx", NewIndexMapping()) + if err != nil { + t.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + t.Fatal(err) + } + }() + + elements := []string{"air", "water", "fire", "earth"} + for j := 0; j < 10000; j++ { + err = index.Index(fmt.Sprintf("%d", j), + map[string]interface{}{"name": elements[j%len(elements)]}) + if err != nil { + t.Fatal(err) + } + } + + query := NewTermQuery("water") + req := NewSearchRequest(query) + + expErr := fmt.Errorf("MEM_LIMIT_EXCEEDED") + f := func(size uint64) error { + if size > 1000 { + return expErr + } + return nil + } + + ctx := context.WithValue(context.Background(), SearchMemCheckCallbackKey, + SearchMemCheckCallbackFn(f)) + _, err = index.SearchInContext(ctx, req) + if err != expErr { + t.Fatalf("Expected: %v, Got: %v", expErr, err) + } +}