diff hgext/remotefilelog/fileserverclient.py @ 40545:3a333a582d7b
remotefilelog: import pruned-down remotefilelog extension from hg-experimental
This is remotefilelog as of my recent patches for compatibility with
current tip of hg, minus support for old versions of Mercurial and
some FB-specific features like their treemanifest extension and
fetching linkrev data from a patched phabricator. The file extutil.py
moved from hgext3rd to remotefilelog.
This is not yet ready to be landed; consider it a preview for now.
Planned changes include:
* replace lz4 with zstd
* rename some capabilities, requirements and wireproto commands to mark
them as experimental
* consolidate bits of shallowutil with related functions (e.g. readfile)
I'm certainly open to other (small) changes, but my rough mission is
to land this largely as-is so we can use it as a model of the
functionality we need going forward for lazy-fetching of file contents
from a server.
# no-check-commit because of a few foo_bar functions
Differential Revision: https://phab.mercurial-scm.org/D4782
author   | Augie Fackler <augie@google.com>
date     | Thu, 27 Sep 2018 13:03:19 -0400
parents  |
children | 6d64e2abe8d3
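One of the planned changes listed in the commit message is swapping lz4 for zstd. As a rough illustration only (none of this is part of the patch), a zstd-based stand-in for the lz4wrapper module used in the diff below could be as small as the following sketch; it assumes the standalone python-zstandard package, and the zstdcompress/zstddecompress names are hypothetical:

from __future__ import absolute_import

import zstandard

def zstdcompress(data):
    # compress() records the content size in the frame header, so the
    # decompressor below does not need to be told the output size
    return zstandard.ZstdCompressor(level=3).compress(data)

def zstddecompress(data):
    return zstandard.ZstdDecompressor().decompress(data)

receivemissing() in the diff would then call a zstddecompress(data) equivalent where it currently calls lz4wrapper.lz4decompress(data).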
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/hgext/remotefilelog/fileserverclient.py	Thu Sep 27 13:03:19 2018 -0400
@@ -0,0 +1,648 @@
+# fileserverclient.py - client for communicating with the cache process
+#
+# Copyright 2013 Facebook, Inc.
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+
+from __future__ import absolute_import
+
+import hashlib
+import io
+import os
+import struct
+import threading
+import time
+
+from mercurial.i18n import _
+from mercurial.node import bin, hex, nullid
+from mercurial import (
+    error,
+    revlog,
+    sshpeer,
+    util,
+    wireprotov1peer,
+)
+from mercurial.utils import procutil
+
+from . import (
+    constants,
+    contentstore,
+    lz4wrapper,
+    metadatastore,
+    shallowutil,
+    wirepack,
+)
+
+_sshv1peer = sshpeer.sshv1peer
+
+# Statistics for debugging
+fetchcost = 0
+fetches = 0
+fetched = 0
+fetchmisses = 0
+
+_lfsmod = None
+_downloading = _('downloading')
+
+def getcachekey(reponame, file, id):
+    pathhash = hashlib.sha1(file).hexdigest()
+    return os.path.join(reponame, pathhash[:2], pathhash[2:], id)
+
+def getlocalkey(file, id):
+    pathhash = hashlib.sha1(file).hexdigest()
+    return os.path.join(pathhash, id)
+
+def peersetup(ui, peer):
+
+    class remotefilepeer(peer.__class__):
+        @wireprotov1peer.batchable
+        def getfile(self, file, node):
+            if not self.capable('getfile'):
+                raise error.Abort(
+                    'configured remotefile server does not support getfile')
+            f = wireprotov1peer.future()
+            yield {'file': file, 'node': node}, f
+            code, data = f.value.split('\0', 1)
+            if int(code):
+                raise error.LookupError(file, node, data)
+            yield data
+
+        @wireprotov1peer.batchable
+        def getflogheads(self, path):
+            if not self.capable('getflogheads'):
+                raise error.Abort('configured remotefile server does not '
+                                  'support getflogheads')
+            f = wireprotov1peer.future()
+            yield {'path': path}, f
+            heads = f.value.split('\n') if f.value else []
+            yield heads
+
+        def _updatecallstreamopts(self, command, opts):
+            if command != 'getbundle':
+                return
+            if 'remotefilelog' not in self.capabilities():
+                return
+            if not util.safehasattr(self, '_localrepo'):
+                return
+            if constants.REQUIREMENT not in self._localrepo.requirements:
+                return
+
+            bundlecaps = opts.get('bundlecaps')
+            if bundlecaps:
+                bundlecaps = [bundlecaps]
+            else:
+                bundlecaps = []
+
+            # shallow, includepattern, and excludepattern are a hacky way of
+            # carrying over data from the local repo to this getbundle
+            # command. We need to do it this way because bundle1 getbundle
+            # doesn't provide any other place we can hook in to manipulate
+            # getbundle args before it goes across the wire. Once we get rid
+            # of bundle1, we can use bundle2's _pullbundle2extraprepare to
+            # do this more cleanly.
+            bundlecaps.append('remotefilelog')
+            if self._localrepo.includepattern:
+                patterns = '\0'.join(self._localrepo.includepattern)
+                includecap = "includepattern=" + patterns
+                bundlecaps.append(includecap)
+            if self._localrepo.excludepattern:
+                patterns = '\0'.join(self._localrepo.excludepattern)
+                excludecap = "excludepattern=" + patterns
+                bundlecaps.append(excludecap)
+            opts['bundlecaps'] = ','.join(bundlecaps)
+
+        def _sendrequest(self, command, args, **opts):
+            self._updatecallstreamopts(command, args)
+            return super(remotefilepeer, self)._sendrequest(command, args,
+                                                            **opts)
+
+        def _callstream(self, command, **opts):
+            supertype = super(remotefilepeer, self)
+            if not util.safehasattr(supertype, '_sendrequest'):
+                self._updatecallstreamopts(command, opts)
+            return super(remotefilepeer, self)._callstream(command, **opts)
+
+    peer.__class__ = remotefilepeer
+
+class cacheconnection(object):
+    """The connection for communicating with the remote cache. Performs
+    gets and sets by communicating with an external process that has the
+    cache-specific implementation.
+    """
+    def __init__(self):
+        self.pipeo = self.pipei = self.pipee = None
+        self.subprocess = None
+        self.connected = False
+
+    def connect(self, cachecommand):
+        if self.pipeo:
+            raise error.Abort(_("cache connection already open"))
+        self.pipei, self.pipeo, self.pipee, self.subprocess = \
+            procutil.popen4(cachecommand)
+        self.connected = True
+
+    def close(self):
+        def tryclose(pipe):
+            try:
+                pipe.close()
+            except Exception:
+                pass
+        if self.connected:
+            try:
+                self.pipei.write("exit\n")
+            except Exception:
+                pass
+            tryclose(self.pipei)
+            self.pipei = None
+            tryclose(self.pipeo)
+            self.pipeo = None
+            tryclose(self.pipee)
+            self.pipee = None
+            try:
+                # Wait for process to terminate, making sure to avoid deadlock.
+                # See https://docs.python.org/2/library/subprocess.html for
+                # warnings about wait() and deadlocking.
+                self.subprocess.communicate()
+            except Exception:
+                pass
+            self.subprocess = None
+        self.connected = False
+
+    def request(self, request, flush=True):
+        if self.connected:
+            try:
+                self.pipei.write(request)
+                if flush:
+                    self.pipei.flush()
+            except IOError:
+                self.close()
+
+    def receiveline(self):
+        if not self.connected:
+            return None
+        try:
+            result = self.pipeo.readline()[:-1]
+            if not result:
+                self.close()
+        except IOError:
+            self.close()
+
+        return result
+
+def _getfilesbatch(
+        remote, receivemissing, progresstick, missed, idmap, batchsize):
+    # Over http(s), iterbatch is a streamy method and we can start
+    # looking at results early. This means we send one (potentially
+    # large) request, but then we show nice progress as we process
+    # file results, rather than showing chunks of $batchsize in
+    # progress.
+    #
+    # Over ssh, iterbatch isn't streamy because batch() wasn't
+    # explicitly designed as a streaming method. In the future we
+    # should probably introduce a streambatch() method upstream and
+    # use that for this.
+    with remote.commandexecutor() as e:
+        futures = []
+        for m in missed:
+            futures.append(e.callcommand('getfile', {
+                'file': idmap[m],
+                'node': m[-40:]
+            }))
+
+        for i, m in enumerate(missed):
+            r = futures[i].result()
+            futures[i] = None  # release memory
+            file_ = idmap[m]
+            node = m[-40:]
+            receivemissing(io.BytesIO('%d\n%s' % (len(r), r)), file_, node)
+            progresstick()
+
+def _getfiles_optimistic(
+        remote, receivemissing, progresstick, missed, idmap, step):
+    remote._callstream("getfiles")
+    i = 0
+    pipeo = remote._pipeo
+    pipei = remote._pipei
+    while i < len(missed):
+        # issue a batch of requests
+        start = i
+        end = min(len(missed), start + step)
+        i = end
+        for missingid in missed[start:end]:
+            # issue new request
+            versionid = missingid[-40:]
+            file = idmap[missingid]
+            sshrequest = "%s%s\n" % (versionid, file)
+            pipeo.write(sshrequest)
+        pipeo.flush()
+
+        # receive batch results
+        for missingid in missed[start:end]:
+            versionid = missingid[-40:]
+            file = idmap[missingid]
+            receivemissing(pipei, file, versionid)
+            progresstick()
+
+    # End the command
+    pipeo.write('\n')
+    pipeo.flush()
+
+def _getfiles_threaded(
+        remote, receivemissing, progresstick, missed, idmap, step):
+    remote._callstream("getfiles")
+    pipeo = remote._pipeo
+    pipei = remote._pipei
+
+    def writer():
+        for missingid in missed:
+            versionid = missingid[-40:]
+            file = idmap[missingid]
+            sshrequest = "%s%s\n" % (versionid, file)
+            pipeo.write(sshrequest)
+        pipeo.flush()
+    writerthread = threading.Thread(target=writer)
+    writerthread.daemon = True
+    writerthread.start()
+
+    for missingid in missed:
+        versionid = missingid[-40:]
+        file = idmap[missingid]
+        receivemissing(pipei, file, versionid)
+        progresstick()
+
+    writerthread.join()
+    # End the command
+    pipeo.write('\n')
+    pipeo.flush()
+
+class fileserverclient(object):
+    """A client for requesting files from the remote file server.
+    """
+    def __init__(self, repo):
+        ui = repo.ui
+        self.repo = repo
+        self.ui = ui
+        self.cacheprocess = ui.config("remotefilelog", "cacheprocess")
+        if self.cacheprocess:
+            self.cacheprocess = util.expandpath(self.cacheprocess)
+
+        # This option causes remotefilelog to pass the full file path to the
+        # cacheprocess instead of a hashed key.
+        self.cacheprocesspasspath = ui.configbool(
+            "remotefilelog", "cacheprocess.includepath")
+
+        self.debugoutput = ui.configbool("remotefilelog", "debug")
+
+        self.remotecache = cacheconnection()
+
+    def setstore(self, datastore, historystore, writedata, writehistory):
+        self.datastore = datastore
+        self.historystore = historystore
+        self.writedata = writedata
+        self.writehistory = writehistory
+
+    def _connect(self):
+        return self.repo.connectionpool.get(self.repo.fallbackpath)
+
+    def request(self, fileids):
+        """Takes a list of filename/node pairs and fetches them from the
+        server. Files are stored in the local cache.
+        A list of nodes that the server couldn't find is returned.
+        If the connection fails, an exception is raised.
+        """
+        if not self.remotecache.connected:
+            self.connect()
+        cache = self.remotecache
+        writedata = self.writedata
+
+        if self.ui.configbool('remotefilelog', 'fetchpacks'):
+            self.requestpack(fileids)
+            return
+
+        repo = self.repo
+        count = len(fileids)
+        request = "get\n%d\n" % count
+        idmap = {}
+        reponame = repo.name
+        for file, id in fileids:
+            fullid = getcachekey(reponame, file, id)
+            if self.cacheprocesspasspath:
+                request += file + '\0'
+            request += fullid + "\n"
+            idmap[fullid] = file
+
+        cache.request(request)
+
+        total = count
+        self.ui.progress(_downloading, 0, total=count)
+
+        missed = []
+        count = 0
+        while True:
+            missingid = cache.receiveline()
+            if not missingid:
+                missedset = set(missed)
+                for missingid in idmap.iterkeys():
+                    if not missingid in missedset:
+                        missed.append(missingid)
+                self.ui.warn(_("warning: cache connection closed early - " +
+                               "falling back to server\n"))
+                break
+            if missingid == "0":
+                break
+            if missingid.startswith("_hits_"):
+                # receive progress reports
+                parts = missingid.split("_")
+                count += int(parts[2])
+                self.ui.progress(_downloading, count, total=total)
+                continue
+
+            missed.append(missingid)
+
+        global fetchmisses
+        fetchmisses += len(missed)
+
+        count = [total - len(missed)]
+        fromcache = count[0]
+        self.ui.progress(_downloading, count[0], total=total)
+        self.ui.log("remotefilelog", "remote cache hit rate is %r of %r\n",
+                    count[0], total, hit=count[0], total=total)
+
+        oldumask = os.umask(0o002)
+        try:
+            # receive cache misses from master
+            if missed:
+                def progresstick():
+                    count[0] += 1
+                    self.ui.progress(_downloading, count[0], total=total)
+                # When verbose is true, sshpeer prints 'running ssh...'
+                # to stdout, which can interfere with some command
+                # outputs
+                verbose = self.ui.verbose
+                self.ui.verbose = False
+                try:
+                    with self._connect() as conn:
+                        remote = conn.peer
+                        # TODO: deduplicate this with the constant in
+                        # shallowrepo
+                        if remote.capable("remotefilelog"):
+                            if not isinstance(remote, _sshv1peer):
+                                raise error.Abort('remotefilelog requires ssh '
+                                                  'servers')
+                            step = self.ui.configint('remotefilelog',
+                                                     'getfilesstep')
+                            getfilestype = self.ui.config('remotefilelog',
+                                                          'getfilestype')
+                            if getfilestype == 'threaded':
+                                _getfiles = _getfiles_threaded
+                            else:
+                                _getfiles = _getfiles_optimistic
+                            _getfiles(remote, self.receivemissing,
+                                      progresstick, missed, idmap, step)
+                        elif remote.capable("getfile"):
+                            if remote.capable('batch'):
+                                batchdefault = 100
+                            else:
+                                batchdefault = 10
+                            batchsize = self.ui.configint(
+                                'remotefilelog', 'batchsize', batchdefault)
+                            _getfilesbatch(
+                                remote, self.receivemissing, progresstick,
+                                missed, idmap, batchsize)
+                        else:
+                            raise error.Abort("configured remotefilelog server"
+                                              " does not support remotefilelog")
+
+                    self.ui.log("remotefilefetchlog",
+                                "Success\n",
+                                fetched_files = count[0] - fromcache,
+                                total_to_fetch = total - fromcache)
+                except Exception:
+                    self.ui.log("remotefilefetchlog",
+                                "Fail\n",
+                                fetched_files = count[0] - fromcache,
+                                total_to_fetch = total - fromcache)
+                    raise
+                finally:
+                    self.ui.verbose = verbose
+                # send to memcache
+                count[0] = len(missed)
+                request = "set\n%d\n%s\n" % (count[0], "\n".join(missed))
+                cache.request(request)
+
+            self.ui.progress(_downloading, None)
+
+            # mark ourselves as a user of this cache
+            writedata.markrepo(self.repo.path)
+        finally:
+            os.umask(oldumask)
+
+    def receivemissing(self, pipe, filename, node):
+        line = pipe.readline()[:-1]
+        if not line:
+            raise error.ResponseError(_("error downloading file contents:"),
+                                      _("connection closed early"))
+        size = int(line)
+        data = pipe.read(size)
+        if len(data) != size:
+            raise error.ResponseError(_("error downloading file contents:"),
+                                      _("only received %s of %s bytes")
+                                      % (len(data), size))
+
+        self.writedata.addremotefilelognode(filename, bin(node),
+                                            lz4wrapper.lz4decompress(data))
+
+    def requestpack(self, fileids):
+        """Requests the given file revisions from the server in a pack format.
+
+        See `remotefilelogserver.getpack` for the file format.
+        """
+        try:
+            with self._connect() as conn:
+                total = len(fileids)
+                rcvd = 0
+
+                remote = conn.peer
+                remote._callstream("getpackv1")
+
+                self._sendpackrequest(remote, fileids)
+
+                packpath = shallowutil.getcachepackpath(
+                    self.repo, constants.FILEPACK_CATEGORY)
+                pipei = remote._pipei
+                receiveddata, receivedhistory = wirepack.receivepack(
+                    self.repo.ui, pipei, packpath)
+                rcvd = len(receiveddata)
+
+            self.ui.log("remotefilefetchlog",
+                        "Success(pack)\n" if (rcvd==total) else "Fail(pack)\n",
+                        fetched_files = rcvd,
+                        total_to_fetch = total)
+        except Exception:
+            self.ui.log("remotefilefetchlog",
+                        "Fail(pack)\n",
+                        fetched_files = rcvd,
+                        total_to_fetch = total)
+            raise
+
+    def _sendpackrequest(self, remote, fileids):
+        """Formats and writes the given fileids to the remote as part of a
+        getpackv1 call.
+        """
+        # Sort the requests by name, so we receive requests in batches by name
+        grouped = {}
+        for filename, node in fileids:
+            grouped.setdefault(filename, set()).add(node)
+
+        # Issue request
+        pipeo = remote._pipeo
+        for filename, nodes in grouped.iteritems():
+            filenamelen = struct.pack(constants.FILENAMESTRUCT, len(filename))
+            countlen = struct.pack(constants.PACKREQUESTCOUNTSTRUCT, len(nodes))
+            rawnodes = ''.join(bin(n) for n in nodes)
+
+            pipeo.write('%s%s%s%s' % (filenamelen, filename, countlen,
+                                      rawnodes))
+            pipeo.flush()
+        pipeo.write(struct.pack(constants.FILENAMESTRUCT, 0))
+        pipeo.flush()
+
+    def connect(self):
+        if self.cacheprocess:
+            cmd = "%s %s" % (self.cacheprocess, self.writedata._path)
+            self.remotecache.connect(cmd)
+        else:
+            # If no cache process is specified, we fake one that always
+            # returns cache misses. This enables tests to run easily
+            # and may eventually allow us to be a drop in replacement
+            # for the largefiles extension.
+            class simplecache(object):
+                def __init__(self):
+                    self.missingids = []
+                    self.connected = True
+
+                def close(self):
+                    pass
+
+                def request(self, value, flush=True):
+                    lines = value.split("\n")
+                    if lines[0] != "get":
+                        return
+                    self.missingids = lines[2:-1]
+                    self.missingids.append('0')
+
+                def receiveline(self):
+                    if len(self.missingids) > 0:
+                        return self.missingids.pop(0)
+                    return None
+
+            self.remotecache = simplecache()
+
+    def close(self):
+        if fetches:
+            msg = ("%s files fetched over %d fetches - " +
+                   "(%d misses, %0.2f%% hit ratio) over %0.2fs\n") % (
+                       fetched,
+                       fetches,
+                       fetchmisses,
+                       float(fetched - fetchmisses) / float(fetched) * 100.0,
+                       fetchcost)
+            if self.debugoutput:
+                self.ui.warn(msg)
+            self.ui.log("remotefilelog.prefetch", msg.replace("%", "%%"),
+                        remotefilelogfetched=fetched,
+                        remotefilelogfetches=fetches,
+                        remotefilelogfetchmisses=fetchmisses,
+                        remotefilelogfetchtime=fetchcost * 1000)
+
+        if self.remotecache.connected:
+            self.remotecache.close()
+
+    def prefetch(self, fileids, force=False, fetchdata=True,
+                 fetchhistory=False):
+        """downloads the given file versions to the cache
+        """
+        repo = self.repo
+        idstocheck = []
+        for file, id in fileids:
+            # hack
+            # - we don't use .hgtags
+            # - workingctx produces ids with length 42,
+            #   which we skip since they aren't in any cache
+            if (file == '.hgtags' or len(id) == 42
+                or not repo.shallowmatch(file)):
+                continue
+
+            idstocheck.append((file, bin(id)))
+
+        datastore = self.datastore
+        historystore = self.historystore
+        if force:
+            datastore = contentstore.unioncontentstore(*repo.shareddatastores)
+            historystore = metadatastore.unionmetadatastore(
+                *repo.sharedhistorystores)
+
+        missingids = set()
+        if fetchdata:
+            missingids.update(datastore.getmissing(idstocheck))
+        if fetchhistory:
+            missingids.update(historystore.getmissing(idstocheck))
+
+        # partition missing nodes into nullid and not-nullid so we can
+        # warn about this filtering potentially shadowing bugs.
+        nullids = len([None for unused, id in missingids if id == nullid])
+        if nullids:
+            missingids = [(f, id) for f, id in missingids if id != nullid]
+            repo.ui.develwarn(
+                ('remotefilelog not fetching %d null revs'
+                 ' - this is likely hiding bugs' % nullids),
+                config='remotefilelog-ext')
+        if missingids:
+            global fetches, fetched, fetchcost
+            fetches += 1
+
+            # We want to be able to detect excess individual file downloads, so
+            # let's log that information for debugging.
+            if fetches >= 15 and fetches < 18:
+                if fetches == 15:
+                    fetchwarning = self.ui.config('remotefilelog',
+                                                  'fetchwarning')
+                    if fetchwarning:
+                        self.ui.warn(fetchwarning + '\n')
+                self.logstacktrace()
+            missingids = [(file, hex(id)) for file, id in missingids]
+            fetched += len(missingids)
+            start = time.time()
+            missingids = self.request(missingids)
+            if missingids:
+                raise error.Abort(_("unable to download %d files") %
+                                  len(missingids))
+            fetchcost += time.time() - start
+            self._lfsprefetch(fileids)
+
+    def _lfsprefetch(self, fileids):
+        if not _lfsmod or not util.safehasattr(
+                self.repo.svfs, 'lfslocalblobstore'):
+            return
+        if not _lfsmod.wrapper.candownload(self.repo):
+            return
+        pointers = []
+        store = self.repo.svfs.lfslocalblobstore
+        for file, id in fileids:
+            node = bin(id)
+            rlog = self.repo.file(file)
+            if rlog.flags(node) & revlog.REVIDX_EXTSTORED:
+                text = rlog.revision(node, raw=True)
+                p = _lfsmod.pointer.deserialize(text)
+                oid = p.oid()
+                if not store.has(oid):
+                    pointers.append(p)
+        if len(pointers) > 0:
+            self.repo.svfs.lfsremoteblobstore.readbatch(pointers, store)
+            assert all(store.has(p.oid()) for p in pointers)
+
+    def logstacktrace(self):
+        import traceback
+        self.ui.log('remotefilelog', 'excess remotefilelog fetching:\n%s\n',
+                    ''.join(traceback.format_stack()))
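
The remotefilelog.cacheprocess hook spawned in connect() speaks the small line protocol driven by cacheconnection and request() above: a "get" command followed by a key count and one cache key per line, to which the process replies with the keys it is missing (plus optional "_hits_" progress lines) terminated by "0"; a "set" command listing the keys the client then fetched from the server; and "exit" on shutdown. A minimal external cache process that always reports misses, mirroring the in-process simplecache fallback, could look like the sketch below. It is an illustration only, not a supported program, and it assumes remotefilelog.cacheprocess.includepath is left off:

#!/usr/bin/env python
# sketch of a cacheprocess that never has anything cached
from __future__ import absolute_import

import sys

def main():
    stdin, stdout = sys.stdin, sys.stdout
    while True:
        cmd = stdin.readline()
        if not cmd or cmd.strip() == 'exit':
            break
        if cmd.strip() not in ('get', 'set'):
            continue
        count = int(stdin.readline())
        keys = [stdin.readline()[:-1] for _i in range(count)]
        if cmd.strip() == 'get':
            for key in keys:
                stdout.write(key + '\n')  # report every key as a miss
            stdout.write('0\n')           # end-of-response marker
            stdout.flush()
        # a real cache would store the keys sent with "set";
        # this sketch simply discards them

if __name__ == '__main__':
    main()

Pointing remotefilelog.cacheprocess at such a script exercises the same code path as a real cache, just with a 0% hit rate, which is exactly what the in-process simplecache stub does for tests.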
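For the pack-based path, _sendpackrequest frames a getpackv1 request as, per file: a length-prefixed filename, a node count, and the raw 20-byte nodes, with a zero-length filename marking the end of the request. A standalone sketch of that framing follows; the struct formats '!H' and '!I' are assumptions standing in for constants.FILENAMESTRUCT and constants.PACKREQUESTCOUNTSTRUCT, and buildpackrequest is not a function in this patch:

from __future__ import absolute_import

import struct

from mercurial.node import bin

def buildpackrequest(fileids):
    """Build getpackv1 request bytes from (filename, hex node) pairs."""
    grouped = {}
    for filename, node in fileids:
        grouped.setdefault(filename, set()).add(node)

    chunks = []
    for filename, nodes in sorted(grouped.items()):
        chunks.append(struct.pack('!H', len(filename)))  # filename length
        chunks.append(filename)
        chunks.append(struct.pack('!I', len(nodes)))     # node count
        chunks.append(''.join(bin(n) for n in nodes))    # raw 20-byte nodes
    chunks.append(struct.pack('!H', 0))                  # zero-length name ends
    return ''.join(chunks)

requestpack() writes this kind of stream to the getpackv1 peer and then hands the reply to wirepack.receivepack, which unpacks it into the local pack store.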