diff options
author | David Hook <dgh@cryptoworkshop.com> | 2013-05-31 11:07:45 +0400 |
---|---|---|
committer | David Hook <dgh@cryptoworkshop.com> | 2013-05-31 11:07:45 +0400 |
commit | 2b976f5364cfdbc37d3086019d93483c983eb80b (patch) | |
tree | cb846af3fd1d43f9c2562a1fb2d06b997ad8f229 /core/src/main/java/org/bouncycastle/crypto/tls/DTLSRecordLayer.java | |
parent | 5f714bd92fbd780d22406f4bc3681be005f6f04a (diff) |
initial reshuffle
Diffstat (limited to 'core/src/main/java/org/bouncycastle/crypto/tls/DTLSRecordLayer.java')
-rw-r--r-- | core/src/main/java/org/bouncycastle/crypto/tls/DTLSRecordLayer.java | 497 |
1 files changed, 497 insertions, 0 deletions
diff --git a/core/src/main/java/org/bouncycastle/crypto/tls/DTLSRecordLayer.java b/core/src/main/java/org/bouncycastle/crypto/tls/DTLSRecordLayer.java new file mode 100644 index 00000000..3fde01aa --- /dev/null +++ b/core/src/main/java/org/bouncycastle/crypto/tls/DTLSRecordLayer.java @@ -0,0 +1,497 @@ +package org.bouncycastle.crypto.tls; + +import java.io.IOException; + +class DTLSRecordLayer + implements DatagramTransport +{ + + private static final int RECORD_HEADER_LENGTH = 13; + private static final int MAX_FRAGMENT_LENGTH = 1 << 14; + private static final long TCP_MSL = 1000L * 60 * 2; + private static final long RETRANSMIT_TIMEOUT = TCP_MSL * 2; + + private final DatagramTransport transport; + private final TlsContext context; + private final TlsPeer peer; + + private final ByteQueue recordQueue = new ByteQueue(); + + private volatile boolean closed = false; + private volatile boolean failed = false; + private volatile ProtocolVersion discoveredPeerVersion = null; + private volatile boolean inHandshake; + private DTLSEpoch currentEpoch, pendingEpoch; + private DTLSEpoch readEpoch, writeEpoch; + + private DTLSHandshakeRetransmit retransmit = null; + private DTLSEpoch retransmitEpoch = null; + private long retransmitExpiry = 0; + + DTLSRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, short contentType) + { + this.transport = transport; + this.context = context; + this.peer = peer; + + this.inHandshake = true; + + this.currentEpoch = new DTLSEpoch(0, new TlsNullCipher(context)); + this.pendingEpoch = null; + this.readEpoch = currentEpoch; + this.writeEpoch = currentEpoch; + } + + ProtocolVersion getDiscoveredPeerVersion() + { + return discoveredPeerVersion; + } + + void initPendingEpoch(TlsCipher pendingCipher) + { + if (pendingEpoch != null) + { + throw new IllegalStateException(); + } + + /* + * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations + * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment + * lifetime." + */ + + // TODO Check for overflow + this.pendingEpoch = new DTLSEpoch(writeEpoch.getEpoch() + 1, pendingCipher); + } + + void handshakeSuccessful(DTLSHandshakeRetransmit retransmit) + { + if (readEpoch == currentEpoch || writeEpoch == currentEpoch) + { + // TODO + throw new IllegalStateException(); + } + + if (retransmit != null) + { + this.retransmit = retransmit; + this.retransmitEpoch = currentEpoch; + this.retransmitExpiry = System.currentTimeMillis() + RETRANSMIT_TIMEOUT; + } + + this.inHandshake = false; + this.currentEpoch = pendingEpoch; + this.pendingEpoch = null; + } + + void resetWriteEpoch() + { + if (retransmitEpoch != null) + { + this.writeEpoch = retransmitEpoch; + } + else + { + this.writeEpoch = currentEpoch; + } + } + + public int getReceiveLimit() + throws IOException + { + return Math.min(MAX_FRAGMENT_LENGTH, + readEpoch.getCipher().getPlaintextLimit(transport.getReceiveLimit() - RECORD_HEADER_LENGTH)); + } + + public int getSendLimit() + throws IOException + { + return Math.min(MAX_FRAGMENT_LENGTH, + writeEpoch.getCipher().getPlaintextLimit(transport.getSendLimit() - RECORD_HEADER_LENGTH)); + } + + public int receive(byte[] buf, int off, int len, int waitMillis) + throws IOException + { + + byte[] record = null; + + for (; ; ) + { + + int receiveLimit = Math.min(len, getReceiveLimit()) + RECORD_HEADER_LENGTH; + if (record == null || record.length < receiveLimit) + { + record = new byte[receiveLimit]; + } + + try + { + if (retransmit != null && System.currentTimeMillis() > retransmitExpiry) + { + retransmit = null; + retransmitEpoch = null; + } + + int received = receiveRecord(record, 0, receiveLimit, waitMillis); + if (received < 0) + { + return received; + } + if (received < RECORD_HEADER_LENGTH) + { + continue; + } + int length = TlsUtils.readUint16(record, 11); + if (received != (length + RECORD_HEADER_LENGTH)) + { + continue; + } + + short type = TlsUtils.readUint8(record, 0); + + // TODO Support user-specified custom protocols? + switch (type) + { + case ContentType.alert: + case ContentType.application_data: + case ContentType.change_cipher_spec: + case ContentType.handshake: + break; + default: + // TODO Exception? + continue; + } + + int epoch = TlsUtils.readUint16(record, 3); + + DTLSEpoch recordEpoch = null; + if (epoch == readEpoch.getEpoch()) + { + recordEpoch = readEpoch; + } + else if (type == ContentType.handshake && retransmitEpoch != null + && epoch == retransmitEpoch.getEpoch()) + { + recordEpoch = retransmitEpoch; + } + + if (recordEpoch == null) + { + continue; + } + + long seq = TlsUtils.readUint48(record, 5); + if (recordEpoch.getReplayWindow().shouldDiscard(seq)) + { + continue; + } + + ProtocolVersion version = TlsUtils.readVersion(record, 1); + if (discoveredPeerVersion != null && !discoveredPeerVersion.equals(version)) + { + continue; + } + + byte[] plaintext = recordEpoch.getCipher().decodeCiphertext( + getMacSequenceNumber(recordEpoch.getEpoch(), seq), type, record, RECORD_HEADER_LENGTH, + received - RECORD_HEADER_LENGTH); + + recordEpoch.getReplayWindow().reportAuthenticated(seq); + + if (discoveredPeerVersion == null) + { + discoveredPeerVersion = version; + } + + switch (type) + { + case ContentType.alert: + { + + if (plaintext.length == 2) + { + short alertLevel = plaintext[0]; + short alertDescription = plaintext[1]; + + peer.notifyAlertReceived(alertLevel, alertDescription); + + if (alertLevel == AlertLevel.fatal) + { + fail(alertDescription); + throw new TlsFatalAlert(alertDescription); + } + + // TODO Can close_notify be a fatal alert? + if (alertDescription == AlertDescription.close_notify) + { + closeTransport(); + } + } + else + { + // TODO What exception? + } + + continue; + } + case ContentType.application_data: + { + if (inHandshake) + { + // TODO Consider buffering application data for new epoch that arrives + // out-of-order with the Finished message + continue; + } + break; + } + case ContentType.change_cipher_spec: + { + // Implicitly receive change_cipher_spec and change to pending cipher state + + if (plaintext.length != 1 || plaintext[0] != 1) + { + continue; + } + + if (pendingEpoch != null) + { + readEpoch = pendingEpoch; + } + + continue; + } + case ContentType.handshake: + { + if (!inHandshake) + { + if (retransmit != null) + { + retransmit.receivedHandshakeRecord(epoch, plaintext, 0, plaintext.length); + } + + // TODO Consider support for HelloRequest + continue; + } + } + } + + /* + * NOTE: If we receive any non-handshake data in the new epoch implies the peer has + * received our final flight. + */ + if (!inHandshake && retransmit != null) + { + this.retransmit = null; + this.retransmitEpoch = null; + } + + System.arraycopy(plaintext, 0, buf, off, plaintext.length); + return plaintext.length; + } + catch (IOException e) + { + // NOTE: Assume this is a timeout for the moment + throw e; + } + } + } + + public void send(byte[] buf, int off, int len) + throws IOException + { + + short contentType = ContentType.application_data; + + if (this.inHandshake || this.writeEpoch == this.retransmitEpoch) + { + + contentType = ContentType.handshake; + + short handshakeType = TlsUtils.readUint8(buf, off); + if (handshakeType == HandshakeType.finished) + { + + DTLSEpoch nextEpoch = null; + if (this.inHandshake) + { + nextEpoch = pendingEpoch; + } + else if (this.writeEpoch == this.retransmitEpoch) + { + nextEpoch = currentEpoch; + } + + if (nextEpoch == null) + { + // TODO + throw new IllegalStateException(); + } + + // Implicitly send change_cipher_spec and change to pending cipher state + + // TODO Send change_cipher_spec and finished records in single datagram? + byte[] data = new byte[]{1}; + sendRecord(ContentType.change_cipher_spec, data, 0, data.length); + + writeEpoch = nextEpoch; + } + } + + sendRecord(contentType, buf, off, len); + } + + public void close() + throws IOException + { + if (!closed) + { + if (inHandshake) + { + warn(AlertDescription.user_canceled, "User canceled handshake"); + } + closeTransport(); + } + } + + void fail(short alertDescription) + { + if (!closed) + { + try + { + raiseAlert(AlertLevel.fatal, alertDescription, null, null); + } + catch (Exception e) + { + // Ignore + } + + failed = true; + + closeTransport(); + } + } + + void warn(short alertDescription, String message) + throws IOException + { + raiseAlert(AlertLevel.warning, alertDescription, message, null); + } + + private void closeTransport() + { + if (!closed) + { + /* + * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is + * required to send a close_notify alert before closing the write side of the + * connection. The other party MUST respond with a close_notify alert of its own and + * close down the connection immediately, discarding any pending writes. + */ + + try + { + if (!failed) + { + warn(AlertDescription.close_notify, null); + } + transport.close(); + } + catch (Exception e) + { + // Ignore + } + + closed = true; + } + } + + private void raiseAlert(short alertLevel, short alertDescription, String message, Exception cause) + throws IOException + { + + peer.notifyAlertRaised(alertLevel, alertDescription, message, cause); + + byte[] error = new byte[2]; + error[0] = (byte)alertLevel; + error[1] = (byte)alertDescription; + + sendRecord(ContentType.alert, error, 0, 2); + } + + private int receiveRecord(byte[] buf, int off, int len, int waitMillis) + throws IOException + { + if (recordQueue.size() > 0) + { + int length = 0; + if (recordQueue.size() >= RECORD_HEADER_LENGTH) + { + byte[] lengthBytes = new byte[2]; + recordQueue.read(lengthBytes, 0, 2, 11); + length = TlsUtils.readUint16(lengthBytes, 0); + } + + int received = Math.min(recordQueue.size(), RECORD_HEADER_LENGTH + length); + recordQueue.read(buf, off, received, 0); + recordQueue.removeData(received); + return received; + } + + int received = transport.receive(buf, off, len, waitMillis); + if (received >= RECORD_HEADER_LENGTH) + { + int fragmentLength = TlsUtils.readUint16(buf, off + 11); + int recordLength = RECORD_HEADER_LENGTH + fragmentLength; + if (received > recordLength) + { + recordQueue.addData(buf, off + recordLength, received - recordLength); + received = recordLength; + } + } + + return received; + } + + private void sendRecord(short contentType, byte[] buf, int off, int len) + throws IOException + { + + /* + * RFC 5264 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert, + * or ChangeCipherSpec content types. + */ + if (len < 1 && contentType != ContentType.application_data) + { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + int recordEpoch = writeEpoch.getEpoch(); + long recordSequenceNumber = writeEpoch.allocateSequenceNumber(); + + byte[] ciphertext = writeEpoch.getCipher().encodePlaintext( + getMacSequenceNumber(recordEpoch, recordSequenceNumber), contentType, buf, off, len); + + if (ciphertext.length > MAX_FRAGMENT_LENGTH) + { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + + byte[] record = new byte[ciphertext.length + RECORD_HEADER_LENGTH]; + TlsUtils.writeUint8(contentType, record, 0); + ProtocolVersion version = discoveredPeerVersion != null ? discoveredPeerVersion : context.getClientVersion(); + TlsUtils.writeVersion(version, record, 1); + TlsUtils.writeUint16(recordEpoch, record, 3); + TlsUtils.writeUint48(recordSequenceNumber, record, 5); + TlsUtils.writeUint16(ciphertext.length, record, 11); + System.arraycopy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.length); + + transport.send(record, 0, record.length); + } + + private static long getMacSequenceNumber(int epoch, long sequence_number) + { + return ((long)epoch << 48) | sequence_number; + } +} |