#!/usr/bin/python

# get_yahoo_historical_csv.py
#       --copyright--                   Copyright 2007 (C) Tranzoa, Co. All rights reserved.    Warranty: You're free and on your own here. This code is not necessarily up-to-date or of public quality.
#       --url--                         http://www.tranzoa.net/tzpython/
#       --email--                       pycode is the name to send to. tranzoa.com is the place to send to.
#       --bodstamps--
#       February 10, 2006       bar
#       February 18, 2006       bar     got rid of underscore in get_request
#                                       create get_update_csv  routine
#                                       create make_today_date routine
#                                       create make_date_str   routine
#                                       create find_price      routine
#       February 19, 2006       bar     more routines
#       February 25, 2006       bar     split test
#       March 2, 2006           bar     subtract_pricing and constants
#       March 3, 2006           bar     put scoping resolution on a_csv constants
#       March 22, 2006          bar     rename and put in \tzpython
#       April 20, 2006          bar     optional param to make_today_date()
#                                       shuffle the symbols before doing the snags
#       April 23, 2006          bar     --delay
#                                       use timeout and show_info
#                                       --update_files
#       February 6, 2007        bar     fix the values' regx to handle the new, better way the csv data comes back from the server
#       February 7, 2007        bar     fix a typo from yesterday
#       February 9, 2007        bar     better split detection logic
#                                       get_partial_csv()
#       February 27, 2007       bar     find_date_i and binary search
#                                       fix bug in numeric_date
#       April 23, 2007          bar     find_date_price_i
#                                       find_date_vals
#       November 18, 2007       bar     turn on doxygen
#       November 27, 2007       bar     insert boilerplate copyright
#       May 17, 2008            bar     email adr
#       July 1, 2008            bar     allow zero length files to trigger a fetch at a very high level (don't fuss about a .csv file that's zero length, that is)
#       August 29, 2008         bar     basestring instead of StringType because of unicode strings and others
#       August 9, 2010          bar     allow getting symbol not known without --refresh
#       May 27, 2012            bar     doxygen namespace
#       October 31, 2012        bar     divide by zero
#       --eodstamps--
##      \file
#       \namespace              tzpython.get_yahoo_historical_csv
#
#
#       Get or update the Yahoo historical prices .csv file for given symbol(s).
#
#       Note:
#           It appears that the numbers rather arbitrarily change over time.
#           And they don't match the charts (ref PCAR Aug 2006 to Dec 2006)
#
#

import  copy
import  os.path
import  re
import  time
import  urllib
import  urllib2

import  replace_file
import  tzlib
import  url_getter


# Install a process-wide urllib2 opener whose default header list is empty,
# so that urllib2 does not add its own 'User-agent' header. The browser-like
# headers this module wants are added per request in get_request().
opener            = urllib2.build_opener()
opener.addheaders = []                  # get rid of 'User-agent' the only way that seems to work (yes, I tried lower-casing 'Agent')
urllib2.install_opener(opener)



def get_request(req, timeout = None) :
    """
        Add browser-like request headers to 'req' and fetch it.

        req         a request object with an add_header(name, value) method (callers in this file pass a urllib2.Request)
        timeout     passed straight through to url_getter.url_open_read_with_timeout()

        Returns whatever url_getter.url_open_read_with_timeout() returns
        (presumably the response body; callers in this file treat a false value as "nothing retrieved").
    """

    #
    #   The headers are added in this order. 'Accept-Encoding', 'Keep-Alive' and 'Connection' are deliberately not sent.
    #
    browser_headers = (
                        ( 'User-Agent',       'Mozilla/5.0 (Windows; U; Win98; en-US; rv:1.5) Gecko/20031007' ),
                        ( 'Accept',           'text/xml,application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,image/jpeg,image/gif;q=0.2,*/*;q=0.1' ),
                        ( 'Accept-Language',  'en-us,en;q=0.5' ),
                        ( 'Accept-Charset',   'ISO-8859-1,utf-8;q=0.7,*;q=0.7' ),
                      )

    for name, value in browser_headers :
        req.add_header(name, value)

    return(url_getter.url_open_read_with_timeout(req, timeout))




