diff options
author | Peter Dettman <peter.dettman@bouncycastle.org> | 2013-11-07 13:18:01 +0400 |
---|---|---|
committer | Peter Dettman <peter.dettman@bouncycastle.org> | 2013-11-07 13:18:01 +0400 |
commit | af343d1ae8d119fbf7710cb819cb3798c1ed583a (patch) | |
tree | 5b360598f3c7352c4c81574cbe9f6f3227a8a4d6 /core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java | |
parent | 71de9789c2d1172c33eb2ee8e362b95c0914d85a (diff) |
Support tracking of extra hash algorithms in DeferredHash
Simple logic to choose b/w buffering vs multi-hash-update strategies
Diffstat (limited to 'core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java')
-rw-r--r-- | core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java | 108 |
1 files changed, 75 insertions, 33 deletions
diff --git a/core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java b/core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java index ea43f22c..9736d2e1 100644 --- a/core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java +++ b/core/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java @@ -12,15 +12,19 @@ import org.bouncycastle.util.Shorts; class DeferredHash implements TlsHandshakeHash { + protected static final int BUFFERING_HASH_LIMIT = 4; + protected TlsContext context; private DigestInputBuffer buf; private Hashtable hashes; + private Short prfHashAlgorithm; DeferredHash() { this.buf = new DigestInputBuffer(); this.hashes = new Hashtable(); + this.prfHashAlgorithm = null; } public void init(TlsContext context) @@ -28,63 +32,77 @@ class DeferredHash this.context = context; } - public TlsHandshakeHash commit() + public TlsHandshakeHash notifyPRFDetermined() { - // Ensure the PRF hash algorithm is being tracked + int prfAlgorithm = context.getSecurityParameters().getPrfAlgorithm(); + if (prfAlgorithm == PRFAlgorithm.tls_prf_legacy) { - int prfAlgorithm = context.getSecurityParameters().getPrfAlgorithm(); - if (prfAlgorithm == PRFAlgorithm.tls_prf_legacy) - { - CombinedHash legacyHash = new CombinedHash(); - legacyHash.init(context); - buf.updateDigest(legacyHash); - return legacyHash.commit(); - } - - short prfHashAlgorithm = TlsUtils.getHashAlgorithmForPRFAlgorithm(prfAlgorithm); - if (!this.hashes.containsKey(Shorts.valueOf(prfHashAlgorithm))) - { - Digest prfHash = TlsUtils.createHash(prfHashAlgorithm); - this.hashes.put(Shorts.valueOf(prfHashAlgorithm), prfHash); - } + CombinedHash legacyHash = new CombinedHash(); + legacyHash.init(context); + buf.updateDigest(legacyHash); + return legacyHash.notifyPRFDetermined(); } - Enumeration e = hashes.elements(); - while (e.hasMoreElements()) + this.prfHashAlgorithm = Shorts.valueOf(TlsUtils.getHashAlgorithmForPRFAlgorithm(prfAlgorithm)); + + checkTrackingHash(prfHashAlgorithm); + + return this; + } + + public void trackHashAlgorithm(short hashAlgorithm) + { + if (buf == null) { - Digest hash = (Digest)e.nextElement(); - buf.updateDigest(hash); + throw new IllegalStateException("Too late to track more hash algorithms"); } - this.buf = null; + checkTrackingHash(Shorts.valueOf(hashAlgorithm)); + } - return this; + public void sealHashAlgorithms() + { + checkStopBuffering(); } - public Digest fork() + public void keepHashAlgorithms(short[] hashAlgorithms) { - int prfAlgorithm = context.getSecurityParameters().getPrfAlgorithm(); - if (prfAlgorithm == PRFAlgorithm.tls_prf_legacy) + Hashtable kept = new Hashtable(); + + Enumeration e = hashes.keys(); + while (e.hasMoreElements()) { - throw new IllegalStateException("Legacy PRF shouldn't be calling this"); + Short key = (Short)e.nextElement(); + short hashAlgorithm = key.shortValue(); + + if (hashAlgorithm == prfHashAlgorithm.shortValue() + || TlsProtocol.arrayContains(hashAlgorithms, hashAlgorithm)) + { + kept.put(key, hashes.get(key)); + } } - short hashAlgorithm = TlsUtils.getHashAlgorithmForPRFAlgorithm(prfAlgorithm); + this.hashes = kept; + checkStopBuffering(); + } + + public Digest fork() + { if (buf != null) { - Digest hash = TlsUtils.createHash(hashAlgorithm); + Digest hash = TlsUtils.createHash(prfHashAlgorithm.shortValue()); buf.updateDigest(hash); return hash; } - Digest hash = (Digest)hashes.get(Shorts.valueOf(hashAlgorithm)); - if (hash == null) + Digest prfHash = (Digest)hashes.get(prfHashAlgorithm); + if (prfHash == null) { - throw new IllegalStateException("Digest not registered"); + throw new IllegalStateException("PRF hash algorithm not tracked"); } - return TlsUtils.cloneHash(hashAlgorithm, hash); + return TlsUtils.cloneHash(prfHashAlgorithm.shortValue(), prfHash); } public String getAlgorithmName() @@ -149,4 +167,28 @@ class DeferredHash hash.reset(); } } + + protected void checkStopBuffering() + { + if (buf != null && hashes.size() <= BUFFERING_HASH_LIMIT) + { + Enumeration e = hashes.elements(); + while (e.hasMoreElements()) + { + Digest hash = (Digest)e.nextElement(); + buf.updateDigest(hash); + } + + this.buf = null; + } + } + + protected void checkTrackingHash(Short hashAlgorithm) + { + if (!hashes.containsKey(hashAlgorithm)) + { + Digest hash = TlsUtils.createHash(hashAlgorithm.shortValue()); + hashes.put(hashAlgorithm, hash); + } + } } |