mercurial/hgweb/request.py
changeset 36857 da4e2f87167d
parent 36856 1f7d9024674c
child 36858 01f6bba64424
--- a/mercurial/hgweb/request.py	Sat Mar 10 10:56:10 2018 -0800
+++ b/mercurial/hgweb/request.py	Sat Mar 10 11:06:13 2018 -0800
@@ -61,7 +61,10 @@
 
 @attr.s(frozen=True)
 class parsedrequest(object):
-    """Represents a parsed WSGI request / static HTTP request parameters."""
+    """Represents a parsed WSGI request.
+
+    Contains both parsed parameters as well as a handle on the input stream.
+    """
 
     # Request method.
     method = attr.ib()
@@ -91,8 +94,10 @@
     # wsgiref.headers.Headers instance. Operates like a dict with case
     # insensitive keys.
     headers = attr.ib()
+    # Request body input stream.
+    bodyfh = attr.ib()
 
-def parserequestfromenv(env):
+def parserequestfromenv(env, bodyfh):
     """Parse URL components from environment variables.
 
     WSGI defines request attributes via environment variables. This function
@@ -209,6 +214,12 @@
     if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
         headers['Content-Length'] = env['CONTENT_LENGTH']
 
+    # TODO do this once we remove wsgirequest.inp, otherwise we could have
+    # multiple readers from the underlying input stream.
+    #bodyfh = env['wsgi.input']
+    #if 'Content-Length' in headers:
+    #    bodyfh = util.cappedreader(bodyfh, int(headers['Content-Length']))
+
     return parsedrequest(method=env['REQUEST_METHOD'],
                          url=fullurl, baseurl=baseurl,
                          advertisedurl=advertisedfullurl,
@@ -219,7 +230,8 @@
                          querystring=querystring,
                          querystringlist=querystringlist,
                          querystringdict=querystringdict,
-                         headers=headers)
+                         headers=headers,
+                         bodyfh=bodyfh)
 
 class wsgirequest(object):
     """Higher-level API for a WSGI request.
@@ -233,28 +245,27 @@
         if (version < (1, 0)) or (version >= (2, 0)):
             raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
                                % version)
-        self.inp = wsgienv[r'wsgi.input']
+
+        inp = wsgienv[r'wsgi.input']
 
         if r'HTTP_CONTENT_LENGTH' in wsgienv:
-            self.inp = util.cappedreader(self.inp,
-                                         int(wsgienv[r'HTTP_CONTENT_LENGTH']))
+            inp = util.cappedreader(inp, int(wsgienv[r'HTTP_CONTENT_LENGTH']))
         elif r'CONTENT_LENGTH' in wsgienv:
-            self.inp = util.cappedreader(self.inp,
-                                         int(wsgienv[r'CONTENT_LENGTH']))
+            inp = util.cappedreader(inp, int(wsgienv[r'CONTENT_LENGTH']))
 
         self.err = wsgienv[r'wsgi.errors']
         self.threaded = wsgienv[r'wsgi.multithread']
         self.multiprocess = wsgienv[r'wsgi.multiprocess']
         self.run_once = wsgienv[r'wsgi.run_once']
         self.env = wsgienv
-        self.form = normalize(cgi.parse(self.inp,
+        self.form = normalize(cgi.parse(inp,
                                         self.env,
                                         keep_blank_values=1))
         self._start_response = start_response
         self.server_write = None
         self.headers = []
 
-        self.req = parserequestfromenv(wsgienv)
+        self.req = parserequestfromenv(wsgienv, inp)
 
     def respond(self, status, type, filename=None, body=None):
         if not isinstance(type, str):
@@ -315,7 +326,7 @@
                 # input stream doesn't overrun the actual request. So there's
                 # no guarantee that reading until EOF won't corrupt the stream
                 # state.
-                if not isinstance(self.inp, util.cappedreader):
+                if not isinstance(self.req.bodyfh, util.cappedreader):
                     close = True
                 else:
                     # We /could/ only drain certain HTTP response codes. But 200
@@ -329,9 +340,9 @@
                 self.headers.append((r'Connection', r'Close'))
 
             if drain:
-                assert isinstance(self.inp, util.cappedreader)
+                assert isinstance(self.req.bodyfh, util.cappedreader)
                 while True:
-                    chunk = self.inp.read(32768)
+                    chunk = self.req.bodyfh.read(32768)
                     if not chunk:
                         break