# Cache of a_csv objects, keyed by the cleaned symbol (see clean_symbol()).
# Filled by get_csv() and get_update_csv(); the command-line loop at the
# bottom of this file empties it after each symbol.
known_symbols   =   {}




def clean_symbol(sym) :
    """
        Normalize a company symbol.

        Every character other than ASCII letters, '.', '^' and '-' is dropped
        (digits and white space included) and the result is upper-cased.
        A false value (None, "") is handed back unchanged.
    """

    if  not sym :
        return(sym)

    kept    = re.sub(r"[^a-zA-Z\.\^\-]", "", sym)

    return(kept.strip().upper())



def cmp_vals_dates(vals1, vals2) :
    """
        Compare the dates at the front of two 'vals' arrays.

        Elements 0, 1 and 2 are compared in that order and the first non-zero
        difference (vals1 minus vals2) is returned, so the result is negative,
        zero or positive as vals1 is earlier than, the same day as, or later than vals2.
        Anything past the first three elements is ignored.
    """

    for k in ( 0, 1, 2 ) :
        delta   = vals1[k] - vals2[k]
        if  delta != 0 :
            break

    return(delta)




# Three-letter month names, indexed by month number minus one.
month_names         = [
                        "Jan", "Feb", "Mar",
                        "Apr", "May", "Jun",
                        "Jul", "Aug", "Sep",
                        "Oct", "Nov", "Dec",
                      ]


def make_date_str(ymd_array) :
    """
        Format a date array as "D-Mon-YY" (e.g. [ 2007, 2, 6 ] gives "6-Feb-07").

        ymd_array is indexed with a_csv.YEAR, a_csv.MONTH and a_csv.DAY; the month is 1-based.
    """

    day         = ymd_array[a_csv.DAY]
    month_name  = month_names[ymd_array[a_csv.MONTH] - 1]
    yy          = ymd_array[a_csv.YEAR] % 100

    return("%u-%s-%02u" % ( day, month_name, yy ))


def make_today_date(t = None) :
    """
        Return [ year, month, day ] in local time for the time 't' (seconds since the epoch).

        If 't' is None, the current time is used.
    """

    if  t is None :
        t   = time.time()

    return(list(time.localtime(t)[:3]))


def numeric_date(ymd_array) :
    """
        Collapse a date array in to one number that sorts in date order.

        The value is year * 384 + month * 32 + day. It is only meant for comparing dates, it is not a day count.
    """

    year    = ymd_array[a_csv.YEAR]
    month   = ymd_array[a_csv.MONTH]
    day     = ymd_array[a_csv.DAY]

    return((year * 12 + month) * 32 + day)


def unix_date(ymd_array) :
    """
        Return the local-time midnight of the given date as seconds since the epoch (a float).

        Dates with a year before 1970 yield 0.0.
    """

    if  ymd_array[a_csv.YEAR] < 1970 :
        return(0.0)

    #
    #   time.mktime() is documented to take a struct_time or a full 9-tuple, so hand it a tuple rather than a list.
    #   The trailing -1 leaves the daylight-saving decision to the C library.
    #
    when    = ( ymd_array[a_csv.YEAR], ymd_array[a_csv.MONTH], ymd_array[a_csv.DAY], 0, 0, 0, 0, 0, -1 )

    return(time.mktime(when))



