package types import ( "bytes" "database/sql/driver" "fmt" "net" "strconv" ) type ( // Subnet is used to parse a subnet parameter. Subnet net.IPNet // IP is used to parse an IP parameter. IP net.IP // IPVersion represents the two IP versions currently in use. IPVersion int ) // UnmarshalJSON parses a value into a subnet. // // It is also checked if the provided IP matches the network address // of the subnet. func (s *Subnet) UnmarshalJSON(in []byte) error { in = bytes.Trim(in, `"`) ip, ipnet, err := net.ParseCIDR(string(in)) if err != nil { return fmt.Errorf("not a valid subnet: %#v", err) } if !ipnet.IP.Equal(ip) { return fmt.Errorf("provided IP '%s' is not network address '%s' of declared subnet", ip.String(), ipnet.IP.String()) } *s = Subnet(*ipnet) return nil } // String returns the string representation of the subnet. // // The subnet is returned as the subnet address and prefix separated by `/` // as defined in RFC 4632 and RFC 4291. func (s *Subnet) String() string { return (*net.IPNet)(s).String() } func (i IP) Is4() bool { if a := net.IP(i).To4(); a == nil { return false } return true } func (i IP) Is6() bool { return !i.Is4() } // Value implements the database Value interface. // // This function is needed so that a subnet can be inserted into // the database without much casting. func (s Subnet) Value() (driver.Value, error) { return s.String(), nil } func (i *IP) UnmarshalJSON(in []byte) error { in = bytes.Trim(in, `"`) ip := net.ParseIP(string(in)) if ip == nil { return fmt.Errorf("not a valid ip") } *i = IP(ip) return nil } // UnmarshalJSON parses the incoming version from json. func (v *IPVersion) UnmarshalJSON(in []byte) error { raw, err := strconv.Atoi(string(in)) if err != nil { return fmt.Errorf("can't parse ip version: %#v", err) } if raw != 4 && raw != 6 { return fmt.Errorf("only version 4 and 6 are supported") } *v = IPVersion(raw) return nil }