Skip to content

Commit

Permalink
vm: add default limit to SI serialization context
Browse files Browse the repository at this point in the history
Follow the notion of neo-project/neo#2948.

Signed-off-by: Anna Shaleva <[email protected]>
  • Loading branch information
AnnaShaleva committed Nov 22, 2023
1 parent 8aeb30e commit 6433bc1
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 14 deletions.
9 changes: 5 additions & 4 deletions pkg/vm/stackitem/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func ToJSON(item Item) ([]byte, error) {
// It doesn't contain any pointer and uses less memory than `[]byte`.
type sliceNoPointer struct {
start, end int
itemsCount int
}

func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error) {
Expand Down Expand Up @@ -105,7 +106,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error
}
}
data = append(data, ']')
seen[item] = sliceNoPointer{start, len(data)}
seen[item] = sliceNoPointer{start: start, end: len(data)}
case *Map:
data = append(data, '{')
for i := range it.value {
Expand All @@ -126,7 +127,7 @@ func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error
}
}
data = append(data, '}')
seen[item] = sliceNoPointer{start, len(data)}
seen[item] = sliceNoPointer{start: start, end: len(data)}
case *BigInteger:
if it.Big().CmpAbs(big.NewInt(MaxAllowedInteger)) == 1 {
return nil, fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue)
Expand Down Expand Up @@ -420,15 +421,15 @@ func toJSONWithTypes(data []byte, item Item, seen map[Item]sliceNoPointer) ([]by
data = append(data, '}')

if isBuffer {
seen[item] = sliceNoPointer{start, len(data)}
seen[item] = sliceNoPointer{start: start, end: len(data)}
}
} else {
if len(data)+2 > MaxSize { // also take care of '}'
return nil, errTooBigSize
}
data = append(data, ']', '}')

seen[item] = sliceNoPointer{start, len(data)}
seen[item] = sliceNoPointer{start: start, end: len(data)}
}
return data, nil
}
Expand Down
36 changes: 32 additions & 4 deletions pkg/vm/stackitem/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
// (including itself).
const MaxDeserialized = 2048

// MaxSerialized is the maximum number one serialized item can contain
// (including itself).
const MaxSerialized = MaxDeserialized

// typicalNumOfItems is the number of items covering most serialization needs.
// It's a hint used for map creation, so it does not limit anything, it's just
// a microoptimization to avoid excessive reallocations. Most of the serialized
Expand All @@ -33,6 +37,7 @@ type SerializationContext struct {
uv [9]byte
data []byte
allowInvalid bool
limit int
seen map[Item]sliceNoPointer
}

