0
0
Fork 0

added support for disjunction search

This commit is contained in:
Marty Schoch 2014-04-25 09:31:28 -06:00
parent 5de10307d8
commit 8e71daa4e3
4 changed files with 392 additions and 0 deletions

View File

@ -0,0 +1,37 @@
// Copyright (c) 2014 Couchbase, Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
// except in compliance with the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing permissions
// and limitations under the License.
package search
import (
"fmt"
"github.com/couchbaselabs/bleve/index"
)
type TermDisjunctionQuery struct {
Terms []Query `json:"terms"`
BoostVal float64 `json:"boost"`
Explain bool `json:"explain"`
Min float64 `json:"min"`
}
func (q *TermDisjunctionQuery) Boost() float64 {
return q.BoostVal
}
func (q *TermDisjunctionQuery) Searcher(index index.Index) (Searcher, error) {
return NewTermDisjunctionSearcher(index, q)
}
func (q *TermDisjunctionQuery) Validate() error {
if int(q.Min) > len(q.Terms) {
return fmt.Errorf("Minimum clauses in disjunction exceeds total number of clauses")
}
return nil
}

View File

@ -0,0 +1,68 @@
// Copyright (c) 2014 Couchbase, Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
// except in compliance with the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing permissions
// and limitations under the License.
package search
import (
"fmt"
)
type TermDisjunctionQueryScorer struct {
explain bool
}
func NewTermDisjunctionQueryScorer(explain bool) *TermDisjunctionQueryScorer {
return &TermDisjunctionQueryScorer{
explain: explain,
}
}
func (s *TermDisjunctionQueryScorer) Score(constituents []*DocumentMatch, countMatch, countTotal int) *DocumentMatch {
rv := DocumentMatch{
ID: constituents[0].ID,
}
var sum float64
var childrenExplanations []*Explanation
if s.explain {
childrenExplanations = make([]*Explanation, len(constituents))
}
locations := []FieldTermLocationMap{}
for i, docMatch := range constituents {
sum += docMatch.Score
if s.explain {
childrenExplanations[i] = docMatch.Expl
}
if docMatch.Locations != nil {
locations = append(locations, docMatch.Locations)
}
}
var rawExpl *Explanation
if s.explain {
rawExpl = &Explanation{Value: sum, Message: "sum of:", Children: childrenExplanations}
}
coord := float64(countMatch) / float64(countTotal)
rv.Score = sum * coord
if s.explain {
ce := make([]*Explanation, 2)
ce[0] = rawExpl
ce[1] = &Explanation{Value: coord, Message: fmt.Sprintf("coord(%d/%d)", countMatch, countTotal)}
rv.Expl = &Explanation{Value: rv.Score, Message: "product of:", Children: ce}
}
if len(locations) == 1 {
rv.Locations = locations[0]
} else if len(locations) > 1 {
rv.Locations = mergeLocations(locations)
}
return &rv
}

View File

