#!/usr/bin/python

# two_of_three.py
#       --copyright--                   Copyright 2010 (C) Tranzoa, Co. All rights reserved.    Warranty: You're free and on your own here. This code is not necessarily up-to-date or of public quality.
#       --url--                         http://www.tranzoa.net/tzpython/
#       --email--                       pycode is the name to send to. tranzoa.com is the place to send to.
#       --bodstamps--
#       May 1, 2010             bar
#       May 2, 2010             bar
#       May 4, 2010             experiment with M of N
#       May 7, 2010             optimize some find stuff and do 5/4
#       November 29, 2011       bar     pyflake cleanup
#       --eodstamps--
##      \file
#
#
#       Split a file into 3 files, each half the file's size.
#
#       Any two of the 3 files can be combined to rebuild the original file.
#
#       The split files are each half the original data file length (rounded up, plus one byte).
#       The idea is that this logic encodes the original data using parity bits.
#           During decoding, the parity bits are combined and the original byte's data is looked up.
#           There is a separate lookup table (in essence) for each of the 6 possible combinations of pairs of the 3 split files.
#           A leading byte decodes uniquely for the proper combination so the logic can auto-handle any pairs of split files without any outside knowledge.
#           The parity bit masks are arbitrary, though; they were found by the "find_em()" logic.
#           There are plenty of other parity masks that would work.
#
#       The split files contain leading nibbles that make the leading byte.
#           The leading byte is not data.
#           It is used to identify which of the two-of-three splits the split data is coming from.
#           Its value is computed at startup as a function of the parity masks that this logic uses.
#           It's a value that only comes out right when the splits are joined correctly.
#
#       The split files contain trailing nibble(s) that create one or two decoded bytes. If the last decoded byte is a \1, then the last two decoded bytes are tossed. If it is a zero, then just the zero is tossed.
#
#       Note:
#
#           M of N would need non-table-lookup decoding if N*M were too big for a table size. See below for the logic. It would still be very fast.
#
#


import  array
import  os
import  sys

import  replace_file


BIT_FIDDLE_LIMIT    = 0x10000                   # size of the bit_cnts / parity / hi_bits lookup tables (table indices must be below this)

BUFFER_SIZE = 1024 * 1024                       # default number of bytes to read from an input file at a time
BUFFER_SIZE = BUFFER_SIZE & ~1                  # here for enforcement and reminder: the split logic needs an even buffer size

SPLIT_EXT   = ".2o3"                            # file name extension of the split files

BITS        = 8                                 # bits per original data byte
BITS_BY_TWO = BITS // 2                         # bits per parity nibble (floor division keeps this an int on any Python)
LIM         = 1 << BITS                         # number of distinct byte values



def bit_cnt(n) :
    """ Return how many bits are set in the non-negative integer, n. """
    count   = 0
    while n != 0 :
        n      &= n - 1                         # each step drops the lowest set bit
        count  += 1
    return count
# Lookup tables: set-bit count, and its parity (1 when odd), for every value below BIT_FIDDLE_LIMIT.
bit_cnts    = list(map(bit_cnt, range(BIT_FIDDLE_LIMIT)))
parity      = [ (cnt & 1) for cnt in bit_cnts ]


def mask_parity(n, mask) :
    """
        Return a value whose bit i is the parity of (n & mask[i]).

        n       The value to encode. Each (n & mask[i]) must be below BIT_FIDDLE_LIMIT (it indexes the parity table).
        mask    A sequence of bit masks, one per result bit, lowest result bit first.
    """
    result  = 0
    for ( bit, m ) in enumerate(mask) :
        result |= parity[n & m] << bit
    return(result)


def masks_str(masks) :
    """ Format a list of mask lists as compact hex, e.g. [[0x01, 0x02], [0xff]]. """
    rows    = [ "[" + ", ".join([ ('0x%02x' % v) for v in row ]) + "]" for row in masks ]
    return("[" + ", ".join(rows) + "]")


PAIRS       = [ [ 0, 1 ], [ 0, 2 ], [ 1, 2 ], [ 1, 0 ], [ 2, 0 ], [ 2, 1 ] ]    # every ordered pair of split indices, in the order join_file() tries them

# P_TO_VS[i][j] is the table used when split i decodes together with split j. It maps (hi nibble << BITS_BY_TWO) | lo nibble
# to the original byte value, where the lower-numbered split supplies the hi nibble. The diagonal [ -1, ] entries are placeholders.
P_TO_VS     = [ [ [ -1, ], [0]*LIM, [0]*LIM ], [ [0]*LIM, [ -1, ], [0]*LIM ], [ [0]*LIM, [0]*LIM, [ -1, ] ], ]
VERSION_B   = -1                                                    # the leading byte value, chosen by init_em()