Expand All @@ -45,10 +50,20 @@ type deserContext struct {

// Serialize encodes the given Item into a byte slice.
func Serialize(item Item) ([]byte, error) {
return SerializeLimited(item, MaxSerialized)
}

// SerializeLimited encodes the given Item into a byte slice using custom
// limit to restrict the maximum serialized number of elements.
func SerializeLimited(item Item, limit int) ([]byte, error) {
sc := SerializationContext{
allowInvalid: false,
limit: MaxSerialized,
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
}
if limit > 0 {
sc.limit = limit
}
err := sc.serialize(item)
if err != nil {
return nil, err
Expand Down Expand Up @@ -76,6 +91,7 @@ func EncodeBinary(item Item, w *io.BinWriter) {
func EncodeBinaryProtected(item Item, w *io.BinWriter) {
sc := SerializationContext{
allowInvalid: true,
limit: MaxSerialized,
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
}
err := sc.serialize(item)
Expand All @@ -88,28 +104,32 @@ func EncodeBinaryProtected(item Item, w *io.BinWriter) {

func (w *SerializationContext) writeArray(item Item, arr []Item, start int) error {
w.seen[item] = sliceNoPointer{}
limit := w.limit
w.appendVarUint(uint64(len(arr)))
for i := range arr {
if err := w.serialize(arr[i]); err != nil {
return err
}
}
w.seen[item] = sliceNoPointer{start, len(w.data)}
w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including the array itself.
return nil
}

// NewSerializationContext returns reusable stack item serialization context.
func NewSerializationContext() *SerializationContext {
return &SerializationContext{
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
limit: MaxSerialized,
seen: make(map[Item]sliceNoPointer, typicalNumOfItems),
}
}

// Serialize returns flat slice of bytes with the given item. The process can be protected
// from bad elements if appropriate flag is given (otherwise an error is returned on
// encountering any of them). The buffer returned is only valid until the call to Serialize.
// The number of serialized items is restricted with MaxSerialized.
func (w *SerializationContext) Serialize(item Item, protected bool) ([]byte, error) {
w.allowInvalid = protected
w.limit = MaxSerialized
if w.data != nil {
w.data = w.data[:0]
}
Expand All @@ -135,10 +155,17 @@ func (w *SerializationContext) serialize(item Item) error {
if len(w.data)+v.end-v.start > MaxSize {
return ErrTooBig
}
w.limit -= v.itemsCount
if w.limit < 0 {
return errTooBigElements
}
w.data = append(w.data, w.data[v.start:v.end]...)
return nil
}

w.limit--
if w.limit < 0 {
return errTooBigElements
}
start := len(w.data)
switch t := item.(type) {
case *ByteArray:
Expand Down Expand Up @@ -188,6 +215,7 @@ func (w *SerializationContext) serialize(item Item) error {
}
case *Map:
w.seen[item] = sliceNoPointer{}
limit := w.limit

elems := t.value
w.data = append(w.data, byte(MapT))
Expand All @@ -200,7 +228,7 @@ func (w *SerializationContext) serialize(item Item) error {
return err
}
}
w.seen[item] = sliceNoPointer{start, len(w.data)}
w.seen[item] = sliceNoPointer{start, len(w.data), limit - w.limit + 1} // number of items including Map itself.
case Null:
w.data = append(w.data, byte(AnyT))
case nil:
Expand Down
84 changes: 78 additions & 6 deletions pkg/vm/stackitem/serialization_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stackitem

import (
"strconv"
"testing"

"github.com/nspcc-dev/neo-go/pkg/io"
Expand All @@ -23,7 +24,19 @@ func TestSerializationMaxErr(t *testing.T) {
}

func testSerialize(t *testing.T, expectedErr error, item Item) {
data, err := Serialize(item)
testSerializeLimited(t, expectedErr, item, -1)
}

func testSerializeLimited(t *testing.T, expectedErr error, item Item, limit int) {
var (
data []byte
err error
)
if limit > 0 {
data, err = SerializeLimited(item, limit)
} else {
data, err = Serialize(item)
}
if expectedErr != nil {
require.ErrorIs(t, err, expectedErr)
return
Expand Down Expand Up @@ -58,7 +71,9 @@ func TestSerialize(t *testing.T) {
testSerialize(t, nil, newItem(items))

items = append(items, zeroByteArray)
data, err := Serialize(newItem(items))
_, err := Serialize(newItem(items))
require.ErrorIs(t, err, errTooBigElements)
data, err := SerializeLimited(newItem(items), MaxSerialized+1) // a tiny hack to check deserialization error.
require.NoError(t, err)
_, err = Deserialize(data)
require.ErrorIs(t, err, ErrTooBig)
Expand Down Expand Up @@ -165,13 +180,70 @@ func TestSerialize(t *testing.T) {
for i := 0; i <= MaxDeserialized; i++ {
m.Add(Make(i), zeroByteArray)
}
data, err := Serialize(m)
_, err := Serialize(m)
require.ErrorIs(t, err, errTooBigElements)
data, err := SerializeLimited(m, (MaxSerialized+1)*2+1) // a tiny hack to check deserialization error.
require.NoError(t, err)
_, err = Deserialize(data)
require.ErrorIs(t, err, ErrTooBig)
})
}

func TestSerializeLimited(t *testing.T) {
const customLimit = 10

smallArray := make([]Item, customLimit-1)
for i := range smallArray {
smallArray[i] = NewBool(true)
}
bigArray := make([]Item, customLimit)
for i := range bigArray {
bigArray[i] = NewBool(true)
}
t.Run("array", func(t *testing.T) {
testSerializeLimited(t, nil, NewArray(smallArray), customLimit)
testSerializeLimited(t, errTooBigElements, NewArray(bigArray), customLimit)
})
t.Run("struct", func(t *testing.T) {
testSerializeLimited(t, nil, NewStruct(smallArray), customLimit)
testSerializeLimited(t, errTooBigElements, NewStruct(bigArray), customLimit)
})
t.Run("map", func(t *testing.T) {
smallMap := make([]MapElement, (customLimit-1)/2)
for i := range smallMap {
smallMap[i] = MapElement{
Key: NewByteArray([]byte(strconv.Itoa(i))),
Value: NewBool(true),
}
}
bigMap := make([]MapElement, customLimit/2)
for i := range bigMap {
bigMap[i] = MapElement{
Key: NewByteArray([]byte("key")),
Value: NewBool(true),
}
}
testSerializeLimited(t, nil, NewMapWithValue(smallMap), customLimit)
testSerializeLimited(t, errTooBigElements, NewMapWithValue(bigMap), customLimit)
})
t.Run("seen", func(t *testing.T) {
t.Run("OK", func(t *testing.T) {
tinyArray := NewArray(make([]Item, (customLimit-3)/2)) // 1 for outer array, 1+1 for two inner arrays and the rest are for arrays' elements.
for i := range tinyArray.value {
tinyArray.value[i] = NewBool(true)
}
testSerializeLimited(t, nil, NewArray([]Item{tinyArray, tinyArray}), customLimit)
})
t.Run("big", func(t *testing.T) {
tinyArray := NewArray(make([]Item, (customLimit-2)/2)) // should break on the second array serialisation.
for i := range tinyArray.value {
tinyArray.value[i] = NewBool(true)
}
testSerializeLimited(t, errTooBigElements, NewArray([]Item{tinyArray, tinyArray}), customLimit)
})
})
}

func TestEmptyDeserialization(t *testing.T) {
empty := []byte{}
_, err := Deserialize(empty)
Expand Down Expand Up @@ -202,7 +274,7 @@ func TestDeserializeTooManyElements(t *testing.T) {
require.NoError(t, err)

item = Make([]Item{item})
data, err = Serialize(item)
data, err = SerializeLimited(item, MaxSerialized+1) // tiny hack to avoid serialization error.
require.NoError(t, err)
_, err = Deserialize(data)
require.ErrorIs(t, err, ErrTooBig)
Expand All @@ -214,14 +286,14 @@ func TestDeserializeLimited(t *testing.T) {
for i := 0; i < customLimit-1; i++ { // 1 for zero inner element.
item = Make([]Item{item})
}
data, err := Serialize(item)
data, err := SerializeLimited(item, customLimit) // tiny hack to avoid serialization error.
require.NoError(t, err)
actual, err := DeserializeLimited(data, customLimit)
require.NoError(t, err)
require.Equal(t, item, actual)

item = Make([]Item{item})
data, err = Serialize(item)
data, err = SerializeLimited(item, customLimit+1) // tiny hack to avoid serialization error.
require.NoError(t, err)
_, err = DeserializeLimited(data, customLimit)
require.Error(t, err)
Expand Down

0 comments on commit 6433bc1

Please sign in to comment.