class a_csv :
    """
        One symbol's historical price data, parsed from Yahoo .csv text.

        Attributes set by the constructor:
            sym     the symbol string the caller passed in (not interpreted here)
            hdr     the first line of the .csv text, kept so write_csv_file() can write it back out
            vals    list of per-day arrays, sorted oldest first, each laid out as
                    [ year, month, day, open, high, low, close, volume, adjusted_close ]
                    (index them with the constants below)
    """

    YEAR                =   0
    MONTH               =   1
    DAY                 =   2
    TKR_OPENING_PRICE   =   3
    TKR_HIGH_PRICE      =   4
    TKR_LOW_PRICE       =   5
    TKR_CLOSE_PRICE     =   6           # ticker closing price at the time, not adjusted for splits
    VOLUME              =   7
    CLOSING_PRICE       =   8           # split-adjusted closing price


    #   Data line in the older "D-Mon-YY,..." form (also the form write_csv_file() writes).
    li_re               =   re.compile(r"^(\d{1,2})-([A-Z]{3})-(\d\d),([\d\.]+),([\d\.]+),([\d\.]+),([\d\.]+),([\d\.]+),([\d\.]+)$", re.MULTILINE + re.DOTALL + re.IGNORECASE)
    #   Data line in the newer "YYYY-MM-DD,..." form.
    new_li_re           =   re.compile(r"^(\d\d\d\d)-(\d\d)-(\d\d),([\d\.]+),([\d\.]+),([\d\.]+),([\d\.]+),([\d\.]+),([\d\.]+)$",    re.MULTILINE + re.DOTALL + re.IGNORECASE)

    month_names_to_num  = {
                            "jan" :  1,
                            "feb" :  2,
                            "mar" :  3,
                            "apr" :  4,
                            "may" :  5,
                            "jun" :  6,
                            "jul" :  7,
                            "aug" :  8,
                            "sep" :  9,
                            "oct" : 10,
                            "nov" : 11,
                            "dec" : 12
                          }

    def cmp_dates(me, w1, w2) :
        """
            Compare the dates of the 'vals' entries at indices w1 and w2 (negative, zero, positive as with cmp_vals_dates()).
        """
        return(cmp_vals_dates(me.vals[w1], me.vals[w2]))



    def _show_rejected_text(me, csv_text) :
        """
            Print text we could not parse, but only if the module-level 'show_info' exists and is above 1.

            'show_info' is only created when this file is run as a script, so it is looked up
            defensively rather than referenced directly.
        """

        verbosity   = globals().get('show_info', 0)
        if  verbosity and (verbosity > 1) :
            print(csv_text)
        pass



    def _price_vals(me, fields) :
        """
            Convert regx groups 3..8 (open, high, low, close, volume, adjusted close) to numbers.
        """

        return([
                float(fields[3]),               # open
                float(fields[4]),               # high
                float(fields[5]),               # low
                float(fields[6]),               # close
                long(fields[7]),                # volume
                float(fields[8]),               # adjusted close
              ])



    def _old_style_vals(me, fields) :
        """
            Convert one li_re match ("D-Mon-YY,...") to a 'vals' array.
        """

        year    = int(fields[2])
        if  year < 28 :
            year   += 100                       # two digit years below 28 are taken to be 20xx, the rest 19xx
        year       += 1900

        month   = a_csv.month_names_to_num[fields[1].lower()]
        day     = int(fields[0])

        return([ year, month, day ] + me._price_vals(fields))



    def _new_style_vals(me, fields) :
        """
            Convert one new_li_re match ("YYYY-MM-DD,...") to a 'vals' array.
        """

        return([ int(fields[0]), int(fields[1]), int(fields[2]) ] + me._price_vals(fields))



    def __init__(me, csv_text, sym = "") :
        """
            Parse .csv text in either the old ("D-Mon-YY") or new ("YYYY-MM-DD") line form.

            csv_text    the whole text of the .csv file, header line first
            sym         symbol to remember in 'me.sym' and to name in error messages

            Raises ValueError if the text has no header line or no data lines we recognize.
        """

        me.sym      = sym

        hdr_match   = re.match(r"([^\r\n]+)", csv_text)
        if  hdr_match :
            me.hdr  = hdr_match.group(1)
        else :
            me.hdr  = ""                        # empty text, or text that starts with a line end

        if  not me.hdr :
            me._show_rejected_text(csv_text)
            raise ValueError("No header in csv for %s of length %u" % ( me.sym, len(csv_text) ))

        rows        = a_csv.li_re.findall(csv_text)
        if  len(rows) :
            me.vals = [ me._old_style_vals(row) for row in rows ]
        else :
            rows    = a_csv.new_li_re.findall(csv_text)
            if  not len(rows) :
                me._show_rejected_text(csv_text)
                raise ValueError("No values found in csv for %s of length %u " % ( me.sym, len(csv_text) ))

            me.vals = [ me._new_style_vals(row) for row in rows ]       # replace the regx results with numbers in date-sortable order

        me.vals.sort()                                                  # year, month, day lead each array, so this is oldest-first

        if  len(me.vals) > 0 :
            if  me.cmp_dates(len(me.vals) - 1, 0) < 0 :
                raise ValueError("Wrong times in a_csv" + str(me.vals[0]) + " " + str(me.vals[len(me.vals) - 1]))
            pass

        pass



    def is_split(me, i) :
        """
            Has there been a split between the 'i' vals and the 'i'th - 1 vals?

            The ratio of ticker close to adjusted close on day i - 1 is applied to
            day i's adjusted close; if that misses day i's ticker close by more
            than 0.02 the two days are not on the same adjustment basis.
            A zero adjusted close on day i - 1 is also reported as a split
            (it would otherwise be a divide by zero).

            @returns
                True if there's been a split or unsplit, False otherwise (including when 'i' is not in 1..len-1)
        """

        if  not (0 < i < len(me.vals)) :
            return(False)

        prev    = me.vals[i - 1]
        cur     = me.vals[i]

        if  not prev[a_csv.CLOSING_PRICE] :
            return(True)

        p       = prev[a_csv.TKR_CLOSE_PRICE] / prev[a_csv.CLOSING_PRICE]
        d       = abs(cur[a_csv.TKR_CLOSE_PRICE] - (p * cur[a_csv.CLOSING_PRICE]))

        return(d > 0.02)                        # True tells caller that there was some jump in values (probably because of a split or unsplit)



    def append(me, csv) :
        """
            Merge another csv's days in to ours.

            csv     an a_csv object, or .csv text (which is parsed first), or a false value (nothing to do)

            Days we already have are kept as they are. Days we lack are inserted in date order or tacked on the end.

            Returns True if all went well. Returns False as soon as a newly placed day
            shows a split against its neighbor - 'me.vals' may be partly merged at that point,
            and the caller is expected to throw it away and re-fetch.
        """

        if  not csv :
            return(True)

        if  isinstance(csv, basestring) :
            csv =   a_csv(csv)

        #
        #   Merge the new items in to ours
        #
        i       =   j   =   0
        while (i < len(me.vals)) and (j < len(csv.vals)) :
            d   = cmp_vals_dates(me.vals[i], csv.vals[j])
            if  d > 0 :                                         # theirs is a day we don't have
                me.vals.insert(i, csv.vals[j])
                if  me.is_split(i) or me.is_split(i + 1) :
                    return(False)
                pass
            if  d >= 0 :                                        # theirs has been used (inserted) or matched (skipped)
                j   += 1
            i       += 1

        pi       = len(me.vals)                                 # index the first tacked-on day will have

        #
        #   Now append the new ones to ours
        #
        if  j < len(csv.vals) :
            me.vals     += csv.vals[j:]

            if  me.is_split(pi) :
                return(False)
            pass

        return(True)


    def find_date_i(me, ymd_array) :
        """
            Find the 'vals' index that's at the given date or the next trading day, if the given date isn't in 'vals'.
            If the date is past the end of 'vals', then return the max(0, len(me.vals) - 1).
        """

        dt  = numeric_date(ymd_array)
        lo  = 0
        hi  = len(me.vals)
        while lo < hi :                                         # binary search for the first day that is not before 'dt'
            mid     = (lo + hi) // 2                            # floor division keeps the index an integer
            if  dt  > numeric_date(me.vals[mid]) :
                lo  = mid + 1
            else :
                hi  = mid
            pass

        return(max(0, min(lo, len(me.vals) - 1)))



    def find_date_price_i(me, ymd_array) :
        """
            Find the price's 'vals' index as we consider it to be.

            That is the index from find_date_i(), stepped back one day (but not below zero)
            when the day found is on or after the given date.
        """

        i       = me.find_date_i(ymd_array)

        if  numeric_date(me.vals[i]) >= numeric_date(ymd_array) :
            i   = max(i - 1, 0)                         # return the previous day's closing price if there was a previous trading day

        return(i)



    def find_price(me, ymd_array) :
        """
            Find the split adjusted closing price for the last trading day previous to this day.
        """

        return(me.vals[me.find_date_price_i(ymd_array)][a_csv.CLOSING_PRICE])


    def find_date_vals(me, ymd_array) :
        """
            Return the whole 'vals' array of the day find_date_price_i() picks for this date.
        """

        return(me.vals[me.find_date_price_i(ymd_array)])



    def subtract(me, csv, which_value = CLOSING_PRICE) :
        """
            Replace our values with our performance relative to another csv (e.g. an index).

            csv             an a_csv object, or .csv text (which is parsed first), or a false value (which empties 'me.vals')
            which_value     the 'vals' column to work on

            Only days present in both are kept. The first common day is kept unchanged and sets the base
            values. For each later common day, the other csv's fractional change from its base is
            subtracted from our fractional change from our base, and the column is rewritten as
            our_base + our_base * (our_change - their_change).

            NOTE(review): a base value of zero in either csv raises ZeroDivisionError, and an integer
            column (VOLUME) would get integer division under Python 2 - neither is handled here.
        """

        if  not csv :
            me.vals = []
            return

        if  isinstance(csv, basestring) :
            csv =   a_csv(csv)

        ovals   = []
        op      = None                                                  # our base value, set on the first common day
        sp      = None                                                  # their base value
        i       =   j   =   0
        while (i < len(me.vals)) and (j < len(csv.vals)) :
            d   = cmp_vals_dates(me.vals[i], csv.vals[j])
            if  d > 0 :
                j   += 1                                                # they have a day we don't
            elif d < 0 :
                i   += 1                                                # we have a day they don't
            else :
                if  op is None :
                    op  =  me.vals[i][which_value]                      # original prices
                    sp  = csv.vals[j][which_value]

                    ovals   = [ me.vals[i] ]                            # and the whole day's info is copied to the new values
                else :
                    vls = copy.copy(me.vals[i])

                    d1  = (        vls[which_value] - op) / op
                    d2  = (csv.vals[j][which_value] - sp) / sp

                    vls[which_value] = op + (op * (d1 - d2))

                    ovals.append(vls)

                i  += 1
                j  += 1
            pass

        me.vals = ovals



    def write_csv_file(me, fname) :
        """
            Write the header and all days (newest first, dates in "D-Mon-YY" form) to 'fname'.

            The text is written to fname + ".tmp" and then handed to replace_file.replace_file()
            along with fname + ".bak" (presumably that keeps the old file as the .bak and moves the
            .tmp in to place - see that module).
        """

        tname = fname + ".tmp"
        bname = fname + ".bak"

        fo  = open(tname, "wb")
        try :
            fo.write(me.hdr + "\n")
            for vals in reversed(me.vals) :
                fo.write("%s,%.2f,%.2f,%.2f,%.2f,%lu,%.2f" % ( make_date_str(vals[0:3]),
                                                               vals[a_csv.TKR_OPENING_PRICE],
                                                               vals[a_csv.TKR_HIGH_PRICE],
                                                               vals[a_csv.TKR_LOW_PRICE],
                                                               vals[a_csv.TKR_CLOSE_PRICE],
                                                               vals[a_csv.VOLUME],
                                                               vals[a_csv.CLOSING_PRICE]
                                                             ) + "\n")
        finally :
            fo.close()                                                  # don't leave the temp file open if a write fails

        replace_file.replace_file(fname, tname, bname)






