Crypto code polish. wrapped crypto stuff in try/catch

This commit is contained in:
nathan 2020-09-25 19:54:49 +02:00
parent 2b2e185c4a
commit 1040cf14f5
2 changed files with 67 additions and 62 deletions

View File

@ -16,7 +16,6 @@
package dorkbox.network.connection package dorkbox.network.connection
import dorkbox.netUtil.IP import dorkbox.netUtil.IP
import dorkbox.network.Configuration
import dorkbox.network.handshake.ClientConnectionInfo import dorkbox.network.handshake.ClientConnectionInfo
import dorkbox.network.serialization.AeronInput import dorkbox.network.serialization.AeronInput
import dorkbox.network.serialization.AeronOutput import dorkbox.network.serialization.AeronOutput
@ -30,9 +29,9 @@ import java.net.InetAddress
import java.security.KeyFactory import java.security.KeyFactory
import java.security.KeyPairGenerator import java.security.KeyPairGenerator
import java.security.MessageDigest import java.security.MessageDigest
import java.security.PrivateKey
import java.security.PublicKey
import java.security.SecureRandom import java.security.SecureRandom
import java.security.interfaces.XECPrivateKey
import java.security.interfaces.XECPublicKey
import java.security.spec.NamedParameterSpec import java.security.spec.NamedParameterSpec
import java.security.spec.XECPrivateKeySpec import java.security.spec.XECPrivateKeySpec
import java.security.spec.XECPublicKeySpec import java.security.spec.XECPublicKeySpec
@ -45,34 +44,38 @@ import javax.crypto.spec.SecretKeySpec
/** /**
* Management for all of the crypto stuff used * Management for all of the crypto stuff used
*/ */
internal class CryptoManagement(val logger: KLogger, private val settingsStore: SettingsStore, type: Class<*>, config: Configuration) { internal class CryptoManagement(val logger: KLogger,
private val settingsStore: SettingsStore,
type: Class<*>,
private val enableRemoteSignatureValidation: Boolean) {
private val X25519 = "X25519" private val X25519 = "X25519"
private val X25519KeySpec = NamedParameterSpec(X25519) private val X25519KeySpec = NamedParameterSpec(X25519)
private val keyFactory = KeyFactory.getInstance(X25519) // key size is 32 bytes
private val keyFactory = KeyFactory.getInstance(X25519) // key size is 32 bytes (256 bits)
private val keyAgreement = KeyAgreement.getInstance("XDH") private val keyAgreement = KeyAgreement.getInstance("XDH")
private val aesCipher = Cipher.getInstance("AES/GCM/PKCS5Padding") private val aesCipher = Cipher.getInstance("AES/GCM/PKCS5Padding")
private val hash = MessageDigest.getInstance("SHA-256"); private val hash = MessageDigest.getInstance("SHA-256");
companion object { companion object {
const val curve25519 = "curve25519" const val curve25519 = "curve25519"
const val AES_KEY_SIZE = 256 const val GCM_IV_LENGTH_BYTES = 12
const val GCM_IV_LENGTH = 12 const val GCM_TAG_LENGTH_BITS = 128
const val GCM_TAG_LENGTH = 16
} }
val privateKey: PrivateKey val privateKey: XECPrivateKey
val publicKey: PublicKey val publicKey: XECPublicKey
val privateKeyBytes: ByteArray
val publicKeyBytes: ByteArray val publicKeyBytes: ByteArray
val secureRandom = SecureRandom(settingsStore.getSalt()) private val secureRandom = SecureRandom(settingsStore.getSalt())
private val iv = ByteArray(GCM_IV_LENGTH) private val iv = ByteArray(GCM_IV_LENGTH_BYTES)
val cryptOutput = AeronOutput() val cryptOutput = AeronOutput()
val cryptInput = AeronInput() val cryptInput = AeronInput()
private val enableRemoteSignatureValidation = config.enableRemoteSignatureValidation
init { init {
if (!enableRemoteSignatureValidation) { if (!enableRemoteSignatureValidation) {
logger.warn("WARNING: Disabling remote key validation is a security risk!!") logger.warn("WARNING: Disabling remote key validation is a security risk!!")
@ -109,8 +112,10 @@ internal class CryptoManagement(val logger: KLogger, private val settingsStore:
logger.info("ECC public key: ${Sys.bytesToHex(publicKeyBytes)}") logger.info("ECC public key: ${Sys.bytesToHex(publicKeyBytes)}")
this.publicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(publicKeyBytes))) this.publicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(publicKeyBytes))) as XECPublicKey
this.privateKey = keyFactory.generatePrivate(XECPrivateKeySpec(X25519KeySpec, privateKeyBytes)) this.privateKey = keyFactory.generatePrivate(XECPrivateKeySpec(X25519KeySpec, privateKeyBytes)) as XECPrivateKey
this.privateKeyBytes = privateKeyBytes!!
this.publicKeyBytes = publicKeyBytes!! this.publicKeyBytes = publicKeyBytes!!
} }
@ -189,58 +194,58 @@ internal class CryptoManagement(val logger: KLogger, private val settingsStore:
connectionStreamId: Int, connectionStreamId: Int,
kryoRegDetails: ByteArray): ByteArray { kryoRegDetails: ByteArray): ByteArray {
val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes) try {
secureRandom.nextBytes(iv) val secretKeySpec = generateAesKey(clientPublicKeyBytes, clientPublicKeyBytes, publicKeyBytes)
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv) secureRandom.nextBytes(iv)
aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec)
// now create the byte array that holds all our data val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv)
cryptOutput.reset() aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec)
cryptOutput.writeInt(connectionSessionId)
cryptOutput.writeInt(connectionStreamId)
cryptOutput.writeInt(publicationPort)
cryptOutput.writeInt(subscriptionPort)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBytes(kryoRegDetails)
return iv + aesCipher.doFinal(cryptOutput.toBytes()) // now create the byte array that holds all our data
cryptOutput.reset()
cryptOutput.writeInt(connectionSessionId)
cryptOutput.writeInt(connectionStreamId)
cryptOutput.writeInt(publicationPort)
cryptOutput.writeInt(subscriptionPort)
cryptOutput.writeInt(kryoRegDetails.size)
cryptOutput.writeBytes(kryoRegDetails)
return iv + aesCipher.doFinal(cryptOutput.toBytes())
} catch (e: Exception) {
logger.error("Error during AES encrypt", e)
return ByteArray(0)
}
} }
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt) // NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt)
fun decrypt(registrationData: ByteArray?, serverPublicKeyBytes: ByteArray?): ClientConnectionInfo? { fun decrypt(registrationData: ByteArray, serverPublicKeyBytes: ByteArray): ClientConnectionInfo? {
if (registrationData == null || serverPublicKeyBytes == null) { try {
val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes)
// now decrypt the data
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, registrationData, 0, GCM_IV_LENGTH_BYTES)
aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec)
cryptInput.buffer = aesCipher.doFinal(registrationData, GCM_IV_LENGTH_BYTES, registrationData.size - GCM_IV_LENGTH_BYTES)
val sessionId = cryptInput.readInt()
val streamId = cryptInput.readInt()
val publicationPort = cryptInput.readInt()
val subscriptionPort = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val regDetails = cryptInput.readBytes(regDetailsSize)
// now read data off
return ClientConnectionInfo(sessionId = sessionId,
streamId = streamId,
publicationPort = publicationPort,
subscriptionPort = subscriptionPort,
publicKey = serverPublicKeyBytes,
kryoRegistrationDetails = regDetails)
} catch (e: Exception) {
logger.error("Error during AES decrypt!", e)
return null return null
} }
val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes)
// now read the encrypted data
registrationData.copyInto(destination = iv, endIndex = GCM_IV_LENGTH)
val secretBytes = ByteArray(registrationData.size - GCM_IV_LENGTH)
registrationData.copyInto(destination = secretBytes, startIndex = GCM_IV_LENGTH)
// now decrypt the data
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH * 8, iv)
aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec)
cryptInput.buffer = aesCipher.doFinal(secretBytes)
val sessionId = cryptInput.readInt()
val streamId = cryptInput.readInt()
val publicationPort = cryptInput.readInt()
val subscriptionPort = cryptInput.readInt()
val regDetailsSize = cryptInput.readInt()
val regDetails = cryptInput.readBytes(regDetailsSize)
// now read data off
return ClientConnectionInfo(sessionId = sessionId,
streamId = streamId,
publicationPort = publicationPort,
subscriptionPort = subscriptionPort,
publicKey = serverPublicKeyBytes,
kryoRegistrationDetails = regDetails)
} }
override fun hashCode(): Int { override fun hashCode(): Int {

View File

@ -130,7 +130,7 @@ internal constructor(val type: Class<*>, internal val config: Configuration) : A
// we have to be able to specify the property store // we have to be able to specify the property store
settingsStore = createSettingsStore(logger) settingsStore = createSettingsStore(logger)
crypto = CryptoManagement(logger, settingsStore, type, config) crypto = CryptoManagement(logger, settingsStore, type, config.enableRemoteSignatureValidation)
// Only starts the media driver if we are NOT already running! // Only starts the media driver if we are NOT already running!
try { try {