diff --git a/cmd/pkiadm/serial.go b/cmd/pkiadm/serial.go new file mode 100644 index 0000000..ea1ad9a --- /dev/null +++ b/cmd/pkiadm/serial.go @@ -0,0 +1,100 @@ +package main + +import ( + "fmt" + "math" + "os" + "text/tabwriter" + + "github.com/gibheer/pkiadm" + "github.com/pkg/errors" + flag "github.com/spf13/pflag" +) + +func createSerial(args []string, client *pkiadm.Client) error { + fs := flag.NewFlagSet("create-private", flag.ExitOnError) + fs.Usage = func() { + fmt.Printf("Usage of %s:\n", "pkiadm create-private") + fmt.Println(`Create a new serial producer for certificate generation. New IDs will be generated by random in the defined limits.`) + fs.PrintDefaults() + } + ser := pkiadm.Serial{} + fs.StringVar(&ser.ID, "id", "", "set the unique id for the new serial") + fs.Int64Var(&ser.Min, "min", 0, "set the minimum id") + fs.Int64Var(&ser.Max, "max", math.MaxInt64, "set the maximum id") + fs.Parse(args) + + if err := client.CreateSerial(ser); err != nil { + return errors.Wrap(err, "could not create serial") + } + + return nil +} +func setSerial(args []string, client *pkiadm.Client) error { + fs := flag.NewFlagSet("set-private", flag.ExitOnError) + ser := pkiadm.Serial{} + fs.StringVar(&ser.ID, "id", "", "set the unique id for the serial to change") + fs.Int64Var(&ser.Min, "min", 0, "set the minimum id") + fs.Int64Var(&ser.Max, "max", math.MaxInt64, "set the maximum id") + fs.Parse(args) + + fieldList := []string{} + for _, field := range []string{"type", "bits"} { + flag := fs.Lookup(field) + if flag.Changed { + fieldList = append(fieldList, field) + } + } + + if err := client.SetSerial(ser, fieldList); err != nil { + return err + } + return nil +} +func deleteSerial(args []string, client *pkiadm.Client) error { + fs := flag.NewFlagSet("delete-private", flag.ExitOnError) + var id = fs.String("id", "", "set the id of the serial to delete") + fs.Parse(args) + + if err := client.DeleteSerial(*id); err != nil { + return err + } + return nil +} +func listSerial(args []string, client *pkiadm.Client) error { + fs := flag.NewFlagSet("list-private", flag.ExitOnError) + fs.Parse(args) + + sers, err := client.ListSerial() + if err != nil { + return err + } + + if len(sers) == 0 { + return nil + } + out := tabwriter.NewWriter(os.Stdout, 2, 2, 1, ' ', tabwriter.AlignRight) + fmt.Fprintf(out, "%s\t%s\t%s\t\n", "id", "min", "max") + for _, ser := range sers { + fmt.Fprintf(out, "%s\t%d\t%d\t\n", ser.ID, ser.Min, ser.Max) + } + out.Flush() + + return nil +} +func showSerial(args []string, client *pkiadm.Client) error { + fs := flag.NewFlagSet("show-private", flag.ExitOnError) + var id = fs.String("id", "", "set the id of the serial to show") + fs.Parse(args) + + ser, err := client.ShowSerial(*id) + if err != nil { + return err + } + out := tabwriter.NewWriter(os.Stdout, 2, 2, 1, ' ', tabwriter.AlignRight) + fmt.Fprintf(out, "ID:\t%s\t\n", ser.ID) + fmt.Fprintf(out, "min:\t%d\t\n", ser.Min) + fmt.Fprintf(out, "max:\t%d\t\n", ser.Max) + out.Flush() + return nil +} diff --git a/cmd/pkiadmd/serial.go b/cmd/pkiadmd/serial.go index ff3750b..296f985 100644 --- a/cmd/pkiadmd/serial.go +++ b/cmd/pkiadmd/serial.go @@ -2,6 +2,7 @@ package main import ( "crypto/rand" + "fmt" "math/big" "github.com/gibheer/pkiadm" @@ -50,9 +51,102 @@ func (s *Serial) DependsOn() []pkiadm.ResourceName { return []pkiadm.ResourceNam // Generate generates a new serial number and stores it to avoid double // assigning. func (s *Serial) Generate() (*big.Int, error) { - val, err := rand.Int(rand.Reader, big.NewInt(s.Max-s.Min)) - if err != nil { - return big.NewInt(-1), err + for { + val, err := rand.Int(rand.Reader, big.NewInt(s.Max-s.Min)) + if err != nil { + return big.NewInt(-1), err + } + if _, found := s.UsedIDs[val.Int64()]; !found { + s.UsedIDs[val.Int64()] = true + return big.NewInt(val.Int64() + s.Min), nil + } } - return big.NewInt(val.Int64() + s.Min), nil +} + +func (s *Server) CreateSerial(inSer pkiadm.Serial, res *pkiadm.Result) error { + s.lock() + defer s.unlock() + + ser, err := NewSerial(inSer.ID, inSer.Min, inSer.Max) + if err != nil { + res.SetError(err, "Could not create new serial '%s'", inSer.ID) + return nil + } + if err := s.storage.AddSerial(ser); err != nil { + res.SetError(err, "Could not add serial '%s'", inSer.ID) + return nil + } + return s.store(res) +} +func (s *Server) SetSerial(changeset pkiadm.SerialChange, res *pkiadm.Result) error { + s.lock() + defer s.unlock() + + ser, err := s.storage.GetSerial(pkiadm.ResourceName{ID: changeset.Serial.ID, Type: pkiadm.RTSerial}) + if err != nil { + res.SetError(err, "Could not find serial '%s'", changeset.Serial.ID) + return nil + } + + for _, field := range changeset.FieldList { + switch field { + case "min": + ser.Min = changeset.Serial.Min + case "max": + ser.Max = changeset.Serial.Max + default: + res.SetError(fmt.Errorf("unknown field"), "unknown field '%s'", field) + return nil + } + } + if err := s.storage.Update(pkiadm.ResourceName{ID: ser.ID, Type: pkiadm.RTSerial}); err != nil { + res.SetError(err, "Could not update serial '%s'", changeset.Serial.ID) + return nil + } + return s.store(res) +} +func (s *Server) DeleteSerial(inSer pkiadm.ResourceName, res *pkiadm.Result) error { + s.lock() + defer s.unlock() + + ser, err := s.storage.GetSerial(pkiadm.ResourceName{ID: inSer.ID, Type: pkiadm.RTSerial}) + if err != nil { + res.SetError(err, "Could not find serial '%s'", inSer.ID) + return nil + } + + if err := s.storage.Remove(ser); err != nil { + res.SetError(err, "Could not remove serial '%s'", ser.ID) + return nil + } + return s.store(res) +} +func (s *Server) ShowSerial(inSer pkiadm.ResourceName, res *pkiadm.ResultSerial) error { + s.lock() + defer s.unlock() + + ser, err := s.storage.GetSerial(pkiadm.ResourceName{ID: inSer.ID, Type: pkiadm.RTSerial}) + if err != nil { + res.Result.SetError(err, "Could not find serial '%s'", inSer.ID) + return nil + } + res.Serials = []pkiadm.Serial{pkiadm.Serial{ + ID: ser.ID, + Min: ser.Min, + Max: ser.Max, + }} + return nil +} +func (s *Server) ListSerial(filter pkiadm.Filter, res *pkiadm.ResultSerial) error { + s.lock() + defer s.unlock() + + for _, ser := range s.storage.Serials { + res.Serials = append(res.Serials, pkiadm.Serial{ + ID: ser.ID, + Min: ser.Min, + Max: ser.Max, + }) + } + return nil } diff --git a/serial.go b/serial.go new file mode 100644 index 0000000..1bd1450 --- /dev/null +++ b/serial.go @@ -0,0 +1,56 @@ +package pkiadm + +type ( + Serial struct { + ID string + Min int64 + Max int64 + } + + SerialChange struct { + Serial Serial + FieldList []string + } + + ResultSerial struct { + Result Result + Serials []Serial + } +) + +// CreateSerial sends a RPC request to create a new private key. +func (c *Client) CreateSerial(ser Serial) error { + return c.exec("CreateSerial", ser) +} +func (c *Client) SetSerial(ser Serial, fieldList []string) error { + changeset := SerialChange{ser, fieldList} + return c.exec("SetSerial", changeset) +} +func (c *Client) DeleteSerial(id string) error { + ser := ResourceName{ID: id, Type: RTSerial} + return c.exec("DeleteSerial", ser) +} +func (c *Client) ListSerial() ([]Serial, error) { + result := &ResultSerial{} + if err := c.query("ListSerial", Filter{}, result); err != nil { + return []Serial{}, err + } + if result.Result.HasError { + return []Serial{}, result.Result.Error + } + return result.Serials, nil +} +func (c *Client) ShowSerial(id string) (Serial, error) { + ser := ResourceName{ID: id, Type: RTSerial} + result := &ResultSerial{} + if err := c.query("ShowSerial", ser, result); err != nil { + return Serial{}, err + } + if result.Result.HasError { + return Serial{}, result.Result.Error + } + for _, privateKey := range result.Serials { + return privateKey, nil + } + return Serial{}, nil +}