// zinflate.cpp - written and placed in the public domain by Wei Dai

// This is a complete reimplementation of the DEFLATE decompression algorithm.
// It should not be affected by any security vulnerabilities in the zlib 
// compression library. In particular it is not affected by the double free bug
// (http://www.kb.cert.org/vuls/id/368819).

#include "pch.h"
#include "zinflate.h"


struct CodeLessThan
	inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
		{return lhs < rhs.code;}
	// needed for MSVC .NET 2005
	inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
		{return lhs.code < rhs.code;}

inline bool LowFirstBitReader::FillBuffer(unsigned int length)
	while (m_bitsBuffered < length)
		byte b;
		if (!m_store.Get(b))
			return false;
		m_buffer |= (unsigned long)b << m_bitsBuffered;
		m_bitsBuffered += 8;
	assert(m_bitsBuffered <= sizeof(unsigned long)*8);
	return true;

inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
	bool result = FillBuffer(length);
	return m_buffer & (((unsigned long)1 << length) - 1);

inline void LowFirstBitReader::SkipBits(unsigned int length)
	assert(m_bitsBuffered >= length);
	m_buffer >>= length;
	m_bitsBuffered -= length;

inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
	unsigned long result = PeekBits(length);
	return result;

// Left-aligns a `codeBits`-bit code so that its first bit sits at bit
// MAX_CODE_BITS-1 (the "normalized" form used for sorting and searching).
inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
{
	return code << (MAX_CODE_BITS - codeBits);
}

void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
	// the Huffman codes are represented in 3 ways in this code:
	// 1. most significant code bit (i.e. top of code tree) in the least significant bit position
	// 2. most significant code bit (i.e. top of code tree) in the most significant bit position
	// 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position,
	//    where n is the maximum code length for this code tree
	// (1) is the way the codes come in from the deflate stream
	// (2) is used to sort codes so they can be binary searched
	// (3) is used in this function to compute codes from code lengths
	// a code in representation (2) is called "normalized" here
	// The BitReverse() function is used to convert between (1) and (2)
	// The NormalizeCode() function is used to convert from (3) to (2)

	if (nCodes == 0)
		throw Err("null code");

	m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);

	if (m_maxCodeBits > MAX_CODE_BITS)
		throw Err("code length exceeds maximum");

	if (m_maxCodeBits == 0)
		throw Err("null code");

	// count number of codes of each length
	SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
	std::fill(blCount.begin(), blCount.end(), 0);
	unsigned int i;
	for (i=0; i<nCodes; i++)

	// compute the starting code of each length
	code_t code = 0;
	SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
	nextCode[1] = 0;
	for (i=2; i<=m_maxCodeBits; i++)
		// compute this while checking for overflow: code = (code + blCount[i-1]) << 1
		if (code > code + blCount[i-1])
			throw Err("codes oversubscribed");
		code += blCount[i-1];
		if (code > (code << 1))
			throw Err("codes oversubscribed");
		code <<= 1;
		nextCode[i] = code;

	if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
		throw Err("codes oversubscribed");
	else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
		throw Err("codes incomplete");

	// compute a vector of <code, length, value> triples sorted by code
	m_codeToValue.resize(nCodes - blCount[0]);
	unsigned int j=0;
	for (i=0; i<nCodes; i++) 
		unsigned int len = codeBits[i];
		if (len != 0)
			code = NormalizeCode(nextCode[len]++, len);
			m_codeToValue[j].code = code;
			m_codeToValue[j].len = len;
			m_codeToValue[j].value = i;
	std::sort(m_codeToValue.begin(), m_codeToValue.end());

	// initialize the decoding cache
	m_cacheBits = STDMIN(9U, m_maxCodeBits);
	m_cacheMask = (1 << m_cacheBits) - 1;
	m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
	assert(m_normalizedCacheMask == BitReverse(m_cacheMask));

	if (m_cache.size() != size_t(1) << m_cacheBits)
		m_cache.resize(1 << m_cacheBits);

	for (i=0; i<m_cache.size(); i++)
		m_cache[i].type = 0;

