343 lines
13 KiB
Kotlin
343 lines
13 KiB
Kotlin
/*
|
|
* Copyright 2023 dorkbox, llc
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
package dorkbox.network.connection
|
|
|
|
import dorkbox.bytes.Hash
|
|
import dorkbox.hex.toHexString
|
|
import dorkbox.network.handshake.ClientConnectionInfo
|
|
import dorkbox.network.serialization.AeronInput
|
|
import dorkbox.network.serialization.AeronOutput
|
|
import dorkbox.network.serialization.SettingsStore
|
|
import dorkbox.util.entropy.Entropy
|
|
import org.slf4j.Logger
|
|
import java.math.BigInteger
|
|
import java.net.InetAddress
|
|
import java.security.KeyFactory
|
|
import java.security.KeyPairGenerator
|
|
import java.security.SecureRandom
|
|
import java.security.interfaces.XECPrivateKey
|
|
import java.security.interfaces.XECPublicKey
|
|
import java.security.spec.NamedParameterSpec
|
|
import java.security.spec.XECPrivateKeySpec
|
|
import java.security.spec.XECPublicKeySpec
|
|
import javax.crypto.Cipher
|
|
import javax.crypto.KeyAgreement
|
|
import javax.crypto.spec.GCMParameterSpec
|
|
import javax.crypto.spec.SecretKeySpec
|
|
|
|
|
|
/**
|
|
* Management for all the crypto stuff used
|
|
*/
|
|
internal class CryptoManagement(val logger: Logger,
|
|
private val settingsStore: SettingsStore,
|
|
type: Class<*>,
|
|
private val enableRemoteSignatureValidation: Boolean) {
|
|
companion object {
|
|
private val X25519 = "X25519"
|
|
const val curve25519 = "curve25519"
|
|
|
|
const val GCM_IV_LENGTH_BYTES = 12 // 12 bytes for a 96-bit IV
|
|
const val GCM_TAG_LENGTH_BITS = 128
|
|
const val AES_ALGORITHM = "AES/GCM/NoPadding"
|
|
|
|
val NOCRYPT = SecretKeySpec(ByteArray(1), "NOCRYPT")
|
|
val secureRandom = SecureRandom()
|
|
}
|
|
|
|
|
|
private val X25519KeySpec = NamedParameterSpec(X25519)
|
|
|
|
private val keyFactory = KeyFactory.getInstance(X25519) // key size is 32 bytes (256 bits)
|
|
private val keyAgreement = KeyAgreement.getInstance("XDH")
|
|
|
|
private val aesCipher = Cipher.getInstance(AES_ALGORITHM)
|
|
|
|
|
|
|
|
val privateKey: XECPrivateKey
|
|
val publicKey: XECPublicKey
|
|
|
|
// These are both 32 bytes long (256 bits)
|
|
val privateKeyBytes: ByteArray
|
|
val publicKeyBytes: ByteArray
|
|
|
|
private val iv = ByteArray(GCM_IV_LENGTH_BYTES)
|
|
val cryptOutput = AeronOutput()
|
|
val cryptInput = AeronInput()
|
|
|
|
init {
|
|
if (!enableRemoteSignatureValidation) {
|
|
logger.warn("WARNING: Disabling remote key validation is a security risk!!")
|
|
}
|
|
|
|
secureRandom.setSeed(settingsStore.salt)
|
|
|
|
// initialize the private/public keys used for negotiating ECC handshakes
|
|
// these are ONLY used for IP connections. LOCAL connections do not need a handshake!
|
|
val privateKeyBytes: ByteArray
|
|
val publicKeyBytes: ByteArray
|
|
|
|
if (settingsStore.validKeys()) {
|
|
privateKeyBytes = settingsStore.privateKey
|
|
publicKeyBytes = settingsStore.publicKey
|
|
} else {
|
|
try {
|
|
// seed our RNG based off of this and create our ECC keys
|
|
val seedBytes = Entropy["There are no ECC keys for the ${type.simpleName} yet"]
|
|
logger.debug("Now generating ECC ($curve25519) keys. Please wait!")
|
|
|
|
secureRandom.nextBytes(seedBytes)
|
|
|
|
// see: https://stackoverflow.com/questions/60303561/how-do-i-pass-a-44-bytes-x25519-public-key-created-by-openssl-to-cryptokit-which
|
|
val (xdhPublic, xdhPrivate) = createKeyPair(secureRandom)
|
|
|
|
publicKeyBytes = xdhPublic.u.toByteArray()
|
|
privateKeyBytes = xdhPrivate.scalar
|
|
|
|
// save to properties file
|
|
settingsStore.privateKey = privateKeyBytes
|
|
settingsStore.publicKey = publicKeyBytes
|
|
} catch (e: Exception) {
|
|
val message = "Unable to initialize/generate ECC keys. FORCED SHUTDOWN."
|
|
logger.error(message, e)
|
|
throw SecurityException(message, e)
|
|
}
|
|
}
|
|
|
|
logger.info("ECC public key: ${publicKeyBytes.toHexString()}")
|
|
|
|
this.publicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(publicKeyBytes))) as XECPublicKey
|
|
this.privateKey = keyFactory.generatePrivate(XECPrivateKeySpec(X25519KeySpec, privateKeyBytes)) as XECPrivateKey
|
|
|
|
this.privateKeyBytes = privateKeyBytes
|
|
this.publicKeyBytes = publicKeyBytes
|
|
}
|
|
|
|
private fun createKeyPair(secureRandom: SecureRandom): Pair<XECPublicKeySpec, XECPrivateKeySpec> {
|
|
val kpg: KeyPairGenerator = KeyPairGenerator.getInstance("XDH")
|
|
kpg.initialize(NamedParameterSpec.X25519, secureRandom)
|
|
val keyPair = kpg.generateKeyPair()
|
|
|
|
// get the ACTUAL key spec, so we can get the ACTUAL key byte arrays
|
|
val xdhPublic: XECPublicKeySpec = keyFactory.getKeySpec(keyPair.public, XECPublicKeySpec::class.java)
|
|
val xdhPrivate: XECPrivateKeySpec = keyFactory.getKeySpec(keyPair.private, XECPrivateKeySpec::class.java)
|
|
|
|
return Pair(xdhPublic, xdhPrivate)
|
|
}
|
|
|
|
/**
|
|
* If the key does not match AND we have disabled remote key validation -- key validation is REQUIRED!
|
|
*
|
|
* @return true if all is OK (the remote address public key matches the one saved or we disabled remote key validation.)
|
|
* false if we should abort
|
|
*/
|
|
internal fun validateRemoteAddress(remoteAddress: InetAddress, remoteAddressString: String, publicKey: ByteArray?): PublicKeyValidationState {
|
|
if (publicKey == null) {
|
|
logger.error("Error validating public key for ${remoteAddressString}! It was null (and should not have been)")
|
|
return PublicKeyValidationState.INVALID
|
|
}
|
|
|
|
try {
|
|
val savedPublicKey = settingsStore.getRegisteredServerKey(remoteAddress)
|
|
if (savedPublicKey != null) {
|
|
// COMPARE!
|
|
if (!publicKey.contentEquals(savedPublicKey)) {
|
|
return if (enableRemoteSignatureValidation) {
|
|
// keys do not match, abort!
|
|
logger.error("The public key for remote connection $remoteAddressString does not match. Denying connection attempt")
|
|
PublicKeyValidationState.INVALID
|
|
}
|
|
else {
|
|
logger.warn("The public key for remote connection $remoteAddressString does not match. Permitting connection attempt.")
|
|
PublicKeyValidationState.TAMPERED
|
|
}
|
|
}
|
|
}
|
|
} catch (e: SecurityException) {
|
|
// keys do not match, abort!
|
|
logger.error("Error validating public key for $remoteAddressString!", e)
|
|
return PublicKeyValidationState.INVALID
|
|
}
|
|
|
|
return PublicKeyValidationState.VALID
|
|
}
|
|
|
|
private fun makeInfo(serverPublicKeyBytes: ByteArray, secretKey: SecretKeySpec): ClientConnectionInfo {
|
|
val sessionIdPub = cryptInput.readInt()
|
|
val sessionIdSub = cryptInput.readInt()
|
|
val streamIdPub = cryptInput.readInt()
|
|
val streamIdSub = cryptInput.readInt()
|
|
val regDetailsSize = cryptInput.readInt()
|
|
val sessionTimeout = cryptInput.readLong()
|
|
val regDetails = cryptInput.readBytes(regDetailsSize)
|
|
|
|
// now save data off
|
|
return ClientConnectionInfo(
|
|
sessionIdPub = sessionIdPub,
|
|
sessionIdSub = sessionIdSub,
|
|
streamIdPub = streamIdPub,
|
|
streamIdSub = streamIdSub,
|
|
publicKey = serverPublicKeyBytes,
|
|
sessionTimeout = sessionTimeout,
|
|
kryoRegistrationDetails = regDetails,
|
|
secretKey = secretKey)
|
|
}
|
|
|
|
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the server, mutually exclusive calls to decrypt)
|
|
fun nocrypt(
|
|
sessionIdPub: Int,
|
|
sessionIdSub: Int,
|
|
streamIdPub: Int,
|
|
streamIdSub: Int,
|
|
sessionTimeout: Long,
|
|
kryoRegDetails: ByteArray
|
|
): ByteArray {
|
|
|
|
return try {
|
|
// now create the byte array that holds all our data
|
|
cryptOutput.reset()
|
|
cryptOutput.writeInt(sessionIdPub)
|
|
cryptOutput.writeInt(sessionIdSub)
|
|
cryptOutput.writeInt(streamIdPub)
|
|
cryptOutput.writeInt(streamIdSub)
|
|
cryptOutput.writeInt(kryoRegDetails.size)
|
|
cryptOutput.writeLong(sessionTimeout)
|
|
cryptOutput.writeBytes(kryoRegDetails)
|
|
|
|
cryptOutput.toBytes()
|
|
} catch (e: Exception) {
|
|
logger.error("Error during AES encrypt", e)
|
|
ByteArray(0)
|
|
}
|
|
}
|
|
|
|
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt)
|
|
fun nocrypt(registrationData: ByteArray, serverPublicKeyBytes: ByteArray): ClientConnectionInfo? {
|
|
return try {
|
|
// The message was intended for this client. Try to parse it as one of the available message types.
|
|
// this message is NOT-ENCRYPTED!
|
|
cryptInput.buffer = registrationData
|
|
|
|
makeInfo(serverPublicKeyBytes, NOCRYPT)
|
|
} catch (e: Exception) {
|
|
logger.error("Error during IPC decrypt!", e)
|
|
null
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generate the AES key based on ECDH
|
|
*/
|
|
internal fun generateAesKey(remotePublicKeyBytes: ByteArray, bytesA: ByteArray, bytesB: ByteArray): SecretKeySpec {
|
|
val clientPublicKey = keyFactory.generatePublic(XECPublicKeySpec(X25519KeySpec, BigInteger(remotePublicKeyBytes)))
|
|
keyAgreement.init(privateKey)
|
|
keyAgreement.doPhase(clientPublicKey, true)
|
|
val sharedSecret = keyAgreement.generateSecret()
|
|
|
|
// Derive a key from the shared secret and both public keys
|
|
val hash = Hash.sha256
|
|
hash.reset()
|
|
hash.update(sharedSecret)
|
|
hash.update(bytesA)
|
|
hash.update(bytesB)
|
|
|
|
return SecretKeySpec(hash.digest(), "AES")
|
|
}
|
|
|
|
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the server, mutually exclusive calls to decrypt)
|
|
fun encrypt(
|
|
cryptoSecretKey: SecretKeySpec,
|
|
sessionIdPub: Int,
|
|
sessionIdSub: Int,
|
|
streamIdPub: Int,
|
|
streamIdSub: Int,
|
|
sessionTimeout: Long,
|
|
kryoRegDetails: ByteArray
|
|
): ByteArray {
|
|
|
|
try {
|
|
secureRandom.nextBytes(iv)
|
|
|
|
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, iv)
|
|
aesCipher.init(Cipher.ENCRYPT_MODE, cryptoSecretKey, gcmParameterSpec)
|
|
|
|
// now create the byte array that holds all our data
|
|
cryptOutput.reset()
|
|
cryptOutput.writeInt(sessionIdPub)
|
|
cryptOutput.writeInt(sessionIdSub)
|
|
cryptOutput.writeInt(streamIdPub)
|
|
cryptOutput.writeInt(streamIdSub)
|
|
cryptOutput.writeInt(kryoRegDetails.size)
|
|
cryptOutput.writeLong(sessionTimeout)
|
|
cryptOutput.writeBytes(kryoRegDetails)
|
|
|
|
return iv + aesCipher.doFinal(cryptOutput.toBytes())
|
|
} catch (e: Exception) {
|
|
logger.error("Error during AES encrypt", e)
|
|
return ByteArray(0)
|
|
}
|
|
}
|
|
|
|
// NOTE: ALWAYS CALLED ON THE SAME THREAD! (from the client, mutually exclusive calls to encrypt)
|
|
fun decrypt(registrationData: ByteArray, serverPublicKeyBytes: ByteArray): ClientConnectionInfo? {
|
|
return try {
|
|
val secretKeySpec = generateAesKey(serverPublicKeyBytes, publicKeyBytes, serverPublicKeyBytes)
|
|
|
|
// now decrypt the data
|
|
val gcmParameterSpec = GCMParameterSpec(GCM_TAG_LENGTH_BITS, registrationData, 0, GCM_IV_LENGTH_BYTES)
|
|
aesCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, gcmParameterSpec)
|
|
|
|
cryptInput.buffer = aesCipher.doFinal(registrationData, GCM_IV_LENGTH_BYTES, registrationData.size - GCM_IV_LENGTH_BYTES)
|
|
|
|
makeInfo(serverPublicKeyBytes, secretKeySpec)
|
|
|
|
} catch (e: Exception) {
|
|
logger.error("Error during AES decrypt!", e)
|
|
null
|
|
}
|
|
}
|
|
|
|
override fun hashCode(): Int {
|
|
val prime = 31
|
|
var result = 1
|
|
result = prime * result + publicKeyBytes.hashCode()
|
|
return result
|
|
}
|
|
|
|
override fun equals(other: Any?): Boolean {
|
|
if (this === other) {
|
|
return true
|
|
}
|
|
if (other == null) {
|
|
return false
|
|
}
|
|
if (javaClass != other.javaClass) {
|
|
return false
|
|
}
|
|
|
|
val other1 = other as CryptoManagement
|
|
if (!privateKey.encoded!!.contentEquals(other1.privateKey.encoded)) {
|
|
return false
|
|
}
|
|
if (!publicKeyBytes.contentEquals(other1.publicKeyBytes)) {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
}
|