#include "jabber.h"
#include "MString.h"

/////////////////////////////////////////////////////////////////////////////////////////
// CMBaseString 

CNilMStringData CMBaseString::m_nil;

CMStringData* CMBaseString::Allocate( int nChars, int nCharSize )
{
	CMStringData* pData;
	nChars++; // nil char
	size_t nDataBytes = nCharSize * nChars;
	size_t nTotalSize = nDataBytes + sizeof(CMStringData);

	pData = static_cast<CMStringData*>(malloc(nTotalSize));
	if (pData == NULL)
		return NULL;

	pData->nRefs = 1;
	pData->nAllocLength = nChars - 1;
	pData->nDataLength = 0;
	return pData;
}

void CMBaseString::Free(CMStringData* pData)
{
	free(pData);
}

CMStringData* CMBaseString::Realloc(CMStringData* pData, int nChars, int nCharSize)
{
	CMStringData* pNewData;
	nChars++; // nil char
	ULONG nDataBytes = nCharSize * nChars;
	ULONG nTotalSize = nDataBytes + sizeof(CMStringData);

	pNewData = static_cast<CMStringData*>(realloc(pData, nTotalSize));
	if (pNewData == NULL)
		return NULL;

	pNewData->nAllocLength = nChars - 1;
	return pNewData;
}

CMStringData* CMBaseString::GetNilString()
{
	m_nil.AddRef();
	return &m_nil;
}

/////////////////////////////////////////////////////////////////////////////////////////
// CMStringData

void* CMStringData::data()
{
	return (this + 1);
}

void CMStringData::AddRef()
{
	InterlockedIncrement(&nRefs);
}

bool CMStringData::IsLocked() const
{
	return nRefs < 0;
}

bool CMStringData::IsShared() const
{
	return (nRefs > 1);
}

void CMStringData::Lock()
{
	nRefs--;  // Locked buffers can't be shared, so no interlocked operation necessary
	if ( nRefs == 0 )
		nRefs = -1;
}

void CMStringData::Release()
{
	if (InterlockedDecrement(&nRefs) <= 0)
		CMBaseString::Free(this);
}

void CMStringData::Unlock()
{
	if (IsLocked())
	{
		nRefs++;  // Locked buffers can't be shared, so no interlocked operation necessary
		if (nRefs == 0)
			nRefs = 1;
	}
}

CNilMStringData::CNilMStringData()
{
	nRefs = 2;  // Never gets freed
	nDataLength = 0;
	nAllocLength = 0;
	achNil[0] = 0;
	achNil[1] = 0;
}

/////////////////////////////////////////////////////////////////////////////////////////
// ChTraitsCRT<wchar_t>

#if _MSC_VER < 1400
static HINSTANCE hCrt = NULL;

typedef int ( __cdecl *_vscprintf_func )( LPCSTR pszFormat, va_list args );
static _vscprintf_func _vscprintf_ptr = NULL;

typedef int ( __cdecl *_vscwprintf_func )( LPCWSTR pszFormat, va_list args );
static _vscwprintf_func _vscwprintf_ptr = NULL;

typedef int ( __cdecl *_vsnprintf_func )( char*, size_t, const char*, va_list );
static _vsnprintf_func _vsnprintf_ptr = NULL;

typedef int ( __cdecl *_vsnwprintf_func )( wchar_t *, size_t, const wchar_t *, va_list );
static _vsnwprintf_func _vsnwprintf_ptr = NULL;

typedef int ( __cdecl *vswprintf_func )( wchar_t *, size_t, const wchar_t *, va_list );
static vswprintf_func vswprintf_ptr = NULL;

typedef int ( __cdecl *vsprintf_func )( char*, size_t, const char*, va_list );
static vsprintf_func vsprintf_ptr = NULL;

static void checkCrt( void )
{
	if ( hCrt == NULL ) {
		hCrt = GetModuleHandleA( "msvcrt.dll" );
		_vscprintf_ptr = (_vscprintf_func)GetProcAddress( hCrt, "_vscprintf" );
		_vscwprintf_ptr = (_vscwprintf_func)GetProcAddress( hCrt, "_vscwprintf" );
		_vsnprintf_ptr = (_vsnprintf_func)GetProcAddress( hCrt, "_vsnprintf" );
		_vsnwprintf_ptr = (_vsnwprintf_func)GetProcAddress( hCrt, "_vsnwprintf" );
		vswprintf_ptr = (vswprintf_func)GetProcAddress( hCrt, "vswprintf" );
		vsprintf_ptr = (vsprintf_func)GetProcAddress( hCrt, "vsprintf" );
}	}
#endif

int __stdcall ChTraitsCRT<wchar_t>::GetFormattedLength( LPCWSTR pszFormat, va_list args )
{
	#if _MSC_VER < 1400
		checkCrt();

		if ( _vscwprintf_ptr != NULL )
			return _vscwprintf_ptr( pszFormat, args );

		WCHAR buf[ 4000 ];
		return vswprintf_ptr( buf, SIZEOF(buf), pszFormat, args );
	#else
		return _vscwprintf( pszFormat, args );
	#endif
}

int __stdcall ChTraitsCRT<wchar_t>::Format( LPWSTR pszBuffer, size_t nLength, LPCWSTR pszFormat, va_list args)
{
	#if _MSC_VER < 1400
		checkCrt();

		if ( _vsnwprintf_ptr != NULL )
			return _vsnwprintf_ptr( pszBuffer, nLength, pszFormat, args );

		return vswprintf_ptr( pszBuffer, nLength, pszFormat, args );
	#else
		return _vsnwprintf( pszBuffer, nLength, pszFormat, args );
	#endif
}

/////////////////////////////////////////////////////////////////////////////////////////
// ChTraitsCRT<char>

int __stdcall ChTraitsCRT<char>::GetFormattedLength( LPCSTR pszFormat, va_list args )
{
	#if _MSC_VER < 1400
		checkCrt();

		if ( _vscprintf_ptr != NULL )
			return _vscprintf_ptr( pszFormat, args );

		char buf[4000];
		return vsprintf_ptr( buf, sizeof(buf), pszFormat, args );
	#else
		return _vscprintf( pszFormat, args );
	#endif
}

int __stdcall ChTraitsCRT<char>::Format( LPSTR pszBuffer, size_t nlength, LPCSTR pszFormat, va_list args )
{
	#if _MSC_VER < 1400
		checkCrt();

		return _vsnprintf( pszBuffer, nlength, pszFormat, args );
	#else
		return vsprintf_s( pszBuffer, nlength, pszFormat, args );
	#endif
}