// Fills the cache slot for the cache-width prefix of `normalizedCode`.
//   type 1: the matching code fits in m_cacheBits bits; value and len are
//           stored directly.
//   type 2: all codes under this prefix share one length (entry.len), so
//           Decode indexes entry.begin with the remaining bits.
//   type 3: codes under this prefix have mixed lengths; Decode binary
//           searches [entry.begin, entry.end).
void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
{
	normalizedCode &= m_normalizedCacheMask;
	// last table entry whose code is <= the prefix
	const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
	if (codeInfo.len <= m_cacheBits)
	{
		entry.type = 1;
		entry.value = codeInfo.value;
		entry.len = codeInfo.len;
	}
	else
	{
		entry.begin = &codeInfo;
		// last table entry whose code still starts with this prefix
		const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
		if (codeInfo.len == last->len)
		{
			entry.type = 2;
			entry.len = codeInfo.len;
		}
		else
		{
			entry.type = 3;
			entry.end = last+1;
		}
	}
}

// Decodes the symbol whose code is at the start of `code`.
// `code` holds the first stream bit in its least significant bit, as returned
// by LowFirstBitReader::PeekBuffer. Stores the symbol in `value` and returns
// the length in bits of the code that was matched. Fills the cache slot lazily.
inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
{
	assert(m_codeToValue.size() > 0);
	LookupEntry &entry = m_cache[code & m_cacheMask];

	// only needed when the cache slot cannot answer directly (type != 1)
	code_t normalizedCode;
	if (entry.type != 1)
		normalizedCode = BitReverse(code);

	if (entry.type == 0)
		FillCacheEntry(entry, normalizedCode);

	if (entry.type == 1)
	{
		value = entry.value;
		return entry.len;
	}
	else
	{
		const CodeInfo &codeInfo = (entry.type == 2)
			? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
			: *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
		value = codeInfo.value;
		return codeInfo.len;
	}
}

// Decodes one symbol from `reader` into `value`.
// Returns false without consuming any bits when fewer bits are buffered than
// the matched code needs (input ran out mid-code); `value` may still have been
// written. Otherwise consumes exactly the code's bits and returns true.
bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
{
	// Best effort: near end of input fewer than MAX_CODE_BITS bits may exist,
	// which is fine whenever the actual code is shorter (checked below).
	reader.FillBuffer(MAX_CODE_BITS);
	unsigned int codeBits = Decode(reader.PeekBuffer(), value);
	if (codeBits > reader.BitsBuffered())
		return false;
	reader.SkipBits(codeBits);
	return true;
}

// *************************************************************

// Creates an inflator that writes decompressed output to `attachment`.
// If `repeat` is true, another stream may follow the end of the current one
// (see ProcessInput: POST_STREAM returns to PRE_STREAM).
Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
	: AutoSignaling<Filter>(propagation)
	, m_state(PRE_STREAM), m_repeat(repeat), m_reader(m_inQueue)
{
	// without this the attachment argument was silently ignored
	Detach(attachment);
}

void Inflator::IsolatedInitialize(const NameValuePairs &parameters)
	m_state = PRE_STREAM;
	parameters.GetValue("Repeat", m_repeat);

// Appends one decompressed byte to the sliding window. When the window
// fills, the unflushed part is handed to ProcessDecompressedData and writing
// wraps to the start.
void Inflator::OutputByte(byte b)
{
	m_window[m_current++] = b;
	if (m_current == m_window.size())
	{
		ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
		m_lastFlush = 0;
		m_current = 0;
		m_wrappedAround = true;
	}
}

void Inflator::OutputString(const byte *string, size_t length)
	while (length)
		size_t len = UnsignedMin(length, m_window.size() - m_current);
		memcpy(m_window + m_current, string, len);
		m_current += len;
		if (m_current == m_window.size())
			ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
			m_lastFlush = 0;
			m_current = 0;
			m_wrappedAround = true;
		string += len;
		length -= len;

void Inflator::OutputPast(unsigned int length, unsigned int distance)
	size_t start;
	if (distance <= m_current)
		start = m_current - distance;
	else if (m_wrappedAround && distance <= m_window.size())
		start = m_current + m_window.size() - distance;
		throw BadBlockErr();

	if (start + length > m_window.size())
		for (; start < m_window.size(); start++, length--)
		start = 0;

	if (start + length > m_current || m_current + length >= m_window.size())
		while (length--)
		memcpy(m_window + m_current, m_window + start, length);
		m_current += length;

