Fixed issues surrounding capacity and strings/byte arrays. (we should not be using capacity, require() should manage if we need to grow capacity or not

This commit is contained in:
nathan 2019-06-13 01:42:09 +02:00
parent 0b8c9080c6
commit 7b74a730ba

View File

@ -62,6 +62,9 @@ import io.netty.buffer.Unpooled;
* Modified from KRYO to use ByteBuf. * Modified from KRYO to use ByteBuf.
*/ */
public class ByteBufOutput extends Output { public class ByteBufOutput extends Output {
// NOTE: capacity IS NOT USED!
private ByteBuf byteBuf; private ByteBuf byteBuf;
private int initialReaderIndex = 0; private int initialReaderIndex = 0;
private int initialWriterIndex = 0; private int initialWriterIndex = 0;
@ -84,7 +87,6 @@ public class ByteBufOutput extends Output {
* maxBufferSize and an exception is thrown. Can be -1 for no maximum. */ * maxBufferSize and an exception is thrown. Can be -1 for no maximum. */
public ByteBufOutput (int bufferSize, int maxBufferSize) { public ByteBufOutput (int bufferSize, int maxBufferSize) {
if (maxBufferSize < -1) throw new IllegalArgumentException("maxBufferSize cannot be < -1: " + maxBufferSize); if (maxBufferSize < -1) throw new IllegalArgumentException("maxBufferSize cannot be < -1: " + maxBufferSize);
this.capacity = bufferSize;
this.maxCapacity = maxBufferSize == -1 ? Util.maxArraySize : maxBufferSize; this.maxCapacity = maxBufferSize == -1 ? Util.maxArraySize : maxBufferSize;
byteBuf = Unpooled.buffer(bufferSize); byteBuf = Unpooled.buffer(bufferSize);
} }
@ -166,7 +168,6 @@ public class ByteBufOutput extends Output {
this.byteBuf = buffer; this.byteBuf = buffer;
this.maxCapacity = maxBufferSize == -1 ? Util.maxArraySize : maxBufferSize; this.maxCapacity = maxBufferSize == -1 ? Util.maxArraySize : maxBufferSize;
capacity = buffer.capacity();
position = initialWriterIndex; position = initialWriterIndex;
total = 0; total = 0;
outputStream = null; outputStream = null;
@ -194,8 +195,16 @@ public class ByteBufOutput extends Output {
byteBuf.setIndex(initialReaderIndex, initialWriterIndex); byteBuf.setIndex(initialReaderIndex, initialWriterIndex);
} }
/**
* Ensures the buffer is large enough to read the specified number of bytes.
* @return true if the buffer has been resized.
*/
protected boolean require (int required) throws KryoException { protected boolean require (int required) throws KryoException {
int origCode = byteBuf.ensureWritable(1, true); if (byteBuf.isWritable(required)) {
return false;
}
int origCode = byteBuf.ensureWritable(required, true);
if (origCode == 0) { if (origCode == 0) {
// 0 if the buffer has enough writable bytes, and its capacity is unchanged. // 0 if the buffer has enough writable bytes, and its capacity is unchanged.
@ -203,12 +212,10 @@ public class ByteBufOutput extends Output {
} }
else if (origCode == 2) { else if (origCode == 2) {
// 2 if the buffer has enough writable bytes, and its capacity has been increased. // 2 if the buffer has enough writable bytes, and its capacity has been increased.
capacity = byteBuf.capacity();
return true; return true;
} }
else if (origCode == 3) { else if (origCode == 3) {
// 3 if the buffer does not have enough bytes, but its capacity has been increased to its maximum. // 3 if the buffer does not have enough bytes, but its capacity has been increased to its maximum.
capacity = byteBuf.capacity();
return true; return true;
} }
else { else {
@ -216,8 +223,8 @@ public class ByteBufOutput extends Output {
flush(); flush();
} }
// only got here because we were unable to resize the buffer! // only got here because we were unable to resize the buffer! So we flushed it first to try again!
origCode = byteBuf.ensureWritable(1, true); origCode = byteBuf.ensureWritable(required, true);
if (origCode == 0) { if (origCode == 0) {
// 0 if the buffer has enough writable bytes, and its capacity is unchanged. // 0 if the buffer has enough writable bytes, and its capacity is unchanged.
@ -228,12 +235,10 @@ public class ByteBufOutput extends Output {
} }
else if (origCode == 2) { else if (origCode == 2) {
// 2 if the buffer has enough writable bytes, and its capacity has been increased. // 2 if the buffer has enough writable bytes, and its capacity has been increased.
capacity = byteBuf.capacity();
return true; return true;
} }
else if (origCode == 3) { else if (origCode == 3) {
// 3 if the buffer does not have enough bytes, but its capacity has been increased to its maximum. // 3 if the buffer does not have enough bytes, but its capacity has been increased to its maximum.
capacity = byteBuf.capacity();
return true; return true;
} }
else { else {
@ -268,7 +273,7 @@ public class ByteBufOutput extends Output {
} }
public void write (int value) throws KryoException { public void write (int value) throws KryoException {
if (position == capacity) require(1); require(1);
byteBuf.writeByte((byte)value); byteBuf.writeByte((byte)value);
position++; position++;
} }
@ -285,13 +290,13 @@ public class ByteBufOutput extends Output {
// byte: // byte:
public void writeByte (byte value) throws KryoException { public void writeByte (byte value) throws KryoException {
if (position == capacity) require(1); require(1);
byteBuf.writeByte(value); byteBuf.writeByte(value);
position++; position++;
} }
public void writeByte (int value) throws KryoException { public void writeByte (int value) throws KryoException {
if (position == capacity) require(1); require(1);
byteBuf.writeByte((byte)value); byteBuf.writeByte((byte)value);
position++; position++;
} }
@ -303,16 +308,10 @@ public class ByteBufOutput extends Output {
public void writeBytes (byte[] bytes, int offset, int count) throws KryoException { public void writeBytes (byte[] bytes, int offset, int count) throws KryoException {
if (bytes == null) throw new IllegalArgumentException("bytes cannot be null."); if (bytes == null) throw new IllegalArgumentException("bytes cannot be null.");
int copyCount = Math.min(capacity - position, count);
while (true) { require(count);
byteBuf.writeBytes(bytes, offset, copyCount); byteBuf.writeBytes(bytes, offset, count);
position += copyCount; position += count;
count -= copyCount;
if (count == 0) return;
offset += copyCount;
copyCount = Math.min(capacity, count);
require(copyCount);
}
} }
// int: // int:
@ -330,7 +329,7 @@ public class ByteBufOutput extends Output {
public int writeVarInt (int value, boolean optimizePositive) throws KryoException { public int writeVarInt (int value, boolean optimizePositive) throws KryoException {
if (!optimizePositive) value = (value << 1) ^ (value >> 31); if (!optimizePositive) value = (value << 1) ^ (value >> 31);
if (value >>> 7 == 0) { if (value >>> 7 == 0) {
if (position == capacity) require(1); require(1);
position++; position++;
byteBuf.writeByte((byte)value); byteBuf.writeByte((byte)value);
return 1; return 1;
@ -376,7 +375,7 @@ public class ByteBufOutput extends Output {
if (!optimizePositive) value = (value << 1) ^ (value >> 31); if (!optimizePositive) value = (value << 1) ^ (value >> 31);
int first = (value & 0x3F) | (flag ? 0x80 : 0); // Mask first 6 bits, bit 8 is the flag. int first = (value & 0x3F) | (flag ? 0x80 : 0); // Mask first 6 bits, bit 8 is the flag.
if (value >>> 6 == 0) { if (value >>> 6 == 0) {
if (position == capacity) require(1); require(1);
byteBuf.writeByte((byte)first); byteBuf.writeByte((byte)first);
position++; position++;
return 1; return 1;
@ -437,7 +436,7 @@ public class ByteBufOutput extends Output {
public int writeVarLong (long value, boolean optimizePositive) throws KryoException { public int writeVarLong (long value, boolean optimizePositive) throws KryoException {
if (!optimizePositive) value = (value << 1) ^ (value >> 63); if (!optimizePositive) value = (value << 1) ^ (value >> 63);
if (value >>> 7 == 0) { if (value >>> 7 == 0) {
if (position == capacity) require(1); require(1);
position++; position++;
byteBuf.writeByte((byte)value); byteBuf.writeByte((byte)value);
return 1; return 1;
@ -584,7 +583,7 @@ public class ByteBufOutput extends Output {
// boolean: // boolean:
public void writeBoolean (boolean value) throws KryoException { public void writeBoolean (boolean value) throws KryoException {
if (position == capacity) require(1); require(1);
byteBuf.writeByte((byte)(value ? 1 : 0)); byteBuf.writeByte((byte)(value ? 1 : 0));
position++; position++;
} }
@ -601,24 +600,33 @@ public class ByteBufOutput extends Output {
writeByte(1 | 0x80); // 1 means empty string, bit 8 means UTF8. writeByte(1 | 0x80); // 1 means empty string, bit 8 means UTF8.
return; return;
} }
// Detect ASCII.
outer: require(charCount); // must be able to write this number of chars
if (charCount > 1 && charCount <= 32) {
for (int i = 0; i < charCount; i++) // Detect ASCII, we only do this for small strings
if (value.charAt(i) > 127) break outer; boolean isAscii = charCount <= 32;
if (capacity - position < charCount)
writeAscii_slow(value, charCount); if (isAscii) {
else { for (int i = 0; i < charCount; i++) {
for (int i = 0, n = value.length(); i < n; ++i) if (value.charAt(i) > 127) {
byteBuf.writeByte((byte)value.charAt(i)); isAscii = false;
position += charCount; break; // not ascii
}
} }
byteBuf.setByte(position - 1, (byte)(byteBuf.getByte(position - 1) | 0x80));
return;
} }
writeVarIntFlag(true, charCount + 1, true);
int charIndex = 0; if (isAscii) {
if (capacity - position >= charCount) { // this is ascii
for (int i = 0, n = value.length(); i < n; ++i) {
byteBuf.writeByte((byte)value.charAt(i));
}
position += charCount;
byteBuf.setByte(position - 1, (byte)(byteBuf.getByte(position - 1) | 0x80));
} else {
writeVarIntFlag(true, charCount + 1, true);
int charIndex = 0;
// Try to write 7 bit chars. // Try to write 7 bit chars.
ByteBuf byteBuf = this.byteBuf; ByteBuf byteBuf = this.byteBuf;
while (true) { while (true) {
@ -632,8 +640,9 @@ public class ByteBufOutput extends Output {
} }
} }
position = byteBuf.writerIndex(); position = byteBuf.writerIndex();
if (charIndex < charCount) writeUtf8_slow(value, charCount, charIndex);
} }
if (charIndex < charCount) writeUtf8_slow(value, charCount, charIndex);
} }
public void writeAscii (String value) throws KryoException { public void writeAscii (String value) throws KryoException {
@ -646,168 +655,132 @@ public class ByteBufOutput extends Output {
writeByte(1 | 0x80); // 1 means empty string, bit 8 means UTF8. writeByte(1 | 0x80); // 1 means empty string, bit 8 means UTF8.
return; return;
} }
if (capacity - position < charCount)
writeAscii_slow(value, charCount); require(charCount); // must be able to write this number of chars
else {
ByteBuf byteBuf = this.byteBuf; ByteBuf byteBuf = this.byteBuf;
for (int i = 0, n = value.length(); i < n; ++i) for (int i = 0, n = value.length(); i < n; ++i) {
byteBuf.writeByte((byte)value.charAt(i)); byteBuf.writeByte((byte)value.charAt(i));
position += charCount;
} }
position += charCount;
byteBuf.setByte(position - 1, (byte)(byteBuf.getByte(position - 1) | 0x80)); // Bit 8 means end of ASCII. byteBuf.setByte(position - 1, (byte)(byteBuf.getByte(position - 1) | 0x80)); // Bit 8 means end of ASCII.
} }
private void writeUtf8_slow (String value, int charCount, int charIndex) { private void writeUtf8_slow (String value, int charCount, int charIndex) {
for (; charIndex < charCount; charIndex++) { for (; charIndex < charCount; charIndex++) {
if (position == capacity) require(Math.min(capacity, charCount - charIndex));
position++;
int c = value.charAt(charIndex); int c = value.charAt(charIndex);
if (c <= 0x007F) if (c <= 0x007F) {
byteBuf.writeByte((byte)c); writeByte((byte)c);
}
else if (c > 0x07FF) { else if (c > 0x07FF) {
require(3);
byteBuf.writeByte((byte)(0xE0 | c >> 12 & 0x0F)); byteBuf.writeByte((byte)(0xE0 | c >> 12 & 0x0F));
require(2);
position += 2;
byteBuf.writeByte((byte)(0x80 | c >> 6 & 0x3F)); byteBuf.writeByte((byte)(0x80 | c >> 6 & 0x3F));
byteBuf.writeByte((byte)(0x80 | c & 0x3F)); byteBuf.writeByte((byte)(0x80 | c & 0x3F));
position += 3;
} else { } else {
require(2);
byteBuf.writeByte((byte)(0xC0 | c >> 6 & 0x1F)); byteBuf.writeByte((byte)(0xC0 | c >> 6 & 0x1F));
if (position == capacity) require(1);
position++;
byteBuf.writeByte((byte)(0x80 | c & 0x3F)); byteBuf.writeByte((byte)(0x80 | c & 0x3F));
position += 2;
} }
} }
} }
private void writeAscii_slow (String value, int charCount) throws KryoException {
ByteBuf buffer = this.byteBuf;
int charIndex = 0;
int charsToWrite = Math.min(charCount, capacity - position);
while (charIndex < charCount) {
byte[] tmp = new byte[charCount];
value.getBytes(charIndex, charIndex + charsToWrite, tmp, 0);
buffer.writeBytes(tmp, 0, charsToWrite);
charIndex += charsToWrite;
position += charsToWrite;
charsToWrite = Math.min(charCount - charIndex, capacity);
if (require(charsToWrite)) buffer = this.byteBuf;
}
}
// Primitive arrays: // Primitive arrays:
public void writeInts (int[] array, int offset, int count) throws KryoException { public void writeInts (int[] array, int offset, int count) throws KryoException {
if (capacity >= count << 2 && require(count << 2)) { require(count << 2);
ByteBuf byteBuf = this.byteBuf;
for (int n = offset + count; offset < n; offset++) { ByteBuf byteBuf = this.byteBuf;
int value = array[offset]; for (int n = offset + count; offset < n; offset++) {
byteBuf.writeByte((byte)value); int value = array[offset];
byteBuf.writeByte((byte)(value >> 8)); byteBuf.writeByte((byte)value);
byteBuf.writeByte((byte)(value >> 16)); byteBuf.writeByte((byte)(value >> 8));
byteBuf.writeByte((byte)(value >> 24)); byteBuf.writeByte((byte)(value >> 16));
} byteBuf.writeByte((byte)(value >> 24));
position = byteBuf.writerIndex();
} else {
for (int n = offset + count; offset < n; offset++)
writeInt(array[offset]);
} }
position = byteBuf.writerIndex();
} }
public void writeLongs (long[] array, int offset, int count) throws KryoException { public void writeLongs (long[] array, int offset, int count) throws KryoException {
if (capacity >= count << 3 && require(count << 3)) { require(count << 3);
ByteBuf byteBuf = this.byteBuf;
for (int n = offset + count; offset < n; offset++) { ByteBuf byteBuf = this.byteBuf;
long value = array[offset]; for (int n = offset + count; offset < n; offset++) {
byteBuf.writeByte((byte)value); long value = array[offset];
byteBuf.writeByte((byte)(value >>> 8)); byteBuf.writeByte((byte)value);
byteBuf.writeByte((byte)(value >>> 16)); byteBuf.writeByte((byte)(value >>> 8));
byteBuf.writeByte((byte)(value >>> 24)); byteBuf.writeByte((byte)(value >>> 16));
byteBuf.writeByte((byte)(value >>> 32)); byteBuf.writeByte((byte)(value >>> 24));
byteBuf.writeByte((byte)(value >>> 40)); byteBuf.writeByte((byte)(value >>> 32));
byteBuf.writeByte((byte)(value >>> 48)); byteBuf.writeByte((byte)(value >>> 40));
byteBuf.writeByte((byte)(value >>> 56)); byteBuf.writeByte((byte)(value >>> 48));
} byteBuf.writeByte((byte)(value >>> 56));
position = byteBuf.writerIndex();
} else {
for (int n = offset + count; offset < n; offset++)
writeLong(array[offset]);
} }
position = byteBuf.writerIndex();
} }
public void writeFloats (float[] array, int offset, int count) throws KryoException { public void writeFloats (float[] array, int offset, int count) throws KryoException {
if (capacity >= count << 2 && require(count << 2)) { require(count << 2);
ByteBuf byteBuf = this.byteBuf;
for (int n = offset + count; offset < n; offset++) { ByteBuf byteBuf = this.byteBuf;
int value = Float.floatToIntBits(array[offset]); for (int n = offset + count; offset < n; offset++) {
byteBuf.writeByte((byte)value); int value = Float.floatToIntBits(array[offset]);
byteBuf.writeByte((byte)(value >> 8)); byteBuf.writeByte((byte)value);
byteBuf.writeByte((byte)(value >> 16)); byteBuf.writeByte((byte)(value >> 8));
byteBuf.writeByte((byte)(value >> 24)); byteBuf.writeByte((byte)(value >> 16));
} byteBuf.writeByte((byte)(value >> 24));
position = byteBuf.writerIndex();
} else {
for (int n = offset + count; offset < n; offset++)
writeFloat(array[offset]);
} }
position = byteBuf.writerIndex();
} }
public void writeDoubles (double[] array, int offset, int count) throws KryoException { public void writeDoubles (double[] array, int offset, int count) throws KryoException {
if (capacity >= count << 3 && require(count << 3)) { require(count << 3);
ByteBuf byteBuf = this.byteBuf;
for (int n = offset + count; offset < n; offset++) { ByteBuf byteBuf = this.byteBuf;
long value = Double.doubleToLongBits(array[offset]); for (int n = offset + count; offset < n; offset++) {
byteBuf.writeByte((byte)value); long value = Double.doubleToLongBits(array[offset]);
byteBuf.writeByte((byte)(value >>> 8)); byteBuf.writeByte((byte)value);
byteBuf.writeByte((byte)(value >>> 16)); byteBuf.writeByte((byte)(value >>> 8));
byteBuf.writeByte((byte)(value >>> 24)); byteBuf.writeByte((byte)(value >>> 16));
byteBuf.writeByte((byte)(value >>> 32)); byteBuf.writeByte((byte)(value >>> 24));
byteBuf.writeByte((byte)(value >>> 40)); byteBuf.writeByte((byte)(value >>> 32));
byteBuf.writeByte((byte)(value >>> 48)); byteBuf.writeByte((byte)(value >>> 40));
byteBuf.writeByte((byte)(value >>> 56)); byteBuf.writeByte((byte)(value >>> 48));
} byteBuf.writeByte((byte)(value >>> 56));
position = byteBuf.writerIndex();
} else {
for (int n = offset + count; offset < n; offset++)
writeDouble(array[offset]);
} }
position = byteBuf.writerIndex();
} }
public void writeShorts (short[] array, int offset, int count) throws KryoException { public void writeShorts (short[] array, int offset, int count) throws KryoException {
if (capacity >= count << 1 && require(count << 1)) { require(count << 1);
for (int n = offset + count; offset < n; offset++) {
int value = array[offset]; for (int n = offset + count; offset < n; offset++) {
byteBuf.writeByte((byte)value); int value = array[offset];
byteBuf.writeByte((byte)(value >>> 8)); byteBuf.writeByte((byte)value);
} byteBuf.writeByte((byte)(value >>> 8));
position = byteBuf.writerIndex();
} else {
for (int n = offset + count; offset < n; offset++)
writeShort(array[offset]);
} }
position = byteBuf.writerIndex();
} }
public void writeChars (char[] array, int offset, int count) throws KryoException { public void writeChars (char[] array, int offset, int count) throws KryoException {
if (capacity >= count << 1 && require(count << 1)) { require(count << 1);
for (int n = offset + count; offset < n; offset++) {
int value = array[offset]; for (int n = offset + count; offset < n; offset++) {
byteBuf.writeByte((byte)value); int value = array[offset];
byteBuf.writeByte((byte)(value >>> 8)); byteBuf.writeByte((byte)value);
} byteBuf.writeByte((byte)(value >>> 8));
position = byteBuf.writerIndex();
} else {
for (int n = offset + count; offset < n; offset++)
writeChar(array[offset]);
} }
position = byteBuf.writerIndex();
} }
public void writeBooleans (boolean[] array, int offset, int count) throws KryoException { public void writeBooleans (boolean[] array, int offset, int count) throws KryoException {
if (capacity >= count && require(count)) { require(count);
for (int n = offset + count; offset < n; offset++)
byteBuf.writeByte(array[offset] ? (byte)1 : 0); for (int n = offset + count; offset < n; offset++)
position = byteBuf.writerIndex(); byteBuf.writeByte(array[offset] ? (byte)1 : 0);
} else { position = byteBuf.writerIndex();
for (int n = offset + count; offset < n; offset++)
writeBoolean(array[offset]);
}
} }
} }