From 1040cf14f586b7e5e469d3bf958cdd3952b93881 Mon Sep 17 00:00:00 2001 From: nathan Date: Fri, 25 Sep 2020 19:54:49 +0200 Subject: [PATCH] Crypto code polish. wrapped crypto stuff in try/catch --- .../network/connection/CryptoManagement.kt | 127 +++++++++--------- src/dorkbox/network/connection/EndPoint.kt | 2 +- 2 files changed, 67 insertions(+), 62 deletions(-) diff --git a/src/dorkbox/network/connection/CryptoManagement.kt b/src/dorkbox/network/connection/CryptoManagement.kt index afbd349c..9018c6b1 100644 --- a/src/dorkbox/network/connection/CryptoManagement.kt +++ b/src/dorkbox/network/connection/CryptoManagement.kt @@ -16,7 +16,6 @@ package dorkbox.network.connection import dorkbox.netUtil.IP -import dorkbox.network.Configuration import dorkbox.network.handshake.ClientConnectionInfo import dorkbox.network.serialization.AeronInput import dorkbox.network.serialization.AeronOutput @@ -30,9 +29,9 @@ import java.net.InetAddress import java.security.KeyFactory import java.security.KeyPairGenerator import java.security.MessageDigest -import java.security.PrivateKey -import java.security.PublicKey import java.security.SecureRandom +import java.security.interfaces.XECPrivateKey +import java.security.interfaces.XECPublicKey import java.security.spec.NamedParameterSpec import java.security.spec.XECPrivateKeySpec import java.security.spec.XECPublicKeySpec @@ -45,34 +44,38 @@ import javax.crypto.spec.SecretKeySpec /** * Management for all of the crypto stuff used */ -internal class CryptoManagement(val logger: KLogger, private val settingsStore: SettingsStore, type: Class<*>, config: Configuration) { +internal class CryptoManagement(val logger: KLogger, + private val settingsStore: SettingsStore, + type: Class<*>, + private val enableRemoteSignatureValidation: Boolean) { private val X25519 = "X25519" private val X25519KeySpec = NamedParameterSpec(X25519) - private val keyFactory = KeyFactory.getInstance(X25519) // key size is 32 bytes + + private val keyFactory = KeyFactory.getInstance(X25519) // key size is 32 bytes (256 bits) private val keyAgreement = KeyAgreement.getInstance("XDH") + private val aesCipher = Cipher.getInstance("AES/GCM/PKCS5Padding") private val hash = MessageDigest.getInstance("SHA-256"); companion object { const val curve25519 = "curve25519" - const val AES_KEY_SIZE = 256 - const val GCM_IV_LENGTH = 12 - const val GCM_TAG_LENGTH = 16 + const val GCM_IV_LENGTH_BYTES = 12 + const val GCM_TAG_LENGTH_BITS = 128 } - val privateKey: PrivateKey - val publicKey: PublicKey + val privateKey: XECPrivateKey + val publicKey: XECPublicKey + + val privateKeyBytes: ByteArray val publicKeyBytes: ByteArray - val secureRandom = SecureRandom(settingsStore.getSalt()) + private val secureRandom = SecureRandom(settingsStore.getSalt()) - private val iv = ByteArray(GCM_IV_LENGTH) + private val iv = ByteArray(GCM_IV_LENGTH_BYTES) val cryptOutput = AeronOutput() val cryptInput = AeronInput() - private val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation - init { if (!enableRemoteSignatureValidation) { logger.warn("WARNING: Disabling remote key validation is a security risk!!") @@ -109,8 +112,10 @@ internal class CryptoManagement(val logger: KLogger, private val settingsStore: logger.info("ECC public key: ${Sys.bytesToHex(publicKeyBytes)}") - this.publicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(publicKeyBytes))) - this.privateKey = keyFactory.generatePrivate(XECPrivateKeySpec(X25519KeySpec, privateKeyBytes)) + this.publicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(publicKeyBytes))) as XECPublicKey + this.privateKey = keyFactory.generatePrivate(XECPrivateKeySpec(X25519KeySpec, privateKeyBytes)) as XECPrivateKey + + this.privateKeyBytes = privateKeyBytes!! this.publicKeyBytes = publicKeyBytes!! } @@ -189,58 +194,58 @@ internal class CryptoManagement(val logger: KLogger, private val settingsStore: connectionStreamId: Int, kryoRegDetails: ByteArray): ByteArray { - val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) - secureRandom.nextBytes(iv) - val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) - aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec) + try { + val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) + secureRandom.nextBytes(iv) - // now create the byte array that holds all our data - cryptOutput.reset() - cryptOutput.writeInt(connectionSessionId) - cryptOutput.writeInt(connectionStreamId) - cryptOutput.writeInt(publicationPort) - cryptOutput.writeInt(subscriptionPort) - cryptOutput.writeInt(kryoRegDetails.size) - cryptOutput.writeBytes(kryoRegDetails) + val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv) + aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec) - return iv + aesCipher.doFinal(cryptOutput.toBytes()) + // now create the byte array that holds all our data + cryptOutput.reset() + cryptOutput.writeInt(connectionSessionId) + cryptOutput.writeInt(connectionStreamId) + cryptOutput.writeInt(publicationPort) + cryptOutput.writeInt(subscriptionPort) + cryptOutput.writeInt(kryoRegDetails.size) + cryptOutput.writeBytes(kryoRegDetails) + + return iv + aesCipher.doFinal(cryptOutput.toBytes()) + } catch (e: Exception) { + logger.error("Error during AES encrypt", e) + return ByteArray(0) + } } // NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt) - fun decrypt(registrationData: ByteArray?, serverPublicKeyBytes: ByteArray?): ClientConnectionInfo? { - if (registrationData == null || serverPublicKeyBytes == null) { + fun decrypt(registrationData: ByteArray, serverPublicKeyBytes: ByteArray): ClientConnectionInfo? { + try { + val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes) + + // now decrypt the data + val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, registrationData, 0, GCM_IV_LENGTH_BYTES) + aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec) + + cryptInput.buffer = aesCipher.doFinal(registrationData, GCM_IV_LENGTH_BYTES, registrationData.size - GCM_IV_LENGTH_BYTES) + + val sessionId = cryptInput.readInt() + val streamId = cryptInput.readInt() + val publicationPort = cryptInput.readInt() + val subscriptionPort = cryptInput.readInt() + val regDetailsSize = cryptInput.readInt() + val regDetails = cryptInput.readBytes(regDetailsSize) + + // now read data off + return ClientConnectionInfo(sessionId = sessionId, + streamId = streamId, + publicationPort = publicationPort, + subscriptionPort = subscriptionPort, + publicKey = serverPublicKeyBytes, + kryoRegistrationDetails = regDetails) + } catch (e: Exception) { + logger.error("Error during AES decrypt!", e) return null } - - val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes) - - // now read the encrypted data - registrationData.copyInto(destination = iv, endIndex = GCM_IV_LENGTH) - - val secretBytes = ByteArray(registrationData.size - GCM_IV_LENGTH) - registrationData.copyInto(destination = secretBytes, startIndex = GCM_IV_LENGTH) - - - // now decrypt the data - val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) - aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec) - - cryptInput.buffer = aesCipher.doFinal(secretBytes) - - val sessionId = cryptInput.readInt() - val streamId = cryptInput.readInt() - val publicationPort = cryptInput.readInt() - val subscriptionPort = cryptInput.readInt() - val regDetailsSize = cryptInput.readInt() - val regDetails = cryptInput.readBytes(regDetailsSize) - - // now read data off - return ClientConnectionInfo(sessionId = sessionId, - streamId = streamId, - publicationPort = publicationPort, - subscriptionPort = subscriptionPort, - publicKey = serverPublicKeyBytes, - kryoRegistrationDetails = regDetails) } override fun hashCode(): Int { diff --git a/src/dorkbox/network/connection/EndPoint.kt b/src/dorkbox/network/connection/EndPoint.kt index 68a2cbe5..0803fc8a 100644 --- a/src/dorkbox/network/connection/EndPoint.kt +++ b/src/dorkbox/network/connection/EndPoint.kt @@ -130,7 +130,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A // we have to be able to specify the property store settingsStore = createSettingsStore(logger) - crypto = CryptoManagement(logger, settingsStore, type, config) + crypto = CryptoManagement(logger, settingsStore, type, config.enableRemoteSignatureValidation) // Only starts the media driver if we are NOT already running! try {