// Accepts compressed input and decodes as much as possible.
// Only blocking mode is supported. Throws UnexpectedEndErr when `messageEnd`
// is set but the stream is incomplete (not at PRE_STREAM or AFTER_END).
// Always returns 0.
void Inflator_Put2_doc_placeholder();
size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
{
	if (!blocking)
		throw BlockingInputOnly("Inflator");

	// NOTE(review): presumably LazyPutter makes inString visible through
	// m_inQueue for this call only — confirm against its definition.
	LazyPutter lp(m_inQueue, inString, length);
	ProcessInput(messageEnd != 0);

	if (messageEnd)
	{
		if (!(m_state == PRE_STREAM || m_state == AFTER_END))
			throw UnexpectedEndErr();
	}

	Output(0, NULL, 0, messageEnd, blocking);
	return 0;
}

bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)
	if (!blocking)
		throw BlockingInputOnly("Inflator");

	if (hardFlush)

	return false;

void Inflator::ProcessInput(bool flush)
	while (true)
		switch (m_state)
		case PRE_STREAM:
			if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
			m_state = WAIT_HEADER;
			m_wrappedAround = false;
			m_current = 0;
			m_lastFlush = 0;
			m_window.New(1 << GetLog2WindowSize());
			// maximum number of bytes before actual compressed data starts
			const size_t MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15);
			if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
			if (!DecodeBody())
			if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
			m_state = m_repeat ? PRE_STREAM : AFTER_END;
			Output(0, NULL, 0, GetAutoSignalPropagation(), true);	// TODO: non-blocking
			if (m_inQueue.IsEmpty())
		case AFTER_END:

void Inflator::DecodeHeader()
	if (!m_reader.FillBuffer(3))
		throw UnexpectedEndErr();
	m_eof = m_reader.GetBits(1) != 0;
	m_blockType = (byte)m_reader.GetBits(2);
	switch (m_blockType)
	case 0:	// stored
		m_reader.SkipBits(m_reader.BitsBuffered() % 8);
		if (!m_reader.FillBuffer(32))
			throw UnexpectedEndErr();
		m_storedLen = (word16)m_reader.GetBits(16);
		word16 nlen = (word16)m_reader.GetBits(16);
		if (nlen != (word16)~m_storedLen)
			throw BadBlockErr();
	case 1:	// fixed codes
		m_nextDecode = LITERAL;
	case 2:	// dynamic codes
		if (!m_reader.FillBuffer(5+5+4))
			throw UnexpectedEndErr();
		unsigned int hlit = m_reader.GetBits(5);
		unsigned int hdist = m_reader.GetBits(5);
		unsigned int hclen = m_reader.GetBits(4);

		FixedSizeSecBlock<unsigned int, 286+32> codeLengths;
		unsigned int i;
		static const unsigned int border[] = {    // Order of the bit length code lengths
			16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
		std::fill(codeLengths.begin(), codeLengths+19, 0);
		for (i=0; i<hclen+4; i++)
			codeLengths[border[i]] = m_reader.GetBits(3);

			HuffmanDecoder codeLengthDecoder(codeLengths, 19);
			for (i = 0; i < hlit+257+hdist+1; )
				unsigned int k, count, repeater;
				bool result = codeLengthDecoder.Decode(m_reader, k);
				if (!result)
					throw UnexpectedEndErr();
				if (k <= 15)
					count = 1;
					repeater = k;
				else switch (k)
				case 16:
					if (!m_reader.FillBuffer(2))
						throw UnexpectedEndErr();
					count = 3 + m_reader.GetBits(2);
					if (i == 0)
						throw BadBlockErr();
					repeater = codeLengths[i-1];
				case 17:
					if (!m_reader.FillBuffer(3))
						throw UnexpectedEndErr();
					count = 3 + m_reader.GetBits(3);
					repeater = 0;
				case 18:
					if (!m_reader.FillBuffer(7))
						throw UnexpectedEndErr();
					count = 11 + m_reader.GetBits(7);
					repeater = 0;
				if (i + count > hlit+257+hdist+1)
					throw BadBlockErr();
				std::fill(codeLengths + i, codeLengths + i + count, repeater);
				i += count;
			m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257);
			if (hdist == 0 && codeLengths[hlit+257] == 0)
				if (hlit != 0)	// a single zero distance code length means all literals
					throw BadBlockErr();
				m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
			m_nextDecode = LITERAL;
		catch (HuffmanDecoder::Err &)
			throw BadBlockErr();
		throw BadBlockErr();	// reserved block type
	m_state = DECODING_BODY;

