diff --git a/core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala b/core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala index 123ca6728..b35f06563 100644 --- a/core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala +++ b/core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala @@ -312,11 +312,13 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { } trait SizedCodec[T] extends GenCodec[T] { - def size(value: T): Int + def size(value: T): Int = size(value, Opt.Empty) + + def size(value: T, output: Opt[SequentialOutput]): Int protected final def declareSizeFor(output: SequentialOutput, value: T): Unit = if (output.sizePolicy != SizePolicy.Ignored) { - output.declareSize(size(value)) + output.declareSize(size(value, output.opt)) } } @@ -336,8 +338,8 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { object OOOFieldsObjectCodec { // this was introduced so that transparent wrapper cases are possible in flat sealed hierarchies final class Transformed[A, B](val wrapped: OOOFieldsObjectCodec[B], onWrite: A => B, onRead: B => A) extends OOOFieldsObjectCodec[A] { - def size(value: A): Int = - wrapped.size(onWrite(value)) + def size(value: A, output: Opt[SequentialOutput]): Int = + wrapped.size(onWrite(value), output) def readObject(input: ObjectInput, outOfOrderFields: FieldValues): A = onRead(wrapped.readObject(input, outOfOrderFields)) diff --git a/core/src/main/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarker.scala b/core/src/main/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarker.scala new file mode 100644 index 000000000..a80b30808 --- /dev/null +++ b/core/src/main/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarker.scala @@ -0,0 +1,15 @@ +package com.avsystem.commons +package serialization + +/** + * Instructs [[GenCodec]] to ignore the [[transientDefault]] annotation when serializing a case class. + * This ensures that even if a field's value is the same as its default, it will be included in the serialized + * representation. Deserialization behavior remains unchanged. If a field is missing from the input, the default + * value will be used as usual. + * + * This marker can be helpful when using the same model class in multiple contexts with different serialization + * formats that have conflicting requirements for handling default values. + * + * @see [[CustomMarkersOutputWrapper]] for an easy way to add markers to existing [[Output]] implementations + */ +object IgnoreTransientDefaultMarker extends CustomEventMarker[Unit] diff --git a/core/src/main/scala/com/avsystem/commons/serialization/cbor/CborOptimizedCodecs.scala b/core/src/main/scala/com/avsystem/commons/serialization/cbor/CborOptimizedCodecs.scala index d46a06454..a7e168778 100644 --- a/core/src/main/scala/com/avsystem/commons/serialization/cbor/CborOptimizedCodecs.scala +++ b/core/src/main/scala/com/avsystem/commons/serialization/cbor/CborOptimizedCodecs.scala @@ -87,7 +87,7 @@ class OOOFieldCborRawKeysCodec[T](stdObjectCodec: OOOFieldsObjectCodec[T], keyCo stdObjectCodec.writeFields(output, value) } - def size(value: T): Int = stdObjectCodec.size(value) + def size(value: T, output: Opt[SequentialOutput]): Int = stdObjectCodec.size(value, output) def nullable: Boolean = stdObjectCodec.nullable } diff --git a/core/src/main/scala/com/avsystem/commons/serialization/customMarkerWrappers.scala b/core/src/main/scala/com/avsystem/commons/serialization/customMarkerWrappers.scala new file mode 100644 index 000000000..ff3bcc2f6 --- /dev/null +++ b/core/src/main/scala/com/avsystem/commons/serialization/customMarkerWrappers.scala @@ -0,0 +1,107 @@ +package com.avsystem.commons +package serialization + +trait AcceptsAdditionalCustomMarkers extends AcceptsCustomEvents { + + protected def markers: Set[CustomEventMarker[?]] + + override def customEvent[T](marker: CustomEventMarker[T], event: T): Boolean = + markers(marker) || super.customEvent(marker, event) +} + +/** + * [[Input]] implementation that adds additional markers [[CustomEventMarker]] to the provided [[Input]] instance + */ +final class CustomMarkersInputWrapper private( + override protected val wrapped: Input, + override protected val markers: Set[CustomEventMarker[?]], +) extends InputWrapper with AcceptsAdditionalCustomMarkers { + + override def readList(): ListInput = + new CustomMarkersInputWrapper.AdjustedListInput(super.readList(), markers) + + override def readObject(): ObjectInput = + new CustomMarkersInputWrapper.AdjustedObjectInput(super.readObject(), markers) +} +object CustomMarkersInputWrapper { + def apply(input: Input, markers: CustomEventMarker[?]*): CustomMarkersInputWrapper = + CustomMarkersInputWrapper(input, markers.toSet) + + def apply(input: Input, markers: Set[CustomEventMarker[?]]): CustomMarkersInputWrapper = + new CustomMarkersInputWrapper(input, markers) + + private final class AdjustedListInput( + override protected val wrapped: ListInput, + override protected val markers: Set[CustomEventMarker[?]], + ) extends ListInputWrapper with AcceptsAdditionalCustomMarkers { + override def nextElement(): Input = new CustomMarkersInputWrapper(super.nextElement(), markers) + } + + private final class AdjustedFieldInput( + override protected val wrapped: FieldInput, + override protected val markers: Set[CustomEventMarker[?]], + ) extends FieldInputWrapper with AcceptsAdditionalCustomMarkers { + + override def readList(): ListInput = new AdjustedListInput(super.readList(), markers) + override def readObject(): ObjectInput = new AdjustedObjectInput(super.readObject(), markers) + } + + private final class AdjustedObjectInput( + override protected val wrapped: ObjectInput, + override protected val markers: Set[CustomEventMarker[?]], + ) extends ObjectInputWrapper with AcceptsAdditionalCustomMarkers { + + override def nextField(): FieldInput = new AdjustedFieldInput(super.nextField(), markers) + override def peekField(name: String): Opt[FieldInput] = + super.peekField(name).map(new AdjustedFieldInput(_, markers)) + } +} + +/** + * [[Output]] implementation that adds additional markers [[CustomEventMarker]] to the provided [[Output]] instance + */ +final class CustomMarkersOutputWrapper private( + override protected val wrapped: Output, + override protected val markers: Set[CustomEventMarker[?]], +) extends OutputWrapper with AcceptsAdditionalCustomMarkers { + + override def writeSimple(): SimpleOutput = + new CustomMarkersOutputWrapper.AdjustedSimpleOutput(super.writeSimple(), markers) + + override def writeList(): ListOutput = + new CustomMarkersOutputWrapper.AdjustedListOutput(super.writeList(), markers) + + override def writeObject(): ObjectOutput = + new CustomMarkersOutputWrapper.AdjustedObjectOutput(super.writeObject(), markers) +} + +object CustomMarkersOutputWrapper { + def apply(output: Output, markers: CustomEventMarker[?]*): CustomMarkersOutputWrapper = + CustomMarkersOutputWrapper(output, markers.toSet) + + def apply(output: Output, markers: Set[CustomEventMarker[?]]): CustomMarkersOutputWrapper = + new CustomMarkersOutputWrapper(output, markers) + + private final class AdjustedSimpleOutput( + override protected val wrapped: SimpleOutput, + override protected val markers: Set[CustomEventMarker[?]], + ) extends SimpleOutputWrapper with AcceptsAdditionalCustomMarkers + + private final class AdjustedListOutput( + override protected val wrapped: ListOutput, + override protected val markers: Set[CustomEventMarker[?]], + ) extends ListOutputWrapper with AcceptsAdditionalCustomMarkers { + + override def writeElement(): Output = + new CustomMarkersOutputWrapper(super.writeElement(), markers) + } + + private final class AdjustedObjectOutput( + override protected val wrapped: ObjectOutput, + override protected val markers: Set[CustomEventMarker[?]], + ) extends ObjectOutputWrapper with AcceptsAdditionalCustomMarkers { + + override def writeField(key: String): Output = + new CustomMarkersOutputWrapper(super.writeField(key), markers) + } +} diff --git a/core/src/main/scala/com/avsystem/commons/serialization/macroCodecs.scala b/core/src/main/scala/com/avsystem/commons/serialization/macroCodecs.scala index 9475bf5cf..77ced836b 100644 --- a/core/src/main/scala/com/avsystem/commons/serialization/macroCodecs.scala +++ b/core/src/main/scala/com/avsystem/commons/serialization/macroCodecs.scala @@ -12,7 +12,7 @@ class SingletonCodec[T <: Singleton]( ) extends ErrorReportingCodec[T] with OOOFieldsObjectCodec[T] { final def nullable = true final def readObject(input: ObjectInput, outOfOrderFields: FieldValues): T = singletonValue - def size(value: T): Int = 0 + def size(value: T, output: Opt[SequentialOutput]): Int = 0 def writeFields(output: ObjectOutput, value: T): Unit = () } @@ -109,7 +109,7 @@ abstract class ProductCodec[T <: Product]( nullable: Boolean, fieldNames: Array[String] ) extends ApplyUnapplyCodec[T](typeRepr, nullable, fieldNames) { - def size(value: T): Int = value.productArity + def size(value: T, output: Opt[SequentialOutput]): Int = value.productArity final def writeFields(output: ObjectOutput, value: T): Unit = { val size = value.productArity diff --git a/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala b/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala new file mode 100644 index 000000000..b2f6b9aef --- /dev/null +++ b/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala @@ -0,0 +1,75 @@ +package com.avsystem.commons +package serialization + +import com.avsystem.commons.serialization.CodecTestData.HasDefaults + +object IgnoreTransientDefaultMarkerTest { + final case class NestedHasDefaults( + @transientDefault flag: Boolean = false, + obj: HasDefaults, + list: Seq[HasDefaults], + @transientDefault defaultObj: HasDefaults = HasDefaults(), + ) + object NestedHasDefaults extends HasGenCodec[NestedHasDefaults] + + final case class HasOptParam( + @transientDefault flag: Boolean = false, + @optionalParam str: Opt[String] = Opt.Empty, + ) + object HasOptParam extends HasGenCodec[HasOptParam] +} + +class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest { + import IgnoreTransientDefaultMarkerTest.* + + override type Raw = Any + + def writeToOutput(write: Output => Unit): Any = { + var result: Any = null + write(CustomMarkersOutputWrapper(new SimpleValueOutput(v => result = v), IgnoreTransientDefaultMarker)) + result + } + + def createInput(raw: Any): Input = + CustomMarkersInputWrapper(new SimpleValueInput(raw), IgnoreTransientDefaultMarker) + + test("write case class with default values") { + testWrite(HasDefaults(str = "lol"), Map("str" -> "lol", "int" -> 42)) + testWrite(HasDefaults(43, "lol"), Map("int" -> 43, "str" -> "lol")) + testWrite(HasDefaults(str = null), Map("str" -> null, "int" -> 42)) + testWrite(HasDefaults(str = "dafuq"), Map("str" -> "dafuq", "int" -> 42)) + } + + //noinspection RedundantDefaultArgument + test("read case class with default values") { + testRead(Map("str" -> "lol", "int" -> 42), HasDefaults(str = "lol", int = 42)) + testRead(Map("str" -> "lol"), HasDefaults(str = "lol", int = 42)) + testRead(Map("int" -> 43, "str" -> "lol"), HasDefaults(int = 43, str = "lol")) + testRead(Map("str" -> null, "int" -> 42), HasDefaults(str = null, int = 42)) + testRead(Map("str" -> null), HasDefaults(str = null, int = 42)) + testRead(Map(), HasDefaults(str = "dafuq", int = 42)) + } + + test("write case class with opt values") { + testWrite(HasOptParam(str = "lol".opt), Map("flag" -> false, "str" -> "lol")) + testWrite(HasOptParam(), Map("flag" -> false)) + } + + //noinspection RedundantDefaultArgument + test("write nested case class with default values") { + testWrite( + value = NestedHasDefaults( + flag = false, + obj = HasDefaults(str = "lol"), + list = Seq(HasDefaults(int = 43)), + defaultObj = HasDefaults(), + ), + expectedRepr = Map( + "flag" -> false, + "defaultObj" -> Map[String, Any]("str" -> "kek", "int" -> 42), + "obj" -> Map[String, Any]("str" -> "lol", "int" -> 42), + "list" -> List(Map[String, Any]("str" -> "kek", "int" -> 43)), + ), + ) + } +} diff --git a/core/src/test/scala/com/avsystem/commons/serialization/ObjectSizeTest.scala b/core/src/test/scala/com/avsystem/commons/serialization/ObjectSizeTest.scala index 62faf1ac4..35fd3b16c 100644 --- a/core/src/test/scala/com/avsystem/commons/serialization/ObjectSizeTest.scala +++ b/core/src/test/scala/com/avsystem/commons/serialization/ObjectSizeTest.scala @@ -3,7 +3,7 @@ package serialization import org.scalatest.funsuite.AnyFunSuite -case class RecordWithDefaults( +final case class RecordWithDefaults( @transientDefault a: String = "", b: Int = 42 ) { @@ -11,7 +11,7 @@ case class RecordWithDefaults( } object RecordWithDefaults extends HasApplyUnapplyCodec[RecordWithDefaults] -class CustomRecordWithDefaults(val a: String, val b: Int) +final class CustomRecordWithDefaults(val a: String, val b: Int) object CustomRecordWithDefaults extends HasApplyUnapplyCodec[CustomRecordWithDefaults] { def apply(@transientDefault a: String = "", b: Int = 42): CustomRecordWithDefaults = new CustomRecordWithDefaults(a, b) @@ -19,19 +19,61 @@ object CustomRecordWithDefaults extends HasApplyUnapplyCodec[CustomRecordWithDef Opt((crwd.a, crwd.b)) } -class CustomWrapper(val a: String) +final class CustomWrapper(val a: String) object CustomWrapper extends HasApplyUnapplyCodec[CustomWrapper] { def apply(@transientDefault a: String = ""): CustomWrapper = new CustomWrapper(a) def unapply(cw: CustomWrapper): Opt[String] = Opt(cw.a) } +final case class RecordWithOpts( + @optionalParam abc: Opt[String] = Opt.Empty, + @transientDefault flag: Opt[Boolean] = Opt.Empty, + b: Int = 42, +) +object RecordWithOpts extends HasApplyUnapplyCodec[RecordWithOpts] + +final case class SingleFieldRecordWithOpts(@optionalParam abc: Opt[String] = Opt.Empty) +object SingleFieldRecordWithOpts extends HasApplyUnapplyCodec[SingleFieldRecordWithOpts] + +final case class SingleFieldRecordWithTD(@transientDefault abc: String = "abc") +object SingleFieldRecordWithTD extends HasApplyUnapplyCodec[SingleFieldRecordWithTD] + class ObjectSizeTest extends AnyFunSuite { test("computing object size") { assert(RecordWithDefaults.codec.size(RecordWithDefaults()) == 2) assert(RecordWithDefaults.codec.size(RecordWithDefaults("fuu")) == 3) + assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt)) == 2) + assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt, true.opt)) == 3) + assert(RecordWithOpts.codec.size(RecordWithOpts()) == 1) + assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts()) == 0) + assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts("abc".opt)) == 1) + assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD()) == 0) + assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD("haha")) == 1) assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults()) == 1) assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults("fuu")) == 2) assert(CustomWrapper.codec.size(CustomWrapper()) == 0) assert(CustomWrapper.codec.size(CustomWrapper("fuu")) == 1) } + + test("computing object size with custom output") { + val defaultIgnoringOutput = new SequentialOutput { + override def customEvent[T](marker: CustomEventMarker[T], event: T): Boolean = + marker match { + case IgnoreTransientDefaultMarker => true + case _ => super.customEvent(marker, event) + } + override def finish(): Unit = () + } + assert(RecordWithDefaults.codec.size(RecordWithDefaults(), defaultIgnoringOutput.opt) == 3) + assert(RecordWithDefaults.codec.size(RecordWithDefaults("fuu"), defaultIgnoringOutput.opt) == 3) + assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt), defaultIgnoringOutput.opt) == 3) + assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt, true.opt), defaultIgnoringOutput.opt) == 3) + assert(RecordWithOpts.codec.size(RecordWithOpts(), defaultIgnoringOutput.opt) == 2) + assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts(), defaultIgnoringOutput.opt) == 0) // @optionalParam field should NOT be counted + assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts("abc".opt), defaultIgnoringOutput.opt) == 1) + assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD(), defaultIgnoringOutput.opt) == 1) // @transientDefault field should be counted + assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD("haha"), defaultIgnoringOutput.opt) == 1) + assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults(), defaultIgnoringOutput.opt) == 2) + assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults("fuu"), defaultIgnoringOutput.opt) == 2) + } } diff --git a/core/src/test/scala/com/avsystem/commons/serialization/cbor/CborInputOutputTest.scala b/core/src/test/scala/com/avsystem/commons/serialization/cbor/CborInputOutputTest.scala index fa77ef529..96798f1fb 100644 --- a/core/src/test/scala/com/avsystem/commons/serialization/cbor/CborInputOutputTest.scala +++ b/core/src/test/scala/com/avsystem/commons/serialization/cbor/CborInputOutputTest.scala @@ -10,7 +10,7 @@ import org.scalatest.funsuite.AnyFunSuite import java.io.{ByteArrayOutputStream, DataOutputStream} -case class Record( +final case class Record( b: Boolean, i: Int, l: List[String], @@ -19,7 +19,7 @@ case class Record( ) object Record extends HasGenCodec[Record] -case class CustomKeysRecord( +final case class CustomKeysRecord( @cborKey(1) first: Int, @cborKey(true) second: Boolean, @cborKey(Vector(1, 2, 3)) third: String, @@ -28,6 +28,18 @@ case class CustomKeysRecord( ) object CustomKeysRecord extends HasCborCodec[CustomKeysRecord] +final case class CustomKeysRecordWithDefaults( + @transientDefault @cborKey(1) first: Int = 0, + @cborKey(true) second: Boolean, +) +object CustomKeysRecordWithDefaults extends HasCborCodec[CustomKeysRecordWithDefaults] + +final case class CustomKeysRecordWithNoDefaults( + @cborKey(1) first: Int = 0, + @cborKey(true) second: Boolean, +) +object CustomKeysRecordWithNoDefaults extends HasCborCodec[CustomKeysRecordWithNoDefaults] + @cborDiscriminator(0) sealed trait GenericSealedTrait[+T] object GenericSealedTrait extends HasPolyCborCodec[GenericSealedTrait] { @@ -61,14 +73,22 @@ class CborInputOutputTest extends AnyFunSuite { keyCodec: CborKeyCodec = CborKeyCodec.Default )(implicit pos: Position): Unit = test(s"${pos.lineNumber}: $value") { - val baos = new ByteArrayOutputStream - val output = new CborOutput(new DataOutputStream(baos), keyCodec, SizePolicy.Optional) - GenCodec.write[T](output, value) - val bytes = baos.toByteArray - assert(Bytes(bytes).toString == binary) - assert(RawCbor(bytes).readAs[T](keyCodec) == value) + assertRoundtrip(value, binary, keyCodec) } + private def assertRoundtrip[T: GenCodec]( + value: T, + binary: String, + keyCodec: CborKeyCodec = CborKeyCodec.Default + )(implicit pos: Position): Unit = { + val baos = new ByteArrayOutputStream + val output = new CborOutput(new DataOutputStream(baos), keyCodec, SizePolicy.Optional) + GenCodec.write[T](output, value) + val bytes = baos.toByteArray + assert(Bytes(bytes).toString == binary) + assert(RawCbor(bytes).readAs[T](keyCodec) == value) + } + // binary representation from cbor.me roundtrip(null, "F6") @@ -213,6 +233,24 @@ class CborInputOutputTest extends AnyFunSuite { """{"first":42,"second":true,"third":"foo","strMap":{"foo":1},"intMap":{"1":"foo"}}""") } + test("writing with IgnoreTransientDefaultMarker to CBOR output") { + val baos = new ByteArrayOutputStream + val output = CustomMarkersOutputWrapper( + new CborOutput(new DataOutputStream(baos), keyCodec, SizePolicy.Optional), + IgnoreTransientDefaultMarker, + ) + val value = CustomKeysRecordWithDefaults(first = 0, second = true) + GenCodec.write(output, value) + val bytes = Bytes(baos.toByteArray) + + val expectedRawValue = "A20100F5F5" + assert(bytes.toString == expectedRawValue) + assert(RawCbor(bytes.bytes).readAs[CustomKeysRecordWithDefaults](keyCodec) == value) + + // should be the same as model with @transientDefault and serialization ignoring it + assertRoundtrip(CustomKeysRecordWithNoDefaults(first = 0, second = true), expectedRawValue) + } + test("chunked text string") { assert(CborInput.readRawCbor[String](RawCbor.fromHex("7F626162626162626162FF")) == "ababab") } diff --git a/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala b/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala index 3ae6fc7a3..36a92e5bb 100644 --- a/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala +++ b/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala @@ -10,6 +10,8 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with import c.universe._ + private def IgnoreTransientDefaultMarkerObj: Tree = q"$SerializationPkg.IgnoreTransientDefaultMarker" + override def allowOptionalParams: Boolean = true def mkTupleCodec[T: WeakTypeTag](elementCodecs: Tree*): Tree = instrument { @@ -120,7 +122,7 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with q""" new $SerializationPkg.SingletonCodec[$tpe](${tpe.toString}, $safeSingleValue) { ..${generated.map({ case (sym, depTpe) => generatedDepDeclaration(sym, depTpe) })} - override def size(value: $tpe): $IntCls = ${generated.size} + override def size(value: $tpe, output: $OptCls[$SerializationPkg.SequentialOutput]): $IntCls = ${generated.size} override def writeFields(output: $SerializationPkg.ObjectOutput, value: $tpe): $UnitCls = { ..${generated.map({ case (sym, _) => generatedWrite(sym) })} } @@ -172,16 +174,49 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with } } - def writeField(p: ApplyParam, value: Tree): Tree = { - val transientValue = - if (isTransientDefault(p)) Some(p.defaultValue) - else p.optionLike.map(ol => q"${ol.reference(Nil)}.none") - + def doWriteField(p: ApplyParam, value: Tree, transientValue: Option[Tree]): Tree = { val writeArgs = q"output" :: q"${p.idx}" :: value :: transientValue.toList val writeTargs = if (isOptimizedPrimitive(p)) Nil else List(p.valueType) q"writeField[..$writeTargs](..$writeArgs)" } + def writeFieldNoTransientDefault(p: ApplyParam, value: Tree): Tree = { + val transientValue = p.optionLike.map(ol => q"${ol.reference(Nil)}.none") + doWriteField(p, value, transientValue) + } + + def writeFieldTransientDefaultPossible(p: ApplyParam, value: Tree): Tree = + if (isTransientDefault(p)) doWriteField(p, value, Some(p.defaultValue)) + else writeFieldNoTransientDefault(p, value) + + def writeField(p: ApplyParam, value: Tree, ignoreTransientDefault: Tree): Tree = + if (isTransientDefault(p)) // optimize code to avoid calling 'output.customEvent' when param does not have @transientDefault + q""" + if($ignoreTransientDefault) ${writeFieldNoTransientDefault(p, value)} + else ${writeFieldTransientDefaultPossible(p, value)} + """ + else + writeFieldNoTransientDefault(p, value) + + def ignoreTransientDefaultCheck: Tree = + q"output.customEvent($IgnoreTransientDefaultMarkerObj, ())" + + // when params size is 1 + def writeSingle(p: ApplyParam, value: Tree): Tree = + writeField(p, value, ignoreTransientDefaultCheck) + + // when params size is greater than 1 + def writeMultiple(value: ApplyParam => Tree): Tree = + // optimize code to avoid calling 'output.customEvent' when there no params with @transientDefault + // extracted to `val` to avoid calling 'output.customEvent' multiple times + if (anyParamHasTransientDefault) + q""" + val ignoreTransientDefault = $ignoreTransientDefaultCheck + ..${params.map(p => writeField(p, value(p), q"ignoreTransientDefault"))} + """ + else + q"..${params.map(p => writeFieldNoTransientDefault(p, value(p)))}" + def writeFields: Tree = params match { case Nil => if (canUseFields) @@ -194,57 +229,90 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with """ case List(p: ApplyParam) => if (canUseFields) - writeField(p, q"value.${p.sym.name}") + q"${writeSingle(p, q"value.${p.sym.name}")}" else q""" val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) - if(unapplyRes.isEmpty) unapplyFailed else ${writeField(p, q"unapplyRes.get")} + if(unapplyRes.isEmpty) unapplyFailed + else ${writeSingle(p, q"unapplyRes.get")} """ case _ => if (canUseFields) - q"..${params.map(p => writeField(p, q"value.${p.sym.name}"))}" + q"${writeMultiple(p => q"value.${p.sym.name}")}" else q""" val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) - if(unapplyRes.isEmpty) unapplyFailed else { + if(unapplyRes.isEmpty) unapplyFailed + else { val t = unapplyRes.get - ..${params.map(p => writeField(p, q"t.${tupleGet(p.idx)}"))} + ${writeMultiple(p => q"t.${tupleGet(p.idx)}")} } """ } + def anyParamHasTransientDefault: Boolean = + params.exists(isTransientDefault) + + def isOptionLike(p: ApplyParam): Boolean = + p.optionLike.nonEmpty + def mayBeTransient(p: ApplyParam): Boolean = - p.optionLike.nonEmpty || isTransientDefault(p) + isOptionLike(p) || isTransientDefault(p) def transientValue(p: ApplyParam): Tree = p.optionLike match { case Some(optionLike) => q"${optionLike.reference(Nil)}.none" case None => p.defaultValue } - def countTransientFields: Tree = + // assumes usage in SizedCodec.size(value, output) method implementation + def countTransientFields: Tree = { + def checkIgnoreTransientDefaultMarker: Tree = + q"output.isDefined && output.get.customEvent($IgnoreTransientDefaultMarkerObj, ())" + + def doCount(paramsToCount: List[ApplyParam], accessor: ApplyParam => Tree): Tree = + paramsToCount.foldLeft[Tree](q"0") { + (acc, p) => q"$acc + (if(${accessor(p)} == ${transientValue(p)}) 1 else 0)" + } + + def countOnlyOptionLike(accessor: ApplyParam => Tree): Tree = + doCount(params.filter(isOptionLike), accessor) + + def countTransient(accessor: ApplyParam => Tree): Tree = + doCount(params.filter(mayBeTransient), accessor) + + def countMultipleParams(accessor: ApplyParam => Tree): Tree = + if (anyParamHasTransientDefault) + q"if($checkIgnoreTransientDefaultMarker) ${countOnlyOptionLike(accessor)} else ${countTransient(accessor)}" + else + countTransient(accessor) + + def countSingleParam(param: ApplyParam, value: Tree): Tree = + if (isTransientDefault(param)) + q"if(!$checkIgnoreTransientDefaultMarker && $value == ${transientValue(param)}) 1 else 0" + else + q"if($value == ${transientValue(param)}) 1 else 0" + if (canUseFields) - params.filter(mayBeTransient).foldLeft[Tree](q"0") { - (acc, p) => q"$acc + (if(value.${p.sym.name} == ${transientValue(p)}) 1 else 0)" + countMultipleParams(p => q"value.${p.sym.name}") + else if (!params.exists(mayBeTransient)) + q"0" + else + params match { + case List(p: ApplyParam) => + q""" + val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) + if(unapplyRes.isEmpty) unapplyFailed else { ${countSingleParam(p, q"unapplyRes.get")} } + """ + case _ => + q""" + val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) + if(unapplyRes.isEmpty) unapplyFailed else { + val t = unapplyRes.get + ${countMultipleParams(p => q"t.${tupleGet(p.idx)}")} + } + """ } - else if (!params.exists(mayBeTransient)) q"0" - else params match { - case List(p: ApplyParam) => - q""" - val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) - if(unapplyRes.isEmpty) unapplyFailed else if(unapplyRes.get == ${transientValue(p)}) 1 else 0 - """ - case _ => - val res = params.filter(mayBeTransient).foldLeft[Tree](q"0") { - (acc, p) => q"$acc + (if(t.${tupleGet(p.idx)} == ${transientValue(p)}) 1 else 0)" - } - q""" - val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) - if(unapplyRes.isEmpty) unapplyFailed else { - val t = unapplyRes.get - $res - } - """ - } + } if (isTransparent(dtpe.typeSymbol)) params match { case List(p: ApplyParam) => @@ -292,8 +360,8 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with def sizeMethod: List[Tree] = if (useProductCodec) Nil else { val res = q""" - def size(value: $dtpe): $IntCls = - ${params.size} + ${generated.size} - $countTransientFields + def size(value: $dtpe, output: $OptCls[$SerializationPkg.SequentialOutput]): $IntCls = + ${params.size} + ${generated.size} - { $countTransientFields } """ List(res) }