100 lines
2.5 KiB
Go
100 lines
2.5 KiB
Go
|
//+build !go1.15
|
||
|
|
||
|
package jwt
|
||
|
|
||
|
import (
|
||
|
"crypto/ecdsa"
|
||
|
"crypto/rand"
|
||
|
"math/big"
|
||
|
"math/bits"
|
||
|
)
|
||
|
|
||
|
// Implements the Sign method from SigningMethod
|
||
|
// For this signing method, key must be an ecdsa.PrivateKey struct
|
||
|
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
|
||
|
// Get the key
|
||
|
var ecdsaKey *ecdsa.PrivateKey
|
||
|
switch k := key.(type) {
|
||
|
case *ecdsa.PrivateKey:
|
||
|
ecdsaKey = k
|
||
|
default:
|
||
|
return "", ErrInvalidKeyType
|
||
|
}
|
||
|
|
||
|
// Create the hasher
|
||
|
if !m.Hash.Available() {
|
||
|
return "", ErrHashUnavailable
|
||
|
}
|
||
|
|
||
|
hasher := m.Hash.New()
|
||
|
hasher.Write([]byte(signingString))
|
||
|
|
||
|
// Sign the string and return r, s
|
||
|
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
|
||
|
curveBits := ecdsaKey.Curve.Params().BitSize
|
||
|
|
||
|
if m.CurveBits != curveBits {
|
||
|
return "", ErrInvalidKey
|
||
|
}
|
||
|
|
||
|
keyBytes := curveBits / 8
|
||
|
if curveBits%8 > 0 {
|
||
|
keyBytes += 1
|
||
|
}
|
||
|
|
||
|
// We serialize the outputs (r and s) into big-endian byte arrays
|
||
|
// padded with zeros on the left to make sure the sizes work out.
|
||
|
// Output must be 2*keyBytes long.
|
||
|
out := make([]byte, 2*keyBytes)
|
||
|
fillBytesInt(r, out[0:keyBytes]) // r is assigned to the first half of output.
|
||
|
fillBytesInt(s, out[keyBytes:]) // s is assigned to the second half of output.
|
||
|
|
||
|
return EncodeSegment(out), nil
|
||
|
} else {
|
||
|
return "", err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// fillBytesInt writes the absolute value of x into buf as a big-endian
// integer, left-padded with zero bytes, and returns buf.
//
// It panics if the value does not fit into buf. It serves the same purpose as
// (*big.Int).FillBytes for toolchains where that method is not available.
func fillBytesInt(x *big.Int, buf []byte) []byte {
	// Zero the destination first so that untouched leading bytes act as padding.
	for i := range buf {
		buf[i] = 0
	}

	// Number of bytes in one big.Word (a machine-sized unsigned integer).
	const wordSize = bits.UintSize / 8

	// Bits returns the little-endian word slice backing x: limbs[0] is the
	// least significant word.
	limbs := x.Bits()

	// end is one past the buffer index that receives the least significant
	// byte of the limb currently being written. It starts at the slot for the
	// most significant limb and may lie partly or wholly before the buffer.
	end := len(buf) - (len(limbs)-1)*wordSize

	// Walk the limbs from most significant to least significant.
	for n := len(limbs) - 1; n >= 0; n-- {
		limb := limbs[n]

		// Emit the limb's bytes from least to most significant, i.e. from the
		// highest buffer index of its slot down to the lowest.
		for idx := end - 1; idx >= end-wordSize; idx-- {
			b := byte(limb)
			limb >>= 8

			switch {
			case idx >= 0:
				buf[idx] = b
			case b != 0:
				// A significant byte would land before the start of buf.
				panic("math/big: buffer too small to fit value") // have to use the same panic string for complete compatibility
			}
		}

		end += wordSize
	}

	return buf
}
|