Browse Source

Merge pull request #11625 from jakesmith/hpcc-20435

HPCC-20435 Allow 16/24/32 bit aes keys and fix 0 length issue.

Reviewed-by: Gavin Halliday <ghalliday@hpccsystems.com>
Gavin Halliday 6 years ago
parent
commit
237655d517

+ 35 - 12
system/security/cryptohelper/ske.cpp

@@ -31,22 +31,42 @@
 #include "pke.hpp"
 #include "ske.hpp"
 
+static const EVP_CIPHER *getAesCipher(size32_t keyLen)
+{
+    switch (keyLen)
+    {
+    case 128/8:
+        return EVP_aes_128_cbc();
+    case 192/8:
+        return EVP_aes_192_cbc();
+    case 256/8:
+        return EVP_aes_256_cbc();
+    default:
+        throw makeStringException(0, "Invalid AES key size, must be 128, 192 or 256 bit");
+    }
+}
+
 namespace cryptohelper
 {
 
-size32_t aesKeyEncrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, const char key[aesKeySize], const char iv[aesBlockSize])
+// NB: static random IV used for AES, custom IV can be supplied to cryptohelper::aes* routines
+static const char staticAesIV[16+1] = "j5P2Sz&DnW'FOW^{";
+
+size32_t aesEncrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, size32_t keyLen, const char *key, const char iv[aesBlockSize])
 {
+    if (0 == inSz)
+        return 0;
     OwnedEVPCipherCtx ctx(EVP_CIPHER_CTX_new());
     if (!ctx)
         throw makeEVPException(0, "Failed EVP_CIPHER_CTX_new");
-
     /* Initialise the encryption operation. IMPORTANT - ensure you use a key
      * and IV size appropriate for your cipher
      * In this example we are using 256 bit AES (i.e. a 256 bit key). The
      * IV size for *most* modes is the same as the block size. For AES this
      * is 128 bits
      * */
-    if (1 != EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), nullptr, (const unsigned char *)key, (const unsigned char *)iv))
+    if (!iv) iv = staticAesIV;
+    if (1 != EVP_EncryptInit_ex(ctx, getAesCipher(keyLen), nullptr, (const unsigned char *)key, (const unsigned char *)iv))
         throw makeEVPException(0, "Failed EVP_EncryptInit_ex");
 
     /* Provide the message to be encrypted, and obtain the encrypted output.
@@ -72,8 +92,10 @@ size32_t aesKeyEncrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, co
     return (size32_t)ciphertext_len;
 }
 
-size32_t aesKeyDecrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, const char *key, const char *iv)
+size32_t aesDecrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, size32_t keyLen, const char *key, const char *iv)
 {
+    if (0 == inSz)
+        return 0;
     OwnedEVPCipherCtx ctx(EVP_CIPHER_CTX_new());
     if (!ctx)
         throw makeEVPException(0, "Failed EVP_CIPHER_CTX_new");
@@ -90,7 +112,8 @@ size32_t aesKeyDecrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, co
      * IV size for *most* modes is the same as the block size. For AES this
      * is 128 bits
      * */