def init_em() :
    """
        Build the P_TO_VS decode tables and choose VERSION_B. Called once, at import time.

        VERSION_B is the first byte value that, encoded by one ordered pair of splits, does not decode back to
        itself when decoded as any other ordered pair. split_file() writes it first so that join_file() can tell
        which two splits, in which order, it was handed.

        Raises ValueError if any byte value fails to round-trip through any ordered pair of splits,
        or if no byte value is usable as VERSION_B.
    """
    global  VERSION_B

    splits  = [ a_two_of_three(0), a_two_of_three(1), a_two_of_three(2) ]
    for s in splits :
        s.p_his = [ p << BITS_BY_TWO for p in s.p_los ]

    # Fill each ordered pair's table. The lower-numbered split's nibble is the hi half of the key, whichever split is "me".
    for ( me, om ) in PAIRS :
        hi      = splits[min(me, om)]
        lo      = splits[max(me, om)]
        table   = P_TO_VS[me][om]
        for n in xrange(LIM) :
            table[hi.p_his[n] | lo.p_los[n]]    = n

    def decodes_to(bb, enc_pair, dec_pair) :
        """ Encode bb with the splits in enc_pair, and tell whether decoding them as the splits in dec_pair gives bb back. """
        a0  = splits[enc_pair[0]].encode(bb)
        a1  = splits[enc_pair[1]].encode(bb)
        return(splits[dec_pair[0]].decode(a0, splits[dec_pair[1]], a1) == bb)

    # Sanity check: every byte value must round-trip through every ordered pair of splits.
    # (Checked for all values, not just those searched before VERSION_B is found.)
    for b in xrange(LIM) :
        bb  = [ b, b ]
        for pair in PAIRS :
            if  not decodes_to(bb, pair, pair) :
                raise(ValueError("Ooops! Bug %u-%u!" % ( pair[0], pair[1] )))

    # Pick the first byte value that no wrongly-paired decode reproduces.
    for b in xrange(LIM) :
        bb  = [ b, b ]
        if  not any(decodes_to(bb, pair, opair) for pair in PAIRS for opair in PAIRS if opair != pair) :
            VERSION_B   = b
            break

    if  VERSION_B < 0 :
        raise(ValueError("Cannot find a suitable VERSION_B!"))


class   a_two_of_three(object) :
    """
        One of the three splits of a file.

        which       0, 1 or 2. Selects this split's parity masks and its index in to P_TO_VS.

        Each original byte encodes to a BITS_BY_TWO-bit nibble whose bit i is the parity of (byte & masks[i]).
        Two different splits' nibbles for the same byte decode back to the byte through the P_TO_VS tables built by init_em().
    """

    def __init__(me, which) :

        me.which    = which
        # This split's parity masks, one per nibble bit (found by find_em()).
        me.masks    = [ [
                            0xe8,
                            0xd4,
                            0xd8,
                            0xa9,
                        ],
                        [
                            0x65,
                            0x96,
                            0x15,
                            0x2e,
                        ],
                        [
                            0x6a,
                            0x9a,
                            0x89,
                            0x78,
                        ]
                      ][which]
        me.p_los    = [ mask_parity(n, me.masks) for n in xrange(LIM) ]     # the nibble for each byte value


    def encode(me, bytes_or_string) :
        """ Return the list of nibbles, one per byte, for a string or a sequence of byte values (0 .. LIM-1). """
        if  isinstance(bytes_or_string, basestring) :
            bytes_or_string   = array.array('B', bytes_or_string)

        return([ me.p_los[n] for n in bytes_or_string ])


    def encode_to_string(me, bytes_or_string) :
        """
            Encode an even number of bytes to a string half as long, packing two nibbles per character
            (the earlier byte's nibble in the high half).

            Raises IndexError if the input has an odd length.
        """
        if  len(bytes_or_string) & 1 :
            raise(IndexError("Encoded data must be an even number of bytes!"))

        a   = me.encode(bytes_or_string)
        a   = "".join([ chr((a[n] << BITS_BY_TWO) | a[n + 1]) for n in xrange(0, len(a), 2) ])

        return(a)


    def decode(me, my_nibbles, om, om_nibbles) :
        """
            Decode this split's nibbles together with another split's nibbles for the same bytes.

            my_nibbles      This split's nibbles, one per original byte.
            om              The other split's a_two_of_three instance.
            om_nibbles      The other split's nibbles, parallel to my_nibbles.

            Returns a list of byte values, one per entry in my_nibbles.
            The lower-numbered split's nibble is the high half of each P_TO_VS lookup key.
        """
        p2va    = P_TO_VS[me.which][om.which]
        # Conditional expressions, not "and/or", so an empty om_nibbles is not silently swapped for my_nibbles.
        los     = om_nibbles if me.which < om.which else my_nibbles
        his     = om_nibbles if me.which > om.which else my_nibbles
        return([ p2va[(his[n] << BITS_BY_TWO) | los[n]] for n in xrange(len(my_nibbles)) ])


    def decode_bytes(me, my_bytes, om, om_bytes) :
        """ Like decode(), but each input is a sequence of byte values that each pack two nibbles, high nibble first. """
        msk = (1 << BITS_BY_TWO) - 1
        myn = []
        for b in my_bytes :
            myn.append((b >> BITS_BY_TWO) & msk)
            myn.append( b                 & msk)
        omn = []
        for b in om_bytes :
            omn.append((b >> BITS_BY_TWO) & msk)
            omn.append( b                 & msk)

        return(me.decode(myn, om, omn))


    def decode_strings(me, my_str, om, om_str) :
        """ Like decode_bytes(), but the inputs are strings as written by encode_to_string(). """
        return(me.decode_bytes([ ord(c) for c in my_str ], om, [ ord(c) for c in om_str ]))


    @staticmethod
    def decoded_string(ba) :
        """ Turn a list of byte values in to a string. """
        return("".join([ chr(c) for c in ba ]))


    pass        #   a_two_of_three


