gitMergeCommon.py on commit [PATCH] Introduce a 'die' function. (654291a)
   1import sys, re, os, traceback
   2from sets import Set
   3
   4if sys.version_info[0] < 2 or \
   5       (sys.version_info[0] == 2 and sys.version_info[1] < 4):
   6    print 'Python version 2.4 required, found', \
   7          str(sys.version_info[0])+'.'+str(sys.version_info[1])+'.'+ \
   8          str(sys.version_info[2])
   9    sys.exit(1)
  10
  11import subprocess
  12
  13def die(*args):
  14    printList(args, sys.stderr)
  15    sys.exit(2)
  16
  17# Debugging machinery
  18# -------------------
  19
  20DEBUG = 0
  21functionsToDebug = Set()
  22
  23def addDebug(func):
  24    if type(func) == str:
  25        functionsToDebug.add(func)
  26    else:
  27        functionsToDebug.add(func.func_name)
  28
  29def debug(*args):
  30    if DEBUG:
  31        funcName = traceback.extract_stack()[-2][2]
  32        if funcName in functionsToDebug:
  33            printList(args)
  34
  35def printList(list, file=sys.stdout):
  36    for x in list:
  37        file.write(str(x))
  38        file.write(' ')
  39    file.write('\n')
  40
  41# Program execution
  42# -----------------
  43
  44class ProgramError(Exception):
  45    def __init__(self, progStr, error):
  46        self.progStr = progStr
  47        self.error = error
  48
  49addDebug('runProgram')
  50def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
  51    debug('runProgram prog:', str(prog), 'input:', str(input))
  52    if type(prog) is str:
  53        progStr = prog
  54    else:
  55        progStr = ' '.join(prog)
  56    
  57    try:
  58        if pipeOutput:
  59            stderr = subprocess.STDOUT
  60            stdout = subprocess.PIPE
  61        else:
  62            stderr = None
  63            stdout = None
  64        pop = subprocess.Popen(prog,
  65                               shell = type(prog) is str,
  66                               stderr=stderr,
  67                               stdout=stdout,
  68                               stdin=subprocess.PIPE,
  69                               env=env)
  70    except OSError, e:
  71        debug('strerror:', e.strerror)
  72        raise ProgramError(progStr, e.strerror)
  73
  74    if input != None:
  75        pop.stdin.write(input)
  76    pop.stdin.close()
  77
  78    if pipeOutput:
  79        out = pop.stdout.read()
  80    else:
  81        out = ''
  82
  83    code = pop.wait()
  84    if returnCode:
  85        ret = [out, code]
  86    else:
  87        ret = out
  88    if code != 0 and not returnCode:
  89        debug('error output:', out)
  90        debug('prog:', prog)
  91        raise ProgramError(progStr, out)
  92#    debug('output:', out.replace('\0', '\n'))
  93    return ret
  94
  95# Code for computing common ancestors
  96# -----------------------------------
  97
  98currentId = 0
  99def getUniqueId():
 100    global currentId
 101    currentId += 1
 102    return currentId
 103
 104# The 'virtual' commit objects have SHAs which are integers
 105shaRE = re.compile('^[0-9a-f]{40}$')
 106def isSha(obj):
 107    return (type(obj) is str and bool(shaRE.match(obj))) or \
 108           (type(obj) is int and obj >= 1)
 109
 110class Commit:
 111    def __init__(self, sha, parents, tree=None):
 112        self.parents = parents
 113        self.firstLineMsg = None
 114        self.children = []
 115
 116        if tree:
 117            tree = tree.rstrip()
 118            assert(isSha(tree))
 119        self._tree = tree
 120
 121        if not sha:
 122            self.sha = getUniqueId()
 123            self.virtual = True
 124            self.firstLineMsg = 'virtual commit'
 125            assert(isSha(tree))
 126        else:
 127            self.virtual = False
 128            self.sha = sha.rstrip()
 129        assert(isSha(self.sha))
 130
 131    def tree(self):
 132        self.getInfo()
 133        assert(self._tree != None)
 134        return self._tree
 135
 136    def shortInfo(self):
 137        self.getInfo()
 138        return str(self.sha) + ' ' + self.firstLineMsg
 139
 140    def __str__(self):
 141        return self.shortInfo()
 142
 143    def getInfo(self):
 144        if self.virtual or self.firstLineMsg != None:
 145            return
 146        else:
 147            info = runProgram(['git-cat-file', 'commit', self.sha])
 148            info = info.split('\n')
 149            msg = False
 150            for l in info:
 151                if msg:
 152                    self.firstLineMsg = l
 153                    break
 154                else:
 155                    if l.startswith('tree'):
 156                        self._tree = l[5:].rstrip()
 157                    elif l == '':
 158                        msg = True
 159
 160class Graph:
 161    def __init__(self):
 162        self.commits = []
 163        self.shaMap = {}
 164
 165    def addNode(self, node):
 166        assert(isinstance(node, Commit))
 167        self.shaMap[node.sha] = node
 168        self.commits.append(node)
 169        for p in node.parents:
 170            p.children.append(node)
 171        return node
 172
 173    def reachableNodes(self, n1, n2):
 174        res = {}
 175        def traverse(n):
 176            res[n] = True
 177            for p in n.parents:
 178                traverse(p)
 179
 180        traverse(n1)
 181        traverse(n2)
 182        return res
 183
 184    def fixParents(self, node):
 185        for x in range(0, len(node.parents)):
 186            node.parents[x] = self.shaMap[node.parents[x]]
 187
 188# addDebug('buildGraph')
 189def buildGraph(heads):
 190    debug('buildGraph heads:', heads)
 191    for h in heads:
 192        assert(isSha(h))
 193
 194    g = Graph()
 195
 196    out = runProgram(['git-rev-list', '--parents'] + heads)
 197    for l in out.split('\n'):
 198        if l == '':
 199            continue
 200        shas = l.split(' ')
 201
 202        # This is a hack, we temporarily use the 'parents' attribute
 203        # to contain a list of SHA1:s. They are later replaced by proper
 204        # Commit objects.
 205        c = Commit(shas[0], shas[1:])
 206
 207        g.commits.append(c)
 208        g.shaMap[c.sha] = c
 209
 210    for c in g.commits:
 211        g.fixParents(c)
 212
 213    for c in g.commits:
 214        for p in c.parents:
 215            p.children.append(c)
 216    return g
 217
 218# Write the empty tree to the object database and return its SHA1
 219def writeEmptyTree():
 220    tmpIndex = os.environ['GIT_DIR'] + '/merge-tmp-index'
 221    def delTmpIndex():
 222        try:
 223            os.unlink(tmpIndex)
 224        except OSError:
 225            pass
 226    delTmpIndex()
 227    newEnv = os.environ.copy()
 228    newEnv['GIT_INDEX_FILE'] = tmpIndex
 229    res = runProgram(['git-write-tree'], env=newEnv).rstrip()
 230    delTmpIndex()
 231    return res
 232
 233def addCommonRoot(graph):
 234    roots = []
 235    for c in graph.commits:
 236        if len(c.parents) == 0:
 237            roots.append(c)
 238
 239    superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
 240    graph.addNode(superRoot)
 241    for r in roots:
 242        r.parents = [superRoot]
 243    superRoot.children = roots
 244    return superRoot
 245
 246def getCommonAncestors(graph, commit1, commit2):
 247    '''Find the common ancestors for commit1 and commit2'''
 248    assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
 249
 250    def traverse(start, set):
 251        stack = [start]
 252        while len(stack) > 0:
 253            el = stack.pop()
 254            set.add(el)
 255            for p in el.parents:
 256                if p not in set:
 257                    stack.append(p)
 258    h1Set = Set()
 259    h2Set = Set()
 260    traverse(commit1, h1Set)
 261    traverse(commit2, h2Set)
 262    shared = h1Set.intersection(h2Set)
 263
 264    if len(shared) == 0:
 265        shared = [addCommonRoot(graph)]
 266        
 267    res = Set()
 268
 269    for s in shared:
 270        if len([c for c in s.children if c in shared]) == 0:
 271            res.add(s)
 272    return list(res)