bool Inflator::DecodeBody()
	bool blockEnd = false;
	switch (m_blockType)
	case 0:	// stored
		assert(m_reader.BitsBuffered() == 0);
		while (!m_inQueue.IsEmpty() && !blockEnd)
			size_t size;
			const byte *block = m_inQueue.Spy(size);
			size = UnsignedMin(m_storedLen, size);
			OutputString(block, size);
			m_storedLen -= (word16)size;
			if (m_storedLen == 0)
				blockEnd = true;
	case 1:	// fixed codes
	case 2:	// dynamic codes
		static const unsigned int lengthStarts[] = {
			3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
			35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
		static const unsigned int lengthExtraBits[] = {
			0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
			3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
		static const unsigned int distanceStarts[] = {
			1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
			257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
			8193, 12289, 16385, 24577};
		static const unsigned int distanceExtraBits[] = {
			0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
			7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
			12, 12, 13, 13};

		const HuffmanDecoder& literalDecoder = GetLiteralDecoder();
		const HuffmanDecoder& distanceDecoder = GetDistanceDecoder();

		switch (m_nextDecode)
		case LITERAL:
			while (true)
				if (!literalDecoder.Decode(m_reader, m_literal))
					m_nextDecode = LITERAL;
				if (m_literal < 256)
				else if (m_literal == 256)	// end of block
					blockEnd = true;
					if (m_literal > 285)
						throw BadBlockErr();
					unsigned int bits;
					bits = lengthExtraBits[m_literal-257];
					if (!m_reader.FillBuffer(bits))
						m_nextDecode = LENGTH_BITS;
					m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
		case DISTANCE:
					if (!distanceDecoder.Decode(m_reader, m_distance))
						m_nextDecode = DISTANCE;
					bits = distanceExtraBits[m_distance];
					if (!m_reader.FillBuffer(bits))
						m_nextDecode = DISTANCE_BITS;
					m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
					OutputPast(m_literal, m_distance);
	if (blockEnd)
		if (m_eof)
			if (m_reader.BitsBuffered())
				// undo too much lookahead
				SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8);
				for (unsigned int i=0; i<buffer.size(); i++)
					buffer[i] = (byte)m_reader.GetBits(8);
				m_inQueue.Unget(buffer, buffer.size());
			m_state = POST_STREAM;
			m_state = WAIT_HEADER;
	return blockEnd;

void Inflator::FlushOutput()
	if (m_state != PRE_STREAM)
		assert(m_current >= m_lastFlush);
		ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
		m_lastFlush = m_current;

struct NewFixedLiteralDecoder
	HuffmanDecoder * operator()() const
		unsigned int codeLengths[288];
		std::fill(codeLengths + 0, codeLengths + 144, 8);
		std::fill(codeLengths + 144, codeLengths + 256, 9);
		std::fill(codeLengths + 256, codeLengths + 280, 7);
		std::fill(codeLengths + 280, codeLengths + 288, 8);
		std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
		pDecoder->Initialize(codeLengths, 288);
		return pDecoder.release();

struct NewFixedDistanceDecoder
	HuffmanDecoder * operator()() const
		unsigned int codeLengths[32];
		std::fill(codeLengths + 0, codeLengths + 32, 5);
		std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
		pDecoder->Initialize(codeLengths, 32);
		return pDecoder.release();

// Literal/length decoder for the current block: the shared fixed decoder for
// block type 1, otherwise the one built by DecodeHeader for a dynamic block.
const HuffmanDecoder& Inflator::GetLiteralDecoder() const
{
	return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedLiteralDecoder>().Ref() : m_dynamicLiteralDecoder;
}

// Distance decoder for the current block: the shared fixed decoder for block
// type 1, otherwise the one built by DecodeHeader for a dynamic block.
const HuffmanDecoder& Inflator::GetDistanceDecoder() const
{
	return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedDistanceDecoder>().Ref() : m_dynamicDistanceDecoder;
}