0
0
Fork 0

reformat code with gofmt

Yes, I know that this breaks the history search, but it had to be done
sooner or later. I also adjusted my editor to follow the guidelines more
closely.
This commit is contained in:
Gibheer 2015-03-25 20:43:18 +01:00
parent bb41ff218a
commit ba5a59931e
5 changed files with 677 additions and 590 deletions

View File

@ -2,118 +2,122 @@
package main package main
import ( import (
"fmt" "flag"
"flag" "fmt"
"os" "os"
) )
type ( type (
Command struct { Command struct {
Use string // command name (used for matching) Use string // command name (used for matching)
Short string // a short description to display Short string // a short description to display
Long string // a long help text Long string // a long help text
Example string // an example string Example string // an example string
Run func(*Command, []string) // the command to run Run func(*Command, []string) // the command to run
flagSet *flag.FlagSet // internal flagset with all flags flagSet *flag.FlagSet // internal flagset with all flags
commands []*Command // the list of subcommands commands []*Command // the list of subcommands
} }
) )
// This function adds a new sub command. // This function adds a new sub command.
func (c *Command) AddCommand(cmds... *Command) { func (c *Command) AddCommand(cmds ...*Command) {
res := c.commands res := c.commands
for _, cmd := range cmds { for _, cmd := range cmds {
res = append(res, cmd) res = append(res, cmd)
} }
c.commands = res c.commands = res
} }
// Evaluate the arguments and call either the subcommand or parse it as flags. // Evaluate the arguments and call either the subcommand or parse it as flags.
func (c *Command) eval(args []string) error { func (c *Command) eval(args []string) error {
var name string = "" var name string = ""
var rest []string = []string{} var rest []string = []string{}
if len(args) > 0 { if len(args) > 0 {
name = args[0] name = args[0]
} }
if len(args) > 1 { if len(args) > 1 {
rest = args[1:] rest = args[1:]
} }
if name == "help" { if name == "help" {
c.Help(rest) c.Help(rest)
return nil return nil
} }
for _, cmd := range c.commands { for _, cmd := range c.commands {
if cmd.Use == name { if cmd.Use == name {
return cmd.eval(rest) return cmd.eval(rest)
} }
} }
if err := c.Flags().Parse(args); err != nil { return err } if err := c.Flags().Parse(args); err != nil {
if c.Run != nil { return err
c.Run(c, rest) }
} else { if c.Run != nil {
c.Help(rest) c.Run(c, rest)
} } else {
return nil c.Help(rest)
}
return nil
} }
// Execute the command. It will fetch os.Args[1:] itself. // Execute the command. It will fetch os.Args[1:] itself.
func (c *Command) Execute() error { func (c *Command) Execute() error {
return c.eval(os.Args[1:]) return c.eval(os.Args[1:])
} }
// Return the flagset currently in use. // Return the flagset currently in use.
func (c *Command) Flags() *flag.FlagSet { func (c *Command) Flags() *flag.FlagSet {
if c.flagSet == nil { if c.flagSet == nil {
c.flagSet = flag.NewFlagSet(c.Use, flag.ContinueOnError) c.flagSet = flag.NewFlagSet(c.Use, flag.ContinueOnError)
} }
return c.flagSet return c.flagSet
} }
// Print the help for the current command or a subcommand. // Print the help for the current command or a subcommand.
func (c *Command) Help(args []string) { func (c *Command) Help(args []string) {
if len(args) > 0 { if len(args) > 0 {
for _, cmd := range c.commands { for _, cmd := range c.commands {
if args[0] == cmd.Use { if args[0] == cmd.Use {
cmd.Help([]string{}) cmd.Help([]string{})
return return
} }
} }
} }
if c.Long != "" { fmt.Println(c.Long, "\n") } if c.Long != "" {
c.Usage() fmt.Println(c.Long, "\n")
}
c.Usage()
} }
// Print the usage information. // Print the usage information.
func (c *Command) Usage() { func (c *Command) Usage() {
usage := "" usage := ""
if c.Use != "" { if c.Use != "" {
usage = usage + " " + c.Use usage = usage + " " + c.Use
} }
if len(c.commands) > 0 { if len(c.commands) > 0 {
usage = usage + " command" usage = usage + " command"
} }
if c.flagSet != nil { if c.flagSet != nil {
usage = usage + " [flags]" usage = usage + " [flags]"
} }
fmt.Printf("Usage: %s%s\n", os.Args[0], usage) fmt.Printf("Usage: %s%s\n", os.Args[0], usage)
if len(c.commands) > 0 { if len(c.commands) > 0 {
fmt.Printf("\nwhere command is one of:\n") fmt.Printf("\nwhere command is one of:\n")
for _, cmd := range c.commands { for _, cmd := range c.commands {
fmt.Printf("\t%s\t\t%s\n", cmd.Use, cmd.Short) fmt.Printf("\t%s\t\t%s\n", cmd.Use, cmd.Short)
} }
} }
if c.flagSet != nil { if c.flagSet != nil {
fmt.Printf("\nwhere flags is any of:\n") fmt.Printf("\nwhere flags is any of:\n")
c.Flags().SetOutput(os.Stdout) c.Flags().SetOutput(os.Stdout)
c.Flags().PrintDefaults() c.Flags().PrintDefaults()
} }
if c.Example != "" { if c.Example != "" {
fmt.Println("\nexample:") fmt.Println("\nexample:")
fmt.Printf("\t%s\n", c.Example) fmt.Printf("\t%s\n", c.Example)
} }
} }

764
flags.go
View File

