From cc544fa3d26237626a9a12fd66a8d904bcdde3cc Mon Sep 17 00:00:00 2001 From: Timo Stamm Date: Fri, 10 May 2024 16:58:42 +0200 Subject: [PATCH] Speed up reflect with classes --- packages/protobuf-bench/README.md | 2 +- .../protobuf/src/reflect/reflect-types.ts | 27 +- packages/protobuf/src/reflect/reflect.ts | 585 ++++++++++-------- 3 files changed, 365 insertions(+), 249 deletions(-) diff --git a/packages/protobuf-bench/README.md b/packages/protobuf-bench/README.md index 3fd282b9d..c070ad41e 100644 --- a/packages/protobuf-bench/README.md +++ b/packages/protobuf-bench/README.md @@ -10,5 +10,5 @@ server would usually do. | code generator | bundle size | minified | compressed | |---------------------|------------------------:|-----------------------:|-------------------:| -| protobuf-es | 126,105 b | 65,235 b | 15,901 b | +| protobuf-es | 126,685 b | 66,254 b | 15,996 b | | protobuf-javascript | 394,384 b | 288,654 b | 45,122 b | diff --git a/packages/protobuf/src/reflect/reflect-types.ts b/packages/protobuf/src/reflect/reflect-types.ts index 84e200bd4..cebed3f68 100644 --- a/packages/protobuf/src/reflect/reflect-types.ts +++ b/packages/protobuf/src/reflect/reflect-types.ts @@ -144,7 +144,7 @@ export interface ReflectMessage { */ addListItem( field: Field, - value: NewListItem, + value: ReflectAddListItemValue, ): FieldError | undefined; /** @@ -153,7 +153,7 @@ export interface ReflectMessage { setMapEntry( field: Field, key: MapEntryKey, - value: NewMapEntryValue, + value: ReflectSetMapEntryValue, ): FieldError | undefined; /** @@ -264,12 +264,18 @@ export interface ReflectMap [unsafeLocal]: Record; } +/** + * A ReflectMap key. + */ export type MapEntryKey = string | number | bigint | boolean; type enumVal = number; +/** + * The return type of ReflectMessage.get() + */ // prettier-ignore -type ReflectGetValue = ( +export type ReflectGetValue = ( Field extends { fieldKind: "map" } ? ( Field extends { mapKind: "message" } ? ReflectMap : Field extends { mapKind: "enum" } ? ReflectMap : @@ -283,8 +289,11 @@ type ReflectGetValue = ( never ); +/** + * The type of the "value" argument of ReflectMessage.set() + */ // prettier-ignore -type ReflectSetValue = ( +export type ReflectSetValue = ( Field extends { fieldKind: "map" } ? ReflectMap : Field extends { fieldKind: "list" } ? ReflectList : Field extends { fieldKind: "enum" } ? number : @@ -293,16 +302,22 @@ type ReflectSetValue = ( never ); +/** + * The type of the "value" argument of ReflectMessage.addListItem() + */ // prettier-ignore -type NewListItem = ( +export type ReflectAddListItemValue = ( Field extends { listKind: "scalar"; scalar: infer T } ? ScalarValue : Field extends { listKind: "enum" } ? enumVal : Field extends { listKind: "message" } ? ReflectMessage : never ); +/** + * The type of the "value" argument of ReflectMessage.setMapEntry() + */ // prettier-ignore -type NewMapEntryValue = ( +export type ReflectSetMapEntryValue = ( Field extends { mapKind: "enum" } ? enumVal : Field extends { mapKind: "message" } ? ReflectMessage : Field extends { mapKind: "scalar"; scalar: infer T } ? ScalarValue : diff --git a/packages/protobuf/src/reflect/reflect.ts b/packages/protobuf/src/reflect/reflect.ts index ee160f530..61e7778b4 100644 --- a/packages/protobuf/src/reflect/reflect.ts +++ b/packages/protobuf/src/reflect/reflect.ts @@ -13,14 +13,18 @@ // limitations under the License. import type { DescField, DescMessage, DescOneof } from "../desc-types.js"; -import type { Message, MessageShape } from "../types.js"; +import type { Message, MessageShape, UnknownField } from "../types.js"; import { checkField, checkListItem, checkMapEntry } from "./reflect-check.js"; import { FieldError } from "./error.js"; import type { MapEntryKey, + ReflectAddListItemValue, + ReflectSetMapEntryValue, + ReflectGetValue, ReflectList, ReflectMap, ReflectMessage, + ReflectSetValue, } from "./reflect-types.js"; import { unsafeAddListItem, @@ -60,142 +64,192 @@ export function reflect( disableFieldValueCheck?: boolean; }, ): ReflectMessage { - message ??= create(messageDesc); - const check = opt?.disableFieldValueCheck !== true; - let fieldsByNumber: Map | undefined; - let sortedFields: DescField[] | undefined; - return { + return new ReflectMessageImpl( + messageDesc, message, - [unsafeLocal]: message, - desc: messageDesc, - fields: messageDesc.fields, - oneofs: messageDesc.oneofs, - members: messageDesc.members, + opt?.disableFieldValueCheck !== true, + ); +} - get sortedFields() { - return ( - sortedFields ?? - (sortedFields = messageDesc.fields - .concat() - .sort((a, b) => a.number - b.number)) - ); - }, +class ReflectMessageImpl implements ReflectMessage { + readonly desc: DescMessage; + readonly fields: readonly DescField[]; + get sortedFields() { + return ( + this._sortedFields ?? + (this._sortedFields = this.desc.fields + .concat() + .sort((a, b) => a.number - b.number)) + ); + } + readonly members: readonly (DescField | DescOneof)[]; + readonly message: Message; + readonly oneofs: readonly DescOneof[]; + readonly [unsafeLocal]: Message; + private readonly check: boolean; + private _fieldsByNumber: Map | undefined; + private _sortedFields: DescField[] | undefined; - findNumber(number) { - if (!fieldsByNumber) { - fieldsByNumber = new Map( - messageDesc.fields.map((f) => [f.number, f]), - ); - } - return fieldsByNumber.get(number); - }, + constructor( + messageDesc: Desc, + message?: MessageShape, + // TODO either remove this option, or support it in reflect-list and reflect-map as well + check = true, + ) { + this.check = check; + this.desc = messageDesc; + this.message = this[unsafeLocal] = message ?? create(messageDesc); + this.fields = messageDesc.fields; + this.oneofs = messageDesc.oneofs; + this.members = messageDesc.members; + } - oneofCase(oneof) { - assertOwn(message, oneof); - return unsafeOneofCase(message, oneof); - }, + findNumber(number: number): DescField | undefined { + if (!this._fieldsByNumber) { + this._fieldsByNumber = new Map( + this.desc.fields.map((f) => [f.number, f]), + ); + } + return this._fieldsByNumber.get(number); + } - isSet(field) { - assertOwn(message, field); - return unsafeIsSet(message, field); - }, + oneofCase(oneof: DescOneof): DescField | undefined { + assertOwn(this.message, oneof); + return unsafeOneofCase(this.message, oneof); + } - clear(field: DescField) { - assertOwn(message, field); - unsafeClear(message, field); - }, + isSet(field: DescField): boolean { + assertOwn(this.message, field); + return unsafeIsSet(this.message, field); + } - get(field) { - assertOwn(message, field); - let value = unsafeGet(message, field); - switch (field.fieldKind) { - case "list": - return reflectList(field, value as unknown[]); - case "map": - // TODO fix types - // eslint-disable-next-line @typescript-eslint/no-unsafe-return - return reflectMap(field, value as Record) as any; // eslint-disable-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-return - case "message": - if ( - value !== undefined && - !field.oneof && - isWrapperDesc(field.message) - ) { - value = { - $typeName: field.message.typeName, - value: longToReflect(field.message.fields[0], value), - } satisfies Message & { value: unknown }; - } - return reflect(field.message, value as Message); - case "scalar": - return value === undefined - ? scalarZeroValue(field.scalar, LongType.BIGINT) - : longToReflect(field, value); - case "enum": - return value ?? field.enum.values[0].number; - } - }, + clear(field: DescField): void { + assertOwn(this.message, field); + unsafeClear(this.message, field); + } - set(field, value) { - assertOwn(message, field); - if (check) { - const err = checkField(field, value); - if (err) { - return err; + get(field: Field): ReflectGetValue { + assertOwn(this.message, field); + let value = unsafeGet(this.message, field); + switch (field.fieldKind) { + case "list": + return new ReflectListImpl( + field, + value as unknown[], + this.check, + ) as ReflectGetValue; + case "map": + // TODO fix types + // eslint-disable-next-line @typescript-eslint/no-unsafe-return + return reflectMap( + field, + value as Record, + this.check, + ) as any; // eslint-disable-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-return + case "message": + if ( + value !== undefined && + !field.oneof && + isWrapperDesc(field.message) + ) { + value = { + $typeName: field.message.typeName, + value: longToReflect(field.message.fields[0], value), + } satisfies Message & { value: unknown }; } + return new ReflectMessageImpl( + field.message, + value as Message | undefined, + this.check, + ) as ReflectGetValue; + case "scalar": + return ( + value === undefined + ? scalarZeroValue(field.scalar, LongType.BIGINT) + : longToReflect(field, value) + ) as ReflectGetValue; + case "enum": + return (value ?? field.enum.values[0].number) as ReflectGetValue; + } + } + + set( + field: Field, + value: ReflectSetValue, + ): FieldError | undefined { + assertOwn(this.message, field); + if (this.check) { + const err = checkField(field, value); + if (err) { + return err; } - let local: unknown = value; - if (isReflectMap(value) || isReflectList(value)) { - local = value[unsafeLocal]; - } else if (isReflectMessage(value)) { - const msg = value.message; - local = !field.oneof && isWrapper(msg) ? msg.value : msg; - } else { - local = longToLocal(field, value); - } - unsafeSet(message, field, local); - return undefined; - }, + } + let local: unknown = value; + if (isReflectMap(value) || isReflectList(value)) { + local = value[unsafeLocal]; + } else if (isReflectMessage(value)) { + const msg = value.message; + local = !field.oneof && isWrapper(msg) ? msg.value : msg; + } else { + local = longToLocal(field, value); + } + unsafeSet(this.message, field, local); + return undefined; + } - addListItem(field, value) { - assertOwn(message, field); - assertKind(field, "list"); - if (check) { - if (checkListItem(field, 0, value)) { - const arr = unsafeGet(message, field) as unknown[]; - return checkListItem(field, arr.length, value); - } + addListItem< + Field extends DescField & { + fieldKind: "list"; + }, + >( + field: Field, + value: ReflectAddListItemValue, + ): FieldError | undefined { + assertOwn(this.message, field); + assertKind(field, "list"); + if (this.check) { + if (checkListItem(field, 0, value)) { + const arr = unsafeGet(this.message, field) as unknown[]; + return checkListItem(field, arr.length, value); } - unsafeAddListItem(message, field, listItemToLocal(field, value)); - return undefined; - }, + } + unsafeAddListItem(this.message, field, listItemToLocal(field, value)); + return undefined; + } - setMapEntry(field, key, value) { - assertOwn(message, field); - assertKind(field, "map"); - if (check) { - const err = checkMapEntry(field, key, value); - if (err) { - return err; - } + setMapEntry< + Field extends DescField & { + fieldKind: "map"; + }, + >( + field: Field, + key: MapEntryKey, + value: ReflectSetMapEntryValue, + ): FieldError | undefined { + assertOwn(this.message, field); + assertKind(field, "map"); + if (this.check) { + const err = checkMapEntry(field, key, value); + if (err) { + return err; } - unsafeSetMapEntry( - message, - field, - mapKeyToLocal(key), - mapValueToLocal(field, value), - ); - return undefined; - }, + } + unsafeSetMapEntry( + this.message, + field, + mapKeyToLocal(key), + mapValueToLocal(field, value), + ); + return undefined; + } - getUnknown() { - return message.$unknown; - }, + getUnknown(): UnknownField[] | undefined { + return this.message.$unknown; + } - setUnknown(value) { - message.$unknown = value; - }, - }; + setUnknown(value: UnknownField[]): void { + this.message.$unknown = value; + } } function assertKind(field: DescField, kind: DescField["fieldKind"]) { @@ -224,64 +278,88 @@ function assertOwn(owner: Message, member: DescField | DescOneof) { export function reflectList( field: DescField & { fieldKind: "list" }, unsafeInput?: unknown[], + check = true, ): ReflectList { - const arr = unsafeInput ?? []; - return { - [unsafeLocal]: arr, - field() { - return field; - }, - get size() { - return arr.length; - }, - get(index) { - const item = arr[index]; - return item === undefined - ? undefined - : (listItemToReflect(field, item) as V); - }, - set(index, item) { - if (index < 0 || index >= arr.length) { - return new FieldError(field, `list item #${index + 1}: out of range`); - } - const err = checkListItem(field, index, item); - if (!err) { - arr[index] = listItemToLocal(field, item); - } - return err; - }, - add(...items) { - let err: FieldError | undefined; - for (let i = 0; i < items.length && !err; i++) { - err = checkListItem(field, arr.length + i, items[i]); + return new ReflectListImpl(field, unsafeInput ?? [], check); +} + +class ReflectListImpl implements ReflectList { + field(): DescField & { fieldKind: "list" } { + return this._field; + } + get size(): number { + return this._arr.length; + } + [unsafeLocal]: unknown[]; + private _arr: unknown[]; + private _field: DescField & { fieldKind: "list" }; + private check: boolean; + + constructor( + field: DescField & { fieldKind: "list" }, + unsafeInput: unknown[], + check: boolean, + ) { + this._field = field; + this._arr = this[unsafeLocal] = unsafeInput; + this.check = check; + } + + get(index: number) { + const item = this._arr[index]; + return item === undefined + ? undefined + : (listItemToReflect(this._field, item, this.check) as V); + } + set(index: number, item: V) { + if (index < 0 || index >= this._arr.length) { + return new FieldError( + this._field, + `list item #${index + 1}: out of range`, + ); + } + if (this.check) { + const err = checkListItem(this._field, index, item); + if (err) { + return err; } - if (!err) { - for (const item of items) { - arr.push(listItemToLocal(field, item)); + } + this._arr[index] = listItemToLocal(this._field, item); + return undefined; + } + add(...items: V[]) { + if (this.check) { + for (let i = 0; i < items.length; i++) { + const err = checkListItem(this._field, this._arr.length + i, items[i]); + if (err) { + return err; } } - return err; - }, - clear() { - arr.splice(0, arr.length); - }, - [Symbol.iterator]() { - return this.values(); - }, - keys() { - return arr.keys(); - }, - *values() { - for (const item of arr) { - yield listItemToReflect(field, item) as V; - } - }, - *entries() { - for (let i = 0; i < arr.length; i++) { - yield [i, listItemToReflect(field, arr[i]) as V]; - } - }, - }; + } + for (const item of items) { + this._arr.push(listItemToLocal(this._field, item)); + } + return undefined; + } + clear() { + this._arr.splice(0, this._arr.length); + } + [Symbol.iterator]() { + return this.values(); + } + keys() { + return this._arr.keys(); + } + *values() { + for (const item of this._arr) { + yield listItemToReflect(this._field, item, this.check) as V; + } + } + *entries(): IterableIterator<[number, V]> { + for (let i = 0; i < this._arr.length; i++) { + yield [i, listItemToReflect(this._field, this._arr[i], this.check) as V]; + } + } } /** @@ -290,73 +368,94 @@ export function reflectList( export function reflectMap( field: DescField & { fieldKind: "map" }, unsafeInput?: Record, + check = true, ): ReflectMap { - const obj = unsafeInput ?? {}; - return { - [unsafeLocal]: obj, - field() { - return field; - }, - set(key, value) { - const err = checkMapEntry(field, key, value); - if (!err) { - obj[mapKeyToLocal(key)] = mapValueToLocal(field, value); - } - return err; - }, - delete(key) { - const k = mapKeyToLocal(key); - const has = Object.prototype.hasOwnProperty.call(obj, k); - if (has) { - delete obj[k]; - } - return has; - }, - clear() { - for (const key of Object.keys(obj)) { - delete obj[key]; - } - }, - get(key) { - let val = obj[mapKeyToLocal(key)]; - if (val !== undefined) { - val = mapValueToReflect(field, val); - } - return val as V | undefined; - }, - has(key) { - return Object.prototype.hasOwnProperty.call(obj, mapKeyToLocal(key)); - }, - *keys() { - for (const objKey of Object.keys(obj)) { - yield mapKeyToReflect(objKey, field.mapKey) as K; - } - }, - *entries() { - for (const objEntry of Object.entries(obj)) { - yield [ - mapKeyToReflect(objEntry[0], field.mapKey) as K, - mapValueToReflect(field, objEntry[1]) as V, - ]; - } - }, - [Symbol.iterator]() { - return this.entries(); - }, - get size() { - return Object.keys(obj).length; - }, - *values() { - for (const val of Object.values(obj)) { - yield mapValueToReflect(field, val) as V; - } - }, - forEach(callbackfn, thisArg) { - for (const mapEntry of this.entries()) { - callbackfn.call(thisArg, mapEntry[1], mapEntry[0], this); + return new ReflectMapImpl(field, unsafeInput, check); +} + +class ReflectMapImpl implements ReflectMap { + private readonly check: boolean; + private readonly _field: DescField & { fieldKind: "map" }; + [unsafeLocal]: Record; + private readonly obj: Record; + + constructor( + field: DescField & { fieldKind: "map" }, + unsafeInput?: Record, + check = true, + ) { + this.obj = this[unsafeLocal] = unsafeInput ?? {}; + this.check = check; + this._field = field; + } + field() { + return this._field; + } + set(key: K, value: V) { + if (this.check) { + const err = checkMapEntry(this._field, key, value); + if (err) { + return err; } - }, - }; + } + this.obj[mapKeyToLocal(key)] = mapValueToLocal(this._field, value); + return undefined; + } + delete(key: K) { + const k = mapKeyToLocal(key); + const has = Object.prototype.hasOwnProperty.call(this.obj, k); + if (has) { + delete this.obj[k]; + } + return has; + } + clear() { + for (const key of Object.keys(this.obj)) { + delete this.obj[key]; + } + } + get(key: K) { + let val = this.obj[mapKeyToLocal(key)]; + if (val !== undefined) { + val = mapValueToReflect(this._field, val, this.check); + } + return val as V | undefined; + } + has(key: K) { + return Object.prototype.hasOwnProperty.call(this.obj, mapKeyToLocal(key)); + } + *keys() { + for (const objKey of Object.keys(this.obj)) { + yield mapKeyToReflect(objKey, this._field.mapKey) as K; + } + } + *entries(): IterableIterator<[K, V]> { + for (const objEntry of Object.entries(this.obj)) { + yield [ + mapKeyToReflect(objEntry[0], this._field.mapKey) as K, + mapValueToReflect(this._field, objEntry[1], this.check) as V, + ]; + } + } + [Symbol.iterator]() { + return this.entries(); + } + get size() { + return Object.keys(this.obj).length; + } + *values() { + for (const val of Object.values(this.obj)) { + yield mapValueToReflect(this._field, val, this.check) as V; + } + } + forEach( + callbackfn: (value: V, key: K, map: ReadonlyMap) => void, + thisArg?: unknown, + ) { + for (const mapEntry of this.entries()) { + callbackfn.call(thisArg, mapEntry[1], mapEntry[0], this); + } + } } function listItemToLocal( @@ -372,9 +471,10 @@ function listItemToLocal( function listItemToReflect( field: DescField & { fieldKind: "list" }, value: unknown, + check: boolean, ): unknown { if (field.listKind == "message") { - return reflect(field.message, value as Message); + return new ReflectMessageImpl(field.message, value as Message, check); } return longToReflect(field, value); } @@ -392,9 +492,10 @@ function mapValueToLocal( function mapValueToReflect( field: DescField & { fieldKind: "map" }, value: unknown, + check: boolean, ): unknown { if (field.mapKind == "message") { - return reflect(field.message, value as Message); + return new ReflectMessageImpl(field.message, value as Message, check); } return value; }