diff --git a/query_boolean.go b/query_boolean.go index fdd5ccce..5a0fffef 100644 --- a/query_boolean.go +++ b/query_boolean.go @@ -39,32 +39,40 @@ func (q *BooleanQuery) SetBoost(b float64) *BooleanQuery { func (q *BooleanQuery) Searcher(i *indexImpl, explain bool) (search.Searcher, error) { - var err error - var mustSearcher search.Searcher + var mustSearcher *search.ConjunctionSearcher if q.Must != nil { - mustSearcher, err = q.Must.Searcher(i, explain) + ms, err := q.Must.Searcher(i, explain) if err != nil { return nil, err } + if ms != nil { + mustSearcher = ms.(*search.ConjunctionSearcher) + } } - var shouldSearcher search.Searcher + var shouldSearcher *search.DisjunctionSearcher if q.Should != nil { - shouldSearcher, err = q.Should.Searcher(i, explain) + ss, err := q.Should.Searcher(i, explain) if err != nil { return nil, err } + if ss != nil { + shouldSearcher = ss.(*search.DisjunctionSearcher) + } } - var mustNotSearcher search.Searcher + var mustNotSearcher *search.DisjunctionSearcher if q.MustNot != nil { - mustNotSearcher, err = q.MustNot.Searcher(i, explain) + mns, err := q.MustNot.Searcher(i, explain) if err != nil { return nil, err } + if mns != nil { + mustNotSearcher = mns.(*search.DisjunctionSearcher) + } } - return search.NewBooleanSearcher(i.i, mustSearcher.(*search.ConjunctionSearcher), shouldSearcher.(*search.DisjunctionSearcher), mustNotSearcher.(*search.DisjunctionSearcher), explain) + return search.NewBooleanSearcher(i.i, mustSearcher, shouldSearcher, mustNotSearcher, explain) } func (q *BooleanQuery) Validate() error { diff --git a/query_disjunction.go b/query_disjunction.go index 0506aa42..aa38147e 100644 --- a/query_disjunction.go +++ b/query_disjunction.go @@ -50,7 +50,7 @@ func (q *DisjunctionQuery) SetMin(m float64) *DisjunctionQuery { return q } -func (q *DisjunctionQuery) Searcher(i *indexImpl, explain bool) (*search.DisjunctionSearcher, error) { +func (q *DisjunctionQuery) Searcher(i *indexImpl, explain bool) (search.Searcher, error) { searchers := make([]search.Searcher, len(q.Disjuncts)) for in, disjunct := range q.Disjuncts { var err error