@ -4,205 +4,205 @@ package main
// often used by multiple functions. // often used by multiple functions.
import ( import (
"crypto/elliptic" "crypto/elliptic"
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"net" "net"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"time" "time"
"github.com/gibheer/pki" "github.com/gibheer/pki"
) )
const ( const (
RsaLowerLength = 2048 RsaLowerLength = 2048
RsaUpperLength = 16384 RsaUpperLength = 16384
) )
var ( var (
EcdsaCurves = []int{224, 256, 384, 521} EcdsaCurves = []int{224, 256, 384, 521}
) )
type ( type (
// holds all certificate related flags, which need parsing afterwards // holds all certificate related flags, which need parsing afterwards
certiticateRequestRawFlags struct { certiticateRequestRawFlags struct {
manual struct { manual struct {
serialNumber string // the serial number for the cert serialNumber string // the serial number for the cert
commonName string // the common name used in the cert commonName string // the common name used in the cert
dnsNames string // all alternative names in the certificate (comma separated list) dnsNames string // all alternative names in the certificate (comma separated list)
ipAddresses string // all IP addresses in the certificate (comma separated list) ipAddresses string // all IP addresses in the certificate (comma separated list)
emailAddresses string // alternative email addresses emailAddresses string // alternative email addresses
} }
automatic struct { automatic struct {
Country string // the country names which should end up in the cert (comma separated list) Country string // the country names which should end up in the cert (comma separated list)
Organization string // the organization names (comma separated list) Organization string // the organization names (comma separated list)
OrganizationalUnit string // the organizational units (comma separated list) OrganizationalUnit string // the organizational units (comma separated list)
Locality string // the city or locality (comma separated list) Locality string // the city or locality (comma separated list)
Province string // the province name (comma separated list) Province string // the province name (comma separated list)
StreetAddress string // the street addresses of the organization (comma separated list) StreetAddress string // the street addresses of the organization (comma separated list)
PostalCode string // the postal codes of the locality PostalCode string // the postal codes of the locality
} }
} }
// a container go gather all incoming flags for further processing // a container go gather all incoming flags for further processing
paramContainer struct { paramContainer struct {
outputPath string // path to output whatever is generated outputPath string // path to output whatever is generated
inputPath string // path to an input resource inputPath string // path to an input resource
cryptType string // type of something (private key) cryptType string // type of something (private key)
length int // the length of something (private key) length int // the length of something (private key)
privateKeyPath string // path to the private key privateKeyPath string // path to the private key
publicKeyPath string // path to the public key publicKeyPath string // path to the public key
signRequestPath string // path to the certificate sign request signRequestPath string // path to the certificate sign request
certificateFlags certiticateRequestRawFlags // container for certificate related flags certificateFlags certiticateRequestRawFlags // container for certificate related flags
signature string // a base64 encoded signature signature string // a base64 encoded signature
certGeneration certGenerationRaw certGeneration certGenerationRaw
} }
privateKeyGenerationFlags struct { privateKeyGenerationFlags struct {
Type string // type of the private key (rsa, ecdsa) Type string // type of the private key (rsa, ecdsa)
Curve elliptic.Curve // curve for ecdsa Curve elliptic.Curve // curve for ecdsa
Size int // bitsize for rsa Size int // bitsize for rsa
} }
certGenerationRaw struct { certGenerationRaw struct {
serial int64 serial int64
notBefore string notBefore string
notAfter string notAfter string
isCA bool isCA bool
length int length int
} }
flagCheck func()(error) flagCheck func() error
) )
var ( var (
CmdRoot = &Command { CmdRoot = &Command{
Short: "A tool to manage keys and certificates.", Short: "A tool to manage keys and certificates.",
Long: `This tool provides a way to manage private and public keys, create Long: `This tool provides a way to manage private and public keys, create
certificate requests and certificates and sign/verify messages.`, certificate requests and certificates and sign/verify messages.`,
} }
CmdCreatePrivateKey = &Command { CmdCreatePrivateKey = &Command{
Use: "create-private", Use: "create-private",
Short: "create a private key", Short: "create a private key",
Long: "Create an ecdsa or rsa key with this command", Long: "Create an ecdsa or rsa key with this command",
Example: "create-private -type=ecdsa -length=521", Example: "create-private -type=ecdsa -length=521",
Run: create_private_key, Run: create_private_key,
} }
CmdCreatePublicKey = &Command { CmdCreatePublicKey = &Command{
Use: "create-public", Use: "create-public",
Short: "create a public key from a private key", Short: "create a public key from a private key",
Long: "Create a public key derived from a private key.", Long: "Create a public key derived from a private key.",
Example: "create-public -private-key=foo.ecdsa", Example: "create-public -private-key=foo.ecdsa",
Run: create_public_key, Run: create_public_key,
} }
CmdSignInput = &Command { CmdSignInput = &Command{
Use: "sign-input", Use: "sign-input",
Short: "sign a text using a private key", Short: "sign a text using a private key",
Long: "Create a signature using a private key", Long: "Create a signature using a private key",
Example: "sign-input -private-key=foo.ecdsa -input=textfile", Example: "sign-input -private-key=foo.ecdsa -input=textfile",
Run: sign_input, Run: sign_input,
} }
CmdVerifyInput = &Command { CmdVerifyInput = &Command{
Use: "verify-input", Use: "verify-input",
Short: "verify a text using a signature", Short: "verify a text using a signature",
Long: "Verify a text using a signature and a public key.", Long: "Verify a text using a signature and a public key.",
Example: "verify-input -public-key=foo.ecdsa.pub -input=textfile -signature=abc456", Example: "verify-input -public-key=foo.ecdsa.pub -input=textfile -signature=abc456",
Run: verify_input, Run: verify_input,
} }
CmdCreateSignRequest = &Command { CmdCreateSignRequest = &Command{
Use: "create-sign-request", Use: "create-sign-request",
Short: "create a certificate sign request", Short: "create a certificate sign request",
Long: "Create a certificate sign request.", Long: "Create a certificate sign request.",
Example: "create-sign-request -private-key=foo.ecdsa -common-name=foo -serial=1", Example: "create-sign-request -private-key=foo.ecdsa -common-name=foo -serial=1",
Run: create_sign_request, Run: create_sign_request,
} }
CmdCreateCert = &Command { CmdCreateCert = &Command{
Use: "create-cert", Use: "create-cert",
Short: "create a certificate from a sign request", Short: "create a certificate from a sign request",
Long: "Create a certificate based on a certificate sign request.", Long: "Create a certificate based on a certificate sign request.",
Example: "create-cert -private-key=foo.ecdsa -csr-path=foo.csr", Example: "create-cert -private-key=foo.ecdsa -csr-path=foo.csr",
Run: create_cert, Run: create_cert,
} }
// variable to hold the raw arguments before checking // variable to hold the raw arguments before checking
flagContainer *paramContainer flagContainer *paramContainer
// loaded private key // loaded private key
FlagPrivateKey pki.PrivateKey FlagPrivateKey pki.PrivateKey
// loaded public key // loaded public key
FlagPublicKey pki.PublicKey FlagPublicKey pki.PublicKey
// the IO handler for input // the IO handler for input
FlagInput io.ReadCloser FlagInput io.ReadCloser
// the IO handler for output // the IO handler for output
FlagOutput io.WriteCloser FlagOutput io.WriteCloser
// signature from the args // signature from the args
FlagSignature []byte FlagSignature []byte
// private key specific stuff // private key specific stuff
FlagPrivateKeyGeneration privateKeyGenerationFlags FlagPrivateKeyGeneration privateKeyGenerationFlags
// a certificate filled with the parameters // a certificate filled with the parameters
FlagCertificateRequestData *pki.CertificateData FlagCertificateRequestData *pki.CertificateData
// the certificate sign request // the certificate sign request
FlagCertificateSignRequest *pki.CertificateRequest FlagCertificateSignRequest *pki.CertificateRequest
// certificate specific creation stuff // certificate specific creation stuff
FlagCertificateGeneration pki.CertificateOptions FlagCertificateGeneration pki.CertificateOptions
) )
func InitFlags() { func InitFlags() {
flagContainer = &paramContainer{} flagContainer = &paramContainer{}
CmdRoot.AddCommand( CmdRoot.AddCommand(
CmdCreatePrivateKey, CmdCreatePrivateKey,
CmdCreatePublicKey, CmdCreatePublicKey,
CmdSignInput, CmdSignInput,
CmdVerifyInput, CmdVerifyInput,
CmdCreateSignRequest, CmdCreateSignRequest,
CmdCreateCert, CmdCreateCert,
) )
// private-key // private-key
InitFlagOutput(CmdCreatePrivateKey) InitFlagOutput(CmdCreatePrivateKey)
InitFlagPrivateKeyGeneration(CmdCreatePrivateKey) InitFlagPrivateKeyGeneration(CmdCreatePrivateKey)
// public-key // public-key
InitFlagOutput(CmdCreatePublicKey) InitFlagOutput(CmdCreatePublicKey)
InitFlagPrivateKey(CmdCreatePublicKey) InitFlagPrivateKey(CmdCreatePublicKey)
// sign-input // sign-input
InitFlagInput(CmdSignInput) InitFlagInput(CmdSignInput)
InitFlagPrivateKey(CmdSignInput) InitFlagPrivateKey(CmdSignInput)
InitFlagOutput(CmdSignInput) InitFlagOutput(CmdSignInput)
// verify-input // verify-input
InitFlagInput(CmdVerifyInput) InitFlagInput(CmdVerifyInput)
InitFlagPrivateKey(CmdVerifyInput) InitFlagPrivateKey(CmdVerifyInput)
InitFlagOutput(CmdVerifyInput) InitFlagOutput(CmdVerifyInput)
InitFlagSignature(CmdVerifyInput) InitFlagSignature(CmdVerifyInput)
// create-sign-request // create-sign-request
InitFlagPrivateKey(CmdCreateSignRequest) InitFlagPrivateKey(CmdCreateSignRequest)
InitFlagOutput(CmdCreateSignRequest) InitFlagOutput(CmdCreateSignRequest)
InitFlagCertificateFields(CmdCreateSignRequest) InitFlagCertificateFields(CmdCreateSignRequest)
// create-certificate // create-certificate
InitFlagPrivateKey(CmdCreateCert) InitFlagPrivateKey(CmdCreateCert)
InitFlagOutput(CmdCreateCert) InitFlagOutput(CmdCreateCert)
InitFlagCert(CmdCreateCert) InitFlagCert(CmdCreateCert)
InitFlagCSR(CmdCreateCert) InitFlagCSR(CmdCreateCert)
} }
func checkFlags(checks... flagCheck) error { func checkFlags(checks ...flagCheck) error {
for _, check := range checks { for _, check := range checks {
if err := check(); err != nil { if err := check(); err != nil {
return err return err
} }
} }
return nil return nil
} }
//// print a message with the usage part //// print a message with the usage part
@ -213,281 +213,315 @@ func checkFlags(checks... flagCheck) error {
// add the private key option to the requested flags // add the private key option to the requested flags
func InitFlagPrivateKey(cmd *Command) { func InitFlagPrivateKey(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.privateKeyPath, "private-key", "", "path to the private key (required)") cmd.Flags().StringVar(&flagContainer.privateKeyPath, "private-key", "", "path to the private key (required)")
} }
// check the private key flag and load the private key // check the private key flag and load the private key
func checkPrivateKey() error { func checkPrivateKey() error {
if flagContainer.privateKeyPath == "" { return fmt.Errorf("No private key given!") } if flagContainer.privateKeyPath == "" {
// check permissions of private key file return fmt.Errorf("No private key given!")
info, err := os.Stat(flagContainer.privateKeyPath) }
if err != nil { return fmt.Errorf("Error reading private key: %s", err) } // check permissions of private key file
if info.Mode().Perm().String()[4:] != "------" { info, err := os.Stat(flagContainer.privateKeyPath)
return fmt.Errorf("private key file modifyable by others!") if err != nil {
} return fmt.Errorf("Error reading private key: %s", err)
}
if info.Mode().Perm().String()[4:] != "------" {
return fmt.Errorf("private key file modifyable by others!")
}
pk, err := ReadPrivateKeyFile(flagContainer.privateKeyPath) pk, err := ReadPrivateKeyFile(flagContainer.privateKeyPath)
if err != nil { return fmt.Errorf("Error reading private key: %s", err) } if err != nil {
FlagPrivateKey = pk return fmt.Errorf("Error reading private key: %s", err)
return nil }
FlagPrivateKey = pk
return nil
} }
// add the public key flag // add the public key flag
func InitFlagPublicKey(cmd *Command) { func InitFlagPublicKey(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.publicKeyPath, "public-key", "", "path to the public key (required)") cmd.Flags().StringVar(&flagContainer.publicKeyPath, "public-key", "", "path to the public key (required)")
} }
// parse public key flag // parse public key flag
func checkPublicKey() error { func checkPublicKey() error {
if flagContainer.publicKeyPath == "" { return fmt.Errorf("No public key given!") } if flagContainer.publicKeyPath == "" {
return fmt.Errorf("No public key given!")
}
pu, err := ReadPublicKeyFile(flagContainer.publicKeyPath) pu, err := ReadPublicKeyFile(flagContainer.publicKeyPath)
if err != nil { return fmt.Errorf("Error reading public key: %s", err) } if err != nil {
FlagPublicKey = pu return fmt.Errorf("Error reading public key: %s", err)
return nil }
FlagPublicKey = pu
return nil
} }
// add flag to load certificate flags // add flag to load certificate flags
func InitFlagCert(cmd *Command) { func InitFlagCert(cmd *Command) {
cmd.Flags().Int64Var(&flagContainer.certGeneration.serial, "serial", 0, "serial number of all certificates") cmd.Flags().Int64Var(&flagContainer.certGeneration.serial, "serial", 0, "serial number of all certificates")
cmd.Flags().BoolVar(&flagContainer.certGeneration.isCA, "ca", false, "check if the resulting certificate is a ca") cmd.Flags().BoolVar(&flagContainer.certGeneration.isCA, "ca", false, "check if the resulting certificate is a ca")
cmd.Flags().IntVar( cmd.Flags().IntVar(
&flagContainer.certGeneration. &flagContainer.certGeneration.
length, length,
"length", "length",
0, 0,
"the number of certificates allowed in the chain between this cert and the end certificate", "the number of certificates allowed in the chain between this cert and the end certificate",
) )
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certGeneration.notBefore, &flagContainer.certGeneration.notBefore,
"not-before", "not-before",
time.Now().Format(time.RFC3339), time.Now().Format(time.RFC3339),
"time before the certificate is not valid in RFC3339 format (default now)", "time before the certificate is not valid in RFC3339 format (default now)",
) )
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certGeneration. &flagContainer.certGeneration.
notAfter, notAfter,
"not-after", "not-after",
time.Now().Add(time.Duration(180 * 24 * time.Hour)).Format(time.RFC3339), time.Now().Add(time.Duration(180*24*time.Hour)).Format(time.RFC3339),
"time after which the certificate is not valid in RFC3339 format (default now + 180 days)", "time after which the certificate is not valid in RFC3339 format (default now + 180 days)",
) )
} }
// parse the certificate data // parse the certificate data
func checkCertFlags() error { func checkCertFlags() error {
FlagCertificateGeneration.IsCA = flagContainer.certGeneration.isCA FlagCertificateGeneration.IsCA = flagContainer.certGeneration.isCA
FlagCertificateGeneration.CALength = flagContainer.certGeneration.length FlagCertificateGeneration.CALength = flagContainer.certGeneration.length
FlagCertificateGeneration.SerialNumber = big.NewInt(flagContainer.certGeneration.serial) FlagCertificateGeneration.SerialNumber = big.NewInt(flagContainer.certGeneration.serial)
var err error var err error
if notbefore := flagContainer.certGeneration.notBefore; notbefore != "" { if notbefore := flagContainer.certGeneration.notBefore; notbefore != "" {
FlagCertificateGeneration.NotBefore, err = parseTimeRFC3339(notbefore) FlagCertificateGeneration.NotBefore, err = parseTimeRFC3339(notbefore)
if err != nil { return err } if err != nil {
} return err
if notafter := flagContainer.certGeneration.notAfter; notafter != "" { }
FlagCertificateGeneration.NotAfter, err = parseTimeRFC3339(notafter) }
if err != nil { return err } if notafter := flagContainer.certGeneration.notAfter; notafter != "" {
} FlagCertificateGeneration.NotAfter, err = parseTimeRFC3339(notafter)
return nil if err != nil {
return err
}
}
return nil
} }
func parseTimeRFC3339(tr string) (time.Time, error) { func parseTimeRFC3339(tr string) (time.Time, error) {
return time.Parse(time.RFC3339, tr) return time.Parse(time.RFC3339, tr)
} }
// add flag to load certificate sign request // add flag to load certificate sign request
func InitFlagCSR(cmd *Command) { func InitFlagCSR(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.signRequestPath, "csr-path", "", "path to the certificate sign request") cmd.Flags().StringVar(&flagContainer.signRequestPath, "csr-path", "", "path to the certificate sign request")
} }
// parse the certificate sign request // parse the certificate sign request
func checkCSR() error { func checkCSR() error {
rest, err := ioutil.ReadFile(flagContainer.signRequestPath) rest, err := ioutil.ReadFile(flagContainer.signRequestPath)
if err != nil { return fmt.Errorf("Error reading certificate sign request: %s", err) } if err != nil {
return fmt.Errorf("Error reading certificate sign request: %s", err)
}
var csr_asn1 []byte var csr_asn1 []byte
var block *pem.Block var block *pem.Block
for len(rest) > 0 { for len(rest) > 0 {
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
if block.Type == "CERTIFICATE REQUEST" { if block.Type == "CERTIFICATE REQUEST" {
csr_asn1 = block.Bytes csr_asn1 = block.Bytes
break break
} }
} }
if len(csr_asn1) == 0 { if len(csr_asn1) == 0 {
return fmt.Errorf( return fmt.Errorf(
"No certificate sign request found in %s", "No certificate sign request found in %s",
flagContainer.signRequestPath, flagContainer.signRequestPath,
) )
} }
csr, err := pki.LoadCertificateSignRequest(csr_asn1) csr, err := pki.LoadCertificateSignRequest(csr_asn1)
if err != nil { return fmt.Errorf("Invalid certificate sign request: %s", err) } if err != nil {
FlagCertificateSignRequest = csr return fmt.Errorf("Invalid certificate sign request: %s", err)
return nil }
FlagCertificateSignRequest = csr
return nil
} }
func InitFlagOutput(cmd *Command) { func InitFlagOutput(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.outputPath, "output", "STDOUT", "path to the output or STDOUT") cmd.Flags().StringVar(&flagContainer.outputPath, "output", "STDOUT", "path to the output or STDOUT")
} }
// parse the output parameter and open the file handle // parse the output parameter and open the file handle
func checkOutput() error { func checkOutput() error {
if flagContainer.outputPath == "STDOUT" { if flagContainer.outputPath == "STDOUT" {
FlagOutput = os.Stdout FlagOutput = os.Stdout
return nil return nil
} }
var err error var err error
FlagOutput, err = os.OpenFile( FlagOutput, err = os.OpenFile(
flagContainer.outputPath, flagContainer.outputPath,
os.O_WRONLY | os.O_APPEND | os.O_CREATE, // do not kill users files! os.O_WRONLY|os.O_APPEND|os.O_CREATE, // do not kill users files!
0600, 0600,
) )
if err != nil { return err } if err != nil {
return nil return err
}
return nil
} }
// add the input parameter to load resources from // add the input parameter to load resources from
func InitFlagInput(cmd *Command) { func InitFlagInput(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.inputPath, "input", "STDIN", "path to the input or STDIN") cmd.Flags().StringVar(&flagContainer.inputPath, "input", "STDIN", "path to the input or STDIN")
} }
// parse the input parameter and open the file handle // parse the input parameter and open the file handle
func checkInput() error { func checkInput() error {
if flagContainer.inputPath == "STDIN" { if flagContainer.inputPath == "STDIN" {
FlagInput = os.Stdin FlagInput = os.Stdin
return nil return nil
} }
var err error var err error
FlagInput, err = os.Open(flagContainer.inputPath) FlagInput, err = os.Open(flagContainer.inputPath)
if err != nil { return err } if err != nil {
return nil return err
}
return nil
} }
// This function adds the private key generation flags. // This function adds the private key generation flags.
func InitFlagPrivateKeyGeneration(cmd *Command) { func InitFlagPrivateKeyGeneration(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.cryptType, "type", "ecdsa", "the type of the private key (ecdsa, rsa)") cmd.Flags().StringVar(&flagContainer.cryptType, "type", "ecdsa", "the type of the private key (ecdsa, rsa)")
cmd.Flags().IntVar( cmd.Flags().IntVar(
&flagContainer.length, &flagContainer.length,
"length", 521, "length", 521,
fmt.Sprintf("%d - %d for rsa; one of %v for ecdsa", RsaLowerLength, RsaUpperLength, EcdsaCurves), fmt.Sprintf("%d - %d for rsa; one of %v for ecdsa", RsaLowerLength, RsaUpperLength, EcdsaCurves),
) )
} }
// check the private key generation variables and move them to the work space // check the private key generation variables and move them to the work space
func checkPrivateKeyGeneration() error { func checkPrivateKeyGeneration() error {
pk_type := flagContainer.cryptType pk_type := flagContainer.cryptType
FlagPrivateKeyGeneration.Type = pk_type FlagPrivateKeyGeneration.Type = pk_type
switch pk_type { switch pk_type {
case "ecdsa": case "ecdsa":
switch flagContainer.length { switch flagContainer.length {
case 224: FlagPrivateKeyGeneration.Curve = elliptic.P224() case 224:
case 256: FlagPrivateKeyGeneration.Curve = elliptic.P256() FlagPrivateKeyGeneration.Curve = elliptic.P224()
case 384: FlagPrivateKeyGeneration.Curve = elliptic.P384() case 256:
case 521: FlagPrivateKeyGeneration.Curve = elliptic.P521() FlagPrivateKeyGeneration.Curve = elliptic.P256()
default: return fmt.Errorf("Curve %d unknown!", flagContainer.length) case 384:
} FlagPrivateKeyGeneration.Curve = elliptic.P384()
case "rsa": case 521:
size := flagContainer.length FlagPrivateKeyGeneration.Curve = elliptic.P521()
if RsaLowerLength <= size && size <= RsaUpperLength { default:
FlagPrivateKeyGeneration.Size = size return fmt.Errorf("Curve %d unknown!", flagContainer.length)
} else { }
return fmt.Errorf("Length of %d is not allowed for rsa!", size) case "rsa":
} size := flagContainer.length
default: return fmt.Errorf("Type %s is unknown!", pk_type) if RsaLowerLength <= size && size <= RsaUpperLength {
} FlagPrivateKeyGeneration.Size = size
return nil } else {
return fmt.Errorf("Length of %d is not allowed for rsa!", size)
}
default:
return fmt.Errorf("Type %s is unknown!", pk_type)
}
return nil
} }
// add the signature flag to load a signature from a signing process // add the signature flag to load a signature from a signing process
func InitFlagSignature(cmd *Command) { func InitFlagSignature(cmd *Command) {
cmd.Flags().StringVar(&flagContainer.signature, "signature", "", "the base64 encoded signature to use for verification") cmd.Flags().StringVar(&flagContainer.signature, "signature", "", "the base64 encoded signature to use for verification")
} }
// parse the signature flag // parse the signature flag
func checkSignature() error { func checkSignature() error {
var err error var err error
FlagSignature, err = base64.StdEncoding.DecodeString(flagContainer.signature) FlagSignature, err = base64.StdEncoding.DecodeString(flagContainer.signature)
if err != nil { return err } if err != nil {
return nil return err
}
return nil
} }
// add the certificate fields to the flags // add the certificate fields to the flags
func InitFlagCertificateFields(cmd *Command) { func InitFlagCertificateFields(cmd *Command) {
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.manual.serialNumber, &flagContainer.certificateFlags.manual.serialNumber,
"serial", "1", "unique serial number of the CA"); "serial", "1", "unique serial number of the CA")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.manual.commonName, &flagContainer.certificateFlags.manual.commonName,
"common-name", "", "common name of the entity to certify"); "common-name", "", "common name of the entity to certify")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.manual.dnsNames, &flagContainer.certificateFlags.manual.dnsNames,
"dns-names", "", "comma separated list of alternative fqdn entries for the entity"); "dns-names", "", "comma separated list of alternative fqdn entries for the entity")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.manual.emailAddresses, &flagContainer.certificateFlags.manual.emailAddresses,
"email-address", "", "comma separated list of alternative email entries for the entity"); "email-address", "", "comma separated list of alternative email entries for the entity")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.manual.ipAddresses, &flagContainer.certificateFlags.manual.ipAddresses,
"ip-address", "", "comma separated list of alternative ip entries for the entity"); "ip-address", "", "comma separated list of alternative ip entries for the entity")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.Country, &flagContainer.certificateFlags.automatic.Country,
"country", "", "comma separated list of countries the entitiy resides in"); "country", "", "comma separated list of countries the entitiy resides in")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.Organization, &flagContainer.certificateFlags.automatic.Organization,
"organization", "", "comma separated list of organizations the entity belongs to"); "organization", "", "comma separated list of organizations the entity belongs to")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.OrganizationalUnit, &flagContainer.certificateFlags.automatic.OrganizationalUnit,
"organization-unit", "", "comma separated list of organization units or departments the entity belongs to"); "organization-unit", "", "comma separated list of organization units or departments the entity belongs to")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.Locality, &flagContainer.certificateFlags.automatic.Locality,
"locality", "", "comma separated list of localities or cities the entity resides in"); "locality", "", "comma separated list of localities or cities the entity resides in")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.Province, &flagContainer.certificateFlags.automatic.Province,
"province", "", "comma separated list of provinces the entity resides in"); "province", "", "comma separated list of provinces the entity resides in")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.StreetAddress, &flagContainer.certificateFlags.automatic.StreetAddress,
"street-address", "", "comma separated list of street addresses the entity resides in"); "street-address", "", "comma separated list of street addresses the entity resides in")
cmd.Flags().StringVar( cmd.Flags().StringVar(
&flagContainer.certificateFlags.automatic.PostalCode, &flagContainer.certificateFlags.automatic.PostalCode,
"postal-code", "", "comma separated list of postal codes of the localities"); "postal-code", "", "comma separated list of postal codes of the localities")
} }
// parse the certificate fields into a raw certificate // parse the certificate fields into a raw certificate
func checkCertificateFields() error { func checkCertificateFields() error {
FlagCertificateRequestData = pki.NewCertificateData() FlagCertificateRequestData = pki.NewCertificateData()
// convert the automatic flags // convert the automatic flags
container_type := reflect.ValueOf(&flagContainer.certificateFlags.automatic).Elem() container_type := reflect.ValueOf(&flagContainer.certificateFlags.automatic).Elem()
cert_data_type := reflect.ValueOf(&FlagCertificateRequestData.Subject).Elem() cert_data_type := reflect.ValueOf(&FlagCertificateRequestData.Subject).Elem()
for _, field := range []string{"Country", "Organization", "OrganizationalUnit", for _, field := range []string{"Country", "Organization", "OrganizationalUnit",
"Locality", "Province", "StreetAddress", "PostalCode"} { "Locality", "Province", "StreetAddress", "PostalCode"} {
field_value := container_type.FieldByName(field).String() field_value := container_type.FieldByName(field).String()
if field_value == "" { continue } if field_value == "" {
target := cert_data_type.FieldByName(field) continue
target.Set(reflect.ValueOf(strings.Split(field_value, ","))) }
} target := cert_data_type.FieldByName(field)
target.Set(reflect.ValueOf(strings.Split(field_value, ",")))
}
// convert the manual flags // convert the manual flags
data := FlagCertificateRequestData data := FlagCertificateRequestData
raw_data := flagContainer.certificateFlags.manual raw_data := flagContainer.certificateFlags.manual
data.Subject.SerialNumber = raw_data.serialNumber data.Subject.SerialNumber = raw_data.serialNumber
data.Subject.CommonName = raw_data.commonName data.Subject.CommonName = raw_data.commonName
if raw_data.dnsNames != "" { if raw_data.dnsNames != "" {
data.DNSNames = strings.Split(raw_data.dnsNames, ",") data.DNSNames = strings.Split(raw_data.dnsNames, ",")
} }
if raw_data.emailAddresses != "" { if raw_data.emailAddresses != "" {
data.EmailAddresses = strings.Split(raw_data.emailAddresses, ",") data.EmailAddresses = strings.Split(raw_data.emailAddresses, ",")
} }
if raw_data.ipAddresses == "" { return nil } if raw_data.ipAddresses == "" {
raw_ips := strings.Split(raw_data.ipAddresses, ",") return nil
data.IPAddresses = make([]net.IP, len(raw_ips)) }
for i, ip := range raw_ips { raw_ips := strings.Split(raw_data.ipAddresses, ",")
data.IPAddresses[i] = net.ParseIP(ip) data.IPAddresses = make([]net.IP, len(raw_ips))
if data.IPAddresses[i] == nil { for i, ip := range raw_ips {
return fmt.Errorf("'%s' is not a valid IP", ip) data.IPAddresses[i] = net.ParseIP(ip)
} if data.IPAddresses[i] == nil {
} return fmt.Errorf("'%s' is not a valid IP", ip)
}
}
return nil return nil
} }

