contrib/python-zstandard/tests/test_decompressor.py
changeset 40121 73fef626dae3
parent 37495 b1fb341d8a61
child 42070 675775c33ab6
--- a/contrib/python-zstandard/tests/test_decompressor.py	Tue Sep 25 20:55:03 2018 +0900
+++ b/contrib/python-zstandard/tests/test_decompressor.py	Mon Oct 08 16:27:40 2018 -0700
@@ -293,10 +293,6 @@
     def test_context_manager(self):
         dctx = zstd.ZstdDecompressor()
 
-        reader = dctx.stream_reader(b'foo')
-        with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'):
-            reader.read(1)
-
         with dctx.stream_reader(b'foo') as reader:
             with self.assertRaisesRegexp(ValueError, 'cannot __enter__ multiple times'):
                 with reader as reader2:
@@ -331,17 +327,23 @@
         dctx = zstd.ZstdDecompressor()
 
         with dctx.stream_reader(b'foo') as reader:
+            self.assertFalse(reader.closed)
             self.assertTrue(reader.readable())
             self.assertFalse(reader.writable())
             self.assertTrue(reader.seekable())
             self.assertFalse(reader.isatty())
+            self.assertFalse(reader.closed)
             self.assertIsNone(reader.flush())
+            self.assertFalse(reader.closed)
+
+        self.assertTrue(reader.closed)
 
     def test_read_closed(self):
         dctx = zstd.ZstdDecompressor()
 
         with dctx.stream_reader(b'foo') as reader:
             reader.close()
+            self.assertTrue(reader.closed)
             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                 reader.read(1)
 
@@ -372,10 +374,10 @@
             self.assertEqual(reader.tell(), len(source))
 
             # Read after EOF should return empty bytes.
-            self.assertEqual(reader.read(), b'')
+            self.assertEqual(reader.read(1), b'')
             self.assertEqual(reader.tell(), len(result))
 
-        self.assertTrue(reader.closed())
+        self.assertTrue(reader.closed)
 
     def test_read_buffer_small_chunks(self):
         cctx = zstd.ZstdCompressor()
@@ -408,8 +410,11 @@
             chunk = reader.read(8192)
             self.assertEqual(chunk, source)
             self.assertEqual(reader.tell(), len(source))
-            self.assertEqual(reader.read(), b'')
+            self.assertEqual(reader.read(1), b'')
             self.assertEqual(reader.tell(), len(source))
+            self.assertFalse(reader.closed)
+
+        self.assertTrue(reader.closed)
 
     def test_read_stream_small_chunks(self):
         cctx = zstd.ZstdCompressor()
@@ -440,7 +445,9 @@
             while reader.read(16):
                 pass
 
-        with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'):
+        self.assertTrue(reader.closed)
+
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
             reader.read(10)
 
     def test_illegal_seeks(self):
@@ -474,8 +481,7 @@
             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                 reader.seek(4, os.SEEK_SET)
 
-        with self.assertRaisesRegexp(
-            zstd.ZstdError, 'seek\(\) must be called from an active context'):
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
             reader.seek(0)
 
     def test_seek(self):
@@ -492,6 +498,39 @@
             reader.seek(4, os.SEEK_CUR)
             self.assertEqual(reader.read(2), b'ar')
 
+    def test_no_context_manager(self):
+        source = b'foobar' * 60
+        cctx = zstd.ZstdCompressor()
+        frame = cctx.compress(source)
+
+        dctx = zstd.ZstdDecompressor()
+        reader = dctx.stream_reader(frame)
+
+        self.assertEqual(reader.read(6), b'foobar')
+        self.assertEqual(reader.read(18), b'foobar' * 3)
+        self.assertFalse(reader.closed)
+
+        # Calling close prevents subsequent use.
+        reader.close()
+        self.assertTrue(reader.closed)
+
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
+            reader.read(6)
+
+    def test_read_after_error(self):
+        source = io.BytesIO(b'')
+        dctx = zstd.ZstdDecompressor()
+
+        reader = dctx.stream_reader(source)
+
+        with reader:
+            with self.assertRaises(TypeError):
+                reader.read()
+
+        with reader:
+            with self.assertRaisesRegexp(ValueError, 'stream is closed'):
+                reader.read(100)
+
 
 @make_cffi
 class TestDecompressor_decompressobj(unittest.TestCase):