def read_csv_file(fname, sym = "") :
    """
        Read and parse a .csv file.

        fname   name of the file to read (a false value gets None back)
        sym     symbol to hand to the a_csv constructor

        Returns an a_csv object, or None if there is no file name or the file can't be opened.
        Whatever the a_csv constructor raises for unparsable text is passed on to the caller.
    """

    if  not fname :
        return(None)

    try :
        fi  = open(fname, "r")
    except IOError :
        return(None)

    try :
        text    = fi.read()
    finally :
        fi.close()                              # close the file even if the read blows up

    return(a_csv(text, sym))





def get_partial_csv(osym, from_t = None, to_t = None, timeout = None, show_info = False) :
    """
        Get a .csv object from a given date through a given date.
    """

    sym  = clean_symbol(osym)

    if  sym :

        if  from_t :
            start_month = from_t[1] - 1
            start_day   = from_t[2]
            start_year  = from_t[0]
        else :
            start_month = 0
            start_day   = 1
            start_year  = 1928

        if  to_t :
            end_month   = to_t[1] - 1
            end_day     = to_t[2]
            end_year    = to_t[0]
        else :
            t           = time.gmtime()
            end_month   = t[1] - 1
            end_day     = t[2]
            end_year    = t[0]

        url_sym         = urllib.quote(sym)

        url             = "http://ichart.finance.yahoo.com/table.csv?s=%s&d=%u&e=%u&f=%u&g=d&a=%u&b=%u&c=%u&ignore=.csv" % ( url_sym, end_month, end_day, end_year, start_month, start_day, start_year )

        if  show_info :
            print "Getting .csv for", sym, " ", url

        csv             = get_request(urllib2.Request(url), timeout)

        if  csv :
            return(a_csv(csv, osym))

        if  show_info :
            print "No .csv retrieved for", sym
        pass

    return(None)