@ -0,0 +1,174 @@
// Copyright (c) 2014 Couchbase, Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
// except in compliance with the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing permissions
// and limitations under the License.
package search
import (
"math"
"sort"
"github.com/couchbaselabs/bleve/index"
)
type TermDisjunctionSearcher struct {
index index.Index
searchers OrderedSearcherList
queryNorm float64
currs []*DocumentMatch
currentId string
scorer *TermDisjunctionQueryScorer
min float64
}
func NewTermDisjunctionSearcher(index index.Index, query *TermDisjunctionQuery) (*TermDisjunctionSearcher, error) {
// build the downstream searchres
searchers := make(OrderedSearcherList, len(query.Terms))
for i, termQuery := range query.Terms {
searcher, err := termQuery.Searcher(index)
if err != nil {
return nil, err
}
searchers[i] = searcher
}
// sort the searchers
sort.Sort(sort.Reverse(searchers))
// build our searcher
rv := TermDisjunctionSearcher{
index: index,
searchers: searchers,
currs: make([]*DocumentMatch, len(searchers)),
scorer: NewTermDisjunctionQueryScorer(query.Explain),
min: query.Min,
}
rv.computeQueryNorm()
err := rv.initSearchers()
if err != nil {
return nil, err
}
return &rv, nil
}
func (s *TermDisjunctionSearcher) computeQueryNorm() {
// first calculate sum of squared weights
sumOfSquaredWeights := 0.0
for _, termSearcher := range s.searchers {
sumOfSquaredWeights += termSearcher.Weight()
}
// now compute query norm from this
s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights)
// finally tell all the downsteam searchers the norm
for _, termSearcher := range s.searchers {
termSearcher.SetQueryNorm(s.queryNorm)
}
}
func (s *TermDisjunctionSearcher) initSearchers() error {
var err error
// get all searchers pointing at their first match
for i, termSearcher := range s.searchers {
s.currs[i], err = termSearcher.Next()
if err != nil {
return err
}
}
s.currentId = s.nextSmallestId()
return nil
}
func (s *TermDisjunctionSearcher) nextSmallestId() string {
rv := ""
for _, curr := range s.currs {
if curr != nil && (curr.ID < rv || rv == "") {
rv = curr.ID
}
}
return rv
}
func (s *TermDisjunctionSearcher) Weight() float64 {
var rv float64
for _, searcher := range s.searchers {
rv += searcher.Weight()
}
return rv
}
func (s *TermDisjunctionSearcher) SetQueryNorm(qnorm float64) {
for _, searcher := range s.searchers {
searcher.SetQueryNorm(qnorm)
}
}
func (s *TermDisjunctionSearcher) Next() (*DocumentMatch, error) {
var err error
var rv *DocumentMatch
matching := make([]*DocumentMatch, 0)
found := false
for !found && s.currentId != "" {
for _, curr := range s.currs {
if curr != nil && curr.ID == s.currentId {
matching = append(matching, curr)
}
}
if len(matching) >= int(s.min) {
found = true
// score this match
rv = s.scorer.Score(matching, len(matching), len(s.searchers))
}
// reset matching
matching = make([]*DocumentMatch, 0)
// invoke next on all the matching searchers
for i, curr := range s.currs {
if curr != nil && curr.ID == s.currentId {
searcher := s.searchers[i]
s.currs[i], err = searcher.Next()
if err != nil {
return nil, err
}
}
}
s.currentId = s.nextSmallestId()
}
return rv, nil
}
func (s *TermDisjunctionSearcher) Advance(ID string) (*DocumentMatch, error) {
// get all searchers pointing at their first match
var err error
for i, termSearcher := range s.searchers {
s.currs[i], err = termSearcher.Advance(ID)
if err != nil {
return nil, err
}
}
s.currentId = s.nextSmallestId()
return s.Next()
}
func (s *TermDisjunctionSearcher) Count() uint64 {
// for now return a worst case
var sum uint64 = 0
for _, searcher := range s.searchers {
sum += searcher.Count()
}
return sum
}
func (s *TermDisjunctionSearcher) Close() {
for _, searcher := range s.searchers {
searcher.Close()
}
}

View File

@ -0,0 +1,113 @@
// Copyright (c) 2014 Couchbase, Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
// except in compliance with the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the
// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing permissions
// and limitations under the License.
package search
import (
"testing"
"github.com/couchbaselabs/bleve/index"
)
func TestTermDisjunctionSearch(t *testing.T) {
tests := []struct {
index index.Index
query Query
results []*DocumentMatch
}{
{
index: twoDocIndex,
query: &TermDisjunctionQuery{
Terms: []Query{
&TermQuery{
Term: "marty",
Field: "name",
BoostVal: 1.0,
Explain: true,
},
&TermQuery{
Term: "dustin",
Field: "name",
BoostVal: 1.0,
Explain: true,
},
},
Explain: true,
Min: 0,
},
results: []*DocumentMatch{
&DocumentMatch{
ID: "1",
Score: 0.6775110856165737,
},
&DocumentMatch{
ID: "3",
Score: 0.6775110856165737,
},
},
},
}
for testIndex, test := range tests {
searcher, err := test.query.Searcher(test.index)
defer searcher.Close()
next, err := searcher.Next()
i := 0
for err == nil && next != nil {
if i < len(test.results) {
if next.ID != test.results[i].ID {
t.Errorf("expected result %d to have id %s got %s for test %d", i, test.results[i].ID, next.ID, testIndex)
}
if next.Score != test.results[i].Score {
t.Errorf("expected result %d to have score %v got %v for test %d", i, test.results[i].Score, next.Score, testIndex)
t.Logf("scoring explanation: %s", next.Expl)
}
}
next, err = searcher.Next()
i++
}
if err != nil {
t.Fatalf("error iterating searcher: %v for test %d", err, testIndex)
}
if len(test.results) != i {
t.Errorf("expected %d results got %d for test %d", len(test.results), i, testIndex)
}
}
}
func TestDisjunctionAdvance(t *testing.T) {
query := &TermDisjunctionQuery{
Terms: []Query{
&TermQuery{
Term: "marty",
Field: "name",
BoostVal: 1.0,
Explain: true,
},
&TermQuery{
Term: "dustin",
Field: "name",
BoostVal: 1.0,
Explain: true,
},
},
Explain: true,
Min: 0,
}
searcher, err := query.Searcher(twoDocIndex)
match, err := searcher.Advance("3")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if match == nil {
t.Errorf("expected 3, got nil")
}
}