diff mbox series

[PULL,14/17] crypto/builtin: Split and simplify AES_encrypt_cbc

Message ID 20200910100623.1088965-15-berrange@redhat.com
State New
Headers show
Series [PULL,01/17] tests: fix output message formatting for crypto benchmarks | expand

Commit Message

Daniel P. Berrangé Sept. 10, 2020, 10:06 a.m. UTC
From: Richard Henderson <richard.henderson@linaro.org>

Split into encrypt/decrypt functions, dropping the "enc" argument.
Now that the function is private to this file, we know that "len"
is a multiple of AES_BLOCK_SIZE.  So drop the odd block size code.

Name the functions do_aes_*crypt_cbc to match the *_ecb functions.
Reorder and re-type the arguments to match as well.

Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
Signed-off-by: Daniel P. Berrangé <berrange@redhat.com>
---
 crypto/cipher-builtin.c.inc | 99 ++++++++++++++++---------------------
 1 file changed, 43 insertions(+), 56 deletions(-)
diff mbox series

Patch

diff --git a/crypto/cipher-builtin.c.inc b/crypto/cipher-builtin.c.inc
index 61baad265a..b1fe3b08c3 100644
--- a/crypto/cipher-builtin.c.inc
+++ b/crypto/cipher-builtin.c.inc
@@ -104,61 +104,50 @@  static void do_aes_decrypt_ecb(const void *vctx,
     }
 }
 
-static void AES_cbc_encrypt(const unsigned char *in, unsigned char *out,
-                            const unsigned long length, const AES_KEY *key,
-                            unsigned char *ivec, const int enc)
+static void do_aes_encrypt_cbc(const AES_KEY *key,
+                               size_t len,
+                               uint8_t *out,
+                               const uint8_t *in,
+                               uint8_t *ivec)
 {
-    unsigned long n;
-    unsigned long len = length;
-    unsigned char tmp[AES_BLOCK_SIZE];
-
-    assert(in && out && key && ivec);
-
-    if (enc) {
-        while (len >= AES_BLOCK_SIZE) {
-            for (n = 0; n < AES_BLOCK_SIZE; ++n) {
-                tmp[n] = in[n] ^ ivec[n];
-            }
-            AES_encrypt(tmp, out, key);
-            memcpy(ivec, out, AES_BLOCK_SIZE);
-            len -= AES_BLOCK_SIZE;
-            in += AES_BLOCK_SIZE;
-            out += AES_BLOCK_SIZE;
-        }
-        if (len) {
-            for (n = 0; n < len; ++n) {
-                tmp[n] = in[n] ^ ivec[n];
-            }
-            for (n = len; n < AES_BLOCK_SIZE; ++n) {
-                tmp[n] = ivec[n];
-            }
-            AES_encrypt(tmp, tmp, key);
-            memcpy(out, tmp, AES_BLOCK_SIZE);
-            memcpy(ivec, tmp, AES_BLOCK_SIZE);
-        }
-    } else {
-        while (len >= AES_BLOCK_SIZE) {
-            memcpy(tmp, in, AES_BLOCK_SIZE);
-            AES_decrypt(in, out, key);
-            for (n = 0; n < AES_BLOCK_SIZE; ++n) {
-                out[n] ^= ivec[n];
-            }
-            memcpy(ivec, tmp, AES_BLOCK_SIZE);
-            len -= AES_BLOCK_SIZE;
-            in += AES_BLOCK_SIZE;
-            out += AES_BLOCK_SIZE;
-        }
-        if (len) {
-            memcpy(tmp, in, AES_BLOCK_SIZE);
-            AES_decrypt(tmp, tmp, key);
-            for (n = 0; n < len; ++n) {
-                out[n] = tmp[n] ^ ivec[n];
-            }
-            memcpy(ivec, tmp, AES_BLOCK_SIZE);
+    uint8_t tmp[AES_BLOCK_SIZE];
+    size_t n;
+
+    /* We have already verified that len % AES_BLOCK_SIZE == 0. */
+    while (len) {
+        for (n = 0; n < AES_BLOCK_SIZE; ++n) {
+            tmp[n] = in[n] ^ ivec[n];
         }
+        AES_encrypt(tmp, out, key);
+        memcpy(ivec, out, AES_BLOCK_SIZE);
+        len -= AES_BLOCK_SIZE;
+        in += AES_BLOCK_SIZE;
+        out += AES_BLOCK_SIZE;
     }
 }
 
+static void do_aes_decrypt_cbc(const AES_KEY *key,
+                               size_t len,
+                               uint8_t *out,
+                               const uint8_t *in,
+                               uint8_t *ivec)
+{
+    uint8_t tmp[AES_BLOCK_SIZE];
+    size_t n;
+
+    /* We have already verified that len % AES_BLOCK_SIZE == 0. */
+    while (len) {
+        memcpy(tmp, in, AES_BLOCK_SIZE);
+        AES_decrypt(in, out, key);
+        for (n = 0; n < AES_BLOCK_SIZE; ++n) {
+            out[n] ^= ivec[n];
+        }
+        memcpy(ivec, tmp, AES_BLOCK_SIZE);
+        len -= AES_BLOCK_SIZE;
+        in += AES_BLOCK_SIZE;
+        out += AES_BLOCK_SIZE;
+    }
+}
 
 static int qcrypto_cipher_encrypt_aes(QCryptoCipher *cipher,
                                       const void *in,
@@ -174,9 +163,8 @@  static int qcrypto_cipher_encrypt_aes(QCryptoCipher *cipher,
         do_aes_encrypt_ecb(&ctxt->state.aes.key, len, out, in);
         break;
     case QCRYPTO_CIPHER_MODE_CBC:
-        AES_cbc_encrypt(in, out, len,
-                        &ctxt->state.aes.key.enc,
-                        ctxt->state.aes.iv, 1);
+        do_aes_encrypt_cbc(&ctxt->state.aes.key.enc, len, out, in,
+                           ctxt->state.aes.iv);
         break;
     case QCRYPTO_CIPHER_MODE_XTS:
         xts_encrypt(&ctxt->state.aes.key,
@@ -208,9 +196,8 @@  static int qcrypto_cipher_decrypt_aes(QCryptoCipher *cipher,
         do_aes_decrypt_ecb(&ctxt->state.aes.key, len, out, in);
         break;
     case QCRYPTO_CIPHER_MODE_CBC:
-        AES_cbc_encrypt(in, out, len,
-                        &ctxt->state.aes.key.dec,
-                        ctxt->state.aes.iv, 0);
+        do_aes_decrypt_cbc(&ctxt->state.aes.key.dec, len, out, in,
+                           ctxt->state.aes.iv);
         break;
     case QCRYPTO_CIPHER_MODE_XTS:
         xts_decrypt(&ctxt->state.aes.key,