def get_csv(osym, current_csv, timeout = None, show_info = False) :
    """
        Get the .csv object for a given symbol. If there is already a current csv object, then just get updates to it.

        osym            the symbol
        current_csv     a_csv object we already have for the symbol, or None to fetch the whole history
        timeout         handed on to get_partial_csv()
        show_info       handed on to get_partial_csv()

        Returns the up-to-date a_csv object (which is also put in the known_symbols cache),
        or None if the symbol cleans to nothing or nothing could be fetched.
    """

    global  known_symbols

    sym  = clean_symbol(osym)
    if  not sym :
        return(None)

    if  current_csv and (sym in known_symbols) :
        return(current_csv)                                             # return the input csv if the csv is in our cache (implying that it came from the cache rather than from a file just read)

    start       = None
    if  current_csv :
        start   = current_csv.vals[-1]                                  # resume from the last day we have

    fetched     = get_partial_csv(osym, start, time.gmtime(), timeout = timeout, show_info = show_info)
    if  not fetched :
        return(None)

    if  current_csv :
        if  not current_csv.append(fetched) :
            return(get_csv(osym, None, timeout, show_info))             # force a full update because there was a discontinuity in the numbers
        fetched = current_csv

    known_symbols[sym] = fetched                                        # cache the new csv

    return(fetched)