42
io.go
View File

@ -3,39 +3,43 @@ package main
// handle all io and de/encoding of data // handle all io and de/encoding of data
import ( import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"io/ioutil" "io/ioutil"
) )
var ( var (
ErrBlockNotFound = errors.New("block not found") ErrBlockNotFound = errors.New("block not found")
) )
// load a pem section from a file // load a pem section from a file
func readSectionFromFile(path, btype string) ([]byte, error) { func readSectionFromFile(path, btype string) ([]byte, error) {
raw, err := readFile(path) raw, err := readFile(path)
if err != nil { return raw, err } if err != nil {
return raw, err
}
return decodeSection(raw, btype) return decodeSection(raw, btype)
} }
// read a file completely and report possible errors // read a file completely and report possible errors
func readFile(path string) ([]byte, error) { func readFile(path string) ([]byte, error) {
raw, err := ioutil.ReadFile(path) raw, err := ioutil.ReadFile(path)
if err != nil { return EmptyByteArray, err } if err != nil {
return raw, nil return EmptyByteArray, err
}
return raw, nil
} }
// decode a pem encoded file and search for the specified section // decode a pem encoded file and search for the specified section
func decodeSection(data []byte, btype string) ([]byte, error) { func decodeSection(data []byte, btype string) ([]byte, error) {
rest := data rest := data
for len(rest) > 0 { for len(rest) > 0 {
var block *pem.Block var block *pem.Block
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
if block.Type == btype { if block.Type == btype {
return block.Bytes, nil return block.Bytes, nil
} }
} }
return EmptyByteArray, ErrBlockNotFound return EmptyByteArray, ErrBlockNotFound
} }

