src/share/classes/com/sun/crypto/provider/GHASH.java

Print this page
rev 10474 : com.sun.crypto.provider.GHASH performance fix

@@ -26,13 +26,11 @@
  * (C) Copyright IBM Corp. 2013
  */
 
 package com.sun.crypto.provider;
 
-import java.util.Arrays;
-import java.security.*;
-import static com.sun.crypto.provider.AESConstants.AES_BLOCK_SIZE;
+import java.security.ProviderException;
 
 /**
  * This class represents the GHASH function defined in NIST 800-38D
  * under section 6.4. It needs to be constructed w/ a hash subkey, i.e.
  * block H. Given input of 128-bit blocks, it will process and output

@@ -42,66 +40,94 @@
  *
  * @since 1.8
  */
 final class GHASH {
 
-    private static final byte P128 = (byte) 0xe1; //reduction polynomial
+    private static long getLong(byte[] buffer, int offset) {
+        long result = 0;
+        int end = offset + 8;
+        for (int i = offset; i < end; ++i) {
+            result = (result << 8) + (buffer[i] & 0xFF);
+        }
+        return result;
+    }
 
-    private static boolean getBit(byte[] b, int pos) {
-        int p = pos / 8;
-        pos %= 8;
-        int i = (b[p] >>> (7 - pos)) & 1;
-        return i != 0;
-    }
-
-    private static void shift(byte[] b) {
-        byte temp, temp2;
-        temp2 = 0;
-        for (int i = 0; i < b.length; i++) {
-            temp = (byte) ((b[i] & 0x01) << 7);
-            b[i] = (byte) ((b[i] & 0xff) >>> 1);
-            b[i] = (byte) (b[i] | temp2);
-            temp2 = temp;
+    private static void putLong(byte[] buffer, int offset, long value) {
+        int end = offset + 8;
+        for (int i = end - 1; i >= offset; --i) {
+            buffer[i] = (byte) value;
+            value >>= 8;
         }
     }
 
-    // Given block X and Y, returns the muliplication of X * Y
-    private static byte[] blockMult(byte[] x, byte[] y) {
-        if (x.length != AES_BLOCK_SIZE || y.length != AES_BLOCK_SIZE) {
-            throw new RuntimeException("illegal input sizes");
-        }
-        byte[] z = new byte[AES_BLOCK_SIZE];
-        byte[] v = y.clone();
-        // calculate Z1-Z127 and V1-V127
-        for (int i = 0; i < 127; i++) {
+    private static final int AES_BLOCK_SIZE = 16;
+    
+    // Multiplies state0, state1 by V0, V1.
+    private void blockMult(long V0, long V1) {
+        long Z0 = 0;
+        long Z1 = 0;
+        long X; 
+        
+        // Separate loops for processing state0 and state1.
+        X = state0;
+        for (int i = 0; i < 64; i++) {
             // Zi+1 = Zi if bit i of x is 0
-            if (getBit(x, i)) {
-                for (int n = 0; n < z.length; n++) {
-                    z[n] ^= v[n];
-                }
+            long mask = X >> 63;
+            Z0 ^= V0 & mask;
+            Z1 ^= V1 & mask;
+            
+            // Save mask for conditional reduction below.
+            mask = (V1 << 63) >> 63;
+            
+            // V = rightshift(V)
+            long carry = V0 & 1;
+            V0 = V0 >>> 1;
+            V1 = (V1 >>> 1) | (carry << 63);
+
+            // Conditional reduction modulo P128. 
+            V0 ^= 0xe100000000000000L & mask;
+            X <<= 1;
             }
-            boolean lastBitOfV = getBit(v, 127);
-            shift(v);
-            if (lastBitOfV) v[0] ^= P128;
+        
+        X = state1;
+        for (int i = 64; i < 127; i++) {
+            // Zi+1 = Zi if bit i of x is 0
+            long mask = X >> 63;
+            Z0 ^= V0 & mask;
+            Z1 ^= V1 & mask;
+            
+            // Save mask for conditional reduction below.
+            mask = (V1 << 63) >> 63;
+            
+            // V = rightshift(V)
+            long carry = V0 & 1;
+            V0 = V0 >>> 1;
+            V1 = (V1 >>> 1) | (carry << 63);
+
+            // Conditional reduction.
+            V0 ^= 0xe100000000000000L & mask;
+            X <<= 1;
         }
+        
         // calculate Z128
-        if (getBit(x, 127)) {
-            for (int n = 0; n < z.length; n++) {
-                z[n] ^= v[n];
-            }
-        }
-        return z;
+        long mask = X >> 63;
+        Z0 ^= V0 & mask;
+        Z1 ^= V1 & mask;
+        
+        // Save result.
+        state0 = Z0;
+        state1 = Z1;
     }
 
     // hash subkey H; should not change after the object has been constructed
-    private final byte[] subkeyH;
+    private final long subkeyH0, subkeyH1;
 
     // buffer for storing hash
-    private byte[] state;
+    private long state0, state1;
 
     // variables for save/restore calls
-    private byte[] stateSave = null;
+    private long stateSave0, stateSave1;
 
     /**
      * Initializes the cipher in the specified mode with the given key
      * and iv.
      *

@@ -112,45 +138,47 @@
      */
     GHASH(byte[] subkeyH) throws ProviderException {
         if ((subkeyH == null) || subkeyH.length != AES_BLOCK_SIZE) {
             throw new ProviderException("Internal error");
         }
-        this.subkeyH = subkeyH;
-        this.state = new byte[AES_BLOCK_SIZE];
+        this.subkeyH0 = getLong(subkeyH, 0);
+        this.subkeyH1 = getLong(subkeyH, 8);
     }
 
     /**
      * Resets the GHASH object to its original state, i.e. blank w/
      * the same subkey H. Used after digest() is called and to re-use
      * this object for different data w/ the same H.
      */
     void reset() {
-        Arrays.fill(state, (byte) 0);
+        state0 = 0;
+        state1 = 0;
     }
 
     /**
      * Save the current snapshot of this GHASH object.
      */
     void save() {
-        stateSave = state.clone();
+        stateSave0 = state0;
+        stateSave1 = state1;
     }
 
     /**
      * Restores this object using the saved snapshot.
      */
     void restore() {
-        state = stateSave;
+        state0 = stateSave0;
+        state1 = stateSave1;
     }
 
     private void processBlock(byte[] data, int ofs) {
         if (data.length - ofs < AES_BLOCK_SIZE) {
             throw new RuntimeException("need complete block");
         }
-        for (int n = 0; n < state.length; n++) {
-            state[n] ^= data[ofs + n];
-        }
-        state = blockMult(state, subkeyH);
+        state0 ^= getLong(data, ofs);
+        state1 ^= getLong(data, ofs + 8);
+        blockMult(subkeyH0, subkeyH1);
     }
 
     void update(byte[] in) {
         update(in, 0, in.length);
     }

@@ -167,12 +195,12 @@
             processBlock(in, i);
         }
     }
 
     byte[] digest() {
-        try {
-            return state.clone();
-        } finally {
+        byte[] result = new byte[AES_BLOCK_SIZE];
+        putLong(result, 0, state0);
+        putLong(result, 8, state1);
             reset();
-        }
+        return result;
     }
 }