def get_update_csv(osym, output_dir, refresh = False, check_base = False, timeout = None, show_info = False) :
    """
        For the given symbol, load, update, or entirely replace the
        historical data csv file in the given directory.

        osym        the symbol (cleaned with clean_symbol() for the cache key and the file name)
        output_dir  directory the <lower-case symbol>.csv file lives in
        refresh     True  : ignore the cache and the file, fetch the whole history
                    False : start from the cache or the file and ask get_csv() to bring it up to date
                            (if nothing is cached or on disk, the whole history is fetched)
                    None  : if the cache or the file has data, go with it and fetch nothing
                            (unless check_base is set); otherwise fetch the whole history
        check_base  if true and we have data, re-fetch from January 1, 1928 through our first day
                    and compare first days - on any mismatch (or fetch failure) drop what we have and fetch it all
        timeout     handed on to the fetching routines
        show_info   true to print progress

        The file is (re)written only when the number of days differs from what was loaded.

        Returns the a_csv object (also left in the known_symbols cache), or None if nothing could be had.
    """

    global  known_symbols

    sym     = clean_symbol(osym)

    # file_name_able() presumably maps the symbol to something safe to use as a file name - see tzlib
    fname   = os.path.join(output_dir, tzlib.file_name_able(clean_symbol(osym).lower())) + ".csv"

    lv      = 0                                                             # how many days we started with (0 if we loaded nothing)
    csv     = None
    if  not refresh :                                                       # refresh is False or None: try for data we already have
        csv = known_symbols.get(sym, None)                                  # get old values from the cache
        if  not csv :
            try :
                if  not os.path.getsize(fname) :
                    csv = None                                                  # allow zero-length files to trigger us at a much higher level
                else :
                    csv = read_csv_file(fname, osym)                            # or, if they aren't there, from disk
                pass
            except ( OSError, IOError, ) :
                csv     = None
            pass
        if  csv :
            lv  = len(csv.vals)
        if  (refresh == None) and csv :
            # caching it here makes get_csv() (if it gets called below) hand it straight back without fetching
            known_symbols[sym] = csv
        pass

    if  (refresh != None) or (not csv) or check_base :

        if  csv and check_base :
            # get the history up through our first day and see whether the server's first day still matches ours
            bcsv    = get_partial_csv(osym, [ 1928, 1, 1 ], csv.vals[0], timeout = timeout, show_info = show_info)
            if  (not bcsv) or (len(bcsv.vals) == 0) or (str(bcsv.vals[0]) != str(csv.vals[0])) :
                if  show_info :
                    if  bcsv :
                        print "Check base mismatch", sym, csv.vals[0], bcsv.vals[0]
                    else :
                        print "Check base mismatch", sym, csv.vals[0]
                csv = None                                                  # forget what we have so everything is fetched ...
                lv  = 0                                                     # ... and the file is rewritten
            pass

        csv         = get_csv(osym, csv, timeout = timeout, show_info = show_info)      # either get the whole history, or get the most recent history

    if  csv :
        known_symbols[sym] = csv                        # cache it

        if  len(csv.vals) != lv :                       # only write the file if the number of days changed
            csv.write_csv_file(fname)
        elif show_info :
            print "Nothing to add to", sym, "length of", lv
        pass
    elif     show_info :
        print "Cannot do anything for", sym

    return(csv)



