contrib/python-zstandard/c-ext/compressiondict.c
changeset 31796 e0dc40530c5a
parent 30895 c32454d69b85
child 37495 b1fb341d8a61
--- a/contrib/python-zstandard/c-ext/compressiondict.c	Sat Apr 01 13:43:52 2017 -0700
+++ b/contrib/python-zstandard/c-ext/compressiondict.c	Sat Apr 01 15:24:03 2017 -0700
@@ -11,46 +11,48 @@
 extern PyObject* ZstdError;
 
 ZstdCompressionDict* train_dictionary(PyObject* self, PyObject* args, PyObject* kwargs) {
-	static char *kwlist[] = { "dict_size", "samples", "parameters", NULL };
+	static char* kwlist[] = {
+		"dict_size",
+		"samples",
+		"selectivity",
+		"level",
+		"notifications",
+		"dict_id",
+		NULL
+	};
 	size_t capacity;
 	PyObject* samples;
 	Py_ssize_t samplesLen;
-	PyObject* parameters = NULL;
+	unsigned  selectivity = 0;
+	int level = 0;
+	unsigned notifications = 0;
+	unsigned dictID = 0;
 	ZDICT_params_t zparams;
 	Py_ssize_t sampleIndex;
 	Py_ssize_t sampleSize;
 	PyObject* sampleItem;
 	size_t zresult;
-	void* sampleBuffer;
+	void* sampleBuffer = NULL;
 	void* sampleOffset;
 	size_t samplesSize = 0;
-	size_t* sampleSizes;
-	void* dict;
-	ZstdCompressionDict* result;
+	size_t* sampleSizes = NULL;
+	void* dict = NULL;
+	ZstdCompressionDict* result = NULL;
 
-	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "nO!|O!:train_dictionary",
+	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "nO!|IiII:train_dictionary",
 		kwlist,
 		&capacity,
 		&PyList_Type, &samples,
-		(PyObject*)&DictParametersType, &parameters)) {
+		&selectivity, &level, &notifications, &dictID)) {
 		return NULL;
 	}
 
-	/* Validate parameters first since it is easiest. */
-	zparams.selectivityLevel = 0;
-	zparams.compressionLevel = 0;
-	zparams.notificationLevel = 0;
-	zparams.dictID = 0;
-	zparams.reserved[0] = 0;
-	zparams.reserved[1] = 0;
+	memset(&zparams, 0, sizeof(zparams));
 
-	if (parameters) {
-		/* TODO validate data ranges */
-		zparams.selectivityLevel = PyLong_AsUnsignedLong(PyTuple_GetItem(parameters, 0));
-		zparams.compressionLevel = PyLong_AsLong(PyTuple_GetItem(parameters, 1));
-		zparams.notificationLevel = PyLong_AsUnsignedLong(PyTuple_GetItem(parameters, 2));
-		zparams.dictID = PyLong_AsUnsignedLong(PyTuple_GetItem(parameters, 3));
-	}
+	zparams.selectivityLevel = selectivity;
+	zparams.compressionLevel = level;
+	zparams.notificationLevel = notifications;
+	zparams.dictID = dictID;
 
 	/* Figure out the size of the raw samples */
 	samplesLen = PyList_Size(samples);
@@ -68,13 +70,12 @@
 	sampleBuffer = PyMem_Malloc(samplesSize);
 	if (!sampleBuffer) {
 		PyErr_NoMemory();
-		return NULL;
+		goto finally;
 	}
 	sampleSizes = PyMem_Malloc(samplesLen * sizeof(size_t));
 	if (!sampleSizes) {
-		PyMem_Free(sampleBuffer);
 		PyErr_NoMemory();
-		return NULL;
+		goto finally;
 	}
 
 	sampleOffset = sampleBuffer;
@@ -89,33 +90,168 @@
 
 	dict = PyMem_Malloc(capacity);
 	if (!dict) {
-		PyMem_Free(sampleSizes);
-		PyMem_Free(sampleBuffer);
 		PyErr_NoMemory();
-		return NULL;
+		goto finally;
 	}
 
+	/* TODO consider using dup2() to redirect zstd's stderr writing to a buffer */
+	Py_BEGIN_ALLOW_THREADS
 	zresult = ZDICT_trainFromBuffer_advanced(dict, capacity,
 		sampleBuffer, sampleSizes, (unsigned int)samplesLen,
 		zparams);
+	Py_END_ALLOW_THREADS
 	if (ZDICT_isError(zresult)) {
 		PyErr_Format(ZstdError, "Cannot train dict: %s", ZDICT_getErrorName(zresult));
 		PyMem_Free(dict);
-		PyMem_Free(sampleSizes);
-		PyMem_Free(sampleBuffer);
-		return NULL;
+		goto finally;
 	}
 
 	result = PyObject_New(ZstdCompressionDict, &ZstdCompressionDictType);
 	if (!result) {
-		return NULL;
+		goto finally;
 	}
 
 	result->dictData = dict;
 	result->dictSize = zresult;
+	result->d = 0;
+	result->k = 0;
+
+finally:
+	PyMem_Free(sampleBuffer);
+	PyMem_Free(sampleSizes);
+
 	return result;
 }
 
