contrib/python-zstandard/zstd/compress/zstd_opt.c
changeset 42937 69de49c4e39c
parent 42070 675775c33ab6
child 43994 de7838053207
--- a/contrib/python-zstandard/zstd/compress/zstd_opt.c	Sun Sep 15 00:07:30 2019 -0400
+++ b/contrib/python-zstandard/zstd/compress/zstd_opt.c	Sun Sep 15 20:04:00 2019 -0700
@@ -64,9 +64,15 @@
 }
 #endif
 
+static int ZSTD_compressedLiterals(optState_t const* const optPtr)
+{
+    return optPtr->literalCompressionMode != ZSTD_lcm_uncompressed;
+}
+
 static void ZSTD_setBasePrices(optState_t* optPtr, int optLevel)
 {
-    optPtr->litSumBasePrice = WEIGHT(optPtr->litSum, optLevel);
+    if (ZSTD_compressedLiterals(optPtr))
+        optPtr->litSumBasePrice = WEIGHT(optPtr->litSum, optLevel);
     optPtr->litLengthSumBasePrice = WEIGHT(optPtr->litLengthSum, optLevel);
     optPtr->matchLengthSumBasePrice = WEIGHT(optPtr->matchLengthSum, optLevel);
     optPtr->offCodeSumBasePrice = WEIGHT(optPtr->offCodeSum, optLevel);
@@ -99,6 +105,7 @@
             const BYTE* const src, size_t const srcSize,
                   int const optLevel)
 {
+    int const compressedLiterals = ZSTD_compressedLiterals(optPtr);
     DEBUGLOG(5, "ZSTD_rescaleFreqs (srcSize=%u)", (unsigned)srcSize);
     optPtr->priceType = zop_dynamic;
 
@@ -113,9 +120,10 @@
             /* huffman table presumed generated by dictionary */
             optPtr->priceType = zop_dynamic;
 
-            assert(optPtr->litFreq != NULL);
-            optPtr->litSum = 0;
-            {   unsigned lit;
+            if (compressedLiterals) {
+                unsigned lit;
+                assert(optPtr->litFreq != NULL);
+                optPtr->litSum = 0;
                 for (lit=0; lit<=MaxLit; lit++) {
                     U32 const scaleLog = 11;   /* scale to 2K */
                     U32 const bitCost = HUF_getNbBits(optPtr->symbolCosts->huf.CTable, lit);
@@ -163,10 +171,11 @@
         } else {  /* not a dictionary */
 
             assert(optPtr->litFreq != NULL);
-            {   unsigned lit = MaxLit;
+            if (compressedLiterals) {
+                unsigned lit = MaxLit;
                 HIST_count_simple(optPtr->litFreq, &lit, src, srcSize);   /* use raw first block to init statistics */
+                optPtr->litSum = ZSTD_downscaleStat(optPtr->litFreq, MaxLit, 1);
             }
-            optPtr->litSum = ZSTD_downscaleStat(optPtr->litFreq, MaxLit, 1);
 
             {   unsigned ll;
                 for (ll=0; ll<=MaxLL; ll++)
@@ -190,7 +199,8 @@
 
     } else {   /* new block : re-use previous statistics, scaled down */
 
-        optPtr->litSum = ZSTD_downscaleStat(optPtr->litFreq, MaxLit, 1);
+        if (compressedLiterals)
+            optPtr->litSum = ZSTD_downscaleStat(optPtr->litFreq, MaxLit, 1);
         optPtr->litLengthSum = ZSTD_downscaleStat(optPtr->litLengthFreq, MaxLL, 0);
         optPtr->matchLengthSum = ZSTD_downscaleStat(optPtr->matchLengthFreq, MaxML, 0);
         optPtr->offCodeSum = ZSTD_downscaleStat(optPtr->offCodeFreq, MaxOff, 0);
@@ -207,6 +217,10 @@
                                 int optLevel)
 {
     if (litLength == 0) return 0;
+
+    if (!ZSTD_compressedLiterals(optPtr))
+        return (litLength << 3) * BITCOST_MULTIPLIER;  /* Uncompressed - 8 bytes per literal. */
+
     if (optPtr->priceType == zop_predef)
         return (litLength*6) * BITCOST_MULTIPLIER;  /* 6 bit per literal - no statistic used */
 
@@ -241,13 +255,13 @@
  * to provide a cost which is directly comparable to a match ending at same position */
 static int ZSTD_litLengthContribution(U32 const litLength, const optState_t* const optPtr, int optLevel)
 {
-    if (optPtr->priceType >= zop_predef) return WEIGHT(litLength, optLevel);
+    if (optPtr->priceType >= zop_predef) return (int)WEIGHT(litLength, optLevel);
 
     /* dynamic statistics */
     {   U32 const llCode = ZSTD_LLcode(litLength);
-        int const contribution = (LL_bits[llCode] * BITCOST_MULTIPLIER)
-                               + WEIGHT(optPtr->litLengthFreq[0], optLevel)   /* note: log2litLengthSum cancel out */
-                               - WEIGHT(optPtr->litLengthFreq[llCode], optLevel);
+        int const contribution = (int)(LL_bits[llCode] * BITCOST_MULTIPLIER)
+                               + (int)WEIGHT(optPtr->litLengthFreq[0], optLevel)   /* note: log2litLengthSum cancel out */
+                               - (int)WEIGHT(optPtr->litLengthFreq[llCode], optLevel);
 #if 1
         return contribution;
 #else
@@ -264,7 +278,7 @@
                                      const optState_t* const optPtr,
                                      int optLevel)
 {
-    int const contribution = ZSTD_rawLiteralsCost(literals, litLength, optPtr, optLevel)
+    int const contribution = (int)ZSTD_rawLiteralsCost(literals, litLength, optPtr, optLevel)
                            + ZSTD_litLengthContribution(litLength, optPtr, optLevel);
     return contribution;
 }
@@ -310,7 +324,8 @@
                              U32 offsetCode, U32 matchLength)
 {
     /* literals */
-    {   U32 u;
+    if (ZSTD_compressedLiterals(optPtr)) {
+        U32 u;
         for (u=0; u < litLength; u++)
             optPtr->litFreq[literals[u]] += ZSTD_LITFREQ_ADD;
         optPtr->litSum += litLength*ZSTD_LITFREQ_ADD;
@@ -357,13 +372,15 @@
 
 /* Update hashTable3 up to ip (excluded)
    Assumption : always within prefix (i.e. not within extDict) */
-static U32 ZSTD_insertAndFindFirstIndexHash3 (ZSTD_matchState_t* ms, const BYTE* const ip)
+static U32 ZSTD_insertAndFindFirstIndexHash3 (ZSTD_matchState_t* ms,
+                                              U32* nextToUpdate3,
+                                              const BYTE* const ip)
 {
     U32* const hashTable3 = ms->hashTable3;
     U32 const hashLog3 = ms->hashLog3;
     const BYTE* const base = ms->window.base;
-    U32 idx = ms->nextToUpdate3;
-    U32 const target = ms->nextToUpdate3 = (U32)(ip - base);
+    U32 idx = *nextToUpdate3;
+    U32 const target = (U32)(ip - base);
     size_t const hash3 = ZSTD_hash3Ptr(ip, hashLog3);
     assert(hashLog3 > 0);
 
@@ -372,6 +389,7 @@
         idx++;
     }
 
+    *nextToUpdate3 = target;
     return hashTable3[hash3];
 }
 
@@ -488,9 +506,11 @@
     }   }
 
     *smallerPtr = *largerPtr = 0;
-    if (bestLength > 384) return MIN(192, (U32)(bestLength - 384));   /* speed optimization */
-    assert(matchEndIdx > current + 8);
-    return matchEndIdx - (current + 8);
+    {   U32 positions = 0;
+        if (bestLength > 384) positions = MIN(192, (U32)(bestLength - 384));   /* speed optimization */
+        assert(matchEndIdx > current + 8);
+        return MAX(positions, matchEndIdx - (current + 8));
+    }
 }
 
 FORCE_INLINE_TEMPLATE
@@ -505,8 +525,13 @@
     DEBUGLOG(6, "ZSTD_updateTree_internal, from %u to %u  (dictMode:%u)",
                 idx, target, dictMode);
 
-    while(idx < target)
-        idx += ZSTD_insertBt1(ms, base+idx, iend, mls, dictMode == ZSTD_extDict);
+    while(idx < target) {
+        U32 const forward = ZSTD_insertBt1(ms, base+idx, iend, mls, dictMode == ZSTD_extDict);
+        assert(idx < (U32)(idx + forward));
+        idx += forward;
+    }
+    assert((size_t)(ip - base) <= (size_t)(U32)(-1));
+    assert((size_t)(iend - base) <= (size_t)(U32)(-1));
     ms->nextToUpdate = target;
 }
 
@@ -516,11 +541,12 @@
 
 FORCE_INLINE_TEMPLATE
 U32 ZSTD_insertBtAndGetAllMatches (
+                    ZSTD_match_t* matches,   /* store result (found matches) in this table (presumed large enough) */
                     ZSTD_matchState_t* ms,
+                    U32* nextToUpdate3,
                     const BYTE* const ip, const BYTE* const iLimit, const ZSTD_dictMode_e dictMode,
-                    U32 rep[ZSTD_REP_NUM],
+                    const U32 rep[ZSTD_REP_NUM],
                     U32 const ll0,   /* tells if associated literal length is 0 or not. This value must be 0 or 1 */
-                    ZSTD_match_t* matches,
                     const U32 lengthToBeat,
                     U32 const mls /* template */)
 {
@@ -541,8 +567,8 @@
     U32 const dictLimit = ms->window.dictLimit;
     const BYTE* const dictEnd = dictBase + dictLimit;
     const BYTE* const prefixStart = base + dictLimit;
-    U32 const btLow = btMask >= current ? 0 : current - btMask;
-    U32 const windowLow = ms->window.lowLimit;
+    U32 const btLow = (btMask >= current) ? 0 : current - btMask;
+    U32 const windowLow = ZSTD_getLowestMatchIndex(ms, current, cParams->windowLog);
     U32 const matchLow = windowLow ? windowLow : 1;
     U32* smallerPtr = bt + 2*(current&btMask);
     U32* largerPtr  = bt + 2*(current&btMask) + 1;
@@ -612,7 +638,7 @@
 
     /* HC3 match finder */
     if ((mls == 3) /*static*/ && (bestLength < mls)) {
-        U32 const matchIndex3 = ZSTD_insertAndFindFirstIndexHash3(ms, ip);
+        U32 const matchIndex3 = ZSTD_insertAndFindFirstIndexHash3(ms, nextToUpdate3, ip);
         if ((matchIndex3 >= matchLow)
           & (current - matchIndex3 < (1<<18)) /*heuristic : longer distance likely too expensive*/ ) {
             size_t mlen;
@@ -638,9 +664,7 @@
                      (ip+mlen == iLimit) ) {  /* best possible length */
                     ms->nextToUpdate = current+1;  /* skip insertion */
                     return 1;
-                }
-            }
-        }
+        }   }   }
         /* no dictMatchState lookup: dicts don't have a populated HC3 table */
     }
 
