diff --git a/core/crypto/rsa_common.go b/core/crypto/rsa_common.go index c7e305439a..b084b51f2c 100644 --- a/core/crypto/rsa_common.go +++ b/core/crypto/rsa_common.go @@ -11,10 +11,12 @@ import ( const WeakRsaKeyEnv = "LIBP2P_ALLOW_WEAK_RSA_KEYS" var MinRsaKeyBits = 2048 +var MaxRsaKeyBits = 8192 // ErrRsaKeyTooSmall is returned when trying to generate or parse an RSA key // that's smaller than MinRsaKeyBits bits. In test var ErrRsaKeyTooSmall error +var ErrRsaKeyTooBig error = fmt.Errorf("rsa keys must be <= %d bits", MaxRsaKeyBits) func init() { if _, ok := os.LookupEnv(WeakRsaKeyEnv); ok { diff --git a/core/crypto/rsa_go.go b/core/crypto/rsa_go.go index bfbd987bee..cab9e3812f 100644 --- a/core/crypto/rsa_go.go +++ b/core/crypto/rsa_go.go @@ -31,6 +31,9 @@ func GenerateRSAKeyPair(bits int, src io.Reader) (PrivKey, PubKey, error) { if bits < MinRsaKeyBits { return nil, nil, ErrRsaKeyTooSmall } + if bits > MaxRsaKeyBits { + return nil, nil, ErrRsaKeyTooBig + } priv, err := rsa.GenerateKey(src, bits) if err != nil { return nil, nil, err @@ -124,6 +127,9 @@ func UnmarshalRsaPrivateKey(b []byte) (key PrivKey, err error) { if sk.N.BitLen() < MinRsaKeyBits { return nil, ErrRsaKeyTooSmall } + if sk.N.BitLen() > MaxRsaKeyBits { + return nil, ErrRsaKeyTooBig + } return &RsaPrivateKey{sk: *sk}, nil } @@ -141,6 +147,9 @@ func UnmarshalRsaPublicKey(b []byte) (key PubKey, err error) { if pk.N.BitLen() < MinRsaKeyBits { return nil, ErrRsaKeyTooSmall } + if pk.N.BitLen() > MaxRsaKeyBits { + return nil, ErrRsaKeyTooBig + } return &RsaPublicKey{k: *pk}, nil } diff --git a/core/crypto/rsa_test.go b/core/crypto/rsa_test.go index 69151b86c9..2fced564fb 100644 --- a/core/crypto/rsa_test.go +++ b/core/crypto/rsa_test.go @@ -68,6 +68,44 @@ func TestRSASmallKey(t *testing.T) { } } +func TestRSABigKeyFailsToGenerate(t *testing.T) { + _, _, err := GenerateRSAKeyPair(MaxRsaKeyBits*2, rand.Reader) + if err != ErrRsaKeyTooBig { + t.Fatal("should have refused to create too big RSA key") + } +} + +func TestRSABigKey(t *testing.T) { + // Make the global limit smaller for this test to run faster. + // Note we also change the limit below, but this is different + origSize := MaxRsaKeyBits + MaxRsaKeyBits = 2048 + defer func() { MaxRsaKeyBits = origSize }() // + + MaxRsaKeyBits *= 2 + badPriv, badPub, err := GenerateRSAKeyPair(MaxRsaKeyBits, rand.Reader) + if err != nil { + t.Fatalf("should have succeeded, got: %s", err) + } + pubBytes, err := MarshalPublicKey(badPub) + if err != nil { + t.Fatal(err) + } + privBytes, err := MarshalPrivateKey(badPriv) + if err != nil { + t.Fatal(err) + } + MaxRsaKeyBits /= 2 + _, err = UnmarshalPublicKey(pubBytes) + if err != ErrRsaKeyTooBig { + t.Fatal("should have refused to unmarshal a too big key") + } + _, err = UnmarshalPrivateKey(privBytes) + if err != ErrRsaKeyTooBig { + t.Fatal("should have refused to unmarshal a too big key") + } +} + func TestRSASignZero(t *testing.T) { priv, pub, err := GenerateRSAKeyPair(2048, rand.Reader) if err != nil {