+/*
+ * Train a compression dictionary from a list of bytes samples with zstd's
+ * COVER_* trainers. On success returns a new ZstdCompressionDict whose
+ * dictData points at the trained buffer and whose k/d are read back from
+ * params; on failure returns NULL. A negative "threads" becomes cpu_count().
+ */
+ZstdCompressionDict* train_cover_dictionary(PyObject* self, PyObject* args, PyObject* kwargs) {
+	static char* kwlist[] = {
+		"dict_size", "samples", "k", "d", "notifications",
+		"dict_id", "level", "optimize", "steps", "threads", NULL
+	};
+	size_t capacity;
+	PyObject* samples;
+	unsigned k = 0, d = 0;
+	unsigned notifications = 0, dictID = 0, steps = 0;
+	int level = 0;
+	PyObject* optimize = NULL;
+	int doOptimize = 0;
+	int threads = 0;
+	COVER_params_t params;
+	Py_ssize_t samplesLen, sampleSize, i;
+	size_t samplesSize = 0;
+	void* sampleBuffer = NULL;
+	size_t* sampleSizes = NULL;
+	void* sampleOffset;
+	void* dict = NULL;
+	size_t zresult;
+	ZstdCompressionDict* result = NULL;
+
+	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "nO!|IIIIiOIi:train_cover_dictionary",
+		kwlist, &capacity, &PyList_Type, &samples,
+		&k, &d, &notifications, &dictID, &level, &optimize, &steps, &threads)) {
+		return NULL;
+	}
+
+	/* Resolve truthiness now: PyObject_IsTrue() needs the GIL and can fail. */
+	if (optimize) {
+		doOptimize = PyObject_IsTrue(optimize);
+		if (doOptimize < 0) {
+			return NULL;
+		}
+	}
+
+	if (threads < 0) {
+		threads = cpu_count();
+	}
+
+	memset(&params, 0, sizeof(params));
+	params.k = k;
+	params.d = d;
+	params.steps = steps;
+	params.nbThreads = threads;
+	params.notificationLevel = notifications;
+	params.dictID = dictID;
+	params.compressionLevel = level;
+
+	/* Figure out total size of input samples. */
+	samplesLen = PyList_Size(samples);
+	for (i = 0; i < samplesLen; i++) {
+		PyObject* sampleItem = PyList_GET_ITEM(samples, i);
+		if (!PyBytes_Check(sampleItem)) {
+			PyErr_SetString(PyExc_ValueError, "samples must be bytes");
+			return NULL;
+		}
+		samplesSize += PyBytes_GET_SIZE(sampleItem);
+	}
+
+	sampleBuffer = PyMem_Malloc(samplesSize);
+	if (!sampleBuffer) {
+		PyErr_NoMemory();
+		goto finally;
+	}
+	sampleSizes = PyMem_Malloc(samplesLen * sizeof(size_t));
+	if (!sampleSizes) {
+		PyErr_NoMemory();
+		goto finally;
+	}
+
+	/* Pack all samples back to back and record each one's length. */
+	sampleOffset = sampleBuffer;
+	for (i = 0; i < samplesLen; i++) {
+		PyObject* sampleItem = PyList_GET_ITEM(samples, i);
+		sampleSize = PyBytes_GET_SIZE(sampleItem);
+		sampleSizes[i] = sampleSize;
+		memcpy(sampleOffset, PyBytes_AS_STRING(sampleItem), sampleSize);
+		sampleOffset = (char*)sampleOffset + sampleSize;
+	}
+
+	dict = PyMem_Malloc(capacity);
+	if (!dict) {
+		PyErr_NoMemory();
+		goto finally;
+	}
+
+	/* No Python C-API calls in this section: the GIL is released here. */
+	Py_BEGIN_ALLOW_THREADS
+	if (doOptimize) {
+		zresult = COVER_optimizeTrainFromBuffer(dict, capacity,
+			sampleBuffer, sampleSizes, (unsigned)samplesLen, &params);
+	}
+	else {
+		zresult = COVER_trainFromBuffer(dict, capacity,
+			sampleBuffer, sampleSizes, (unsigned)samplesLen, params);
+	}
+	Py_END_ALLOW_THREADS
+
+	if (ZDICT_isError(zresult)) {
+		PyMem_Free(dict);
+		PyErr_Format(ZstdError, "cannot train dict: %s", ZDICT_getErrorName(zresult));
+		goto finally;
+	}
+
+	result = PyObject_New(ZstdCompressionDict, &ZstdCompressionDictType);
+	if (!result) {
+		PyMem_Free(dict);
+		goto finally;
+	}
+
+	result->dictData = dict;
+	result->dictSize = zresult;
+	result->d = params.d;
+	result->k = params.k;
+
+finally:
+	PyMem_Free(sampleBuffer);
+	PyMem_Free(sampleSizes);
+
+	return result;
+}
 
 PyDoc_STRVAR(ZstdCompressionDict__doc__,
 "ZstdCompressionDict(data) - Represents a computed compression dictionary\n"
@@ -180,6 +316,14 @@
 	{ NULL, NULL }
 };
 
+/* Read-only members of ZstdCompressionDict, backed by its k and d fields. */
+static PyMemberDef ZstdCompressionDict_members[] = {
+	{ "k", T_UINT, offsetof(ZstdCompressionDict, k), READONLY, "segment size" },
+	{ "d", T_UINT, offsetof(ZstdCompressionDict, d), READONLY, "dmer size" },
+	/* Sentinel entry: a NULL name terminates the table. */
+	{ NULL, 0, 0, 0, NULL }
+};
+
 static Py_ssize_t ZstdCompressionDict_length(ZstdCompressionDict* self) {
 	return (Py_ssize_t)self->dictSize; /* length is the stored dictSize */
 }
@@ -224,7 +368,7 @@
 	0,                              /* tp_iter */
 	0,                              /* tp_iternext */
 	ZstdCompressionDict_methods,    /* tp_methods */
-	0,                              /* tp_members */
+	ZstdCompressionDict_members,    /* tp_members */
 	0,                              /* tp_getset */
 	0,                              /* tp_base */
 	0,                              /* tp_dict */