#!/usr/bin/python

# markov_text.py
#       --copyright--                   Copyright 2010 (C) Tranzoa, Co. All rights reserved.    Warranty: You're free and on your own here. This code is not necessarily up-to-date or of public quality.
#       --url--                         http://www.tranzoa.net/tzpython/
#       --email--                       pycode is the name to send to. tranzoa.com is the place to send to.
#       --bodstamps--
#       February 28, 2011       bar
#       May 27, 2012            bar     doxygen namespace
#       June 2, 2013            bar     fix interpreter name
#       August 2, 2015          bar     make a bug more clear
#       --eodstamps--
##      \file
#       \namespace              tzpython.markov_text
#
#
#       Output words that sorta seem right if you don't pay attention.
#
#


import  random
import  re

import  tzlib


DEPTH       =        7              # max number of words we put in history items
MAX_CNT     =  2000000              # total number of history items we can know
READY_COUNT =    10000              # total word-instance count that must be exceeded for is_ready()


class   a_markov(object) :
    def __init__(me, n = 1, depth = DEPTH) :
        me.n        = min(DEPTH, max(1, n or 1)) - 1    # shortest history length to learn from, minus one (n clamped to 1..DEPTH)
        me.depth    = max(depth, me.n + 1)              # longest history length to learn from

        me.all      = [ {}, {}, {}, ]           # aging memory sections, oldest first: [ { history_string : { generate_word : cnt } },  ]
        me.wrds     = [ "" ] * me.depth         # rolling window of the most recently learned input words
        me.cnt      = 0                         # how many (history, generate_word) items we currently know
        me.dcnt     = 0                         # how many times we have dropped the oldest memory section
        me.wcnt     = 0                         # how many input words we have learned in total
        me.owrds    = []                        # rolling window of the most recently generated output words


    def is_ready(me) :
        """ Return the total, known word-instance-counts. """

        cnt = sum([ sum([ sum(wd.values()) for wd in wdd.values() ]) for wdd in me.all ])
        return(((cnt > READY_COUNT) and cnt) or 0)


    def need_drop(me) :
        """ Return True if we should forget something to avoid using too much memory. """
        return(me.cnt >= MAX_CNT)


    def is_full_memory(me) :
        """
            Return True if at this exact time memory is full.
                Note: If True, a call to learn_word() will make this False.
        """
        return(me.need_drop() and len(me.all[0]))



    def drop_if_needed(me) :
        """ Forget something when memory is overloaded. """

        if  not me.need_drop() :
            return(False)

        me.dcnt    += 1
        d           = sum([ len(wd) for wd in me.all[0].values() ])
        me.cnt     -= d
        me.all.pop(0)                           # forget the oldest history
        me.all.append({})                       # and start learning a fresh batch of word sequences

        return(True)
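
    # The three dicts in me.all act as aging memory sections: learn_word() always
    # writes into the newest one, me.all[-1]. When the total item count reaches
    # MAX_CNT, the oldest section, me.all[0], is thrown away and a fresh, empty
    # section is appended; this repeats on later learn_word() calls until the
    # item count falls back below MAX_CNT.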


    def learn_word(me, w) :
        """ Learn the next word from the input. 'w' should not be ''. """

        me.drop_if_needed()

        me.wcnt    += 1
        for i in xrange(len(me.wrds) - me.n) :
            s       = " ".join(me.wrds[i:])
            wdd     = me.all[-1]
            if  not s in wdd :
                wdd[s]      = {}
            wd              = wdd[s]            # learn in the latest memory section
            if  not w in wd :                   # unknown generate_word?
                me.cnt     += 1                 # track how many items we know
                wd[w]       = 0                 # remember this generate_word
            wd[w]  += 1                         # and bump the count for the generate_word
        me.wrds     = me.wrds[1:] + [ w ]       # keep the history up to date


    def generate_word(me) :
        """
            Return the next word. Don't call this until is_ready() is true (> 0).

            When there is no output history yet, pick a word at random from everything
            learned, weighted by count. Otherwise, gather the learned next-word counts
            for every suffix of the current output history and pick from those, again
            weighted by count. If nothing learned can follow the current output history,
            forget one output-history word at random and try again.
        """

        if  not len(me.owrds) :
            cnt = me.is_ready()
            if  not cnt :

                return("")

            wi  = random.randint(0, cnt - 1)    # choose a word at random from everything learned, weighted by count
            for wdd in me.all :
                sm  = sum([ sum(wd.values()) for wd in wdd.values() ])
                if  wi  < sm :
                    for s, wd in wdd.items() :
                        wsm = sum(wd.values())
                        if  wi  < wsm :
                            for w, c in wd.items() :
                                if  wi < c :
                                    me.owrds    = s.split(" ")[1:] + [ w ]      # note: entries may be whitespace tokens ("\n") or "" padding, so owrds may be shorter than depth

                                    return(w)

                                wi -= c
                            pass
                        wi -= wsm
                    pass
                wi         -=  sm
            pass

        else :

            wda         = []
            # sa          = []
            for i in xrange(len(me.owrds)) :
                s       = " ".join(me.owrds[i:])
                wda    += [ wdd[s] for wdd in me.all if s in wdd ]      # get a list of all learned words that match the current output
                # sa     += [     s  for wdd in me.all if s in wdd ]
            if  not len(wda) :
                # print
                # print "@@@@ reset", me.owrds
                # print
                del(me.owrds[random.randint(0, len(me.owrds) - 1)])     # bail at random

                return(me.generate_word())

            cnt = sum([ sum(wd.values()) for wd in wda ])

            wi  = random.randint(0, cnt - 1)                            # choose a word at random from the ones that can come next

            for si, wd in enumerate(wda) :
                wsm = sum(wd.values())
                if  wi  < wsm :
                    for w, c in wd.items() :
                        if  wi < c :
                            if  len(me.owrds) < me.depth :
                                me.owrds.append(w)
                            else :
                                me.owrds    = me.owrds[1:] + [ w ]

                            # print "@@@@", cnt, w
                            return(w)
                            # return("{%s}%s" % (sa[si], w))

                        wi -= c

                    raise ValueError("LITTLE BUG!")

                wi         -= wsm
            pass

        raise ValueError("BUG!")

    #   a_markov
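

# A minimal usage sketch of a_markov (illustration only -- _demo_markov is a
# hypothetical helper and is not called anywhere in this module). It streams
# pre-split words through learn_word() and then pulls generated words back out.
# Note that generate_word() returns "" until more than READY_COUNT word-instance
# counts have been learned, so a short word list just yields an empty string.
def _demo_markov(words, out_count = 50, n = 1, depth = DEPTH) :
    m   = a_markov(n = n, depth = depth)
    for w in words :
        if  not w :                 continue    # learn_word() must not be given ''
        if  m.is_full_memory() :    break
        m.learn_word(w)
    oa  = []
    while len(oa) < out_count :
        w   = m.generate_word()
        if  not w :     break                   # not ready yet (or nothing learned)
        oa.append(w)
    return(" ".join(oa))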


def do_it(fo, ifile_name, count = 0, encoding = None, n = 0, depth = DEPTH, flow = 0, lc = False) :
    encoding    = encoding  or 'utf8'
    count       = count     or 0                # zero means no output; a negative count means generate forever
    flow        = flow      or 0                # 0: learn line breaks   1: learn paragraph breaks only   2: free flow

    me          = a_markov(n = n, depth = depth)
    f           = False

    fi          = open(ifile_name, "rt")
    while True  :
        ln      = fi.readline()
        if  not ln :    break

        ln      = ln.lstrip()
        if  (not len(ln)) or (ln[0] != ';') :       # ignore ';' comment lines
            wa  = [ w for w in re.split(r"\s+", ln.decode(encoding)) if len(w) ]    # split into non-empty, whitespace-delimited words

            if  me.is_full_memory() :   break
            if  not flow :
                me.learn_word("\n")                 # line mode: learn a newline token so the output keeps line breaks

            if  (flow == 1) and (not len(wa)) :     # paragraph mode: a blank input line
                if  not f :
                    f   = True
                    if  me.is_full_memory() :   break
                    me.learn_word("\n\n")           # learn one paragraph-break token per run of blank lines
                pass
            else    :
                f   = False                         # a non-blank line ends the current run of blank lines
                for w in wa :
                    if  me.is_full_memory() :   break
                    if  lc :
                        w   = w.lower()
                    me.learn_word(w)
                pass
            pass
        pass
    fi.close()

    while count != 0 :                              # a negative count never hits zero, so it means "generate forever"
        w   = me.generate_word()
        if  w.strip() :
            fo.write((" " + w).encode(encoding))    # put a space in front of real words
        else :
            fo.write(w.encode(encoding))            # newline and paragraph-break tokens go out as-is
        count  -= 1
    pass
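
# Example of driving do_it() directly (illustration only; "corpus.txt" is a
# made-up file name):
#
#   import sys
#   do_it(sys.stdout, "corpus.txt", count = 200, n = 2, depth = 5, flow = 1, lc = True)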




help_str    = """
%s (options) input_file_with_words
Options:
    --output        file_name       Write to given output file.
    --count         N               Output N words and stop.
    --n             N               Learn chains N or longer           (default: 1)
    --depth         N               Maximum chain length to learn      (default: %u)
    --lower_case                    Do things all in lower case.
    --flow                          Don't do lines and paragraphs.
    --paragraphs                    Do paragraphs but not lines.
    --encoding      encoding        Set the encoding of the input file (default: %s)

Print words that are sorta like the input file at a small scale.
"Words" are space-delimited.
Input-file lines whose first non-whitespace character is a semi-colon are ignored.
If the input file appears to contain lines of text, then output lines of text.
If the input file appears to contain paragraphs (blank lines between text),
   then output such blank lines.
"""

#
#
#
if __name__ == '__main__' :
    import  os
    import  sys

    import  TZCommandLineAtFile
    import  output_files


    program_name    = sys.argv.pop(0)
    TZCommandLineAtFile.expand_at_sign_command_line_files(sys.argv)

    ofile_name      = ""
    encoding        = 'utf8'
    count           = -1
    lc              = False
    flow            = 0
    depth           = DEPTH
    n               = 1

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--help", "-h", "-?", "/?", "/h", "/H" ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        print help_str % ( os.path.basename(program_name), depth, encoding, )

        sys.exit(254)


    while True :
        oi  = tzlib.array_find(sys.argv, [ "--lower_case", "--lowercase", "--lower-case", "--lower", "--low", "--lc", "-l", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        lc              = True

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--count", "-c", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        count           = int(sys.argv.pop(oi))

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--n", "-n", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        n               = int(sys.argv.pop(oi))

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--depth", "-d", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        depth           = int(sys.argv.pop(oi))

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--encoding", "--enc", "-e", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        encoding        = sys.argv.pop(oi)

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--flow", "-f", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        flow            = 2

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--paragraphs", "--paragraph", "--para", "-p", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        flow            = 1

    while True :
        oi  = tzlib.array_find(sys.argv, [ "--output", "--out", "-o", ] )
        if  oi < 0 :    break
        del sys.argv[oi]
        ofile_name      = sys.argv.pop(oi)


    if  not len(sys.argv) :
        print >>sys.stderr, "No input file name given!"
        sys.exit(102)

    if  len(sys.argv) > 1 :
        print >>sys.stderr, "I don't understand %s!" % str(sys.argv[1:])
        sys.exit(102)

    ifile_name  = sys.argv.pop(0)

    fo      = sys.stdout
    if  ofile_name :
        fo  = output_files.a_file(ofile_name)

    do_it(fo, ifile_name, count = count, encoding = encoding, n = n, depth = depth, flow = flow, lc = lc)

    fo.write("\n\n")

    if  ofile_name :
        fo.close()

    pass

#
#
#
# eof