@@ -648,19 +672,21 @@
 
     while (nbCompares-- && (matchIndex >= matchLow)) {
         U32* const nextPtr = bt + 2*(matchIndex & btMask);
+        const BYTE* match;
         size_t matchLength = MIN(commonLengthSmaller, commonLengthLarger);   /* guaranteed minimum nb of common bytes */
-        const BYTE* match;
         assert(current > matchIndex);
 
         if ((dictMode == ZSTD_noDict) || (dictMode == ZSTD_dictMatchState) || (matchIndex+matchLength >= dictLimit)) {
             assert(matchIndex+matchLength >= dictLimit);  /* ensure the condition is correct when !extDict */
             match = base + matchIndex;
+            if (matchIndex >= dictLimit) assert(memcmp(match, ip, matchLength) == 0);  /* ensure early section of match is equal as expected */
             matchLength += ZSTD_count(ip+matchLength, match+matchLength, iLimit);
         } else {
             match = dictBase + matchIndex;
+            assert(memcmp(match, ip, matchLength) == 0);  /* ensure early section of match is equal as expected */
             matchLength += ZSTD_count_2segments(ip+matchLength, match+matchLength, iLimit, dictEnd, prefixStart);
             if (matchIndex+matchLength >= dictLimit)
-                match = base + matchIndex;   /* prepare for match[matchLength] */
+                match = base + matchIndex;   /* prepare for match[matchLength] read */
         }
 
         if (matchLength > bestLength) {
@@ -745,10 +771,13 @@
 
 
 FORCE_INLINE_TEMPLATE U32 ZSTD_BtGetAllMatches (
+                        ZSTD_match_t* matches,   /* store result (match found, increasing size) in this table */
                         ZSTD_matchState_t* ms,
+                        U32* nextToUpdate3,
                         const BYTE* ip, const BYTE* const iHighLimit, const ZSTD_dictMode_e dictMode,
-                        U32 rep[ZSTD_REP_NUM], U32 const ll0,
-                        ZSTD_match_t* matches, U32 const lengthToBeat)
+                        const U32 rep[ZSTD_REP_NUM],
+                        U32 const ll0,
+                        U32 const lengthToBeat)
 {
     const ZSTD_compressionParameters* const cParams = &ms->cParams;
     U32 const matchLengthSearch = cParams->minMatch;
@@ -757,12 +786,12 @@
     ZSTD_updateTree_internal(ms, ip, iHighLimit, matchLengthSearch, dictMode);
     switch(matchLengthSearch)
     {
-    case 3 : return ZSTD_insertBtAndGetAllMatches(ms, ip, iHighLimit, dictMode, rep, ll0, matches, lengthToBeat, 3);
+    case 3 : return ZSTD_insertBtAndGetAllMatches(matches, ms, nextToUpdate3, ip, iHighLimit, dictMode, rep, ll0, lengthToBeat, 3);
     default :
-    case 4 : return ZSTD_insertBtAndGetAllMatches(ms, ip, iHighLimit, dictMode, rep, ll0, matches, lengthToBeat, 4);
-    case 5 : return ZSTD_insertBtAndGetAllMatches(ms, ip, iHighLimit, dictMode, rep, ll0, matches, lengthToBeat, 5);
+    case 4 : return ZSTD_insertBtAndGetAllMatches(matches, ms, nextToUpdate3, ip, iHighLimit, dictMode, rep, ll0, lengthToBeat, 4);
+    case 5 : return ZSTD_insertBtAndGetAllMatches(matches, ms, nextToUpdate3, ip, iHighLimit, dictMode, rep, ll0, lengthToBeat, 5);
     case 7 :
-    case 6 : return ZSTD_insertBtAndGetAllMatches(ms, ip, iHighLimit, dictMode, rep, ll0, matches, lengthToBeat, 6);
+    case 6 : return ZSTD_insertBtAndGetAllMatches(matches, ms, nextToUpdate3, ip, iHighLimit, dictMode, rep, ll0, lengthToBeat, 6);
     }
 }
 
@@ -838,6 +867,7 @@
 
     U32 const sufficient_len = MIN(cParams->targetLength, ZSTD_OPT_NUM -1);
     U32 const minMatch = (cParams->minMatch == 3) ? 3 : 4;
+    U32 nextToUpdate3 = ms->nextToUpdate;
 
     ZSTD_optimal_t* const opt = optStatePtr->priceTable;
     ZSTD_match_t* const matches = optStatePtr->matchTable;
@@ -847,7 +877,6 @@
     DEBUGLOG(5, "ZSTD_compressBlock_opt_generic: current=%u, prefix=%u, nextToUpdate=%u",
                 (U32)(ip - base), ms->window.dictLimit, ms->nextToUpdate);
     assert(optLevel <= 2);
-    ms->nextToUpdate3 = ms->nextToUpdate;
     ZSTD_rescaleFreqs(optStatePtr, (const BYTE*)src, srcSize, optLevel);
     ip += (ip==prefixStart);
 
@@ -858,7 +887,7 @@
         /* find first match */
         {   U32 const litlen = (U32)(ip - anchor);
             U32 const ll0 = !litlen;
-            U32 const nbMatches = ZSTD_BtGetAllMatches(ms, ip, iend, dictMode, rep, ll0, matches, minMatch);
+            U32 const nbMatches = ZSTD_BtGetAllMatches(matches, ms, &nextToUpdate3, ip, iend, dictMode, rep, ll0, minMatch);
             if (!nbMatches) { ip++; continue; }
 
             /* initialize opt[0] */
@@ -870,7 +899,7 @@
             /* large match -> immediate encoding */
             {   U32 const maxML = matches[nbMatches-1].len;
                 U32 const maxOffset = matches[nbMatches-1].off;
-                DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffCode=%u at cPos=%u => start new serie",
+                DEBUGLOG(6, "found %u matches of maxLength=%u and maxOffCode=%u at cPos=%u => start new series",
                             nbMatches, maxML, maxOffset, (U32)(ip-prefixStart));
 
                 if (maxML > sufficient_len) {
@@ -955,7 +984,7 @@
                 U32 const litlen = (opt[cur].mlen == 0) ? opt[cur].litlen : 0;
                 U32 const previousPrice = opt[cur].price;
                 U32 const basePrice = previousPrice + ZSTD_litLengthPrice(0, optStatePtr, optLevel);
-                U32 const nbMatches = ZSTD_BtGetAllMatches(ms, inr, iend, dictMode, opt[cur].rep, ll0, matches, minMatch);
+                U32 const nbMatches = ZSTD_BtGetAllMatches(matches, ms, &nextToUpdate3, inr, iend, dictMode, opt[cur].rep, ll0, minMatch);
                 U32 matchNb;
                 if (!nbMatches) {
                     DEBUGLOG(7, "rPos:%u : no match found", cur);
@@ -1079,7 +1108,7 @@
     }   /* while (ip < ilimit) */
 
     /* Return the last literals size */
-    return iend - anchor;
+    return (size_t)(iend - anchor);
 }
 
 
@@ -1108,7 +1137,8 @@
 /* used in 2-pass strategy */
 MEM_STATIC void ZSTD_upscaleStats(optState_t* optPtr)
 {
-    optPtr->litSum = ZSTD_upscaleStat(optPtr->litFreq, MaxLit, 0);
+    if (ZSTD_compressedLiterals(optPtr))
+        optPtr->litSum = ZSTD_upscaleStat(optPtr->litFreq, MaxLit, 0);
     optPtr->litLengthSum = ZSTD_upscaleStat(optPtr->litLengthFreq, MaxLL, 0);
     optPtr->matchLengthSum = ZSTD_upscaleStat(optPtr->matchLengthFreq, MaxML, 0);
     optPtr->offCodeSum = ZSTD_upscaleStat(optPtr->offCodeFreq, MaxOff, 0);
@@ -1117,7 +1147,7 @@
 /* ZSTD_initStats_ultra():
  * make a first compression pass, just to seed stats with more accurate starting values.
  * only works on first block, with no dictionary and no ldm.
- * this function cannot error, hence its constract must be respected.
+ * this function cannot error, hence its contract must be respected.
  */
 static void
 ZSTD_initStats_ultra(ZSTD_matchState_t* ms,
@@ -1142,7 +1172,6 @@
     ms->window.dictLimit += (U32)srcSize;
     ms->window.lowLimit = ms->window.dictLimit;
     ms->nextToUpdate = ms->window.dictLimit;
-    ms->nextToUpdate3 = ms->window.dictLimit;
 
     /* re-inforce weight of collected statistics */
     ZSTD_upscaleStats(&ms->opt);