added support for disjunction search
This commit is contained in:
parent
5de10307d8
commit
8e71daa4e3
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) 2014 Couchbase, Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
|
||||
// except in compliance with the License. You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the
|
||||
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
// either express or implied. See the License for the specific language governing permissions
|
||||
// and limitations under the License.
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/couchbaselabs/bleve/index"
|
||||
)
|
||||
|
||||
type TermDisjunctionQuery struct {
|
||||
Terms []Query `json:"terms"`
|
||||
BoostVal float64 `json:"boost"`
|
||||
Explain bool `json:"explain"`
|
||||
Min float64 `json:"min"`
|
||||
}
|
||||
|
||||
func (q *TermDisjunctionQuery) Boost() float64 {
|
||||
return q.BoostVal
|
||||
}
|
||||
|
||||
func (q *TermDisjunctionQuery) Searcher(index index.Index) (Searcher, error) {
|
||||
return NewTermDisjunctionSearcher(index, q)
|
||||
}
|
||||
|
||||
func (q *TermDisjunctionQuery) Validate() error {
|
||||
if int(q.Min) > len(q.Terms) {
|
||||
return fmt.Errorf("Minimum clauses in disjunction exceeds total number of clauses")
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) 2014 Couchbase, Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
|
||||
// except in compliance with the License. You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the
|
||||
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
// either express or implied. See the License for the specific language governing permissions
|
||||
// and limitations under the License.
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type TermDisjunctionQueryScorer struct {
|
||||
explain bool
|
||||
}
|
||||
|
||||
func NewTermDisjunctionQueryScorer(explain bool) *TermDisjunctionQueryScorer {
|
||||
return &TermDisjunctionQueryScorer{
|
||||
explain: explain,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionQueryScorer) Score(constituents []*DocumentMatch, countMatch, countTotal int) *DocumentMatch {
|
||||
rv := DocumentMatch{
|
||||
ID: constituents[0].ID,
|
||||
}
|
||||
|
||||
var sum float64
|
||||
var childrenExplanations []*Explanation
|
||||
if s.explain {
|
||||
childrenExplanations = make([]*Explanation, len(constituents))
|
||||
}
|
||||
|
||||
locations := []FieldTermLocationMap{}
|
||||
for i, docMatch := range constituents {
|
||||
sum += docMatch.Score
|
||||
if s.explain {
|
||||
childrenExplanations[i] = docMatch.Expl
|
||||
}
|
||||
if docMatch.Locations != nil {
|
||||
locations = append(locations, docMatch.Locations)
|
||||
}
|
||||
}
|
||||
|
||||
var rawExpl *Explanation
|
||||
if s.explain {
|
||||
rawExpl = &Explanation{Value: sum, Message: "sum of:", Children: childrenExplanations}
|
||||
}
|
||||
|
||||
coord := float64(countMatch) / float64(countTotal)
|
||||
rv.Score = sum * coord
|
||||
if s.explain {
|
||||
ce := make([]*Explanation, 2)
|
||||
ce[0] = rawExpl
|
||||
ce[1] = &Explanation{Value: coord, Message: fmt.Sprintf("coord(%d/%d)", countMatch, countTotal)}
|
||||
rv.Expl = &Explanation{Value: rv.Score, Message: "product of:", Children: ce}
|
||||
}
|
||||
|
||||
if len(locations) == 1 {
|
||||
rv.Locations = locations[0]
|
||||
} else if len(locations) > 1 {
|
||||
rv.Locations = mergeLocations(locations)
|
||||
}
|
||||
|
||||
return &rv
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
// Copyright (c) 2014 Couchbase, Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
|
||||
// except in compliance with the License. You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the
|
||||
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
// either express or implied. See the License for the specific language governing permissions
|
||||
// and limitations under the License.
|
||||
package search
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/couchbaselabs/bleve/index"
|
||||
)
|
||||
|
||||
type TermDisjunctionSearcher struct {
|
||||
index index.Index
|
||||
searchers OrderedSearcherList
|
||||
queryNorm float64
|
||||
currs []*DocumentMatch
|
||||
currentId string
|
||||
scorer *TermDisjunctionQueryScorer
|
||||
min float64
|
||||
}
|
||||
|
||||
func NewTermDisjunctionSearcher(index index.Index, query *TermDisjunctionQuery) (*TermDisjunctionSearcher, error) {
|
||||
// build the downstream searchres
|
||||
searchers := make(OrderedSearcherList, len(query.Terms))
|
||||
for i, termQuery := range query.Terms {
|
||||
searcher, err := termQuery.Searcher(index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchers[i] = searcher
|
||||
}
|
||||
// sort the searchers
|
||||
sort.Sort(sort.Reverse(searchers))
|
||||
// build our searcher
|
||||
rv := TermDisjunctionSearcher{
|
||||
index: index,
|
||||
searchers: searchers,
|
||||
currs: make([]*DocumentMatch, len(searchers)),
|
||||
scorer: NewTermDisjunctionQueryScorer(query.Explain),
|
||||
min: query.Min,
|
||||
}
|
||||
rv.computeQueryNorm()
|
||||
err := rv.initSearchers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rv, nil
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) computeQueryNorm() {
|
||||
// first calculate sum of squared weights
|
||||
sumOfSquaredWeights := 0.0
|
||||
for _, termSearcher := range s.searchers {
|
||||
sumOfSquaredWeights += termSearcher.Weight()
|
||||
}
|
||||
// now compute query norm from this
|
||||
s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
|
||||
// finally tell all the downsteam searchers the norm
|
||||
for _, termSearcher := range s.searchers {
|
||||
termSearcher.SetQueryNorm(s.queryNorm)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) initSearchers() error {
|
||||
var err error
|
||||
// get all searchers pointing at their first match
|
||||
for i, termSearcher := range s.searchers {
|
||||
s.currs[i], err = termSearcher.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.currentId = s.nextSmallestId()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) nextSmallestId() string {
|
||||
rv := ""
|
||||
for _, curr := range s.currs {
|
||||
if curr != nil && (curr.ID < rv || rv == "") {
|
||||
rv = curr.ID
|
||||
}
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) Weight() float64 {
|
||||
var rv float64
|
||||
for _, searcher := range s.searchers {
|
||||
rv += searcher.Weight()
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) SetQueryNorm(qnorm float64) {
|
||||
for _, searcher := range s.searchers {
|
||||
searcher.SetQueryNorm(qnorm)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) Next() (*DocumentMatch, error) {
|
||||
var err error
|
||||
var rv *DocumentMatch
|
||||
matching := make([]*DocumentMatch, 0)
|
||||
|
||||
found := false
|
||||
for !found && s.currentId != "" {
|
||||
for _, curr := range s.currs {
|
||||
if curr != nil && curr.ID == s.currentId {
|
||||
matching = append(matching, curr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(matching) >= int(s.min) {
|
||||
found = true
|
||||
// score this match
|
||||
rv = s.scorer.Score(matching, len(matching), len(s.searchers))
|
||||
}
|
||||
|
||||
// reset matching
|
||||
matching = make([]*DocumentMatch, 0)
|
||||
// invoke next on all the matching searchers
|
||||
for i, curr := range s.currs {
|
||||
if curr != nil && curr.ID == s.currentId {
|
||||
searcher := s.searchers[i]
|
||||
s.currs[i], err = searcher.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
s.currentId = s.nextSmallestId()
|
||||
}
|
||||
return rv, nil
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) Advance(ID string) (*DocumentMatch, error) {
|
||||
|
||||
// get all searchers pointing at their first match
|
||||
var err error
|
||||
for i, termSearcher := range s.searchers {
|
||||
s.currs[i], err = termSearcher.Advance(ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.currentId = s.nextSmallestId()
|
||||
|
||||
return s.Next()
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) Count() uint64 {
|
||||
// for now return a worst case
|
||||
var sum uint64 = 0
|
||||
for _, searcher := range s.searchers {
|
||||
sum += searcher.Count()
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func (s *TermDisjunctionSearcher) Close() {
|
||||
for _, searcher := range s.searchers {
|
||||
searcher.Close()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright (c) 2014 Couchbase, Inc.
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
|
||||
// except in compliance with the License. You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the
|
||||
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
// either express or implied. See the License for the specific language governing permissions
|
||||
// and limitations under the License.
|
||||
package search
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/couchbaselabs/bleve/index"
|
||||
)
|
||||
|
||||
func TestTermDisjunctionSearch(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
index index.Index
|
||||
query Query
|
||||
results []*DocumentMatch
|
||||
}{
|
||||
{
|
||||
index: twoDocIndex,
|
||||
query: &TermDisjunctionQuery{
|
||||
Terms: []Query{
|
||||
&TermQuery{
|
||||
Term: "marty",
|
||||
Field: "name",
|
||||
BoostVal: 1.0,
|
||||
Explain: true,
|
||||
},
|
||||
&TermQuery{
|
||||
Term: "dustin",
|
||||
Field: "name",
|
||||
BoostVal: 1.0,
|
||||
Explain: true,
|
||||
},
|
||||
},
|
||||
Explain: true,
|
||||
Min: 0,
|
||||
},
|
||||
results: []*DocumentMatch{
|
||||
&DocumentMatch{
|
||||
ID: "1",
|
||||
Score: 0.6775110856165737,
|
||||
},
|
||||
&DocumentMatch{
|
||||
ID: "3",
|
||||
Score: 0.6775110856165737,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for testIndex, test := range tests {
|
||||
searcher, err := test.query.Searcher(test.index)
|
||||
defer searcher.Close()
|
||||
|
||||
next, err := searcher.Next()
|
||||
i := 0
|
||||
for err == nil && next != nil {
|
||||
if i < len(test.results) {
|
||||
if next.ID != test.results[i].ID {
|
||||
t.Errorf("expected result %d to have id %s got %s for test %d", i, test.results[i].ID, next.ID, testIndex)
|
||||
}
|
||||
if next.Score != test.results[i].Score {
|
||||
t.Errorf("expected result %d to have score %v got %v for test %d", i, test.results[i].Score, next.Score, testIndex)
|
||||
t.Logf("scoring explanation: %s", next.Expl)
|
||||
}
|
||||
}
|
||||
next, err = searcher.Next()
|
||||
i++
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error iterating searcher: %v for test %d", err, testIndex)
|
||||
}
|
||||
if len(test.results) != i {
|
||||
t.Errorf("expected %d results got %d for test %d", len(test.results), i, testIndex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisjunctionAdvance(t *testing.T) {
|
||||
query := &TermDisjunctionQuery{
|
||||
Terms: []Query{
|
||||
&TermQuery{
|
||||
Term: "marty",
|
||||
Field: "name",
|
||||
BoostVal: 1.0,
|
||||
Explain: true,
|
||||
},
|
||||
&TermQuery{
|
||||
Term: "dustin",
|
||||
Field: "name",
|
||||
BoostVal: 1.0,
|
||||
Explain: true,
|
||||
},
|
||||
},
|
||||
Explain: true,
|
||||
Min: 0,
|
||||
}
|
||||
|
||||
searcher, err := query.Searcher(twoDocIndex)
|
||||
match, err := searcher.Advance("3")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if match == nil {
|
||||
t.Errorf("expected 3, got nil")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue