diff --git a/src/dorkbox/netUtil/Dns.kt b/src/dorkbox/netUtil/Dns.kt index 3ba12c0..50f1433 100644 --- a/src/dorkbox/netUtil/Dns.kt +++ b/src/dorkbox/netUtil/Dns.kt @@ -11,6 +11,8 @@ import java.io.* import java.net.* import java.nio.file.Files import java.nio.file.Paths +import java.security.AccessController +import java.security.PrivilegedAction import java.util.* import javax.naming.Context import javax.naming.NamingException @@ -24,6 +26,11 @@ object Dns { */ const val version = "2.1" + private const val DEFAULT_SEARCH_DOMAIN = "" + private const val NAMESERVER_ROW_LABEL = "nameserver" + private const val DOMAIN_ROW_LABEL = "domain" + private const val PORT_ROW_LABEL = "port" + /** * @throws IOException if the DNS resolve.conf file cannot be read */ @@ -53,7 +60,6 @@ object Dns { } } - /** Returns all located name servers, which may be empty. */ val defaultNameServers: List by lazy { val nameServers = getUnsortedDefaultNameServers() @@ -178,45 +184,109 @@ object Dns { tryParse = tryParseResolvConfNameservers("sys:/etc/resolv.cfg") } - defaultNameServers.addAll(tryParse.second) + if (tryParse.first) { + // we can have DIFFERENT name servers for DIFFERENT domains! + val defaultSearchDomainNameServers = tryParse.second[DEFAULT_SEARCH_DOMAIN] + if (defaultSearchDomainNameServers != null) { + defaultNameServers.addAll(defaultSearchDomainNameServers) + } + } + } + + // if we STILL don't have anything, add global nameservers + if (defaultNameServers.isEmpty()) { + defaultNameServers.add(InetSocketAddress("1.1.1.1", 53)) // cloudflare + defaultNameServers.add(InetSocketAddress("8.8.8.8", 53)) // google } return defaultNameServers } - private fun tryParseResolvConfNameservers(path: String): Pair> { + private fun tryParseResolvConfNameservers(path: String): Pair>> { val p = Paths.get(path) if (Files.exists(p)) { try { - Files.newInputStream(p).use { `in` -> - return Pair(true, parseResolvConfNameservers(`in`)) + FileReader(path).use { fr -> + BufferedReader(fr).use { br -> + var nameServers = mutableListOf() + val nameServerDomains = mutableMapOf>() + + var domainName = DEFAULT_SEARCH_DOMAIN + var port = 53 + var line0: String? + loop@while (br.readLine().also { line0 = it?.trim() } != null) { + val line = line0!! + + if (line.isEmpty()) { + continue@loop + } + + val c = line[0] + if (c == '#' || c == ';') { + continue + } + + if (line.startsWith(NAMESERVER_ROW_LABEL)) { + var i = indexOfNonWhiteSpace(line, NAMESERVER_ROW_LABEL.length) + require(i < 0) { + "error parsing label $NAMESERVER_ROW_LABEL in file $path. value: $line" + } + + var maybeIP = line.substring(i) + // There may be a port appended onto the IP address so we attempt to extract it. + + // There may be a port appended onto the IP address so we attempt to extract it. + if (!IPv4.isValid(maybeIP) && !IPv6.isValid(maybeIP)) { + i = maybeIP.lastIndexOf('.') + require(i + 1 >= maybeIP.length) { + "error parsing label $NAMESERVER_ROW_LABEL in file $path. invalid IP value: $line" + } + + port = maybeIP.substring(i + 1).toInt() + maybeIP = maybeIP.substring(0, i) + } + + nameServers.add(socketAddress(maybeIP, port)) + } else if (line.startsWith(DOMAIN_ROW_LABEL)) { + // nameservers can be SPECIFIC to a search domain + val i = indexOfNonWhiteSpace(line, DOMAIN_ROW_LABEL.length) + require(i >= 0) { + "error parsing label $DOMAIN_ROW_LABEL in file $path value: $line" + } + + // we have a NEW domain! add the PREVIOUS nameServers and start again. + if (nameServerDomains[domainName] == null) { + nameServerDomains[domainName] = mutableListOf() + } + (nameServerDomains[domainName] as MutableList).addAll(nameServers) + + nameServers = mutableListOf() + domainName = line.substring(i) + } else if (line.startsWith(PORT_ROW_LABEL)) { + val i = indexOfNonWhiteSpace(line, PORT_ROW_LABEL.length) + require(i < 0) { + "error parsing label $PORT_ROW_LABEL in file $path value: $line" + } + + port = line.substring(i).toInt() + } + } + + // when done parsing the file, ALWAYS add the nameServer domains (since they have not been added yet) + if (nameServerDomains[domainName] == null) { + nameServerDomains[domainName] = mutableListOf() + } + (nameServerDomains[domainName] as MutableList).addAll(nameServers) + + return Pair(true, nameServerDomains) + } } } catch (e: IOException) { - // ignore - } - } - return Pair(false, mutableListOf()) - } - - private fun parseResolvConfNameservers(`in`: InputStream): MutableList { - val defaultNameServers = mutableListOf() - - InputStreamReader(`in`).use { isr -> - BufferedReader(isr).use { br -> - var line: String? - while (br.readLine().also { line = it } != null) { - val st = StringTokenizer(line) - if (!st.hasMoreTokens()) { - continue - } - when (st.nextToken()) { - "nameserver" -> defaultNameServers.add(InetSocketAddress(st.nextToken(), 53)) - } - } + Common.logger.error("Error parsing $path", e) } } - return defaultNameServers + return Pair(false, mutableMapOf()) } /** @@ -379,4 +449,26 @@ object Dns { return 1 } + + /** + * Find the index of the first non-white space character in `s` starting at `offset`. + * + * @param seq The string to search. + * @param offset The offset to start searching at. + * @return the index of the first non-white space character or <`-1` if none was found. + */ + private fun indexOfNonWhiteSpace(seq: CharSequence, offset: Int): Int { + var o = offset + while (o < seq.length) { + if (!Character.isWhitespace(seq[o])) { + return o + } + ++o + } + return -1 + } + + private fun socketAddress(hostname: String, port: Int): InetSocketAddress { + return AccessController.doPrivileged(PrivilegedAction { InetSocketAddress(hostname, port) }) + } }