# Build the P_TO_VS decode tables and choose VERSION_B at import time; init_em() raises ValueError if the parity masks don't work.
init_em()



def hi_bit(n) :
    """ Return the bit length of the non-negative integer, n: the 1-based position of its highest set bit, or 0 for 0. """
    length  = 0
    while n != 0 :
        length += 1
        n     >>= 1
    return length

hi_bits = list(map(hi_bit, range(BIT_FIDDLE_LIMIT)))      # bit length of every value below BIT_FIDDLE_LIMIT



def find_em() :
    """
        Experimental, randomized search for parity masks, printing the best mask sets it finds.
        It never returns: the outer "while True" loop has no exit.

        dbs holds N sets of masks (one set per split), each set holding len(dbs[0]) masks.
        eval_dbs() scores dbs over every combination of M of the N sets (lower is better), and the
        search hill-climbs with single-bit flips (flip_best_bit()), perturbing with random bit flips
        (ding()) when it stops improving.

        NOTE(review): with N = 6 and M = 3, bits is 18, so flipped or random masks can be >= BIT_FIDDLE_LIMIT
        (0x10000). The bit_cnts[] lookups in unk_bits() and eval_dbs() would then raise IndexError unless
        BIT_FIDDLE_LIMIT is raised to at least 1 << bits. (Some of the "5/4" sample masks below also exceed 0x10000.)
    """
    import  copy
    import  itertools
    import  random
    import  time

    bt  = time.time()

    # Each dbs assignment below replaces the one before it. They are a record of earlier search results;
    # the zeroed dbs built from N and M further down is what is actually searched.
    dbs = [ [
                0x08,
                0x04,
                0x42,
                0x20,
            ],
            [
                0x01,
                0x02,
                0x80,
                0x10,
            ],
            [
                0xa0,
                0x05,
                0x40,
                0x18,
            ]
          ]

    dbs = [ [
                0xbc,
                0xba,
                0x65,
                0x9d,
            ],
            [
                0xd7,
                0x66,
                0x28,
                0x13,
            ],
            [
                0x2a,
                0x98,
                0x62,
                0x7c,
            ]
          ]
    dbs = [ [
                0x94,
                0xaa,
                0x65,
                0x9d,
            ],
            [
                0xd7,
                0x66,
                0x38,
                0x13,
            ],
            [
                0x6a,
                0x98,
                0x62,
                0x7e,
            ]
          ]

    # These are the masks a_two_of_three uses.
    dbs = [ [
                0xe8,
                0xd4,
                0xd8,
                0xa9,
            ],
            [
                0x65,
                0x96,
                0x15,
                0x2e,
            ],
            [
                0x6a,
                0x9a,
                0x89,
                0x78,
            ]
          ]

    # Earlier M-of-N results, labeled "N/M".
    dbs     = [[0x80, 0x40, 0x0a, 0x18], [0x10, 0x0c, 0x20, 0x01], [0x82, 0x22, 0x04, 0x51], [0x60, 0x02, 0x8c, 0x05]]          # 4/2
    dbs     = [[0x400, 0x240, 0x802, 0xa0], [0x404, 0x300, 0x10, 0x02], [0x08, 0x30, 0x100, 0x801], [0x0c, 0x80, 0x01, 0x40]]   # 4/3
    dbs     = [[0x208, 0x27, 0x111, 0x143, 0x104], [0x40, 0x182, 0x183, 0x39a, 0x204], [0x302, 0x81, 0x64, 0x140, 0x91], [0x8a, 0x144, 0x61, 0x206, 0x18], [0x08, 0x42, 0x21, 0x04, 0x83]]                          # 5/2
    dbs     = [[0x18, 0x2081, 0x600, 0x120, 0x4080], [0x1000, 0x48, 0x4004, 0x900, 0x21], [0x100, 0x1400, 0x4040, 0x2000, 0x12], [0x42, 0x2200, 0x1004, 0x880, 0x83], [0x208, 0x14, 0x80, 0x802, 0x420]]            # 5/3
    dbs     = [[0x24000, 0x08, 0x9000, 0x400, 0x04], [0x10000, 0x40100, 0x40, 0x20, 0x81], [0x800, 0x01, 0xa0000, 0x8200, 0x402], [0x80000, 0x12, 0x200, 0x40800, 0x2004], [0x30, 0x108, 0x2080, 0x4040, 0x11000]]  # 5/4
    dbs     = [[0xa8, 0x50, 0x05, 0x68c, 0xa1, 0x812], [0x440, 0x282, 0x833, 0x644, 0x29, 0x910], [0x103, 0x186, 0xa80, 0x8c5, 0x888, 0x640], [0x110, 0x421, 0x07, 0x08, 0x850, 0x8c0], [0x112, 0x20, 0x69c, 0x454, 0x1a, 0x04], [0x804, 0x234, 0x25, 0x02, 0x904, 0x400]]  # 6/2

    # N sets of masks, any M of which are evaluated together.
    N       = 6
    M       = 3

    # Width of the masks being searched: bit flips and random masks range over this many bits.
    bits    = N * M

    # The search starts from all-zero masks: N sets of N masks each.
    dbs     = [ [ 0 for j in range(N) ] for i in range(N) ]

    # abits = bits / 2.5                    # a possible goal for the number of set bits in the parity masks

    # Turned on by the main loop (after repeated status prints) to make unk_bits() also try 3-mask combinations.
    do_more = False


    # Not called anywhere in find_em().
    # For the mask sets whose indices are in ii, counts how many extra bits would be needed to tell apart
    # input values whose combined parity values collide (0 means every value's parity is unique).
    def do_it(dbs, ii, show = 0) :

        mc  = len(dbs[ii[0]])

        if  show :
            print "do_it %s" % str(ii)

        vs  = {}
        bc  = 0

        for n in range(1 << (mc * M)) :
            v   = 0
            for i  in ii :
                v <<= mc
                v  |= mask_parity(n, dbs[i])
            vv  = vs.get(v, [ n, 0 ])
            vv[1]      += 1
            vs[v]       = vv
            if  vs[v][1] > 1 :
                if  show > 1 :
                    print "Ambiguous %2d %02x (%2d %02x) %02x %02d" % ( n, n, vs[v][0], vs[v][0], v, vs[v][1], )
                pass
            pass

        bc      = sum([ hi_bits[vs[v][1] - 1] for v in vs.keys() ])

        if  bc and show :
            print "bad count bits needed", bc
            print

        return(bc)


    def unk_bits(masks)             :
        """ This routine simulates decoding, in effect. That is, is goes through the logic that non-table-lookup decoding must do. Would need to be used if the lookup tables are too big. """

        # Repeatedly replaces a mask with the XOR of it and another mask (or two others, when do_more)
        # whenever the XOR has fewer set bits, until no replacement happens. Returns the reduced masks.
        masks                       = list(masks)           # the masks are all the parity masks for the M "sites" - this routine wants to reduce them each to be for 1 bit of original data
        while True                  :
            fnd                     = False
            iia                     = itertools.combinations(range(len(masks)), 2)
            while True              :
                f                   = False
                for ii in iia       :
                    i               = ii[0]
                    j               = ii[1]
                    mi              = masks[i]
                    mj              = masks[j]
                    m               = mi ^ mj
                    mc              = bit_cnts[m]
                    ff              = False
                    if  bit_cnts[mi] > mc   :
                        masks[i]    = m                     # the masks together are an improvement over the 1st mask, so replace the 1st mask with the better version
                        ff          = True
                    elif bit_cnts[mj] > mc  :
                        masks[j]    = m                     # the masks together are an improvement over the 2nd mask, so replace the 2nd mask with the better version
                        ff          = True
                    if  ff :
                        f           = True
                        fnd         = True
                    pass
                if  not f           :
                    break
                pass

            if  do_more and (len(masks) > 3) :
                iia                     = itertools.combinations(range(len(masks)), 3)
                while True              :
                    f                   = False
                    for ii in iia       :
                        i               = ii[0]
                        j               = ii[1]
                        k               = ii[2]
                        mi              = masks[i]
                        mj              = masks[j]
                        mk              = masks[k]
                        m               = mi ^ mj ^ mk
                        mc              = bit_cnts[m]
                        if  (m & mi) and bit_cnts[mi] > mc  :
                            masks[i]    = m                     # the masks together are an improvement over the 1st mask, so replace the 1st mask with the better version
                            f           = True
                            fnd         = True
                        elif (m & mj) and bit_cnts[mj] > mc :
                            masks[j]    = m                     # the masks together are an improvement over the 2nd mask, so replace the 2nd mask with the better version
                            f           = True
                            fnd         = True
                        elif (m & mk) and bit_cnts[mk] > mc :
                            masks[k]    = m                     # the masks together are an improvement over the 3rd mask, so replace the 3rd mask with the better version
                            f           = True
                            fnd         = True
                        pass
                    if  not f           :
                        break
                    pass
                pass

            if  not fnd :
                break
            pass

        return(masks)


    # Score a list of masks: the set bits left after unk_bits() reduction, minus one per mask
    # (the ideal is one bit per mask). A mask reduced to zero is charged bit_cnts[BIT_FIDDLE_LIMIT - 1] bits.
    def count_unk_bits(masks) :
        masks   = unk_bits(masks)

        return(sum([ bit_cnts[m or (BIT_FIDDLE_LIMIT - 1)] for m in masks ]) - len(masks))


    # Disabled: a hook for scoring hand-entered mask lists.
    if  False :
        ms  = []
        for m in ms :
            print count_unk_bits(m), m
        pass


    # Score the mask sets whose indices are in ii, taken together. (show is accepted but unused.)
    def try_it(dbs, ii, show = 0) :
        masks   = []
        for i in ii :
            masks  += dbs[i]

        return(count_unk_bits(masks))


    # Not called anywhere in find_em().
    # ORs together the reduced masks from every combination of M mask sets.
    def get_bad_masks(dbs) :
        iia     = list(itertools.combinations(range(len(dbs)), M))
        bmsks   = [ [ 0, ] * len(dbs[0]) for ii in range(len(dbs)) ]

        for ii in iia :
            masks   = []
            for i in ii :
                masks  += dbs[i]
            masks       = unk_bits(masks)
            for i in xrange(len(masks)) :
                pass
            bmsks       = [ masks[i] | bmsks[i] for i in range(len(masks)) ]

        return(bmsks)



    # Overall score of dbs (lower is better): the worst... actually the minimum per-combination score dominates,
    # then the sum over all combinations of M sets, then the sum of squared set-bit counts of every mask
    # (which favors masks with fewer bits).
    def eval_dbs(dbs) :
        ba  = [ try_it(dbs, ii) for ii in list(itertools.combinations(range(len(dbs)), M)) ]
        # bc  = (min(ba) * 10000000) + (sum(ba) * 256) + sum([ sum([ abs(abits - bit_cnts[m]) for m in ma ]) for ma in dbs ])
        bc  = (min(ba) * 100000000000) + (sum(ba) * 100000) + sum([ sum([             bit_cnts[m] * bit_cnts[m]  for m in ma ]) for ma in dbs ])
        return(bc)


    # Try flipping each bit of each mask (in random order), undoing each flip afterward, and
    # return ( copy of the best-scoring dbs, its score ). Flips that would make a mask zero are skipped.
    def flip_best_bit(dbs) :
        bba = dbs
        bbc = 100000000000000

        ia  = range(len(dbs))
        random.shuffle(ia)
        ja  = range(len(dbs[0]))
        random.shuffle(ja)
        ka  = range(bits)
        random.shuffle(ka)

        for i in ia :
            for j in ja :
                for k in ka :

                    v   = dbs[i][j] ^ (1 << k)
                    if  v :
                        dbs[i][j]   = v

                        bc  = eval_dbs(dbs)
                        if  bbc > bc :
                            bbc = bc
                            bba = copy.deepcopy(dbs)
                            # print bc, i, j, k, ia, ja, ka, masks_str(dbs)
                        elif not bc :
                            # print i, j, masks_str(dbs)
                            pass
                        dbs[i][j]  ^= (1 << k)
                    pass
                pass
            pass

        return(bba, bbc)


    # Flip c randomly chosen bits of dbs, in place.
    def ding(dbs, c) :
        for n in range(c) :
            i           = random.randint(0, len(dbs   ) - 1)
            j           = random.randint(0, len(dbs[0]) - 1)
            k           = random.randint(0, bits - 1)
            dbs[i][j]  ^= (1 << k)
        pass


    # Print the starting total score over every combination of M sets.
    if  True :
        bc  = 0
        for ii in itertools.combinations(range(len(dbs)), M) :
            bc += try_it(dbs, ii, show = 1)
        print "starting bc", bc

    # bbba / bbbc: the best dbs and score found so far.
    bbba    = copy.deepcopy(dbs)
    bbbc    = eval_dbs(bbba)
    print "starting best", bbbc

    start_time  = time.time()
    bt          = start_time
    best_time   = 7
    see_time    = start_time
    bcnt        = 0
    while True :

        bba = copy.deepcopy(dbs)
        bbc = eval_dbs(bba)

        # Hill-climb: keep taking the best single-bit flip while it lowers the score.
        bc                  = bbc
        while True :
            ( nba, nbc )    = flip_best_bit(dbs)
            if  nbc        >= bc :
                break
            dbs             = nba
            bc              = nbc
            if  bc  < bbbc  :
                bt          = time.time()
                best_time   = bt - start_time
                start_time  = bt
                bbba        = copy.deepcopy(dbs)
                bbbc        = bc
                bcnt        = 0
                print "best", bbbc, masks_str(bbba)
            # Status print every 15 seconds; after more than 2 of them without a new best, turn on do_more.
            if  time.time() - see_time > 15 :
                see_time    = time.time()
                print "----------- best", bbbc, masks_str(bbba)
                bcnt       += 1
                if  bcnt    > 2 :
                    do_more = True
                    print "Doing more"
                pass
            pass


        if  bc  < bbbc      :
            bt              = time.time()
            best_time       = bt - start_time
            start_time      = bt
            bbba            = copy.deepcopy(dbs)
            bbbc            = bc
            bcnt            = 0
            print "best", bbbc, masks_str(bbba)

            ding(dbs, 5)

        else                :
            if  bc < bbc    :
                print bc, masks_str(dbs)
                pass

            # Replace zero or duplicate masks within each set with random non-zero ones; ok records whether any were replaced.
            ok      = False
            for i in range(len(dbs)):
                vls     = {}
                for j in range(len(dbs[0])) :
                    v               = dbs[i][j]
                    while True      :
                        if  v and (v not in vls) :
                            break
                        v           = random.randint(1, (1 << bits) - 1)
                        ok          = True

                    dbs[i][j]       = v
                    vls[v]          = True
                pass

            # Perturb heavily when nothing was replaced, and also at random about 1 time in 6.
            if  (not ok) or (not random.randint(0, 5)) :
                ding(dbs, len(dbs) * len(dbs[0]) * len(dbs[0]))

            # When it's been a long time since the last best (or many status prints), perturb harder still.
            if  (time.time() - bt > best_time * 2) or (bcnt > 11) :
                bcnt            = 0
                best_time      *= 1.5
                bt              = time.time()
                print "bbc+", bbc
                ding(dbs, int(len(dbs) * len(dbs) * len(dbs) * len(dbs[0]) / 7))
            pass
        pass
    pass