-    if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), nullptr, (const unsigned char *)key, (const unsigned char *)iv))
+    if (!iv) iv = staticAesIV;
+    if (1 != EVP_DecryptInit_ex(ctx, getAesCipher(keyLen), nullptr, (const unsigned char *)key, (const unsigned char *)iv))
         throw makeEVPException(0, "Failed EVP_DecryptInit_ex");
 
     /* Provide the message to be decrypted, and obtain the plaintext output.
@@ -115,19 +138,19 @@ size32_t aesKeyDecrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, co
 size32_t aesEncryptWithRSAEncryptedKey(MemoryBuffer &out, size32_t inSz, const void *inBytes, const CLoadedKey &publicKey)
 {
     // create random AES key and IV
-    char randomAesKey[aesKeySize];
+    char randomAesKey[aesMaxKeySize];
     char randomIV[aesBlockSize];
-    fillRandomData(aesKeySize, randomAesKey);
+    fillRandomData(aesMaxKeySize, randomAesKey);
     fillRandomData(aesBlockSize, randomIV);
 
     size32_t startSz = out.length();
     DelayedSizeMarker mark(out);
-    publicKeyEncrypt(out, aesKeySize, randomAesKey, publicKey);
+    publicKeyEncrypt(out, aesMaxKeySize, randomAesKey, publicKey);
     mark.write();
     out.append(aesBlockSize, randomIV);
 
     DelayedSizeMarker aesSz(out);
-    aesKeyEncrypt(out, inSz, inBytes, randomAesKey, randomIV);
+    aesEncrypt(out, inSz, inBytes, aesMaxKeySize, randomAesKey, randomIV);
     aesSz.write();
     return out.length()-startSz;
 }
@@ -137,12 +160,12 @@ size32_t aesDecryptWithRSAEncryptedKey(MemoryBuffer &out, size32_t inSz, const v
     MemoryBuffer in;
     in.setBuffer(inSz, (void *)inBytes, false);
     // read encrypted AES key
-    char randomAesKey[aesKeySize];
+    char randomAesKey[aesMaxKeySize];
     size32_t encryptedAESKeySz;
     in.read(encryptedAESKeySz);
     MemoryBuffer aesKey;
     size32_t decryptedAesKeySz = privateKeyDecrypt(aesKey, encryptedAESKeySz, in.readDirect(encryptedAESKeySz), privateKey);
-    if (decryptedAesKeySz != aesKeySize)
+    if (decryptedAesKeySz != aesMaxKeySize)
         throw makeStringException(0, "aesDecryptWithRSAEncryptedKey - invalid input");
 
     unsigned iVPos = in.getPos(); // read directly further down
@@ -151,7 +174,7 @@ size32_t aesDecryptWithRSAEncryptedKey(MemoryBuffer &out, size32_t inSz, const v
     size32_t aesEncryptedSz;
     in.read(aesEncryptedSz);
 
-    return aesKeyDecrypt(out, aesEncryptedSz, in.readDirect(aesEncryptedSz), (const char *)aesKey.bytes(), (const char *)in.bytes()+iVPos);
+    return aesDecrypt(out, aesEncryptedSz, in.readDirect(aesEncryptedSz), aesMaxKeySize, (const char *)aesKey.bytes(), (const char *)in.bytes()+iVPos);
 }
 
 

+ 6 - 3
system/security/cryptohelper/ske.hpp

@@ -35,12 +35,15 @@ namespace cryptohelper
 
 #if defined(_USE_OPENSSL) && !defined(_WIN32)
 
-const unsigned aesKeySize = 256/8; // 256 bits
+const unsigned aesMaxKeySize = 256/8; // 256 bits
 const unsigned aesBlockSize = 128/8; // 128 bits
 
-CRYPTOHELPER_API size32_t aesKeyEncrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, const char key[aesKeySize], const char iv[aesBlockSize]);
-CRYPTOHELPER_API size32_t aesKeyDecrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, const char key[aesKeySize], const char iv[aesBlockSize]);
+// for AES, keyLen must be 16, 24, or 32 Bytes
 
+CRYPTOHELPER_API size32_t aesEncrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, size32_t keyLen, const char *key, const char iv[aesBlockSize] = nullptr);
+CRYPTOHELPER_API size32_t aesDecrypt(MemoryBuffer &out, size32_t inSz, const void *inBytes, size32_t keyLen, const char *key, const char iv[aesBlockSize] = nullptr);
+
+class CLoadedKey;
 // aesEncryptWithRSAEncryptedKey serializes encrypted data along with an RSA encrypted key in the format { RSA-encrypted-AES-key, aes-IV, AES-encrypted-data }
 CRYPTOHELPER_API size32_t aesEncryptWithRSAEncryptedKey(MemoryBuffer &out, size32_t inSz, const void *inBytes, const CLoadedKey &publicKey);
 // aesDecryptWithRSAEncryptedKey deserializes data created by aesEncryptWithRSAEncryptedKey

+ 43 - 31
testing/unittests/cryptotests.cpp

@@ -380,50 +380,62 @@ protected:
             // create random data
             MemoryBuffer messageMb, encryptedMessageMb, decryptedMessageMb;
 
-            char aesKey[aesKeySize];
+            char aesKey[aesMaxKeySize];
             char aesIV[aesBlockSize];
-            fillRandomData(aesKeySize, aesKey);
+            fillRandomData(aesMaxKeySize, aesKey);
             fillRandomData(aesBlockSize, aesIV);
 
             fillRandomData(1024*100, messageMb);
-            printf("aesEncryptDecryptTests with %u bytes\n", messageMb.length());
-            aesKeyEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesKey, aesIV);
-            aesKeyDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+            printf("aesEncryptDecryptTests with %u bytes with 256bit aes key\n", messageMb.length());
+            aesEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            ASSERT(messageMb.length() == decryptedMessageMb.length());
+            ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
+
+            printf("aesEncryptDecryptTests with %u bytes with 192bit aes key\n", messageMb.length());
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), 192/8, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), 192/8, aesKey, aesIV);
+            ASSERT(messageMb.length() == decryptedMessageMb.length());
+            ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
+
+            printf("aesEncryptDecryptTests with %u bytes with 128bit aes key\n", messageMb.length());
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), 128/8, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), 128/8, aesKey, aesIV);
             ASSERT(messageMb.length() == decryptedMessageMb.length());
             ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
 
             messageMb.clear(); // 0 length test
             printf("aesEncryptDecryptTests with %u bytes\n", messageMb.length());
-            aesKeyEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesKey, aesIV);
-            aesKeyDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
             ASSERT(messageMb.length() == decryptedMessageMb.length());
             ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
 
             fillRandomData(1, messageMb.clear()); // 1 byte test
             printf("aesEncryptDecryptTests with %u bytes\n", messageMb.length());
-            aesKeyEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesKey, aesIV);
-            aesKeyDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
             ASSERT(messageMb.length() == decryptedMessageMb.length());
             ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
 
             fillRandomData(cryptohelper::aesBlockSize-1, messageMb.clear()); // aesBlockSize-1 test
             printf("aesEncryptDecryptTests with %u bytes\n", messageMb.length());
-            aesKeyEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesKey, aesIV);
-            aesKeyDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
             ASSERT(messageMb.length() == decryptedMessageMb.length());
             ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
 
             fillRandomData(cryptohelper::aesBlockSize, messageMb.clear()); // aesBlockSize test
             printf("aesEncryptDecryptTests with %u bytes\n", messageMb.length());
-            aesKeyEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesKey, aesIV);
-            aesKeyDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
             ASSERT(messageMb.length() == decryptedMessageMb.length());
             ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
 
             fillRandomData(cryptohelper::aesBlockSize+1, messageMb.clear()); // aesBlockSize+1 test
             printf("aesEncryptDecryptTests with %u bytes\n", messageMb.length());
-            aesKeyEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesKey, aesIV);
-            aesKeyDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+            aesEncrypt(encryptedMessageMb.clear(), messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
+            aesDecrypt(decryptedMessageMb.clear(), encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
             ASSERT(messageMb.length() == decryptedMessageMb.length());
             ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
         }
@@ -450,9 +462,9 @@ protected:
             MemoryBuffer messageMb;
             fillRandomData(1024*100, messageMb);
 
-            char aesKey[aesKeySize];
+            char aesKey[aesMaxKeySize];
             char aesIV[aesBlockSize];
-            fillRandomData(aesKeySize, aesKey);
+            fillRandomData(aesMaxKeySize, aesKey);
             fillRandomData(aesBlockSize, aesIV);
 
             Owned<CLoadedKey> publicKey = loadPublicKeyFromMemory(pubKey, nullptr);
@@ -486,13 +498,13 @@ protected:
         class CAsyncfor : public CAsyncFor
         {
             MemoryBuffer messageMb;
-            char aesKey[aesKeySize];
+            char aesKey[aesMaxKeySize];
             char aesIV[aesBlockSize];
         public:
             CAsyncfor()
             {
                 // create random key
-                fillRandomData(aesKeySize, aesKey);
+                fillRandomData(aesMaxKeySize, aesKey);
                 fillRandomData(aesBlockSize, aesIV);
                 // create random data
                 fillRandomData(1024*100, messageMb);
@@ -500,10 +512,10 @@ protected:
             void Do(unsigned idx)
             {
                 MemoryBuffer encryptedMessageMb;
-                aesKeyEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesKey, aesIV);
+                aesEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
 
                 MemoryBuffer decryptedMessageMb;
-                aesKeyDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+                aesDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
 
                 ASSERT(messageMb.length() == decryptedMessageMb.length());
                 ASSERT(0 == memcmp(messageMb.bytes(), decryptedMessageMb.bytes(), messageMb.length()));
@@ -532,10 +544,10 @@ public:
     void aesCompareJlibVsCryptoHelper()
     {
         MemoryBuffer messageMb, encryptedMessageMb, decryptedMessageMb;
-        char aesKey[aesKeySize];
+        char aesKey[aesMaxKeySize];
         char aesIV[aesBlockSize];
         // create random key
-        fillRandomData(aesKeySize, aesKey);
+        fillRandomData(aesMaxKeySize, aesKey);
         fillRandomData(aesBlockSize, aesIV);
 
         // create random data
@@ -544,12 +556,12 @@ public:
         encryptedMessageMb.ensureCapacity(dataSz+aesBlockSize);
 
         CCycleTimer timer;
-        aesKeyEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesKey, aesIV);
+        cryptohelper::aesEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
         printf("OPENSSL AES %u MB encrypt time: %u ms\n", dataSz/0x100000, timer.elapsedMs());
 
         decryptedMessageMb.ensureCapacity(encryptedMessageMb.length()+aesBlockSize);
         timer.reset();
-        aesKeyDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+        cryptohelper::aesDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
         printf("OPENSSL AES %u MB decrypt time: %u ms\n", dataSz/0x100000, timer.elapsedMs());
 
         ASSERT(messageMb.length() == decryptedMessageMb.length());
@@ -557,12 +569,12 @@ public:
 
         encryptedMessageMb.clear();
         timer.reset();
-        aesEncrypt(aesKey, aesKeySize, messageMb.bytes(), messageMb.length(), encryptedMessageMb);
+        ::aesEncrypt(aesKey, aesMaxKeySize, messageMb.bytes(), messageMb.length(), encryptedMessageMb);
         printf("JLIB    AES %u MB encrypt time: %u ms\n", dataSz/0x100000, timer.elapsedMs());
 
         decryptedMessageMb.clear();
         timer.reset();
-        aesDecrypt(aesKey, aesKeySize, encryptedMessageMb.bytes(), encryptedMessageMb.length(), decryptedMessageMb);
+        ::aesDecrypt(aesKey, aesMaxKeySize, encryptedMessageMb.bytes(), encryptedMessageMb.length(), decryptedMessageMb);
         printf("JLIB    AES %u MB decrypt time: %u ms\n", dataSz/0x100000, timer.elapsedMs());
 
         ASSERT(messageMb.length() == decryptedMessageMb.length());
@@ -572,10 +584,10 @@ public:
     void aesSpeedTest()
     {
         MemoryBuffer messageMb;
-        char aesKey[aesKeySize];
+        char aesKey[aesMaxKeySize];
         char aesIV[aesBlockSize];
         // create random key
-        fillRandomData(aesKeySize, aesKey);
+        fillRandomData(aesMaxKeySize, aesKey);
         fillRandomData(aesBlockSize, aesIV);
 
         // create random data
@@ -584,12 +596,12 @@ public:
         MemoryBuffer encryptedMessageMb;
         encryptedMessageMb.ensureCapacity(dataSz+aesBlockSize);
         CCycleTimer timer;
-        aesKeyEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesKey, aesIV);
+        aesEncrypt(encryptedMessageMb, messageMb.length(), messageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
         printf("AES %u MB encrypt time: %u ms\n", dataSz/0x100000, timer.elapsedMs());
         MemoryBuffer decryptedMessageMb;
         decryptedMessageMb.ensureCapacity(encryptedMessageMb.length()+aesBlockSize);
         timer.reset();
-        aesKeyDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesKey, aesIV);
+        aesDecrypt(decryptedMessageMb, encryptedMessageMb.length(), encryptedMessageMb.bytes(), aesMaxKeySize, aesKey, aesIV);
         printf("AES %u MB decrypt time: %u ms\n", dataSz/0x100000, timer.elapsedMs());
     }