229
main.go
View File

@ -1,149 +1,184 @@
package main package main
import ( import (
"crypto" "crypto"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"github.com/gibheer/pki" "github.com/gibheer/pki"
) )
const ( const (
ErrorProgram int = iota ErrorProgram int = iota
ErrorFlagInput ErrorFlagInput
ErrorInput ErrorInput
) )
var ( var (
EmptyByteArray = make([]byte, 0) EmptyByteArray = make([]byte, 0)
) )
func main() { func main() {
InitFlags() InitFlags()
CmdRoot.Execute() CmdRoot.Execute()
} }
// create a new private key // create a new private key
func create_private_key(cmd *Command, args []string) { func create_private_key(cmd *Command, args []string) {
err := checkFlags(checkOutput, checkPrivateKeyGeneration) err := checkFlags(checkOutput, checkPrivateKeyGeneration)
if err != nil { if err != nil {
crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err) crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err)
} }
var pk pki.Pemmer var pk pki.Pemmer
switch FlagPrivateKeyGeneration.Type { switch FlagPrivateKeyGeneration.Type {
case "ecdsa": pk, err = pki.NewPrivateKeyEcdsa(FlagPrivateKeyGeneration.Curve) case "ecdsa":
case "rsa": pk, err = pki.NewPrivateKeyRsa(FlagPrivateKeyGeneration.Size) pk, err = pki.NewPrivateKeyEcdsa(FlagPrivateKeyGeneration.Curve)
default: crash_with_help(cmd, ErrorInput, "Unknown private key type '%s'", FlagPrivateKeyGeneration.Type) case "rsa":
} pk, err = pki.NewPrivateKeyRsa(FlagPrivateKeyGeneration.Size)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error creating private key: %s", err) } default:
marsh_pem, err := pk.MarshalPem() crash_with_help(cmd, ErrorInput, "Unknown private key type '%s'", FlagPrivateKeyGeneration.Type)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err) } }
_, err = marsh_pem.WriteTo(FlagOutput) if err != nil {
if err != nil { crash_with_help(cmd, ErrorProgram, "Error when writing output: %s", err) } crash_with_help(cmd, ErrorProgram, "Error creating private key: %s", err)
}
marsh_pem, err := pk.MarshalPem()
if err != nil {
crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err)
}
_, err = marsh_pem.WriteTo(FlagOutput)
if err != nil {
crash_with_help(cmd, ErrorProgram, "Error when writing output: %s", err)
}
} }
// create a public key derived from a private key // create a public key derived from a private key
func create_public_key(cmd *Command, args []string) { func create_public_key(cmd *Command, args []string) {
err := checkFlags(checkPrivateKey, checkOutput) err := checkFlags(checkPrivateKey, checkOutput)
if err != nil { if err != nil {
crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err) crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err)
} }
var pub_key pki.Pemmer var pub_key pki.Pemmer
pub_key = FlagPrivateKey.Public() pub_key = FlagPrivateKey.Public()
marsh_pem, err := pub_key.MarshalPem() marsh_pem, err := pub_key.MarshalPem()
if err != nil { crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err) } if err != nil {
_, err = marsh_pem.WriteTo(FlagOutput) crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error when writing output: %s", err) } }
_, err = marsh_pem.WriteTo(FlagOutput)
if err != nil {
crash_with_help(cmd, ErrorProgram, "Error when writing output: %s", err)
}
} }
// sign a message using he private key // sign a message using he private key
func sign_input(cmd *Command, args []string) { func sign_input(cmd *Command, args []string) {
err := checkFlags(checkPrivateKey, checkInput, checkOutput) err := checkFlags(checkPrivateKey, checkInput, checkOutput)
if err != nil { if err != nil {
crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err) crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err)
} }
message, err := ioutil.ReadAll(FlagInput) message, err := ioutil.ReadAll(FlagInput)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error reading input: %s", err) } if err != nil {
signature, err := FlagPrivateKey.Sign(message, crypto.SHA256) crash_with_help(cmd, ErrorProgram, "Error reading input: %s", err)
if err != nil { crash_with_help(cmd, ErrorProgram, "Could not compute signature: %s", err) } }
_, err = io.WriteString(FlagOutput, base64.StdEncoding.EncodeToString(signature)) signature, err := FlagPrivateKey.Sign(message, crypto.SHA256)
if err != nil { crash_with_help(cmd, ErrorProgram, "Could not write to output: %s", err) } if err != nil {
crash_with_help(cmd, ErrorProgram, "Could not compute signature: %s", err)
}
_, err = io.WriteString(FlagOutput, base64.StdEncoding.EncodeToString(signature))
if err != nil {
crash_with_help(cmd, ErrorProgram, "Could not write to output: %s", err)
}
// if we print to stderr, send a final line break to make the output nice // if we print to stderr, send a final line break to make the output nice
if FlagOutput == os.Stdout { if FlagOutput == os.Stdout {
// we can ignore the result, as either Stdout did work or not // we can ignore the result, as either Stdout did work or not
_, _ = io.WriteString(FlagOutput, "\n") _, _ = io.WriteString(FlagOutput, "\n")
} }
} }
// verify a message using a signature and a public key // verify a message using a signature and a public key
func verify_input(cmd *Command, args []string) { func verify_input(cmd *Command, args []string) {
err := checkFlags(checkPrivateKey, checkInput, checkOutput, checkSignature) err := checkFlags(checkPrivateKey, checkInput, checkOutput, checkSignature)
if err != nil { if err != nil {
crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err) crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err)
} }
signature := FlagSignature signature := FlagSignature
message, err := ioutil.ReadAll(FlagInput) message, err := ioutil.ReadAll(FlagInput)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error reading input: %s", err) } if err != nil {
valid, err := FlagPublicKey.Verify(message, signature, crypto.SHA256) crash_with_help(cmd, ErrorProgram, "Error reading input: %s", err)
if err != nil { crash_with_help(cmd, ErrorProgram, "Could not verify message using signature: %s", err) } }
if valid { valid, err := FlagPublicKey.Verify(message, signature, crypto.SHA256)
fmt.Println("valid") if err != nil {
os.Exit(0) crash_with_help(cmd, ErrorProgram, "Could not verify message using signature: %s", err)
} }
fmt.Println("invalid") if valid {
os.Exit(1) fmt.Println("valid")
os.Exit(0)
}
fmt.Println("invalid")
os.Exit(1)
} }
// create a certificate sign request // create a certificate sign request
func create_sign_request(cmd *Command, args []string) { func create_sign_request(cmd *Command, args []string) {
err := checkFlags(checkPrivateKey, checkOutput, checkCertificateFields) err := checkFlags(checkPrivateKey, checkOutput, checkCertificateFields)
if err != nil { if err != nil {
crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err) crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err)
} }
csr, err := FlagCertificateRequestData.ToCertificateRequest(FlagPrivateKey) csr, err := FlagCertificateRequestData.ToCertificateRequest(FlagPrivateKey)
if err != nil { crash_with_help(cmd, ErrorProgram, "Could not create certificate sign request: %s", err) } if err != nil {
pem_block, err := csr.MarshalPem() crash_with_help(cmd, ErrorProgram, "Could not create certificate sign request: %s", err)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err) } }
_, err = pem_block.WriteTo(FlagOutput) pem_block, err := csr.MarshalPem()
if err != nil { crash_with_help(cmd, ErrorProgram, "Could not write to output: %s", err) } if err != nil {
crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err)
}
_, err = pem_block.WriteTo(FlagOutput)
if err != nil {
crash_with_help(cmd, ErrorProgram, "Could not write to output: %s", err)
}
} }
func create_cert(cmd *Command, args []string) { func create_cert(cmd *Command, args []string) {
err := checkFlags(checkPrivateKey, checkOutput, checkCSR, checkCertFlags) err := checkFlags(checkPrivateKey, checkOutput, checkCSR, checkCertFlags)
if err != nil { if err != nil {
crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err) crash_with_help(cmd, ErrorFlagInput, "Flags invalid: %s", err)
} }
// TODO implement flags for all certificate options // TODO implement flags for all certificate options
cert, err := FlagCertificateSignRequest.ToCertificate( cert, err := FlagCertificateSignRequest.ToCertificate(
FlagPrivateKey, FlagPrivateKey,
FlagCertificateGeneration, FlagCertificateGeneration,
nil, nil,
) )
if err != nil { crash_with_help(cmd, ErrorProgram, "Error generating certificate: %s", err) } if err != nil {
pem_block, err := cert.MarshalPem() crash_with_help(cmd, ErrorProgram, "Error generating certificate: %s", err)
if err != nil { crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err) } }
_, err = pem_block.WriteTo(FlagOutput) pem_block, err := cert.MarshalPem()
if err != nil { crash_with_help(cmd, ErrorProgram, "Could not write to output: %s", err) } if err != nil {
crash_with_help(cmd, ErrorProgram, "Error when marshalling to pem: %s", err)
}
_, err = pem_block.WriteTo(FlagOutput)
if err != nil {
crash_with_help(cmd, ErrorProgram, "Could not write to output: %s", err)
}
} }
// crash and provide a helpful message // crash and provide a helpful message
func crash_with_help(cmd *Command, code int, message string, args ...interface{}) { func crash_with_help(cmd *Command, code int, message string, args ...interface{}) {
fmt.Fprintf(os.Stderr, message + "\n", args...) fmt.Fprintf(os.Stderr, message+"\n", args...)
cmd.Usage() cmd.Usage()
os.Exit(code) os.Exit(code)
} }
// return the arguments to the program // return the arguments to the program
func program_args() []string { func program_args() []string {
return os.Args[2:] return os.Args[2:]
} }