def test_em() :
    import  random

    ex  = 0

    m0  = a_two_of_three(0)
    m1  = a_two_of_three(1)
    m2  = a_two_of_three(2)

    da  = range(LIM)
    random.shuffle(da)

    m0.out  = m0.encode(da)
    m1.out  = m1.encode(da)
    m2.out  = m2.encode(da)

    r       = m0.decode(m0.out, m1, m1.out)
    if  r != da :
        print "0 1", r, da
        ex  = 101

    r       = m1.decode(m1.out, m2, m2.out)
    if  r != da :
        print "1 2", r, da
        ex  = 112

    r       = m0.decode(m0.out, m1, m1.out)
    if  r != da :
        print "0 2", r, da
        ex  = 102

    r       = m1.decode(m1.out, m0, m0.out)
    if  r != da :
        print "1 0", r, da
        ex  = 110

    r       = m2.decode(m2.out, m0, m0.out)
    if  r != da :
        print "2 0", r, da
        ex  = 120

    r       = m2.decode(m2.out, m1, m1.out)
    if  r != da :
        print "2 1", r, da
        ex  = 121

    if  not ex :
        print "Test OK."

    sys.exit(ex)


def split_file(ifn, of0, of1, of2, buff_size = None) :
    """
        Split the file, ifn, in to three files, of0, of1 and of2, any two of which join_file() can rejoin.

        buff_size       How many bytes to read from ifn at a time (default: BUFFER_SIZE).
                        It is rounded down to an even number, and is at least 2, because the
                        odd-length '\\0' padding must only ever happen on the final, short read.

        The leading byte of the encoded stream is VERSION_B. The stream then ends either with a '\\0' pad
        (when the data made it odd-length) or with a '\\1\\1' pair, so join_file() knows what to drop.
        Each output is written to its name + ".tmp" and then moved in to place with replace_file().
        An empty input produces empty splits.

        Raises ValueError if an output name is the input file or if two output names are the same file.
    """
    buff_size   = max(2, (buff_size or BUFFER_SIZE) & ~1)

    ofs     = [ of0, of1, of2 ]
    iip     = os.path.abspath(ifn)
    oips    = []
    for of in ofs :
        oip = os.path.abspath(of)
        if  iip == oip :
            raise(ValueError("Splitting input to same output [%s]==[%s]!" % ( iip, oip ) ))
        if  oip in oips :
            raise(ValueError("Splitting to the same output more than once [%s]!" % oip))
        oips.append(oip)

    ms  = [ a_two_of_three(0), a_two_of_three(1), a_two_of_three(2) ]
    tns = [ of + ".tmp" for of in ofs ]

    def write_all(tfs, a) :
        """ Write the bytes in array, a, encoded by each split, to that split's temp file. """
        for ( m, tf ) in zip(ms, tfs) :
            tf.write(m.encode_to_string(a))

    fi  = open(ifn, "rb")
    try :
        tfs = []
        try :
            for tn in tns :
                tfs.append(open(tn, "wb"))

            buf = fi.read(buff_size - 1)
            eof = True                                  # True until a '\0' pad terminates the output
            if  buf :
                buf = chr(VERSION_B) + buf              # lead with VERSION_B so the decoder can figure out which splits, in which order, it has (and insures that the decoder is compatible with the encoder)
                while buf :
                    last    = len(buf) < buff_size      # a short read means the input is exhausted (and an odd length can only happen here)
                    if  len(buf) & 1 :
                        eof     = False
                        buf    += '\0'                  # even out the buffer length with an indicator to not whack the last full byte decoded
                    write_all(tfs, array.array('B', buf))
                    if  last :
                        break
                    buf = fi.read(buff_size)

                if  eof :
                    write_all(tfs, array.array('B', '\1\1'))    # we've not terminated the output, so we'll need a full byte at the end (either nibble telling the decoder to lop off a byte from the decoded output)
                pass
        finally :
            for tf in tfs :
                tf.close()
    finally :
        fi.close()

    for ( of, tn ) in zip(ofs, tns) :
        replace_file.replace_file(of, tn, of + ".bak")


