diff --git a/sign_input.go b/sign_input.go index 6a264dc..c1ab9e0 100644 --- a/sign_input.go +++ b/sign_input.go @@ -4,10 +4,12 @@ import ( "crypto" "crypto/rand" "crypto/sha256" + "encoding/base64" "errors" "flag" "fmt" "io" + "io/ioutil" "os" // "crypto/ecdsa" // "crypto/rsa" @@ -16,16 +18,22 @@ import ( type ( SignInputFlags struct { Message string // the message to sign + MessageStream string // the message stream to sign PrivateKeyPath string // path to the private key Output string // a path or stream to output the private key to + Format string // the format of the output private_key crypto.Signer output_stream io.Writer // the output stream for the CSR + input_stream io.Reader // the input stream to read the message from } ) func sign_input() { flags := parse_sign_input_flags() + if flags.Message != "" && flags.MessageStream != "" { + crash_with_help(2, "Only message or message file can be signed!") + } flags.private_key = load_private_key(flags.PrivateKeyPath) output_stream, err := open_output_stream(flags.Output) @@ -35,6 +43,15 @@ func sign_input() { flags.output_stream = output_stream defer output_stream.Close() + if flags.MessageStream != "" { + input_stream, err := open_input_stream(flags.MessageStream) + if err != nil { + crash_with_help(2, fmt.Sprintf("Error when opening stream %s: %s", flags.MessageStream, err)) + } + flags.input_stream = input_stream + defer input_stream.Close() + } + if err := create_signature(flags); err != nil { fmt.Fprintln(os.Stderr, "Error when creating signature", err) os.Exit(3) @@ -45,15 +62,25 @@ func parse_sign_input_flags() SignInputFlags { flags := SignInputFlags{} fs := flag.NewFlagSet("sign-input", flag.ExitOnError) fs.StringVar(&flags.PrivateKeyPath, "private-key", "", "path to the private key file") - fs.StringVar(&flags.Output, "output", "STDOUT", "path where the generated signature should be stored") + fs.StringVar(&flags.Output, "output", "STDOUT", "path where the generated signature should be stored (STDOUT, STDERR, file)") fs.StringVar(&flags.Message, "message", "", "the message to sign") + fs.StringVar(&flags.MessageStream, "message-stream", "STDIN", "the path to the stream to sign (file, STDIN)") + fs.StringVar(&flags.Format, "format", "base64", "the output format (binary, base64)") fs.Parse(os.Args[2:]) return flags } func create_signature(flags SignInputFlags) error { - message := []byte(flags.Message) + var message []byte + var err error + + if flags.MessageStream != "" { + message, err = ioutil.ReadAll(flags.input_stream) + if err != nil { return err } + } else { + message = []byte(flags.Message) + } // compute sha256 of the message hash := sha256.New() length, _ := hash.Write(message) @@ -66,6 +93,11 @@ func create_signature(flags SignInputFlags) error { nil, ) if err != nil { return err } - flags.output_stream.Write(signature) + if flags.Format == "base64" { + flags.output_stream.Write([]byte(base64.StdEncoding.EncodeToString(signature))) + } else { + flags.output_stream.Write(signature) + } + flags.output_stream.Write([]byte("\n")) return nil } diff --git a/verify_signature.go b/verify_signature.go index 76020bb..b661e9a 100644 --- a/verify_signature.go +++ b/verify_signature.go @@ -6,19 +6,28 @@ import ( "crypto/x509" "encoding/asn1" "encoding/pem" + "encoding/base64" "errors" "flag" "fmt" + "io" "io/ioutil" "math/big" "os" + "strings" ) type ( VerifySignatureFlags struct { Message string // the message to sign + MessageStream string // the path to the input stream PublicKeyPath string // path to the private key Signature string // a path or stream to output the private key to + SignatureStream string // read signature from an input stream + Format string // the format of the signature + + message_stream io.Reader // the message stream + signature_stream io.Reader // the signature stream } // struct to load the signature into (which is basically two bigint in byte form) Signature struct { @@ -28,16 +37,40 @@ type ( func verify_signature() { flags := parse_verify_signature_flags() + if flags.SignatureStream == flags.MessageStream && + ( flags.Message == "" && flags.Signature == "") { + crash_with_help(2, "Signature and Message stream can't be the same source!") + } + + // open streams + if flags.Message == "" && flags.MessageStream != "" { + message_stream, err := open_input_stream(flags.MessageStream) + if err != nil { + crash_with_help(2, fmt.Sprintf("Error when opening stream %s: %s", flags.MessageStream, err)) + } + defer message_stream.Close() + flags.message_stream = message_stream + } + if flags.Signature == "" && flags.SignatureStream != "" { + signature_stream, err := open_input_stream(flags.SignatureStream) + if err != nil { + crash_with_help(2, fmt.Sprintf("Error when opening stream %s: %s", flags.SignatureStream, err)) + } + defer signature_stream.Close() + flags.signature_stream = signature_stream + } + public_key, err := load_public_key_ecdsa(flags.PublicKeyPath) if err != nil { crash_with_help(2, fmt.Sprintf("Error when loading public key: %s", err)) } - signature, err := load_signature(flags.Signature) + signature, err := load_signature(flags) if err != nil { crash_with_help(2, fmt.Sprintf("Error when loading the signature: %s", err)) } + message, err := load_message(flags) hash := sha256.New() - hash.Write([]byte(flags.Message)) + hash.Write([]byte(message)) success := ecdsa.Verify(public_key, hash.Sum(nil), signature.R, signature.S) fmt.Println(success) @@ -49,7 +82,10 @@ func parse_verify_signature_flags() VerifySignatureFlags { fs := flag.NewFlagSet("verify-signature", flag.ExitOnError) fs.StringVar(&flags.PublicKeyPath, "public-key", "", "path to the public key file") fs.StringVar(&flags.Signature, "signature", "", "path where the signature file can be found") - fs.StringVar(&flags.Message, "message", "", "the message to be validated") + fs.StringVar(&flags.SignatureStream, "signature-stream", "", "the path to the stream of the signature (file, STDIN)") + fs.StringVar(&flags.Format, "format", "auto", "the input format of the signature (auto, binary, base64)") + fs.StringVar(&flags.Message, "message", "", "the message to validate") + fs.StringVar(&flags.MessageStream, "message-stream", "STDIN", "the path to the stream to validate (file, STDIN)") fs.Parse(os.Args[2:]) return flags @@ -74,15 +110,50 @@ func load_public_key_ecdsa(path string) (*ecdsa.PublicKey, error) { } // parse the signature from asn1 file -func load_signature(path string) (*Signature, error) { - signature_file, err := os.Open(path) - if err != nil { return nil, err } - signature_raw, err := ioutil.ReadAll(signature_file) - if err != nil { return nil, err } - signature_file.Close() +func load_signature(flags VerifySignatureFlags) (*Signature, error) { + var signature_raw []byte + var err error + if flags.Message != "" { + signature_raw = []byte(flags.Message) + } else { + signature_raw, err = ioutil.ReadAll(flags.signature_stream) + if err != nil { return nil, err } + } + switch strings.ToLower(flags.Format) { + case "auto": + sig, err := load_signature_base64(signature_raw) + if err != nil { + sig, err = load_signature_binary(signature_raw) + if err != nil { return nil, err } + return sig, nil + } + return sig, nil + case "base64": return load_signature_base64(signature_raw) + case "binary": return load_signature_binary(signature_raw) + default: return nil, errors.New("Unknown format!") + } +} + +// convert the signature from base64 into a signature +func load_signature_base64(signature_raw []byte) (*Signature, error) { + asn1_sig, err := base64.StdEncoding.DecodeString(string(signature_raw)) + if err != nil { return nil, err } + return load_signature_binary(asn1_sig) +} + +// convert the signature from asn1 into a signature +func load_signature_binary(signature_raw []byte) (*Signature, error) { var signature Signature - _, err = asn1.Unmarshal(signature_raw, &signature) + _, err := asn1.Unmarshal(signature_raw, &signature) if err != nil { return nil, err } return &signature, nil } + +// load the message from a stream or the parameter +func load_message(flags VerifySignatureFlags) (string, error) { + if flags.Message != "" { return flags.Message, nil } + message, err := ioutil.ReadAll(flags.message_stream) + if err != nil { return "", err } + return string(message), nil +}