def find_price(sym, ymd_array, output_dir, refresh = False) :
    """
        Get the symbol's csv through get_update_csv() and look up its price for the given date.

        Returns what a_csv.find_price() returns for 'ymd_array', or None if no csv could be had.
    """

    csv = get_update_csv(sym, output_dir, refresh = refresh)
    if  not csv :
        return(None)

    return(csv.find_price(ymd_array))



def find_date_vals(sym, ymd_array, output_dir, refresh = False) :
    """
        Get the symbol's csv through get_update_csv() and look up its day array for the given date.

        Returns what a_csv.find_date_vals() returns for 'ymd_array', or None if no csv could be had.
    """

    csv = get_update_csv(sym, output_dir, refresh = refresh)
    if  not csv :
        return(None)

    return(csv.find_date_vals(ymd_array))




if  __name__ == '__main__' :
    import  glob
    import  random
    import  sys
    import  os

    if  len(sys.argv) < 2 :

        print   "Tell me a company symbol"

    else :

        import  TZCommandLineAtFile

        del(sys.argv[0])

        TZCommandLineAtFile.expand_at_sign_command_line_files(sys.argv)

        show_info   =   False
        timeout     =   None
        retries     =   0
        output_dir  =   ""
        refresh     =   False
        check_base  =   False
        wait_time   =   0
        update_all  =   False

        if  (tzlib.array_find(sys.argv, "--help") >= 0) or (tzlib.array_find(sys.argv, "-h") >= 0) or (tzlib.array_find(sys.argv, "-?") >= 0) :
            print """
get_historical_csv      Get historical data .csv file from Yahoo finance, by symbol (e.g. MSFT  ^DJI mmm).

--refresh               Do not append to current .csv file. Overwrite it.
--output_dir    dir     Directory name to put the .csv file(s) in to. (output file name is lower_case_symbol.cvs)
--refresh               Force getting the whole CSV information rather than appending new data.
--check_base            Force getting whole CSV if day-one differs.
--show_info             Print progress/debugging info.
--timeout       seconds Set the web-bit timeout.
--retries       cnt     Set the number of retries.
--delay         seconds How long to delay between each download.
--update_all            Update all csv files in 'output_dir'.

"""
            sys.exit(254)



        while True :
            oi  = tzlib.array_find(sys.argv, "--timeout")
            if  oi < 0 :    break
            del sys.argv[oi]
            timeout         = int(sys.argv.pop(oi))

        while True :
            oi  = tzlib.array_find(sys.argv, "--delay")
            if  oi < 0 :    break
            del sys.argv[oi]
            wait_time       = int(sys.argv.pop(oi))

        while True :
            oi  = tzlib.array_find(sys.argv, "--retries")
            if  oi < 0 :    break
            del sys.argv[oi]
            retries         = int(sys.argv.pop(oi))

        while True :
            oi  = tzlib.array_find(sys.argv, "--show_info")
            if  oi < 0 :    break
            del sys.argv[oi]
            if  not show_info :
                show_info   = 0
            show_info      += 1

        while True :
            oi  = tzlib.array_find(sys.argv, "--refresh")
            if  oi < 0 :    break
            del sys.argv[oi]
            refresh         = True

        while True :
            oi  = tzlib.array_find(sys.argv, "--check_base")
            if  oi < 0 :    break
            del sys.argv[oi]
            check_base      = True

        while True :
            oi  = tzlib.array_find(sys.argv, "--output_dir")
            if  oi < 0 :    break
            del sys.argv[oi]
            output_dir      =   sys.argv.pop(oi)

        while True :
            oi  = tzlib.array_find(sys.argv, "--update_all")
            if  oi < 0 :    break
            del sys.argv[oi]
            update_all      = True


        if  output_dir      :  output_dir   = os.path.normpath(output_dir)

        if  not output_dir  :  output_dir   = "."

        if  not os.path.isdir(output_dir)   :   os.makedirs(output_dir)

        syms    = sys.argv

        if  update_all :
            syms   +=   map(lambda s : s[len(output_dir) + 1:-4], glob.glob(os.path.join(output_dir, "*.csv")))

        for si in xrange(len(syms)) :
            syms[si] = syms[si].replace("_", "^")               # why doesn't map/lambda work????  anyway, this is the best we can do to unmap the file names back to a symbol

        syms    = tzlib.without_dupes(sys.argv)

        #
        #   Run through all the symbols wanted
        #
        random.shuffle(syms)

        ot  = time.time() - wait_time

        while len(syms) > 0 :

            wt  = random.random() * wait_time + wait_time / 2
            while time.time() < ot + wt :
                time.sleep(0.11)

            sym = syms.pop(0)

            if  sym[0:1] == '-' :
                print "Did you mean for", sym, "to be a command line parameter?"

            get_update_csv(sym, output_dir, refresh = refresh, check_base = check_base, timeout = timeout, show_info = show_info)

            ot  =   time.time()

            known_symbols   = {}

        pass

    pass


#
#
#


# The names this module means to offer to the outside world.
# NOTE(review): Python only looks at the lower-case name '__all__' when it does
# 'from module import *'. Spelled '__ALL__', this list is an ordinary variable
# and does not limit what a star-import brings in. Renaming it would change what
# existing star-importers get, so confirm with the callers before doing so.
__ALL__ = [
            'clean_symbol',

            'cmp_vals_dates',
            'make_date_str',
            'make_today_date',
            'numeric_date',
            'unix_date',

            'a_csv',

            'read_csv_file',
            'get_partial_csv',
            'get_csv',

            'get_update_csv',

            'find_price',
            'find_date_vals',
          ]


#
#
#
# eof
