310 lines
7.9 KiB
Kotlin
310 lines
7.9 KiB
Kotlin
/*
 * Copyright 2023 dorkbox, llc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
package dorkbox.dns.dns.records
|
|
|
|
import dorkbox.dns.dns.DnsInput
|
|
import dorkbox.dns.dns.DnsOutput
|
|
import dorkbox.dns.dns.constants.DnsOpCode
|
|
import dorkbox.dns.dns.constants.DnsResponseCode
|
|
import dorkbox.dns.dns.constants.DnsSection
|
|
import dorkbox.dns.dns.constants.Flags
|
|
import dorkbox.os.OS.LINE_SEPARATOR
|
|
import kotlin.random.Random
|
|
|
|
/**
 * A DNS message header.
 *
 * Holds the message id, the 16-bit flags word (which also carries the opcode and rcode) and the four
 * per-section record counts, and converts them to and from DNS wire format.
 *
 * @author Brian Wellington
 * @see DnsMessage
 */
class Header : Cloneable {
    /**
     * Raw message id. A negative value means that no id has been assigned yet; reading [iD] replaces it
     * with a randomly generated id.
     */
    internal var id = -1

    /**
     * The 16-bit flags word exactly as it is read from / written to the wire. Flag bits are indexed from
     * left to right (flag `n` is bit `15 - n`), the opcode is stored in bits 11..14 and the rcode in the
     * lowest four bits.
     */
    var flagsByte = 0
        private set

    /** Record counts, one slot per section, in the order they are written to the wire. */
    private val counts = IntArray(4)

    /**
     * Create a new empty header. The message id is left unassigned and is generated randomly the first
     * time [iD] is read.
     */
    constructor()

    /**
     * Creates a new Header from its DNS wire format representation
     *
     * @param b A byte array containing the DNS Header. Must not be null.
     *
     * @throws NullPointerException if [b] is null
     */
    constructor(b: ByteArray?) : this(DnsInput(b!!))

    /**
     * Parses a Header from a stream containing DNS wire format: the id, then the flags word, then the
     * four section counts, each read as an unsigned 16-bit value.
     */
    internal constructor(`in`: DnsInput) : this(`in`.readU16()) {
        flagsByte = `in`.readU16()
        for (section in counts.indices) {
            counts[section] = `in`.readU16()
        }
    }

    /**
     * Create a new empty header.
     *
     * @param id The message id
     *
     * @throws IllegalArgumentException if [id] is outside of 0..0xFFFF
     */
    constructor(id: Int) {
        iD = id
    }

    /**
     * Converts the header to DNS wire format.
     *
     * @return a new byte array containing the encoded header
     */
    fun toWire(): ByteArray {
        val out = DnsOutput()
        toWire(out)
        return out.toByteArray()
    }

    /**
     * Writes the header to [out] in DNS wire format: the id, then the flags word, then the four section
     * counts, each as an unsigned 16-bit value.
     *
     * Note that this reads [iD], so an id is generated if none has been assigned yet.
     */
    fun toWire(out: DnsOutput) {
        out.writeU16(iD)
        out.writeU16(flagsByte)
        for (count in counts) {
            out.writeU16(count)
        }
    }

    /**
     * The message ID.
     *
     * Reading this property when no id has been assigned yet generates a random id in 0..0xFFFF and
     * stores it, so subsequent reads return the same value.
     *
     * @throws IllegalArgumentException when set to a value outside of 0..0xFFFF
     */
    var iD: Int
        get() {
            if (id >= 0) {
                return id
            }

            synchronized(this) {
                if (id < 0) {
                    // The upper bound of nextInt() is exclusive, so 0x10000 is required for 0xFFFF to be reachable.
                    // NOTE(review): the generator is a non-cryptographic PRNG seeded from System.nanoTime();
                    // confirm this is acceptable if message ids are expected to be unpredictable.
                    id = random.get().nextInt(0x10000)
                }
                return id
            }
        }
        set(id) {
            require(id in 0..0xffff) { "DNS message ID $id is out of range" }
            this.id = id
        }

    /**
     * Sets (turns on) a flag
     *
     * @throws IllegalArgumentException if [flag] is not a valid flag bit
     *
     * @see Flags
     */
    fun setFlag(flag: Flags) {
        // the companion setFlag() validates the flag before touching any bits
        flagsByte = setFlag(flagsByte, flag, true)
    }

    /**
     * Clears (turns off) a flag
     *
     * @throws IllegalArgumentException if [flag] is not a valid flag bit
     *
     * @see Flags
     */
    fun unsetFlag(flag: Flags) {
        // the companion setFlag() validates the flag before touching any bits
        flagsByte = setFlag(flagsByte, flag, false)
    }

    /**
     * Retrieves the state of all flags.
     *
     * @return a 16 element array, indexed by flag bit. Positions for which [Flags.isFlag] is false are
     * always `false`.
     */
    fun getFlags(): BooleanArray {
        return BooleanArray(16) { bit -> Flags.isFlag(bit) && getFlag(bit) }
    }

    /**
     * Retrieves a flag
     *
     * @see Flags
     */
    fun getFlag(flag: Flags): Boolean {
        // bits are indexed from left to right
        return flagsByte and (1 shl (15 - flag.value())) != 0
    }

    /**
     * Retrieves a flag.
     *
     * @param flagValue ALWAYS checked before using, so additional checks are not necessary
     * @see Flags
     */
    private fun getFlag(flagValue: Int): Boolean {
        // bits are indexed from left to right
        return flagsByte and (1 shl (15 - flagValue)) != 0
    }