View File

@ -1,43 +1,53 @@
package main package main
import ( import (
"errors" "errors"
"github.com/gibheer/pki" "github.com/gibheer/pki"
) )
var ( var (
ErrNoPKFound = errors.New("no private key found") ErrNoPKFound = errors.New("no private key found")
ErrNoPUFound = errors.New("no public key found") ErrNoPUFound = errors.New("no public key found")
ErrUnknownFormat = errors.New("key is an unknown format") ErrUnknownFormat = errors.New("key is an unknown format")
) )
// Read the private key from the path and try to figure out which type of key it // Read the private key from the path and try to figure out which type of key it
// might be. // might be.
func ReadPrivateKeyFile(path string) (pki.PrivateKey, error) { func ReadPrivateKeyFile(path string) (pki.PrivateKey, error) {
raw_pk, err := readSectionFromFile(path, pki.PemLabelEcdsa) raw_pk, err := readSectionFromFile(path, pki.PemLabelEcdsa)
if err == nil { if err == nil {
pk, err := pki.LoadPrivateKeyEcdsa(raw_pk) pk, err := pki.LoadPrivateKeyEcdsa(raw_pk)
if err != nil { return nil, err } if err != nil {
return pk, nil return nil, err
} }
raw_pk, err = readSectionFromFile(path, pki.PemLabelRsa) return pk, nil
if err == nil { }
pk, err := pki.LoadPrivateKeyRsa(raw_pk) raw_pk, err = readSectionFromFile(path, pki.PemLabelRsa)
if err != nil { return nil, err } if err == nil {
return pk, nil pk, err := pki.LoadPrivateKeyRsa(raw_pk)
} if err != nil {
return nil, ErrNoPKFound return nil, err
}
return pk, nil
}
return nil, ErrNoPKFound
} }
// read the public key and try to figure out what kind of key it might be // read the public key and try to figure out what kind of key it might be
func ReadPublicKeyFile(path string) (pki.PublicKey, error) { func ReadPublicKeyFile(path string) (pki.PublicKey, error) {
raw_pu, err := readSectionFromFile(path, pki.PemLabelPublic) raw_pu, err := readSectionFromFile(path, pki.PemLabelPublic)
if err != nil { return nil, ErrNoPUFound } if err != nil {
return nil, ErrNoPUFound
}
var public pki.PublicKey var public pki.PublicKey
public, err = pki.LoadPublicKeyEcdsa(raw_pu) public, err = pki.LoadPublicKeyEcdsa(raw_pu)
if err == nil { return public, nil } if err == nil {
public, err = pki.LoadPublicKeyRsa(raw_pu) return public, nil
if err == nil { return public, nil } }
return nil, ErrUnknownFormat public, err = pki.LoadPublicKeyRsa(raw_pu)
if err == nil {
return public, nil
}
return nil, ErrUnknownFormat
} }