diff --git a/geo/geo_dist.go b/geo/geo_dist.go index 859e6a7b..d3ae0ed9 100644 --- a/geo/geo_dist.go +++ b/geo/geo_dist.go @@ -15,6 +15,7 @@ package geo import ( + "fmt" "math" "strconv" "strings" @@ -67,9 +68,23 @@ func ParseDistance(d string) (float64, error) { return parsedNum, nil } +// ParseDistanceUnit attempts to parse a distance unit and return the +// multiplier for converting this to meters. If the unit cannot be parsed +// then 0 and the error message is returned. +func ParseDistanceUnit(u string) (float64, error) { + for _, unit := range distanceUnits { + for _, unitSuffix := range unit.suffixes { + if u == unitSuffix { + return unit.conv, nil + } + } + } + return 0, fmt.Errorf("unknown distance unit: %s", u) +} + // Haversin computes the distance between two points. // This implemenation uses the sloppy math implemenations which trade off -// accuracy for performance. +// accuracy for performance. The distance returned is in kilometers. func Haversin(lon1, lat1, lon2, lat2 float64) float64 { x1 := lat1 * degreesToRadian x2 := lat2 * degreesToRadian diff --git a/geo/geo_dist_test.go b/geo/geo_dist_test.go index 8067da8a..5c8abff4 100644 --- a/geo/geo_dist_test.go +++ b/geo/geo_dist_test.go @@ -15,6 +15,7 @@ package geo import ( + "fmt" "math" "reflect" "strconv" @@ -46,6 +47,30 @@ func TestParseDistance(t *testing.T) { } } +func TestParseDistanceUnit(t *testing.T) { + tests := []struct { + dist string + want float64 + wantErr error + }{ + {"mi", 1609.344, nil}, + {"m", 1, nil}, + {"km", 1000, nil}, + {"", 0, fmt.Errorf("unknown distance unit: ")}, + {"kam", 0, fmt.Errorf("unknown distance unit: kam")}, + } + + for _, test := range tests { + got, err := ParseDistanceUnit(test.dist) + if !reflect.DeepEqual(err, test.wantErr) { + t.Errorf("expected err: %v, got %v for %s", test.wantErr, err, test.dist) + } + if got != test.want { + t.Errorf("expected distance %f got %f for %s", test.want, got, test.dist) + } + } +} + func TestHaversinDistance(t *testing.T) { earthRadiusKMs := 6378.137 halfCircle := earthRadiusKMs * math.Pi diff --git a/geo/sloppy.go b/geo/sloppy.go index ad7306c6..a0f5a366 100644 --- a/geo/sloppy.go +++ b/geo/sloppy.go @@ -137,7 +137,7 @@ func init() { } // earthDiameter returns an estimation of the earth's diameter at the specified -// latitude +// latitude in kilometers func earthDiameter(lat float64) float64 { index := math.Mod(math.Abs(lat)*radiusIndexer+0.5, float64(len(earthDiameterPerLatitude))) if math.IsNaN(index) { diff --git a/search/sort.go b/search/sort.go index 751eec9d..18d51114 100644 --- a/search/sort.go +++ b/search/sort.go @@ -62,12 +62,22 @@ func ParseSearchSortObj(input map[string]interface{}) (SearchSort, error) { if !foundLocation { return nil, fmt.Errorf("unable to parse geo_distance location") } - return &SortGeoDistance{ - Field: field, - Desc: descending, - lon: lon, - lat: lat, - }, nil + rvd := &SortGeoDistance{ + Field: field, + Desc: descending, + lon: lon, + lat: lat, + unitMult: 1.0, + } + if distUnit, ok := input["unit"].(string); ok { + var err error + rvd.unitMult, err = geo.ParseDistanceUnit(distUnit) + if err != nil { + return nil, err + } + rvd.Unit = distUnit + } + return rvd, nil case "field": field, ok := input["field"].(string) if !ok { @@ -546,11 +556,13 @@ var maxDistance = string(numeric.MustNewPrefixCodedInt64(math.MaxInt64, 0)) // Field is the name of the field // Descending reverse the sort order (default false) type SortGeoDistance struct { - Field string - Desc bool - values []string - lon float64 - lat float64 + Field string + Desc bool + Unit string + values []string + lon float64 + lat float64 + unitMult float64 } // UpdateVisitor notifies this sort field that in this document @@ -581,7 +593,13 @@ func (s *SortGeoDistance) Value(i *DocumentMatch) string { docLat := geo.MortonUnhashLat(uint64(i64)) dist := geo.Haversin(s.lon, s.lat, docLon, docLat) - return string(numeric.MustNewPrefixCodedInt64(int64(dist), 0)) + // dist is returned in km, so convert to m + dist *= 1000 + if s.unitMult != 0 { + dist /= s.unitMult + } + distInt64 := numeric.Float64ToInt64(dist) + return string(numeric.MustNewPrefixCodedInt64(distInt64, 0)) } // Descending determines the order of the sort @@ -628,6 +646,9 @@ func (s *SortGeoDistance) MarshalJSON() ([]byte, error) { "lat": s.lat, }, } + if s.Unit != "" { + sfm["unit"] = s.Unit + } if s.Desc { sfm["desc"] = true }