    /**
     * Sets the record count for the given section
     *
     * @param field the section index, which must be in 0..3
     * @param value the new count
     *
     * @throws IllegalArgumentException if [value] is outside of 0..0xFFFF
     *
     * @see DnsSection
     */
    fun setCount(field: Int, value: Int) {
        require(value in 0..0xFFFF) { "DNS section count $value is out of range" }
        counts[field] = value
    }

    /**
     * Adds one to the record count for the given section
     *
     * @param field the section index, which must be in 0..3
     *
     * @throws IllegalStateException if the count is already 0xFFFF
     *
     * @see DnsSection
     */
    fun incCount(field: Int) {
        check(counts[field] != 0xFFFF) { "DNS section count cannot be incremented" }
        counts[field]++
    }

    /**
     * Subtracts one from the record count for the given section
     *
     * @param field the section index, which must be in 0..3
     *
     * @throws IllegalStateException if the count is already 0
     *
     * @see DnsSection
     */
    fun decCount(field: Int) {
        check(counts[field] != 0) { "DNS section count cannot be decremented" }
        counts[field]--
    }

    /**
     * Creates a new Header identical to the current one.
     *
     * The raw id is copied as-is, so if this header has no id assigned yet, neither does the copy.
     */
    public override fun clone(): Any {
        val copy = Header()
        copy.id = id
        copy.flagsByte = flagsByte
        counts.copyInto(copy.counts)
        return copy
    }

    /**
     * Converts the header into a String
     */
    override fun toString(): String {
        return toStringWithRcode(rcode)
    }

    /**
     * The message's rcode, stored in the lowest four bits of [flagsByte].
     *
     * @throws IllegalArgumentException when set to a value outside of 0..0xF
     *
     * @see DnsResponseCode
     */
    var rcode: Int
        get() = flagsByte and 0xF
        set(value) {
            require(value in 0..0xF) { "DNS DnsResponseCode $value is out of range" }
            // clear the old rcode bits, then store the new ones
            flagsByte = (flagsByte and 0xF.inv()) or value
        }

    /**
     * Converts the header into a String, reporting [newrcode] as the status instead of the rcode that
     * is stored in the header.
     *
     * Note that this reads [iD], so an id is generated if none has been assigned yet.
     */
    fun toStringWithRcode(newrcode: Int): String {
        return buildString {
            append(";; ->>HEADER<<- ")
            append("opcode: ${DnsOpCode.string(opcode)}")
            append(", status: ${DnsResponseCode.string(newrcode)}")
            append(", id: $iD")
            append(LINE_SEPARATOR)
            append(";; flags: ").append(printFlags())
            append("; ")
            for (section in 0..3) {
                append(DnsSection.string(section)).append(": ").append(getCount(section)).append(" ")
            }
        }
    }

    /**
     * The message's opcode, stored in bits 11..14 of [flagsByte].
     *
     * @throws IllegalArgumentException when set to a value outside of 0..0xF
     *
     * @see DnsOpCode
     */
    var opcode: Int
        get() = flagsByte shr 11 and 0xF
        set(value) {
            require(value in 0..0xF) { "DNS DnsOpCode $value is out of range" }
            // 0x87FF clears the four opcode bits and keeps every other bit of the flags word
            flagsByte = (flagsByte and 0x87FF) or (value shl 11)
        }

    /**
     * Retrieves the record count for the given section
     *
     * @param field the section index, which must be in 0..3
     *
     * @see DnsSection
     */
    fun getCount(field: Int): Int {
        return counts[field]
    }

    /**
     * Converts the header's flags into a String: the name of every flag that is set, each followed by
     * a single space.
     */
    fun printFlags(): String {
        return buildString {
            for (bit in 0..15) {
                if (Flags.isFlag(bit) && getFlag(bit)) {
                    append(Flags.toFlag(bit).string())
                    append(" ")
                }
            }
        }
    }

    companion object {
        /** Per-thread random number generator used to create message ids. */
        private val random: ThreadLocal<Random> = object : ThreadLocal<Random>() {
            override fun initialValue(): Random {
                return Random(System.nanoTime())
            }
        }

        /**
         * The length of a DNS Header in wire format.
         */
        const val LENGTH = 12

        /**
         * @throws IllegalArgumentException if [flag] is not a valid flag bit
         */
        private fun checkFlag(flag: Flags) {
            require(validFlag(flag)) { "invalid flag bit $flag" }
        }

        /**
         * @return true if [flag] is not null, its value is within 0..0xF and [Flags.isFlag] accepts it
         */
        private fun validFlag(flag: Flags?): Boolean {
            if (flag == null) {
                return false
            }

            val bit = flag.value()
            return bit >= 0 && bit <= 0xF && Flags.isFlag(bit.toInt())
        }

        /**
         * Computes a new flags word from [flags] with a single flag turned on or off.
         *
         * @param flags the current flags word
         * @param flag the flag to change
         * @param value true to set the flag, false to clear it
         *
         * @return the updated flags word
         *
         * @throws IllegalArgumentException if [flag] is not a valid flag bit
         */
        fun setFlag(flags: Int, flag: Flags, value: Boolean): Int {
            checkFlag(flag)

            // bits are indexed from left to right
            val mask = 1 shl (15 - flag.value())
            return if (value) flags or mask else flags and mask.inv()
        }
    }
}
|