From dc03a61130e3d3f5650e5188e9b25c1c89642cf4 Mon Sep 17 00:00:00 2001 From: Bas Westerbaan Date: Mon, 19 Aug 2024 11:49:59 +0200 Subject: [PATCH] Add ML-KEM decapsulation key check. Described in section 7.3 of FIPS 203. The check is only required if the private key is from an untrusted source. We do not distinguish between a trusted and untrusted source in the current API, so we'll perform the check every time we unmarshal the private key. --- kem/kem.go | 3 +++ kem/kyber/templates/pkg.templ.go | 28 ++++++++++++++++++++++++++++ kem/mlkem/mlkem1024/kyber.go | 19 ++++++++++++++++--- kem/mlkem/mlkem512/kyber.go | 19 ++++++++++++++++--- kem/mlkem/mlkem768/kyber.go | 19 ++++++++++++++++--- 5 files changed, 79 insertions(+), 9 deletions(-) diff --git a/kem/kem.go b/kem/kem.go index 6ab0aa3ba..a2f6a2aae 100644 --- a/kem/kem.go +++ b/kem/kem.go @@ -113,6 +113,9 @@ var ( // ErrPubKey is the error used if the provided public key is invalid. ErrPubKey = errors.New("invalid public key") + // ErrPrivKey is the error used if the provided private key is invalid. + ErrPrivKey = errors.New("invalid private key") + // ErrCipherText is the error used if the provided ciphertext is invalid. ErrCipherText = errors.New("invalid ciphertext") ) diff --git a/kem/kyber/templates/pkg.templ.go b/kem/kyber/templates/pkg.templ.go index 44a311405..2c9925c41 100644 --- a/kem/kyber/templates/pkg.templ.go +++ b/kem/kyber/templates/pkg.templ.go @@ -256,19 +256,41 @@ func (sk *PrivateKey) Pack(buf []byte) { // Unpacks sk from buf. // // Panics if buf is not of size PrivateKeySize. +{{ if .NIST -}} +// +// Returns an error if buf is not of size PrivateKeySize, or private key +// doesn't pass the ML-KEM decapsulation key check. +func (sk *PrivateKey) Unpack(buf []byte) error { + if len(buf) != PrivateKeySize { + return kem.ErrPrivKeySize + } +{{- else -}} func (sk *PrivateKey) Unpack(buf []byte) { if len(buf) != PrivateKeySize { panic("buf must be of length PrivateKeySize") } +{{- end }} sk.sk = new(cpapke.PrivateKey) sk.sk.Unpack(buf[:cpapke.PrivateKeySize]) buf = buf[cpapke.PrivateKeySize:] sk.pk = new(cpapke.PublicKey) sk.pk.Unpack(buf[:cpapke.PublicKeySize]) +{{ if .NIST -}} + var hpk [32]byte + h := sha3.New256() + h.Write(buf[:cpapke.PublicKeySize]) + h.Read(hpk[:]) +{{ end -}} buf = buf[cpapke.PublicKeySize:] copy(sk.hpk[:], buf[:32]) copy(sk.z[:], buf[32:]) +{{ if .NIST -}} + if !bytes.Equal(hpk[:], sk.hpk[:]) { + return kem.ErrPrivKey + } + return nil +{{ end -}} } // Packs pk to buf. @@ -463,6 +485,12 @@ func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { return nil, kem.ErrPrivKeySize } var ret PrivateKey + {{ if .NIST -}} + if err := ret.Unpack(buf); err != nil { + return nil, err + } + {{- else -}} ret.Unpack(buf) + {{- end }} return &ret, nil } diff --git a/kem/mlkem/mlkem1024/kyber.go b/kem/mlkem/mlkem1024/kyber.go index a1e73b892..c548faa2a 100644 --- a/kem/mlkem/mlkem1024/kyber.go +++ b/kem/mlkem/mlkem1024/kyber.go @@ -203,9 +203,12 @@ func (sk *PrivateKey) Pack(buf []byte) { // Unpacks sk from buf. // // Panics if buf is not of size PrivateKeySize. -func (sk *PrivateKey) Unpack(buf []byte) { +// +// Returns an error if buf is not of size PrivateKeySize, or private key +// doesn't pass the ML-KEM decapsulation key check. +func (sk *PrivateKey) Unpack(buf []byte) error { if len(buf) != PrivateKeySize { - panic("buf must be of length PrivateKeySize") + return kem.ErrPrivKeySize } sk.sk = new(cpapke.PrivateKey) @@ -213,9 +216,17 @@ func (sk *PrivateKey) Unpack(buf []byte) { buf = buf[cpapke.PrivateKeySize:] sk.pk = new(cpapke.PublicKey) sk.pk.Unpack(buf[:cpapke.PublicKeySize]) + var hpk [32]byte + h := sha3.New256() + h.Write(buf[:cpapke.PublicKeySize]) + h.Read(hpk[:]) buf = buf[cpapke.PublicKeySize:] copy(sk.hpk[:], buf[:32]) copy(sk.z[:], buf[32:]) + if !bytes.Equal(hpk[:], sk.hpk[:]) { + return kem.ErrPrivKey + } + return nil } // Packs pk to buf. @@ -389,6 +400,8 @@ func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { return nil, kem.ErrPrivKeySize } var ret PrivateKey - ret.Unpack(buf) + if err := ret.Unpack(buf); err != nil { + return nil, err + } return &ret, nil } diff --git a/kem/mlkem/mlkem512/kyber.go b/kem/mlkem/mlkem512/kyber.go index 8bb2ce194..85f07aba7 100644 --- a/kem/mlkem/mlkem512/kyber.go +++ b/kem/mlkem/mlkem512/kyber.go @@ -203,9 +203,12 @@ func (sk *PrivateKey) Pack(buf []byte) { // Unpacks sk from buf. // // Panics if buf is not of size PrivateKeySize. -func (sk *PrivateKey) Unpack(buf []byte) { +// +// Returns an error if buf is not of size PrivateKeySize, or private key +// doesn't pass the ML-KEM decapsulation key check. +func (sk *PrivateKey) Unpack(buf []byte) error { if len(buf) != PrivateKeySize { - panic("buf must be of length PrivateKeySize") + return kem.ErrPrivKeySize } sk.sk = new(cpapke.PrivateKey) @@ -213,9 +216,17 @@ func (sk *PrivateKey) Unpack(buf []byte) { buf = buf[cpapke.PrivateKeySize:] sk.pk = new(cpapke.PublicKey) sk.pk.Unpack(buf[:cpapke.PublicKeySize]) + var hpk [32]byte + h := sha3.New256() + h.Write(buf[:cpapke.PublicKeySize]) + h.Read(hpk[:]) buf = buf[cpapke.PublicKeySize:] copy(sk.hpk[:], buf[:32]) copy(sk.z[:], buf[32:]) + if !bytes.Equal(hpk[:], sk.hpk[:]) { + return kem.ErrPrivKey + } + return nil } // Packs pk to buf. @@ -389,6 +400,8 @@ func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { return nil, kem.ErrPrivKeySize } var ret PrivateKey - ret.Unpack(buf) + if err := ret.Unpack(buf); err != nil { + return nil, err + } return &ret, nil } diff --git a/kem/mlkem/mlkem768/kyber.go b/kem/mlkem/mlkem768/kyber.go index 61271cd86..afa483156 100644 --- a/kem/mlkem/mlkem768/kyber.go +++ b/kem/mlkem/mlkem768/kyber.go @@ -203,9 +203,12 @@ func (sk *PrivateKey) Pack(buf []byte) { // Unpacks sk from buf. // // Panics if buf is not of size PrivateKeySize. -func (sk *PrivateKey) Unpack(buf []byte) { +// +// Returns an error if buf is not of size PrivateKeySize, or private key +// doesn't pass the ML-KEM decapsulation key check. +func (sk *PrivateKey) Unpack(buf []byte) error { if len(buf) != PrivateKeySize { - panic("buf must be of length PrivateKeySize") + return kem.ErrPrivKeySize } sk.sk = new(cpapke.PrivateKey) @@ -213,9 +216,17 @@ func (sk *PrivateKey) Unpack(buf []byte) { buf = buf[cpapke.PrivateKeySize:] sk.pk = new(cpapke.PublicKey) sk.pk.Unpack(buf[:cpapke.PublicKeySize]) + var hpk [32]byte + h := sha3.New256() + h.Write(buf[:cpapke.PublicKeySize]) + h.Read(hpk[:]) buf = buf[cpapke.PublicKeySize:] copy(sk.hpk[:], buf[:32]) copy(sk.z[:], buf[32:]) + if !bytes.Equal(hpk[:], sk.hpk[:]) { + return kem.ErrPrivKey + } + return nil } // Packs pk to buf. @@ -389,6 +400,8 @@ func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { return nil, kem.ErrPrivKeySize } var ret PrivateKey - ret.Unpack(buf) + if err := ret.Unpack(buf); err != nil { + return nil, err + } return &ret, nil }