diff --git a/benchmarks/src/main/scala/fs2/data/benchmarks/MsgPackItemSerializerBenchmarks.scala b/benchmarks/src/main/scala/fs2/data/benchmarks/MsgPackItemSerializerBenchmarks.scala new file mode 100644 index 00000000..caf49e34 --- /dev/null +++ b/benchmarks/src/main/scala/fs2/data/benchmarks/MsgPackItemSerializerBenchmarks.scala @@ -0,0 +1,71 @@ +/* + * Copyright 2024 fs2-data Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package data.benchmarks + +import java.util.concurrent.TimeUnit +import org.openjdk.jmh.annotations._ + +import cats.effect.SyncIO + +import scodec.bits._ +import fs2._ + +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@BenchmarkMode(Array(Mode.AverageTime)) +@State(org.openjdk.jmh.annotations.Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 10, time = 2) +class MsgPackItemSerializerBenchmarks { + val msgpackItems: List[fs2.data.msgpack.low.MsgpackItem] = { + val bytes = + fs2.io + .readClassLoaderResource[SyncIO]("twitter_msgpack.txt", 4096) + .through(fs2.text.utf8.decode) + .compile + .string + .map(ByteVector.fromHex(_).get) + .unsafeRunSync() + + Stream + .chunk(Chunk.byteVector(bytes)) + .through(fs2.data.msgpack.low.items[SyncIO]) + .compile + .toList + .unsafeRunSync() + } + + + @Benchmark + def serialize() = + Stream + .emits(msgpackItems) + .through(fs2.data.msgpack.low.toNonValidatedBinary[SyncIO]) + .compile + .drain + .unsafeRunSync() + + @Benchmark + def withValidation() = + Stream + .emits(msgpackItems) + .through(fs2.data.msgpack.low.toBinary[SyncIO]) + .compile + .drain + .unsafeRunSync() +} diff --git a/msgpack/src/main/scala/fs2/data/msgpack/exceptions.scala b/msgpack/src/main/scala/fs2/data/msgpack/exceptions.scala new file mode 100644 index 00000000..d4940a68 --- /dev/null +++ b/msgpack/src/main/scala/fs2/data/msgpack/exceptions.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2024 fs2-data Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package data +package msgpack + +abstract class MsgpackException(msg: String, cause: Throwable = null) extends Exception(msg, cause) + +case class MsgpackMalformedItemException(msg: String, position: Option[Long] = None, inner: Throwable = null) + extends MsgpackException(position.fold(msg)(pos => s"at position $pos"), inner) + +case class MsgpackUnexpectedEndOfStreamException(position: Option[Long] = None, inner: Throwable = null) + extends MsgpackException( + position.fold("Unexpected end of stream")(pos => s"Unexpected end of stream starting at position $pos"), + inner) + +case class MsgpackMalformedByteStreamException(msg: String, inner: Throwable = null) + extends MsgpackException(msg, inner) diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/FormatParsers.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/FormatParsers.scala index b568a6f7..896cf121 100644 --- a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/FormatParsers.scala +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/FormatParsers.scala @@ -34,14 +34,14 @@ private[internal] object FormatParsers { def parseArray[F[_]](length: Int, ctx: ParserContext[F])(implicit F: RaiseThrowable[F]): Pull[F, MsgpackItem, ParserContext[F]] = { requireBytes(length, ctx).map { res => - res.accumulate(v => MsgpackItem.Array(v.toInt(false, ByteOrdering.BigEndian))) + res.accumulate(v => MsgpackItem.Array(v.toLong(false))) } } def parseMap[F[_]](length: Int, ctx: ParserContext[F])(implicit F: RaiseThrowable[F]): Pull[F, MsgpackItem, ParserContext[F]] = { requireBytes(length, ctx).map { res => - res.accumulate(v => MsgpackItem.Map(v.toInt(false, ByteOrdering.BigEndian))) + res.accumulate(v => MsgpackItem.Map(v.toLong(false))) } } @@ -63,7 +63,7 @@ private[internal] object FormatParsers { res <- requireBytes(8, res.toContext) seconds = res.result.toLong(false) } yield res.toContext.prepend(MsgpackItem.Timestamp96(nanosec, seconds)) - case _ => Pull.raiseError(new MsgpackParsingException(s"Invalid timestamp length: ${length}")) + case _ => Pull.raiseError(MsgpackMalformedByteStreamException(s"Invalid timestamp length: ${length}")) } } diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/Helpers.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/Helpers.scala index 12a884cb..881c81ce 100644 --- a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/Helpers.scala +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/Helpers.scala @@ -23,7 +23,6 @@ package internal import scodec.bits.ByteVector private[internal] object Helpers { - case class MsgpackParsingException(str: String) extends Exception /** @param chunk Current chunk * @param idx Index of the current [[Byte]] in `chunk` @@ -67,7 +66,7 @@ private[internal] object Helpers { // Inbounds chunk access is guaranteed by `ensureChunk` Pull.pure(ctx.next.toResult(ctx.chunk(ctx.idx))) } { - Pull.raiseError(new MsgpackParsingException("Unexpected end of input")) + Pull.raiseError(MsgpackUnexpectedEndOfStreamException()) } } @@ -93,7 +92,7 @@ private[internal] object Helpers { go(count - available, ParserContext(chunk, slice.size, rest, acc), newBytes) } } { - Pull.raiseError(new MsgpackParsingException("Unexpected end of input")) + Pull.raiseError(MsgpackUnexpectedEndOfStreamException()) } } diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemParser.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemParser.scala index 4536d936..f411a48a 100644 --- a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemParser.scala +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemParser.scala @@ -37,7 +37,7 @@ private[low] object ItemParser { ((byte & 0xff): @switch) match { case Headers.Nil => Pull.pure(ctx.prepend(MsgpackItem.Nil)) - case Headers.NeverUsed => Pull.raiseError(new MsgpackParsingException("Reserved value 0xc1 used")) + case Headers.NeverUsed => Pull.raiseError(MsgpackMalformedByteStreamException("Reserved value 0xc1 used")) case Headers.False => Pull.pure(ctx.prepend(MsgpackItem.False)) case Headers.True => Pull.pure(ctx.prepend(MsgpackItem.True)) case Headers.Bin8 => parseBin(1, ctx) @@ -77,13 +77,13 @@ private[low] object ItemParser { // fixmap else if ((byte & 0xf0) == 0x80) { val length = byte & 0x0f // 0x8f- 0x80 - Pull.pure(ctx.prepend(MsgpackItem.Map(length))) + Pull.pure(ctx.prepend(MsgpackItem.Map(length.toLong))) } // fixarray else if ((byte & 0xf0) == 0x90) { val length = byte & 0x0f // 0x9f- 0x90 - Pull.pure(ctx.prepend(MsgpackItem.Array(length))) + Pull.pure(ctx.prepend(MsgpackItem.Array(length.toLong))) } // fixstr @@ -98,7 +98,7 @@ private[low] object ItemParser { else if ((byte & 0xe0) == 0xe0) { Pull.pure(ctx.prepend(MsgpackItem.SignedInt(ByteVector(byte)))) } else { - Pull.raiseError(new MsgpackParsingException(s"Invalid type ${byte}")) + Pull.raiseError(MsgpackMalformedByteStreamException(s"Invalid type ${byte}")) } } } diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemSerializer.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemSerializer.scala new file mode 100644 index 00000000..8cd09b8b --- /dev/null +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemSerializer.scala @@ -0,0 +1,247 @@ +/* + * Copyright 2024 fs2-data Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package data +package msgpack +package low +package internal + +import scodec.bits._ + +private[low] object ItemSerializer { + private final val positiveIntMask = hex"7f" + private final val negativeIntMask = hex"e0" + + private final val mapMask = 0x80 + private final val arrayMask = 0x90 + private final val strMask = 0xa0 + + private case class SerializationContext[F[_]](out: Out[F], + chunk: Chunk[MsgpackItem], + idx: Int, + rest: Stream[F, MsgpackItem]) + + /** Buffers [[Chunk]] into 4KiB segments before calling [[Pull.output]]. + * + * @param contents buffered [[Chunk]] + */ + private class Out[F[_]](contents: Chunk[Byte]) { + private val limit = 4096 + + /** Pushes `bv` into the buffer and emits the buffer if it reaches the limit. + */ + @inline + def push(bv: ByteVector): Pull[F, Byte, Out[F]] = + if (contents.size >= limit) + Pull.output(contents).as(new Out(Chunk.byteVector(bv))) + else + Pull.done.as(new Out(contents ++ Chunk.byteVector(bv))) + + /** Splices `bv` into segments and pushes them into the buffer while emitting the buffer at the same time so + * that it never exceeds the limit during the operation. + * + * Use this instead of [[Out.push]] when `bv` may significantly exceed 4KiB. + */ + def pushBuffered(bv: ByteVector): Pull[F, Byte, Out[F]] = { + @inline + def go(chunk: Chunk[Byte], rest: ByteVector): Pull[F, Byte, Out[F]] = + if (rest.isEmpty) + Pull.done.as(new Out(chunk)) + else + Pull.output(chunk) >> go(Chunk.byteVector(rest.take(limit.toLong)), rest.drop(limit.toLong)) + + if (bv.isEmpty) + this.push(bv) + else if (contents.size >= limit) + Pull.output(contents) >> go(Chunk.byteVector(bv.take(limit.toLong)), bv.drop(limit.toLong)) + else + go(contents ++ Chunk.byteVector(bv.take(limit.toLong - contents.size)), bv.drop(limit.toLong - contents.size)) + } + + /** Outputs the whole buffer. */ + @inline + def flush = Pull.output(contents) + } + + @inline + private def step[F[_]: RaiseThrowable](o: Out[F], item: MsgpackItem): Pull[F, Byte, Out[F]] = item match { + case MsgpackItem.UnsignedInt(bytes) => + val bs = bytes.dropWhile(_ == 0) + if (bs.size <= 1) + o.push(ByteVector(Headers.Uint8) ++ bs.padLeft(1)) + else if (bs.size <= 2) + o.push(ByteVector(Headers.Uint16) ++ bs.padLeft(2)) + else if (bs.size <= 4) + o.push(ByteVector(Headers.Uint32) ++ bs.padLeft(4)) + else if (bs.size <= 8) + o.push(ByteVector(Headers.Uint64) ++ bs.padLeft(8)) + else + Pull.raiseError(MsgpackMalformedItemException("Unsigned int exceeds 64 bits")) + + case MsgpackItem.SignedInt(bytes) => + val bs = bytes.dropWhile(_ == 0) + if (bs.size <= 1) + // positive fixint or negative fixint + if ((bs & positiveIntMask) == bs || (bs & negativeIntMask) == negativeIntMask) + o.push(bs.padLeft(1)) + else + o.push(ByteVector(Headers.Int8) ++ bs.padLeft(1)) + else if (bs.size <= 2) + o.push(ByteVector(Headers.Int16) ++ bs.padLeft(2)) + else if (bs.size <= 4) + o.push(ByteVector(Headers.Int32) ++ bs.padLeft(4)) + else if (bs.size <= 8) + o.push(ByteVector(Headers.Int64) ++ bs.padLeft(8)) + else + Pull.raiseError(MsgpackMalformedItemException("Signed int exceeds 64 bits")) + + case MsgpackItem.Float32(float) => + o.push(ByteVector(Headers.Float32) ++ ByteVector.fromInt(java.lang.Float.floatToIntBits(float))) + + case MsgpackItem.Float64(double) => + o.push(ByteVector(Headers.Float64) ++ ByteVector.fromLong(java.lang.Double.doubleToLongBits(double))) + + case MsgpackItem.Str(bytes) => + if (bytes.size <= 31) { + o.push(ByteVector.fromByte((strMask | bytes.size).toByte) ++ bytes) + } else if (bytes.size <= (1 << 8) - 1) { + val size = ByteVector.fromByte(bytes.size.toByte) + o.push(ByteVector(Headers.Str8) ++ size ++ bytes) + } else if (bytes.size <= (1 << 16) - 1) { + val size = ByteVector.fromShort(bytes.size.toShort) + o.push(ByteVector(Headers.Str16) ++ size ++ bytes) + } else if (bytes.size <= (1L << 32) - 1) { + val size = ByteVector.fromInt(bytes.size.toInt) + /* Max length of str32 (incl. type and length info) is 2^32 + 4 bytes + * which is more than Chunk can handle at once + */ + o.pushBuffered(ByteVector(Headers.Str32) ++ size ++ bytes) + } else { + Pull.raiseError(MsgpackMalformedItemException("String exceeds (2^32)-1 bytes")) + } + + case MsgpackItem.Bin(bytes) => + if (bytes.size <= (1 << 8) - 1) { + val size = ByteVector.fromByte(bytes.size.toByte) + o.push(ByteVector(Headers.Bin8) ++ size ++ bytes) + } else if (bytes.size <= (1 << 16) - 1) { + val size = ByteVector.fromShort(bytes.size.toShort) + o.push(ByteVector(Headers.Bin16) ++ size ++ bytes) + } else if (bytes.size <= (1L << 32) - 1) { + val size = ByteVector.fromInt(bytes.size.toInt) + /* Max length of str32 (incl. type and length info) is 2^32 + 4 bytes + * which is more than Chunk can handle at once + */ + o.pushBuffered(ByteVector(Headers.Bin32) ++ size ++ bytes) + } else { + Pull.raiseError(MsgpackMalformedItemException("Binary data exceeds (2^32)-1 bytes")) + } + + case MsgpackItem.Array(size) => + if (size <= 15) { + o.push(ByteVector.fromByte((arrayMask | size).toByte)) + } else if (size <= (1L << 16) - 1) { + val s = ByteVector.fromShort(size.toShort) + o.push(ByteVector(Headers.Array16) ++ s) + } else if (size <= (1L << 32) - 1) { + val s = ByteVector.fromLong(size, 4) + o.push(ByteVector(Headers.Array32) ++ s) + } else { + Pull.raiseError(MsgpackMalformedItemException("Array size exceeds (2^32)-1")) + } + + case MsgpackItem.Map(size) => + if (size <= 15) { + o.push(ByteVector.fromByte((mapMask | size).toByte)) + } else if (size <= (1L << 16) - 1) { + val s = ByteVector.fromShort(size.toShort) + o.push(ByteVector(Headers.Map16) ++ s) + } else if (size <= (1L << 32) - 1) { + val s = ByteVector.fromLong(size, 4) + o.push(ByteVector(Headers.Map32) ++ s) + } else { + Pull.raiseError(MsgpackMalformedItemException("Map size exceeds (2^32)-1 pairs")) + } + + case MsgpackItem.Extension(tpe, bytes) => + val bs = bytes.dropWhile(_ == 0) + if (bs.size <= 1) { + o.push((ByteVector(Headers.FixExt1) :+ tpe) ++ bs.padLeft(1)) + } else if (bs.size <= 2) { + o.push((ByteVector(Headers.FixExt2) :+ tpe) ++ bs.padLeft(2)) + } else if (bs.size <= 4) { + o.push((ByteVector(Headers.FixExt4) :+ tpe) ++ bs.padLeft(4)) + } else if (bs.size <= 8) { + o.push((ByteVector(Headers.FixExt8) :+ tpe) ++ bs.padLeft(8)) + } else if (bs.size <= 16) { + o.push((ByteVector(Headers.FixExt16) :+ tpe) ++ bs.padLeft(16)) + } else if (bs.size <= (1 << 8) - 1) { + val size = ByteVector.fromByte(bs.size.toByte) + o.push((ByteVector(Headers.Ext8) ++ size :+ tpe) ++ bs) + } else if (bs.size <= (1 << 16) - 1) { + val size = ByteVector.fromShort(bs.size.toShort) + o.push((ByteVector(Headers.Ext16) ++ size :+ tpe) ++ bs) + } else { + val size = ByteVector.fromInt(bs.size.toInt) + /* Max length of ext32 (incl. type and length info) is 2^32 + 5 bytes + * which is more than Chunk can handle at once. + */ + o.pushBuffered((ByteVector(Headers.Ext32) ++ size :+ tpe) ++ bs) + } + + case MsgpackItem.Timestamp32(seconds) => + o.push((ByteVector(Headers.FixExt4) :+ Headers.Timestamp.toByte) ++ ByteVector.fromInt(seconds)) + + case MsgpackItem.Timestamp64(combined) => + o.push((ByteVector(Headers.FixExt8) :+ Headers.Timestamp.toByte) ++ ByteVector.fromLong(combined)) + + case MsgpackItem.Timestamp96(nanoseconds, seconds) => + val ns = ByteVector.fromInt(nanoseconds) + val s = ByteVector.fromLong(seconds) + o.push((ByteVector(Headers.Ext8) :+ 12 :+ Headers.Timestamp.toByte) ++ ns ++ s) + + case MsgpackItem.Nil => + o.push(ByteVector(Headers.Nil)) + + case MsgpackItem.False => + o.push(ByteVector(Headers.False)) + + case MsgpackItem.True => + o.push(ByteVector(Headers.True)) + } + + private def stepChunk[F[_]: RaiseThrowable](ctx: SerializationContext[F]): Pull[F, Byte, SerializationContext[F]] = + if (ctx.idx >= ctx.chunk.size) + Pull.done.as(ctx) + else + step(ctx.out, ctx.chunk(ctx.idx)).flatMap { out => + stepChunk(SerializationContext(out, ctx.chunk, ctx.idx + 1, ctx.rest)) + } + + def pipe[F[_]: RaiseThrowable]: Pipe[F, MsgpackItem, Byte] = { stream => + def go(out: Out[F], rest: Stream[F, MsgpackItem]): Pull[F, Byte, Unit] = + rest.pull.uncons.flatMap { + case None => out.flush + case Some((chunk, rest)) => + stepChunk(SerializationContext(out, chunk, 0, rest)).flatMap { case SerializationContext(out, _, _, rest) => + go(out, rest) + } + } + + go(new Out(Chunk.empty), stream).stream + } +} diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemValidator.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemValidator.scala new file mode 100644 index 00000000..59cf0f2a --- /dev/null +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/internal/ItemValidator.scala @@ -0,0 +1,158 @@ +/* + * Copyright 2024 fs2-data Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package data +package msgpack +package low +package internal + +private[low] object ItemValidator { + + case class Expect(n: Long, from: Long) { + def dec = Expect(n - 1, from) + } + + private val PullNone = Pull.pure(None) + + type ValidationContext = (Chunk[MsgpackItem], Int, Long, List[Expect]) + + def pipe[F[_]](implicit F: RaiseThrowable[F]): Pipe[F, MsgpackItem, MsgpackItem] = { in => + def step1(chunk: Chunk[MsgpackItem], idx: Int, position: Long): Pull[F, MsgpackItem, Option[Expect]] = + chunk(idx) match { + case MsgpackItem.UnsignedInt(bytes) => + if (bytes.size > 8) + Pull.raiseError(MsgpackMalformedItemException("Unsigned int exceeds 64 bits", Some(position))) + else PullNone + + case MsgpackItem.SignedInt(bytes) => + if (bytes.size > 8) + Pull.raiseError(MsgpackMalformedItemException("Signed int exceeds 64 bits", Some(position))) + else PullNone + + case MsgpackItem.Float32(_) => + PullNone + + case MsgpackItem.Float64(_) => + PullNone + + case MsgpackItem.Str(bytes) => + if (bytes.size > (1L << 32) - 1) + Pull.raiseError(MsgpackMalformedItemException("String exceeds (2^32)-1 bytes", Some(position))) + else + PullNone + + case MsgpackItem.Bin(bytes) => + if (bytes.size > (1L << 32) - 1) + Pull.raiseError(MsgpackMalformedItemException("Bin exceeds (2^32)-1 bytes", Some(position))) + else + PullNone + + case MsgpackItem.Array(size) => + if (size < 0) + Pull.raiseError(MsgpackMalformedItemException(s"Array has a negative size ${size}", Some(position))) + else if (size >= (1L << 32)) + Pull.raiseError(MsgpackMalformedItemException(s"Array size exceeds (2^32)-1", Some(position))) + else if (size == 0) + PullNone + else + Pull.pure(Some(Expect(size, position))) + + case MsgpackItem.Map(size) => + if (size < 0) + Pull.raiseError(MsgpackMalformedItemException(s"Map has a negative size ${size}", Some(position))) + else if (size >= (1L << 32)) + Pull.raiseError(MsgpackMalformedItemException(s"Map size exceeds (2^32)-1", Some(position))) + else if (size == 0) + PullNone + else + Pull.pure(Some(Expect(size * 2, position))) + + case MsgpackItem.Extension(_, bytes) => + if (bytes.size > (1L << 32) - 1) + Pull.raiseError(MsgpackMalformedItemException("Extension data exceeds (2^32)-1 bytes", Some(position))) + else + PullNone + + case _: MsgpackItem.Timestamp32 => + PullNone + + case item: MsgpackItem.Timestamp64 => + if (item.nanoseconds > 999999999) + Pull.raiseError( + MsgpackMalformedItemException("Timestamp64 nanoseconds is larger than '999999999'", Some(position))) + else + PullNone + + case MsgpackItem.Timestamp96(nanoseconds, _) => + if (nanoseconds > 999999999) + Pull.raiseError( + MsgpackMalformedItemException("Timestamp96 nanoseconds is larger than '999999999'", Some(position))) + else + PullNone + + case MsgpackItem.Nil => + PullNone + + case MsgpackItem.True => + PullNone + + case MsgpackItem.False => + PullNone + } + + def stepChunk(chunk: Chunk[MsgpackItem], + idx: Int, + stream: Stream[F, MsgpackItem], + position: Long, + state: List[Expect]): Pull[F, MsgpackItem, ValidationContext] = { + if (idx >= chunk.size) + Pull.output(chunk).as((Chunk.empty, 0, position, state)) + else + step1(chunk, idx, position).flatMap { el => + val stateNew: List[Expect] = + if (state.isEmpty) + state + else if (state.head.n == 1) + state.tail + else + state.head.dec :: state.tail + + val prepended = el match { + case Some(x) => x :: stateNew + case None => stateNew + } + + stepChunk(chunk, idx + 1, stream, position + 1, prepended) + } + } + + def go(stream: Stream[F, MsgpackItem], idx: Int, position: Long, state: List[Expect]): Pull[F, MsgpackItem, Unit] = + stream.pull.uncons.flatMap { + case Some((chunk, stream)) => + stepChunk(chunk, idx, stream, position, state).flatMap { case (_, idx, position, state) => + go(stream, idx, position, state) + } + case None => + if (state.isEmpty) + Pull.done + else + Pull.raiseError(MsgpackUnexpectedEndOfStreamException(Some(state.head.from))) + } + + go(in, 0, 0, List.empty).stream + } +} diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/model.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/model.scala index 675a45b7..e191d7a0 100644 --- a/msgpack/src/main/scala/fs2/data/msgpack/low/model.scala +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/model.scala @@ -34,8 +34,8 @@ object MsgpackItem { case class Str(bytes: ByteVector) extends MsgpackItem case class Bin(bytes: ByteVector) extends MsgpackItem - case class Array(size: Int) extends MsgpackItem - case class Map(size: Int) extends MsgpackItem + case class Array(size: Long) extends MsgpackItem + case class Map(size: Long) extends MsgpackItem case class Extension(tpe: Byte, bytes: ByteVector) extends MsgpackItem diff --git a/msgpack/src/main/scala/fs2/data/msgpack/low/package.scala b/msgpack/src/main/scala/fs2/data/msgpack/low/package.scala index 118e19df..46525f79 100644 --- a/msgpack/src/main/scala/fs2/data/msgpack/low/package.scala +++ b/msgpack/src/main/scala/fs2/data/msgpack/low/package.scala @@ -18,11 +18,33 @@ package fs2 package data package msgpack -import low.internal.ItemParser +import low.internal.{ItemParser, ItemSerializer, ItemValidator} /** A low-level representation of the MessagePack format. */ package object low { + + /** Transforms a stream of [[scala.Byte]]s into a stream of [[MsgpackItem]]s. + */ def items[F[_]](implicit F: RaiseThrowable[F]): Pipe[F, Byte, MsgpackItem] = ItemParser.pipe[F] + + /** Transforms a stream of [[MsgpackItem]]s into a stream of [[scala.Byte]]s. + * + * Will fail with an error if the stream is malformed. + */ + def toBinary[F[_]: RaiseThrowable]: Pipe[F, MsgpackItem, Byte] = + _.through(ItemValidator.pipe).through(ItemSerializer.pipe) + + /** Transforms a stream of [[MsgpackItem]]s into a stream of [[scala.Byte]]s. + * + * Will not validate the input stream and can potentially produce malformed data. Consider using [[toBinary]]. + */ + def toNonValidatedBinary[F[_]: RaiseThrowable]: Pipe[F, MsgpackItem, Byte] = + ItemSerializer.pipe + + /** Validates a stream of [[MsgpackItem]]s, fails when the stream is malformed. + */ + def validate[F[_]](implicit F: RaiseThrowable[F]): Pipe[F, MsgpackItem, MsgpackItem] = + ItemValidator.pipe[F] } diff --git a/msgpack/src/test/scala/fs2/data/msgpack/SerializerSpec.scala b/msgpack/src/test/scala/fs2/data/msgpack/SerializerSpec.scala new file mode 100644 index 00000000..79a889b5 --- /dev/null +++ b/msgpack/src/test/scala/fs2/data/msgpack/SerializerSpec.scala @@ -0,0 +1,190 @@ +/* + * Copyright 2024 fs2-data Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package data +package msgpack + +import cats.effect._ +import scodec.bits._ +import weaver._ + +import java.nio.charset.StandardCharsets +import low.MsgpackItem + +object SerializerSpec extends SimpleIOSuite { + test("MessagePack item serializer should correctly serialize all formats") { + val cases: List[(List[MsgpackItem], ByteVector)] = List( + // nil, false, true + (List(MsgpackItem.Nil, MsgpackItem.False, MsgpackItem.True), hex"c0c2c3"), + + // positive fixint + (List(MsgpackItem.SignedInt(hex"7b")), hex"7b"), + // negative fixint + (List(MsgpackItem.SignedInt(hex"e6")), hex"e6"), + + // uint 8, uint 16, uint 32, uint 64 + (List(MsgpackItem.UnsignedInt(hex"ab")), hex"ccab"), + (List(MsgpackItem.UnsignedInt(hex"abcd")), hex"cdabcd"), + (List(MsgpackItem.UnsignedInt(hex"abcdef01")), hex"ceabcdef01"), + (List(MsgpackItem.UnsignedInt(hex"abcdef0123456789")), hex"cfabcdef0123456789"), + + // int 8, int 16, int 32, int 64 + (List(MsgpackItem.SignedInt(hex"80")), hex"d080"), + (List(MsgpackItem.SignedInt(hex"80ab")), hex"d180ab"), + (List(MsgpackItem.SignedInt(hex"80abcdef")), hex"d280abcdef"), + (List(MsgpackItem.SignedInt(hex"80abcddef0123456")), hex"d380abcddef0123456"), + + // float 32, float 64 + (List(MsgpackItem.Float32(0.125F)), hex"ca3e000000"), + (List(MsgpackItem.Float64(0.125)), hex"cb3fc0000000000000"), + + // fixstr + (List(MsgpackItem.Str(ByteVector("abc".getBytes(StandardCharsets.UTF_8)))), hex"a3616263"), + + // str 8 + (List(MsgpackItem.Str(ByteVector("abcd".repeat(8).getBytes(StandardCharsets.UTF_8)))), + hex"d920" ++ ByteVector("abcd".repeat(8).getBytes(StandardCharsets.UTF_8))), + + // str 16 + (List(MsgpackItem.Str(ByteVector("a".repeat(Math.pow(2, 8).toInt).getBytes(StandardCharsets.UTF_8)))), + hex"da0100" ++ ByteVector("a".repeat(Math.pow(2, 8).toInt).getBytes(StandardCharsets.UTF_8))), + + // str 32 + (List(MsgpackItem.Str(ByteVector("a".repeat(Math.pow(2, 16).toInt).getBytes(StandardCharsets.UTF_8)))), + hex"db00010000" ++ ByteVector("a".repeat(Math.pow(2, 16).toInt).getBytes(StandardCharsets.UTF_8))), + + // bin 8 + (List(MsgpackItem.Bin(ByteVector("abcd".repeat(8).getBytes(StandardCharsets.UTF_8)))), + hex"c420" ++ ByteVector("abcd".repeat(8).getBytes(StandardCharsets.UTF_8))), + + // bin 16 + (List(MsgpackItem.Bin(ByteVector("a".repeat(Math.pow(2, 8).toInt).getBytes(StandardCharsets.UTF_8)))), + hex"c50100" ++ ByteVector("a".repeat(Math.pow(2, 8).toInt).getBytes(StandardCharsets.UTF_8))), + + // bin 32 + (List(MsgpackItem.Bin(ByteVector("a".repeat(Math.pow(2, 16).toInt).getBytes(StandardCharsets.UTF_8)))), + hex"c600010000" ++ ByteVector("a".repeat(Math.pow(2, 16).toInt).getBytes(StandardCharsets.UTF_8))), + + // fixarray + (List(MsgpackItem.Array(0)), hex"90"), + (List(MsgpackItem.Array(1)), hex"91"), + // array 16 + (List(MsgpackItem.Array(16)), hex"dc0010"), + // array 32 + (List(MsgpackItem.Array(Math.pow(2, 16).toLong)), hex"dd00010000"), + + // fixmap + (List(MsgpackItem.Map(0)), hex"80"), + (List(MsgpackItem.Map(1)), hex"81"), + // map 16 + (List(MsgpackItem.Map(16)), hex"de0010"), + // map 32 + (List(MsgpackItem.Map(Math.pow(2, 16).toLong)), hex"df00010000"), + + // fixext 1 + (List(MsgpackItem.Extension(0x54.toByte, hex"ab")), hex"d454ab"), + // fixext 2 + (List(MsgpackItem.Extension(0x54.toByte, hex"abcd")), hex"d554abcd"), + // fixext 4 + (List(MsgpackItem.Extension(0x54.toByte, hex"abcdef01")), hex"d654abcdef01"), + // fixext 8 + (List(MsgpackItem.Extension(0x54.toByte, hex"abcdef0123456789")), hex"d754abcdef0123456789"), + // fixext 8 + (List(MsgpackItem.Extension(0x54.toByte, hex"abcdef0123456789abcdef0123456789")), + hex"d854abcdef0123456789abcdef0123456789"), + + // ext 8 + (List(MsgpackItem.Extension(0x54, ByteVector.fill(17)(0xab))), hex"c71154" ++ ByteVector.fill(17)(0xab)), + + // ext 16 + (List(MsgpackItem.Extension(0x54, ByteVector.fill(Math.pow(2, 8).toLong)(0xab))), + hex"c8010054" ++ ByteVector.fill(Math.pow(2, 8).toLong)(0xab)), + + // ext 32 + (List(MsgpackItem.Extension(0x54, ByteVector.fill(Math.pow(2, 16).toLong)(0xab))), + hex"c90001000054" ++ ByteVector.fill(Math.pow(2, 16).toLong)(0xab)), + + // timestamp 32 + (List(MsgpackItem.Timestamp32(0x0123abcd)), hex"d6ff0123abcd"), + + // timestamp 64 + (List(MsgpackItem.Timestamp64(0x0123456789abcdefL)), hex"d7ff0123456789abcdef"), + + // timestamp 96 + (List(MsgpackItem.Timestamp96(0x0123abcd, 0x0123456789abcdefL)), hex"c70cff0123abcd0123456789abcdef") + ) + + Stream + .emits(cases) + .evalMap { case (source, serialized) => + Stream + .emits(source) + .through(low.toNonValidatedBinary) + .compile + .fold(ByteVector.empty)(_ :+ _) + .map(expect.same(_, serialized)) + + } + .compile + .foldMonoid + } + + test("MessagePack item serializer should be fixpoint for a subset of ByteVector") { + /* The parser mapping ByteVector to MsgpackItem can be seen as a not injective morphism, that is, there + * are many ByteVectors that will map to the same MsgpackItem. Because of this, we cannot possibly guarantee that + * `serialize(parse(bs))` is fixpoint for an arbitrary `bs`. However, currently implemented serializer *is* + * injective (if we exclude the Timestamp format family as it can be represented with Extension types) and so, we + * can guarantee `serialize(parse(bs)) == bs` if `bs` is a member of a subset of ByteVector that is emitted by a + * serializer. + * + * In other words, the following code will be true for any `bs` if `serialize` is injective and we ignore the + * Timestamp type family: + * {{{ + * val first = serialize(parse(bs)) + * val second = serialize(parse(first)) + * first == second + * }}} + * + * This test makes sure that the above holds. + */ + + val cases = List( + hex"918FA46461746582A662756666657282A474797065A6427566666572A4646174619401234567A474797065CCFFA35F6964B8363663316233363661333137353434376163346335343165A5696E64657800A467756964D92438666665653537302D353938312D346630362D623635382D653435383163363064373539A86973416374697665C3A762616C616E6365CB40A946956A97C84CA361676516A8657965436F6C6F72A4626C7565A46E616D65AD4D6F72746F6E204C6974746C65A761646472657373D9313933372044656172626F726E20436F7572742C204861726C656967682C204D6173736163687573657474732C2033353936AA72656769737465726564BA323032332D30382D32395431303A34353A3335202D30323A3030A86C61746974756465CB4047551159C49774A96C6F6E676974756465CBC065F94A771C970FA47461677397A54C6F72656DA3657374A86465736572756E74A54C6F72656DA46E697369A76C61626F726973A86465736572756E74A7667269656E64739382A2696400A46E616D65B04865726E616E64657A204C6172736F6E82A2696401A46E616D65AF4D616E6E696E672053617267656E7482A2696402A46E616D65AF536176616E6E6168204E65776D616E" + ) + + def round(data: ByteVector) = + Stream + .chunk(Chunk.byteVector(data)) + .through(low.items[IO]) + .through(low.toNonValidatedBinary) + .fold(ByteVector.empty)(_ :+ _) + + val out = for { + data <- Stream.emits(cases) + pre <- round(data) + processed <- round(pre) + } yield { + if (processed == pre) + success + else + failure(s"Serializer should be fixpoint for ${pre} but it emitted ${processed}") + } + + out.compile.foldMonoid + + } +} diff --git a/msgpack/src/test/scala/fs2/data/msgpack/ValidationSpec.scala b/msgpack/src/test/scala/fs2/data/msgpack/ValidationSpec.scala new file mode 100644 index 00000000..5edd065a --- /dev/null +++ b/msgpack/src/test/scala/fs2/data/msgpack/ValidationSpec.scala @@ -0,0 +1,115 @@ +/* + * Copyright 2024 fs2-data Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package data +package msgpack + +import cats.effect._ +import low.MsgpackItem +import scodec.bits.ByteVector +import weaver._ +import scodec.bits._ +import cats.implicits._ + +object ValidationSpec extends SimpleIOSuite { + def validation1[F[_]: Sync](cases: (MsgpackItem, Throwable)*): F[Expectations] = + Stream + .emits(cases) + .evalMap { case (lhs, rhs) => + Stream + .emit(lhs) + .through(low.toBinary[F]) + .compile + .drain + .map(_ => failure(s"Expected error for item ${lhs}")) + .handleError(expect.same(_, rhs)) + } + .compile + .foldMonoid + + def validation[F[_]: Sync](cases: (List[MsgpackItem], Throwable)*): F[Expectations] = + Stream + .emits(cases) + .evalMap { case (lhs, rhs) => + Stream + .emits(lhs) + .through(low.toBinary[F]) + .compile + .drain + .map(_ => failure(s"Expected error for item ${lhs}")) + .handleError(expect.same(_, rhs)) + } + .compile + .foldMonoid + + test("should raise if integer values exceed 64 bits") { + validation1( + MsgpackItem.UnsignedInt(hex"10000000000000000") -> + MsgpackMalformedItemException("Unsigned int exceeds 64 bits", Some(0)), + MsgpackItem.SignedInt(hex"10000000000000000") -> + MsgpackMalformedItemException("Signed int exceeds 64 bits", Some(0)) + ) + } + + test("should raise if string or binary values exceed 2^32 - 1 bytes") { + validation1( + MsgpackItem.Str(ByteVector.empty.padLeft(Math.pow(2, 32).toLong)) -> + MsgpackMalformedItemException("String exceeds (2^32)-1 bytes", Some(0)), + MsgpackItem.Bin(ByteVector.empty.padLeft(Math.pow(2, 32).toLong)) -> + MsgpackMalformedItemException("Bin exceeds (2^32)-1 bytes", Some(0)) + ) + } + + test("should raise on unexpected end of input") { + validation( + List(MsgpackItem.Array(2), MsgpackItem.True) -> + MsgpackUnexpectedEndOfStreamException(Some(0)), + List(MsgpackItem.Array(2), MsgpackItem.Array(1), MsgpackItem.True) -> + MsgpackUnexpectedEndOfStreamException(Some(0)), + List(MsgpackItem.Array(1), MsgpackItem.Array(1)) -> + MsgpackUnexpectedEndOfStreamException(Some(1)), + List(MsgpackItem.Array(0), MsgpackItem.Array(1)) -> + MsgpackUnexpectedEndOfStreamException(Some(1)), + List(MsgpackItem.Map(1), MsgpackItem.True) -> + MsgpackUnexpectedEndOfStreamException(Some(0)), + List(MsgpackItem.Map(1), MsgpackItem.Map(1), MsgpackItem.True, MsgpackItem.True) -> + MsgpackUnexpectedEndOfStreamException(Some(0)), + List(MsgpackItem.Map(2), MsgpackItem.True, MsgpackItem.Map(1)) -> + MsgpackUnexpectedEndOfStreamException(Some(2)), + List(MsgpackItem.Map(2), MsgpackItem.True, MsgpackItem.Map(1)) -> + MsgpackUnexpectedEndOfStreamException(Some(2)), + List(MsgpackItem.Map(0), MsgpackItem.Map(1)) -> + MsgpackUnexpectedEndOfStreamException(Some(1)) + ) + } + + test("should raise if extension data exceeds 2^32 - 1 bytes") { + validation1( + MsgpackItem.Extension(0x54, ByteVector.empty.padLeft(Math.pow(2, 32).toLong)) -> + MsgpackMalformedItemException("Extension data exceeds (2^32)-1 bytes", Some(0)) + ) + } + + test("should raise if nanoseconds fields exceed 999999999") { + validation1( + MsgpackItem.Timestamp64(0xee6b280000000000L) -> + MsgpackMalformedItemException("Timestamp64 nanoseconds is larger than '999999999'", Some(0)), + MsgpackItem.Timestamp96(1000000000, 0) -> + MsgpackMalformedItemException("Timestamp96 nanoseconds is larger than '999999999'", Some(0)) + ) + } +}