diff --git a/index_impl.go b/index_impl.go index 1036aef2..68777f07 100644 --- a/index_impl.go +++ b/index_impl.go @@ -50,9 +50,11 @@ const storePath = "store" var mappingInternalKey = []byte("_mapping") -const SearchMemCheckCallbackKey = "_search_mem_callback_key" +const SearchQueryStartCallbackKey = "_search_query_start_callback_key" +const SearchQueryEndCallbackKey = "_search_query_end_callback_key" -type SearchMemCheckCallbackFn func(size uint64) error +type SearchQueryStartCallbackFn func(size uint64) error +type SearchQueryEndCallbackFn func(size uint64) error func indexStorePath(path string) string { return path + string(os.PathSeparator) + storePath @@ -483,15 +485,24 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr collector.SetFacetsBuilder(facetsBuilder) } - if memCb := ctx.Value(SearchMemCheckCallbackKey); memCb != nil { - if memCbFn, ok := memCb.(SearchMemCheckCallbackFn); ok { - err = memCbFn(memNeededForSearch(req, searcher, collector)) + memNeeded := memNeededForSearch(req, searcher, collector) + if cb := ctx.Value(SearchQueryStartCallbackKey); cb != nil { + if cbF, ok := cb.(SearchQueryStartCallbackFn); ok { + err = cbF(memNeeded) } } if err != nil { return nil, err } + if cb := ctx.Value(SearchQueryEndCallbackKey); cb != nil { + if cbF, ok := cb.(SearchQueryEndCallbackFn); ok { + defer func() { + _ = cbF(memNeeded) + }() + } + } + err = collector.Collect(ctx, searcher, indexReader) if err != nil { return nil, err diff --git a/index_test.go b/index_test.go index 57429dcb..69ca61a9 100644 --- a/index_test.go +++ b/index_test.go @@ -1871,7 +1871,7 @@ func BenchmarkScorchSearchOverhead(b *testing.B) { benchmarkSearchOverhead(scorch.Name, b) } -func TestSearchMemCheckCallback(t *testing.T) { +func TestSearchQueryCallback(t *testing.T) { defer func() { err := os.RemoveAll("testidx") if err != nil { @@ -1910,8 +1910,8 @@ func TestSearchMemCheckCallback(t *testing.T) { return nil } - ctx := context.WithValue(context.Background(), SearchMemCheckCallbackKey, - SearchMemCheckCallbackFn(f)) + ctx := context.WithValue(context.Background(), SearchQueryStartCallbackKey, + SearchQueryStartCallbackFn(f)) _, err = index.SearchInContext(ctx, req) if err != expErr { t.Fatalf("Expected: %v, Got: %v", expErr, err)