Merge pull request #4165 from facebook/cspeed_cmov
Improve compression speed on small blocks
Cyan4973 authored Oct 11, 2024
2 parents 7ba4309 + 8e5823b commit 8c38bda
Showing 3 changed files with 120 additions and 72 deletions.
19 changes: 19 additions & 0 deletions lib/compress/zstd_compress_internal.h
@@ -557,6 +557,25 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value)
return 1;
}

/* ZSTD_selectAddr:
* @return index >= lowLimit ? candidate : backup;
* tries to force branchless codegen. */
MEM_STATIC const BYTE*
ZSTD_selectAddr(U32 index, U32 lowLimit, const BYTE* candidate, const BYTE* backup)
{
#if defined(__GNUC__) && defined(__x86_64__)
__asm__ (
"cmp %1, %2\n"
"cmova %3, %0\n"
: "+r"(candidate)
: "r"(index), "r"(lowLimit), "r"(backup)
);
return candidate;
#else
return index >= lowLimit ? candidate : backup;
#endif
}

/* ZSTD_noCompressBlock() :
* Writes uncompressed block to dst buffer from given src.
* Returns the size of the block */
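For illustration, here is a minimal standalone C sketch (hypothetical code, not part of this patch) of how ZSTD_selectAddr is meant to be used: when the candidate index is below the window's low limit, the load is redirected to a small dummy buffer rather than an out-of-range address, so the subsequent compare can be issued unconditionally. The sketch uses the portable #else path; the buffer sizes and index values are made up.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef uint8_t  BYTE;
typedef uint32_t U32;

/* Portable equivalent of ZSTD_selectAddr() (the x86-64 build nudges this
 * into a cmov with inline asm). */
static const BYTE* selectAddr(U32 index, U32 lowLimit,
                              const BYTE* candidate, const BYTE* backup)
{
    return index >= lowLimit ? candidate : backup;
}

int main(void)
{
    BYTE window[64];                       /* hypothetical match window */
    const BYTE dummy[8] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0};
    const U32 lowLimit = 16;               /* indices below this are invalid */
    U32 matchIdx = 8;                      /* below lowLimit: dummy gets selected */
    const BYTE* addr;

    memset(window, 0xAB, sizeof(window));
    addr = selectAddr(matchIdx, lowLimit, window + matchIdx, dummy);

    /* The subsequent compare always reads from a valid address;
     * a real match is only possible when matchIdx was in range. */
    printf("loading from %s\n", (addr == dummy) ? "dummy" : "window");
    return 0;
}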
39 changes: 25 additions & 14 deletions lib/compress/zstd_double_fast.c
@@ -140,11 +140,17 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(
U32 idxl1; /* the long match index for ip1 */

const BYTE* matchl0; /* the long match for ip */
const BYTE* matchl0_safe; /* matchl0 or safe address */
const BYTE* matchs0; /* the short match for ip */
const BYTE* matchl1; /* the long match for ip1 */
const BYTE* matchs0_safe; /* matchs0 or safe address */

const BYTE* ip = istart; /* the current position */
const BYTE* ip1; /* the next position */
/* Array of ~random bytes, with a low probability of matching real data.
* We load from here, instead of from the hash tables, when matchl0/matchl1
* are invalid indices. Used to avoid unpredictable branches. */
const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};

DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_noDict_generic");

@@ -191,24 +197,29 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(

hl1 = ZSTD_hashPtr(ip1, hBitsL, 8);

if (idxl0 > prefixLowestIndex) {
/* check prefix long match */
if (MEM_read64(matchl0) == MEM_read64(ip)) {
mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
offset = (U32)(ip-matchl0);
while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
goto _match_found;
}
/* idxl0 > prefixLowestIndex is a (somewhat) unpredictable branch.
* However, the expression below compiles into a conditional move. Since
* a match is unlikely, and we only *branch* on idxl0 > prefixLowestIndex
* when there is a match, all branches become predictable. */
matchl0_safe = ZSTD_selectAddr(idxl0, prefixLowestIndex, matchl0, &dummy[0]);

/* check prefix long match */
if (MEM_read64(matchl0_safe) == MEM_read64(ip) && matchl0_safe == matchl0) {
mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
offset = (U32)(ip-matchl0);
while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
goto _match_found;
}

idxl1 = hashLong[hl1];
matchl1 = base + idxl1;

if (idxs0 > prefixLowestIndex) {
/* check prefix short match */
if (MEM_read32(matchs0) == MEM_read32(ip)) {
goto _search_next_long;
}
/* Same optimization as matchl0 above */
matchs0_safe = ZSTD_selectAddr(idxs0, prefixLowestIndex, matchs0, &dummy[0]);

/* check prefix short match */
if (MEM_read32(matchs0_safe) == MEM_read32(ip) && matchs0_safe == matchs0) {
goto _search_next_long;
}

if (ip1 >= nextStep) {
@@ -651,7 +662,7 @@ size_t ZSTD_compressBlock_doubleFast_extDict_generic(
size_t mLength;
hashSmall[hSmall] = hashLong[hLong] = curr; /* update hash table */

if (((ZSTD_index_overlap_check(prefixStartIndex, repIndex))
& (offset_1 <= curr+1 - dictStartIndex)) /* note: we are searching at curr+1 */
&& (MEM_read32(repMatch) == MEM_read32(ip+1)) ) {
const BYTE* repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend;
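The shape of the matchl0/matchs0 checks above can be condensed into a helper like the following hypothetical sketch (not part of the patch; it assumes zstd's internal headers are on the include path for MEM_read64 and ZSTD_selectAddr). The wide load always targets a valid address, either the real candidate or the dummy buffer, and the extra pointer-equality term rejects an accidental match against the dummy bytes, so the rarely-true combined condition is the only branch left on the hot path.

#include "zstd_compress_internal.h"   /* ZSTD_selectAddr, MEM_read64, BYTE, U32 */

static int hasLongMatch(const BYTE* ip, const BYTE* matchl0, U32 idxl0,
                        U32 prefixLowestIndex, const BYTE* dummy)
{
    /* Always load 8 bytes from a valid address: the real candidate when
     * idxl0 is in range, the dummy buffer otherwise (branchless select). */
    const BYTE* safe = ZSTD_selectAddr(idxl0, prefixLowestIndex, matchl0, dummy);
    /* The second term rejects a spurious match against the dummy bytes;
     * the combined condition is rarely true, hence well predicted. */
    return (MEM_read64(safe) == MEM_read64(ip)) && (safe == matchl0);
}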
134 changes: 76 additions & 58 deletions lib/compress/zstd_fast.c
@@ -45,7 +45,7 @@ void ZSTD_fillHashTableForCDict(ZSTD_matchState_t* ms,
size_t const hashAndTag = ZSTD_hashPtr(ip + p, hBits, mls);
if (hashTable[hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { /* not yet filled */
ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr + p);
} } } }
}

static
@@ -97,6 +97,50 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms,
}


typedef int (*ZSTD_match4Found) (const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit);

static int
ZSTD_match4Found_cmov(const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit)
{
/* Array of ~random bytes, with a low probability of matching real data.
* Load from here if the index is invalid.
* Used to avoid unpredictable branches. */
static const BYTE dummy[] = {0x12,0x34,0x56,0x78};

/* matchIdx >= idxLowLimit is a (somewhat) unpredictable branch.
* However, the expression below compiles into a conditional move.
*/
const BYTE* mvalAddr = ZSTD_selectAddr(matchIdx, idxLowLimit, matchAddress, dummy);
/* Note: this used to be written as: return test1 && test2;
* Unfortunately, once inlined, these tests become branches,
* and it then becomes critical that they are executed in the right order (test1 then test2).
* So we have to write the tests in a specific manner to ensure their ordering.
*/
if (MEM_read32(currentPtr) != MEM_read32(mvalAddr)) return 0;
/* force ordering of these tests, which matters once the function is inlined, as they become branches */
#if defined(__GNUC__)
__asm__("");
#endif
return matchIdx >= idxLowLimit;
}

static int
ZSTD_match4Found_branch(const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit)
{
/* Use a branch instead of a cmov,
* because it's faster in scenarios where matchIdx >= idxLowLimit is generally true,
* i.e. almost all candidates are within range. */
U32 mval;
if (matchIdx >= idxLowLimit) {
mval = MEM_read32(matchAddress);
} else {
mval = MEM_read32(currentPtr) ^ 1; /* guaranteed to not match. */
}

return (MEM_read32(currentPtr) == mval);
}


/**
* If you squint hard enough (and ignore repcodes), the search operation at any
* given position is broken into 4 stages:
@@ -148,13 +192,12 @@ ZSTD_ALLOW_POINTER_OVERFLOW_ATTR
size_t ZSTD_compressBlock_fast_noDict_generic(
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM],
void const* src, size_t srcSize,
U32 const mls, U32 const hasStep)
U32 const mls, int useCmov)
{
const ZSTD_compressionParameters* const cParams = &ms->cParams;
U32* const hashTable = ms->hashTable;
U32 const hlog = cParams->hashLog;
/* support stepSize of 0 */
size_t const stepSize = hasStep ? (cParams->targetLength + !(cParams->targetLength) + 1) : 2;
size_t const stepSize = cParams->targetLength + !(cParams->targetLength) + 1; /* min 2 */
const BYTE* const base = ms->window.base;
const BYTE* const istart = (const BYTE*)src;
const U32 endIndex = (U32)((size_t)(istart - base) + srcSize);
@@ -176,8 +219,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(

size_t hash0; /* hash for ip0 */
size_t hash1; /* hash for ip1 */
U32 idx; /* match idx for ip0 */
U32 mval; /* src value at match idx */
U32 matchIdx; /* match idx for ip0 */

U32 offcode;
const BYTE* match0;
@@ -190,6 +232,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
size_t step;
const BYTE* nextStep;
const size_t kStepIncr = (1 << (kSearchStrength - 1));
const ZSTD_match4Found matchFound = useCmov ? ZSTD_match4Found_cmov : ZSTD_match4Found_branch;

DEBUGLOG(5, "ZSTD_compressBlock_fast_generic");
ip0 += (ip0 == prefixStart);
@@ -218,7 +261,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
hash0 = ZSTD_hashPtr(ip0, hlog, mls);
hash1 = ZSTD_hashPtr(ip1, hlog, mls);

idx = hashTable[hash0];
matchIdx = hashTable[hash0];

do {
/* load repcode match for ip[2]*/
@@ -238,35 +281,25 @@
offcode = REPCODE1_TO_OFFBASE;
mLength += 4;

/* First write next hash table entry; we've already calculated it.
* This write is known to be safe because the ip1 is before the
/* Write next hash table entry: it's already calculated.
* This write is known to be safe because ip1 is before the
* repcode (ip2). */
hashTable[hash1] = (U32)(ip1 - base);

goto _match;
}

/* load match for ip[0] */
if (idx >= prefixStartIndex) {
mval = MEM_read32(base + idx);
} else {
mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
}

/* check match at ip[0] */
if (MEM_read32(ip0) == mval) {
/* found a match! */

/* First write next hash table entry; we've already calculated it.
* This write is known to be safe because the ip1 == ip0 + 1, so
* we know we will resume searching after ip1 */
if (matchFound(ip0, base + matchIdx, matchIdx, prefixStartIndex)) {
/* Write next hash table entry (it's already calculated).
* This write is known to be safe because the ip1 == ip0 + 1,
* so searching will resume after ip1 */
hashTable[hash1] = (U32)(ip1 - base);

goto _offset;
}

/* lookup ip[1] */
idx = hashTable[hash1];
matchIdx = hashTable[hash1];

/* hash ip[2] */
hash0 = hash1;
@@ -281,36 +314,19 @@
current0 = (U32)(ip0 - base);
hashTable[hash0] = current0;

/* load match for ip[0] */
if (idx >= prefixStartIndex) {
mval = MEM_read32(base + idx);
} else {
mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
}

/* check match at ip[0] */
if (MEM_read32(ip0) == mval) {
/* found a match! */

/* first write next hash table entry; we've already calculated it */
if (matchFound(ip0, base + matchIdx, matchIdx, prefixStartIndex)) {
/* Write next hash table entry, since it's already calculated */
if (step <= 4) {
/* We need to avoid writing an index into the hash table >= the
* position at which we will pick up our searching after we've
* taken this match.
*
* The minimum possible match has length 4, so the earliest ip0
* can be after we take this match will be the current ip0 + 4.
* ip1 is ip0 + step - 1. If ip1 is >= ip0 + 4, we can't safely
* write this position.
*/
/* Avoid writing an index if it's >= the position where the search will resume.
* The minimum possible match has length 4, so the search can resume at ip0 + 4.
*/
hashTable[hash1] = (U32)(ip1 - base);
}

goto _offset;
}

/* lookup ip[1] */
idx = hashTable[hash1];
matchIdx = hashTable[hash1];

/* hash ip[2] */
hash0 = hash1;
@@ -332,7 +348,7 @@
} while (ip3 < ilimit);

_cleanup:
/* Note that there are probably still a couple positions we could search.
/* Note that there are probably still a couple positions one could search.
* However, it seems to be a meaningful performance hit to try to search
* them. So let's not. */

@@ -361,7 +377,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
_offset: /* Requires: ip0, idx */

/* Compute the offset code. */
match0 = base + idx;
match0 = base + matchIdx;
rep_offset2 = rep_offset1;
rep_offset1 = (U32)(ip0-match0);
offcode = OFFSET_TO_OFFBASE(rep_offset1);
@@ -406,12 +422,12 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
goto _start;
}

#define ZSTD_GEN_FAST_FN(dictMode, mls, step) \
static size_t ZSTD_compressBlock_fast_##dictMode##_##mls##_##step( \
#define ZSTD_GEN_FAST_FN(dictMode, mml, cmov) \
static size_t ZSTD_compressBlock_fast_##dictMode##_##mml##_##cmov( \
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], \
void const* src, size_t srcSize) \
{ \
return ZSTD_compressBlock_fast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mls, step); \
return ZSTD_compressBlock_fast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mml, cmov); \
}

ZSTD_GEN_FAST_FN(noDict, 4, 1)
@@ -428,10 +444,12 @@ size_t ZSTD_compressBlock_fast(
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM],
void const* src, size_t srcSize)
{
U32 const mls = ms->cParams.minMatch;
U32 const mml = ms->cParams.minMatch;
/* use cmov when "candidate in range" branch is likely unpredictable */
int const useCmov = ms->cParams.windowLog < 19;
assert(ms->dictMatchState == NULL);
if (ms->cParams.targetLength > 1) {
switch(mls)
if (useCmov) {
switch(mml)
{
default: /* includes case 3 */
case 4 :
@@ -444,7 +462,8 @@
return ZSTD_compressBlock_fast_noDict_7_1(ms, seqStore, rep, src, srcSize);
}
} else {
switch(mls)
/* use a branch instead */
switch(mml)
{
default: /* includes case 3 */
case 4 :
@@ -456,7 +475,6 @@ size_t ZSTD_compressBlock_fast(
case 7 :
return ZSTD_compressBlock_fast_noDict_7_0(ms, seqStore, rep, src, srcSize);
}

}
}

Expand Down Expand Up @@ -546,7 +564,7 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic(
size_t const dictHashAndTag1 = ZSTD_hashPtr(ip1, dictHBits, mls);
hashTable[hash0] = curr; /* update hash table */

if ((ZSTD_index_overlap_check(prefixStartIndex, repIndex))
&& (MEM_read32(repMatch) == MEM_read32(ip0 + 1))) {
const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend;
mLength = ZSTD_count_2segments(ip0 + 1 + 4, repMatch + 4, iend, repMatchEnd, prefixStart) + 4;
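One more note on ZSTD_match4Found_cmov above in zstd_fast.c: the empty __asm__("") statement is there to keep the two tests in source order once the helper is inlined and both become branches. Below is a self-contained sketch of the same pattern, with hypothetical names and memcpy standing in for MEM_read32; it is an illustration, not code from this patch.

#include <stdint.h>
#include <string.h>

static int match4_like(const unsigned char* cur, const unsigned char* safeAddr,
                       uint32_t matchIdx, uint32_t lowLimit)
{
    uint32_t a, b;
    memcpy(&a, cur, sizeof(a));       /* stand-in for MEM_read32(currentPtr) */
    memcpy(&b, safeAddr, sizeof(b));  /* stand-in for MEM_read32(mvalAddr)   */
    if (a != b) return 0;             /* test1: usually fails, well predicted */
#if defined(__GNUC__)
    __asm__("");                      /* barrier: keep test1 ahead of test2 after inlining */
#endif
    return matchIdx >= lowLimit;      /* test2: was the candidate index valid? */
}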