def join_file(if0, if1, ofn, buff_size = None) :
    """
        Rebuild the original file, ofn, from any two split files, if0 and if1, given in either order.

        buff_size       Output bytes to decode per pass (default: BUFFER_SIZE). Each split byte holds two
                        nibbles, so max(1, buff_size // 2) bytes are read from each split per pass.

        The ordered pair of splits is found by decoding the leading byte with each of PAIRS until it
        yields VERSION_B. The output is written to ofn + ".tmp" and moved in to place with replace_file().

        Raises ValueError if an input is the output, if the inputs differ in size, if no pair of splits
        decodes the leading byte to VERSION_B, or (after the output has been written) if the stream's
        terminator is neither 0 nor 1.
    """
    buff_size   = buff_size or BUFFER_SIZE

    oip = os.path.abspath(ofn)
    for ifn in [ if0, if1 ] :
        iip = os.path.abspath(ifn)
        if  iip == oip :
            raise(ValueError("Joining input to same output [%s]==[%s]!" % ( iip, oip ) ))

    sz0 = os.path.getsize(if0)
    sz1 = os.path.getsize(if1)
    if  sz0 != sz1 :
        raise(ValueError("Input files not the same size [%u != %u]!" % ( sz0, sz1 ) ))

    tfn = ofn + ".tmp"

    if  not sz0 :
        open(tfn, "wb").close()                     # empty splits join to an empty file (no tzlib needed outside __main__)
        replace_file.replace_file(ofn, tfn, ofn + ".bak")
        return

    bs  = max(1, buff_size // 2)                    # never read(0), which would look like end-of-file

    fi0 = open(if0, "rb")
    try :
        fi1 = open(if1, "rb")
        try :
            a0  = [ ord(fi0.read(1)), ]
            a1  = [ ord(fi1.read(1)), ]

            # Find the ordered pair of splits that decodes the leading nibbles to VERSION_B.
            z   = None
            for pair in PAIRS :
                m0  = a_two_of_three(pair[0])
                m1  = a_two_of_three(pair[1])
                za  = m0.decode_bytes(a0, m1, a1)
                z   = za[0]
                if  z == VERSION_B :
                    break
                pass
            if  z != VERSION_B  :
                raise(ValueError("Input is not two splits from same source using this two_of_three!"))

            fo  = open(tfn, "wb")
            try :
                ob  = za[1:]                        # the byte decoded after VERSION_B is data (or the pad)
                while True :
                    xa  = ob[-2:]                   # hold back the last two bytes: they may be the terminator
                    ob  = ob[:-2]
                    if  ob :
                        fo.write(a_two_of_three.decoded_string(ob))

                    a0  = fi0.read(bs)
                    if  not a0 :
                        break

                    a1  = fi1.read(bs)

                    ob  = xa + m0.decode_strings(a0, m1, a1)

                # A trailing 0 is a pad: keep the byte before it. A trailing 1 means both held-back bytes are the terminator.
                if  not xa[-1] :
                    fo.write(a_two_of_three.decoded_string(xa[0:1]))
            finally :
                fo.close()
        finally :
            fi1.close()
    finally :
        fi0.close()

    replace_file.replace_file(ofn, tfn, ofn + ".bak")

    if  xa[-1] not in [ 0, 1 ] :
        raise(ValueError("Data was written but it is incorrect!"))



def test_files() :
    import  random

    dfn     = "two_of_three.tst"
    ofn0    = dfn + "_1.tst"
    ofn1    = dfn + "_2.tst"
    ofn2    = dfn + "_3.tst"
    dfno    = dfn + "_joined.tst"
    for i in range(1131) :
        buff_size   = random.choice([ 2, 8, 10, 20, 32, 1000, 1024, random.randint(2, 100), random.randint(2, 100), random.randint(2, 100), random.randint(2, 100), ])
        buff_size  &= ~1

        ln  = random.choice([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 64, 89, 256, 512, random.randint(2, buff_size * 3), random.randint(2, buff_size * 3), random.randint(2, buff_size * 3), random.randint(2, buff_size * 3), random.randint(2, buff_size * 3), ])

        # print "ln=%-5u buff_size=%-5u" % ( ln, buff_size )

        data    = [ random.randint(0, 255) for i in xrange(ln) ]

        ds      = "".join([ chr(c) for c in data ])

        tzlib.write_whole_binary_file(dfn, ds)

        split_file(dfn, ofn0, ofn1, ofn2, buff_size = buff_size)
        ofns    = [ ofn0, ofn1, ofn2 ]
        random.shuffle(ofns)
        join_file(ofns[0], ofns[1], dfno, buff_size = buff_size)

        dsr     = tzlib.read_whole_binary_file(dfno)
        if  ds != dsr :
            print "Mismatch %u %s %s!" % ( ln, ofns[0], ofns[1] )
            sys.exit(197)

        pass

    if  os.path.isfile(dfn) :
        os.remove(dfn)
    if  os.path.isfile(dfno) :
        os.remove(dfno)
    if  os.path.isfile(ofn0) :
        os.remove(ofn0)
    if  os.path.isfile(ofn1) :
        os.remove(ofn1)
    if  os.path.isfile(ofn2) :
        os.remove(ofn2)

    print "Files OK"





# Usage text printed for --help: the six %s are filled with the program's base name, and the %u with the default buffer size.
help_str    = """
%s (--split) file                                   (output_file_base_name)
%s (--split) file                                   split_output_file_1 split_output_file_2 split_output_file_3
%s (--join)  input_file_base_name                   (output_file)
%s (--join)  split_input_file_X split_input_file_Y  (output_file)
%s           file_not.2o3                           split_output_file_base_name
%s           file.2o3           file.2o3            (joined_output_file)

Options:

    --buffer    copy_buffer_size                    How many bytes to read from the input at a time (default: %u)

Split a file in to 3 files, each 1/2 the length of the input file (plus 1 byte).

Or join any two of the 3 split files to make the original file.

"""


#
#
#
if __name__ == '__main__' :

    import  TZCommandLineAtFile
    import  tzlib


    program_name    = sys.argv.pop(0)
    TZCommandLineAtFile.expand_at_sign_command_line_files(sys.argv)


    buff_size   = BUFFER_SIZE
    split   = None


    if  tzlib.array_find(sys.argv, [ "--help", "-h", "-?", "/?", ] ) >= 0 :
        pn  = os.path.basename(program_name)
        print help_str % ( pn, pn, pn, pn, pn, pn, buff_size, )
        sys.exit(254)


    while True :
        oi  = tzlib.array_find(sys.argv, [ "--find", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        find_em()

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--test", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        test_em()

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--test_files", "--testfiles", "--test-files", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        test_files()
        sys.exit(0)


    while True :
        oi  = tzlib.array_find(sys.argv, [ "--split", "-s", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        split       = True

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--join", "-j", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        split       = False

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--buffer", "--buffer_size", "--buffer-size", "--buffersize", "-b", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        buff_size   = int(sys.argv.pop(oi))
        if  buff_size < 2 :
            print "Buffer size must be 2 or greater!"
            sys.exit(105)
        buff_size  &= ~1


    if  not sys.argv :
        print "--help for help."
        sys.exit(250)

    ifiles  = []
    ofiles  = []

    ifiles.append(sys.argv.pop(0))

    if  len(sys.argv) > 3   :
        print "Too many file names!"
        sys.exit(101)

    if  len(sys.argv) == 3  :
        if  split == False  :
            print "I can't join 3 files to 1 file!"
            sys.exit(102)

        while sys.argv :
            ofiles.append(sys.argv.pop(0))

        split   = True

    elif len(sys.argv) == 2 :
        if  split == True :
            print "I can't split from or to two files!"
            sys.exit(103)

        ifiles.append(sys.argv.pop(0))
        ofiles.append(sys.argv.pop(0))

        split   = False

    elif len(sys.argv) :
        fn  = sys.argv.pop(0)
        if  (split == False) or (fn.lower().endswith(SPLIT_EXT)) or (os.path.isfile(fn) and (os.path.getsize(ifiles[0]) == os.path.getsize(fn))) :
            ifiles.append(fn)
            split   = False
        else :
            split   = True
            ofiles.append(fn)
        pass

    elif split == None :
        if  ifiles[0].lower().endswith(SPLIT_EXT) :
            split   = False
        else        :
            split   = os.path.isfile(ifiles[0])
        pass


    if  split :
        if  len(ifiles) != 1 :
            print "Arrgggh. Not just one file to split %s !" % ifiles
            sys.exit(149)

        ifn = ifiles[0]

        if  ifn.lower().endswith(SPLIT_EXT) :
            print "I don't want to split files with my extension, like %s !" % ifn
            sys.exit(104)

        if  not len(ofiles) :
            ofiles.append(os.path.splitext(ifn)[0])

        if  len(ofiles) == 1 :
            ofn             = ofiles[0]
            ( bfn, ext )    = os.path.splitext(ofn)
            if  ext.lower() != SPLIT_EXT :
                ext         =  SPLIT_EXT
                bfn         = ofn

            ofiles          = [
                                    bfn + "_1" + ext,
                                    bfn + "_2" + ext,
                                    bfn + "_3" + ext,
                              ]
            pass

        split_file(ifn, ofiles[0], ofiles[1], ofiles[2], buff_size = buff_size)

    else    :
        for fn in ofiles :                                                          # note: actually, ofiles will only contain 1 file name
            if  os.path.isfile(fn) and fn.lower().endswith(SPLIT_EXT) :
                print "I don't want to join files to create %s !" % fn
                sys.exit(104)
            pass

        fna = [ "_1" + SPLIT_EXT, "_2" + SPLIT_EXT, "_3" + SPLIT_EXT ]
        if  len(ifiles) == 1 :
            ifn = ifiles.pop(0)
            if  os.path.isfile(ifn) :
                ifiles.append(ifn)
                for exn in fna :
                    if  ifn.lower().endswith(exn) :
                        for xn in fna :
                            if  xn != exn :
                                tfn = ifn[0 : -len(exn)] + xn
                                if  os.path.isfile(tfn) :
                                    ifiles.append(tfn)
                                    break
                                pass
                            pass
                        pass
                    if  len(ifiles) != 1 :
                        break
                    pass
                pass
            else :
                for fn in fna :
                    tfn = ifn + fn
                    print "tfn", tfn
                    if  os.path.isfile(tfn) :
                        ifiles.append(tfn)
                        if  len(ifiles) == 2 :
                            break
                        pass
                    pass
                pass

            if  not len(ifiles) :
                ifiles.append(ifn)
            pass

        if  len(ifiles) != 2 :
            print "Not two files to join %s !" % str(ifiles)
            sys.exit(106)

        if  not len(ofiles) :
            ifn = ifiles[0]
            if  not ifn.lower().endswith(SPLIT_EXT) :
                ifn = ifiles[1]
            if  not ifn.lower().endswith(SPLIT_EXT) :
                print "Cannot figure out what to name the output file from %s !" % ifiles
                sys.exit(107)

            for fn in fna :
                if  ifn.lower().endswith(fn) :
                    ofiles.append(ifn[0 : -len(fn)])
                    break
                pass
            if  not len(ofiles) :
                ofiles.append(ifn[0 : -len(SPLIT_EXT)])

            pass

        if  len(ofiles) != 1 :
            print "Not one file to join to %s !" % str(ofiles)
            sys.exit(108)

        join_file(ifiles[0], ifiles[1], ofiles[0], buff_size = buff_size)

    pass


#
#
#
# eof

