contrib/python-zstandard/tests/test_train_dictionary.py
changeset 31796 e0dc40530c5a
parent 30895 c32454d69b85
child 37495 b1fb341d8a61
--- a/contrib/python-zstandard/tests/test_train_dictionary.py	Sat Apr 01 13:43:52 2017 -0700
+++ b/contrib/python-zstandard/tests/test_train_dictionary.py	Sat Apr 01 15:24:03 2017 -0700
@@ -48,3 +48,63 @@
 
         data = d.as_bytes()
         self.assertEqual(data[0:4], b'\x37\xa4\x30\xec')
+
+    def test_set_dict_id(self):
+        samples = []
+        for i in range(128):
+            samples.append(b'foo' * 64)
+            samples.append(b'foobar' * 64)
+
+        d = zstd.train_dictionary(8192, samples, dict_id=42)
+        self.assertEqual(d.dict_id(), 42)
+
+
+@make_cffi
+class TestTrainCoverDictionary(unittest.TestCase):
+    def test_no_args(self):
+        with self.assertRaises(TypeError):
+            zstd.train_cover_dictionary()
+
+    def test_bad_args(self):
+        with self.assertRaises(TypeError):
+            zstd.train_cover_dictionary(8192, u'foo')
+
+        with self.assertRaises(ValueError):
+            zstd.train_cover_dictionary(8192, [u'foo'])
+
+    def test_basic(self):
+        samples = []
+        for i in range(128):
+            samples.append(b'foo' * 64)
+            samples.append(b'foobar' * 64)
+
+        d = zstd.train_cover_dictionary(8192, samples, k=64, d=16)
+        self.assertIsInstance(d.dict_id(), int_type)
+
+        data = d.as_bytes()
+        self.assertEqual(data[0:4], b'\x37\xa4\x30\xec')
+
+        self.assertEqual(d.k, 64)
+        self.assertEqual(d.d, 16)
+
+    def test_set_dict_id(self):
+        samples = []
+        for i in range(128):
+            samples.append(b'foo' * 64)
+            samples.append(b'foobar' * 64)
+
+        d = zstd.train_cover_dictionary(8192, samples, k=64, d=16,
+                                        dict_id=42)
+        self.assertEqual(d.dict_id(), 42)
+
+    def test_optimize(self):
+        samples = []
+        for i in range(128):
+            samples.append(b'foo' * 64)
+            samples.append(b'foobar' * 64)
+
+        d = zstd.train_cover_dictionary(8192, samples, optimize=True,
+                                        threads=-1, steps=1, d=16)
+
+        self.assertEqual(d.k, 16)
+        self.assertEqual(d.d, 16)