#!/usr/bin/python

# tz_wave.py
#       --copyright--                   Copyright 2009 (C) Tranzoa, Co. All rights reserved.    Warranty: You're free and on your own here. This code is not necessarily up-to-date or of public quality.
#       --url--                         http://www.tranzoa.net/tzpython/
#       --email--                       pycode is the name to send to. tranzoa.com is the place to send to.
#       --bodstamps--
#       January 12, 2009        bar     spun from wave_futz.py, which is work to get this logic in to wave.py
#       January 13, 2009        bar     don't filter not wavs from writepcmsamples
#       January 23, 2009        bar     read_wave return value change to samples,sr,sw from sr,samples
#       February 8, 2009        bar     64-bit fix for array type (and tests by test_wave.py)
#       March 8, 2009           bar     bring the wave_futz.py changes back in here
#                                       unsuccessfully try to get around a leak in python 2.5.2 (and others?) in struct.unpack(), apparently
#       --eodstamps--
##      \file
#
#
#       Make wave.py usable.
#
#


import  wave

import  struct


##################################################################################
############### The following code is for insertion in to wave.py ################
##################################################################################


import  array                                       # NOTE: may as well import array at the file-scope level, not from inside particular routines

_array_fmts = [ None, 'B', 'h', None, 'l' ]         # NOTE: 'B' not 'b'.   8 bit samples are unsigned (This is a functional fix to wave.py!)

if  array.array('i').itemsize == 4:
    _array_fmts[-1] = 'i'                           # handle 64 bit versions (e.g. v2.5.2 under Ubuntu 64 H), where 'l' may not be 4 bytes

#   Index N of the table is the array type code for N-byte samples.
#   Verify every entry before installing the table in to the wave module,
#   so a bad table never leaks in to wave even if the caller catches the error.
for _sw, _fmt in enumerate(_array_fmts) :
    if  _fmt :
        _isz    = array.array(_fmt).itemsize
        if  _isz != _sw :
            raise TypeError("array type code %r has item size %d, but %d-byte samples need item size %d" % ( _fmt, _isz, _sw, _sw ))
        del(_isz)
del(_sw, _fmt)

wave._array_fmts    = _array_fmts



def readpcmsamples(self, nframes):
    """Return a list of lists of integers for raw, PCM samples.

    Reads up to nframes frames with self.readframes() and splits them by
    channel. The number of inner lists is given by the number of channels
    in the wave file; inner list c holds channel c's samples in frame order.

    Sample values are returned in one of these forms:
        8-bit:  0..255          (unsigned, as stored in .wav files)
        16-bit: -32768..32767
        32-bit: -2147483648..2147483647

    In each frame, there is one sample for each channel. A partial frame
    at the end of the data (e.g. from a truncated file) is dropped so that
    every channel list has the same length.

    Raises wave.Error if the sample width has no array format (e.g. 24-bit).
    """
    sampwidth   = self.getsampwidth()
    nc          = self.getnchannels()

    fmts        = wave._array_fmts
    fmt         = None
    if  0 < sampwidth < len(fmts):
        fmt     = fmts[sampwidth]
    if  not fmt:
        raise wave.Error("unsupported sample width for PCM samples: %d" % sampwidth)

    wav         = self.readframes(nframes)

    framesize   = sampwidth * nc
    wav         = wav[:len(wav) - (len(wav) % framesize)]     # drop any partial trailing frame

    #   readframes() hands back samples in the machine's native byte order
    #   (wave.py byte-swaps them on big-endian hosts), so unpack with '='
    #   (native order, standard sizes), not '<'.
    samples     = struct.unpack('=%d%s' % (len(wav) // sampwidth, fmt), wav)

    #   Samples are interleaved: channel c is every nc'th sample starting at c.
    return [ list(samples[c::nc]) for c in range(nc) ]


def writepcmsamples(self, wavs):
    """Write PCM samples to the wave file.

    wavs must follow the structure returned by readpcmsamples: one sequence
    of integer samples per channel, all of the same length. 8-bit samples
    are unsigned (0..255); wider samples are signed.

    Raises wave.Error if the number of sample sequences differs from the
    file's channel count, if the sequences differ in length, or if the
    sample width has no array format (e.g. 24-bit).
    """
    nc  = self.getnchannels()
    if  nc != len(wavs):
        raise wave.Error("# of channels(%u) != # of passed sample channels(%u)" % (nc, len(wavs)))

    ln  = len(wavs[0])
    for w in wavs[1:] :
        if  len(w) != ln:
            raise wave.Error("PCM write with differing sample counts (%u:%u)" % (ln, len(w)))

    sampwidth   = self.getsampwidth()
    fmts        = wave._array_fmts
    fmt         = None
    if  0 < sampwidth < len(fmts):
        fmt     = fmts[sampwidth]
    if  not fmt:
        raise wave.Error("unsupported sample width for PCM samples: %d" % sampwidth)

    if  len(wavs) > 1:
        wav = [ s for frame in zip(*wavs) for s in frame ]         # interleave: one sample per channel per frame
    else:
        wav = wavs[0]

    ws  = array.array(fmt, wav)
    #   array.tostring() was removed in Python 3.9 in favor of tobytes();
    #   Python 2 arrays only have tostring().
    if  hasattr(ws, 'tobytes'):
        ws  = ws.tobytes()
    else:
        ws  = ws.tostring()

    self.writeframes(ws)                                            # writeframes() converts native order to .wav order



################################################################
############### end of prospective wave.py code ################
################################################################



#   Attach the PCM helpers to the standard wave reader/writer classes so that
#   objects returned by wave.open() gain readpcmsamples() / writepcmsamples().
setattr(wave.Wave_read,  'readpcmsamples',  readpcmsamples)
setattr(wave.Wave_write, 'writepcmsamples', writepcmsamples)



def destring_samples(samps) :
    """Return samps as numbers, converting a string of byte characters with ord().

    If samps is non-empty and its first element supports "+ 1", samps is
    taken to already hold numbers and is returned as-is. Otherwise a new
    list holding ord() of each element is returned. Empty or false samps
    is returned unchanged.
    """
    if  not samps :
        return(samps)

    try :
        samps[0] + 1                                        # probe: numbers add, characters raise TypeError
    except TypeError :
        return([ ord(ch) for ch in samps ])                 # e.g. a string of byte-valued characters

    return(samps)


def remove_dc(samps) :
    """Subtract the mean from every sample, removing any DC offset.

    samps may be a sequence of numbers or a string of byte characters
    (see destring_samples). Returns a new list of float samples; empty or
    false samps is returned unchanged.
    """
    if  not samps :
        return(samps)

    values  = destring_samples(samps)
    mean    = float(sum(values)) / len(values)              # simple average (this code was used for an ADC board)
    return([ v - mean for v in values ])


def make_samples_full_scale(samps, sample_width) :
    """Scale samples so the largest magnitude reaches full scale.

    samps           - a sequence of numeric samples, or a string of byte
                      characters (see destring_samples)
    sample_width    - bytes per sample. Full scale is 2**(8*sample_width - 1) - 1
                      (127 for 1, 32767 for 2, 8388607 for 3, 2147483647 for 4);
                      widths below 2 use 127.

    Returns the scaled samples as a list of floats. When samps is empty,
    it is returned unchanged; when every sample is zero, the (destringed)
    samples are returned unscaled.
    """
    if  samps :
        samps       = destring_samples(samps)

        mx_amp      = 127.0
        if  sample_width > 1 :
            mx_amp  = float(2 ** (8 * sample_width - 1) - 1)   # largest positive signed value for this width

        mx  = max(abs(min(samps)), max(samps))                  # peak magnitude, either polarity
        if  mx > 0  :
            mxm     = mx_amp / mx
            samps   = [ s * mxm for s in samps ]                # scale the samples to full amplitude

    return(samps)


def clamp_samples(samps, sample_width) :
    """Limit samples to the signed range of sample_width-byte PCM samples.

    samps           - a sequence of numeric samples, or a string of byte
                      characters (see destring_samples)
    sample_width    - bytes per sample. The range is
                      -2**(8*sample_width - 1) .. 2**(8*sample_width - 1) - 1
                      (-128..127 for 1, -32768..32767 for 2, ...);
                      widths below 2 use -128..127.

    Returns a new list in which out-of-range samples are replaced by the
    (float) limits. Empty or false samps is returned unchanged.
    """
    if  samps :
        samps       = destring_samples(samps)

        mx_amp      = 127.0
        if  sample_width > 1 :
            mx_amp  = float(2 ** (8 * sample_width - 1) - 1)   # largest positive signed value for this width
        mn_amp      = -mx_amp - 1                               # two's complement: one more negative value than positive

        samps       = [ max(mn_amp, min(mx_amp, s)) for s in samps ]    # limit the samples to what they can be

    return(samps)



def pcm_samples(samps) :
    """Convert samples to ints for writing as PCM (int() truncates toward zero).

    samps may be a sequence of numbers or a string of byte characters
    (see destring_samples). Returns a new list.
    """
    return(list(map(int, destring_samples(samps))))


def normalize_samples(samps, sample_width) :
    """Remove the DC offset from samps, then clamp them (see clamp_samples).

    Empty or false samps is returned unchanged.
    """
    if  not samps :
        return(samps)
    return(clamp_samples(remove_dc(samps), sample_width))




def write_wave(ofile_name, wavs, sample_rate, sample_width) :
    """Write per-channel sample sequences to a PCM .wav file.

    ofile_name      - output file, passed to wave.open() for writing
    wavs            - one sequence of samples per channel; samples are
                      converted to ints by pcm_samples(). 8-bit samples are
                      signed (-128..127) and are shifted to the unsigned
                      0..255 values .wav files store.
    sample_rate     - frames per second
    sample_width    - bytes per sample

    Shorter channels are padded with silence to the length of the longest.
    Nothing is written if wavs is empty. The output file is closed even if
    writing fails.
    """
    if  not wavs :
        return

    wavs        = [ pcm_samples(wav)            for wav in wavs ]

    silence     = 0
    if  sample_width == 1 :
        silence = 128
        wavs    = [ [ s + 128 for s in wav ]    for wav in wavs ]   # convert 8 bit signed samples to unsigned 8 bit values .wav files expect

    mxln        = max([ len(wav)                for wav in wavs ])
    wavs        = [ wav + [ silence ] * (mxln - len(wav)) for wav in wavs ]     # make all the sample arrays the same length as the longest

    fo          = wave.open(ofile_name, "wb")
    try :
        fo.setsampwidth(sample_width)
        fo.setframerate(sample_rate)
        fo.setnchannels(len(wavs))
        fo.writepcmsamples(wavs)
    finally :
        fo.close()                                                  # release the file even if a setter or the write raises



def read_wave(ifile_name) :
    """Read a PCM .wav file.

    Returns (wavs, sample_rate, sample_width), where wavs holds one list of
    integer samples per channel (from readpcmsamples). 8-bit samples are
    converted from the unsigned 0..255 values stored in .wav files to signed
    -128..127 values, which is what write_wave() expects.

    Returns (None, 0, 0) if an IOError occurs while opening or reading the
    file. Other exceptions are not caught. The file is closed even if
    reading fails.
    """
    try :
        fi      = wave.open(ifile_name, "rb")
        try :
            sr      = fi.getframerate()
            sw      = fi.getsampwidth()
            wavs    = fi.readpcmsamples(fi.getnframes())
        finally :
            fi.close()                                              # release the file even if reading raises
    except IOError :
        return(None, 0, 0)

    if  sw == 1 :
        wavs    = [ [ s - 128 for s in wav ]    for wav in wavs ]   # unsigned 8 bit -> signed

    return(wavs, sr, sw)





#
#
#       Test code
#
#
#       February 8, 2009
#           The DCRC (Google: dcrc md5    or use md5sum) for these files is:
#
#           i:\projects\python >dcrc /5 *.wav
#               .   .Message.Digest . 5 .   .  Binary file
#            58eae0fd3f7bf7aac9303524ab8d5e2a  test_16_mono.wav                          93560  1-12-09  6:09A
#            6134ba9b2f1c0507f1485b185a671568  x_16_mono_filtered_from_8_bit.wav         93560  2-08-09  1:09P
#            48f8da402c85e762e6eced1204da011d  x_16_stereo_right_channel_backward.wav   187076  2-08-09  1:09P
#            fed50d2d43663a9f9c3f19911f9f0cc8  x_16_stereo_right_channel_filtered.wav   187076  2-08-09  1:09P
#            a7c38e10e41415991d2776e70bb88c39  x_8_mono.wav                              46802  2-08-09  1:09P
#            296cc91ccc967a0202536374a3a8e98d  x_8_mono_forward_backward_together.wav    46802  2-08-09  1:09P
#            555ba38228d44a37a6693dc1f7590536  x_8_stereo_left_channel_backward.wav      93560  2-08-09  1:09P
#            a895a656bc0e73e1172b01e4987a386c  Total for 7 files.
#
#
if __name__ == '__main__' :


    def write_our_wave(ofile_name, wavs, sample_rate, sample_width) :
        """Remove DC, scale each channel to full scale, and write the channels to a .wav file."""
        wavs    = [ make_samples_full_scale(remove_dc(sa), sample_width) for sa in wavs ]
        write_wave(ofile_name, wavs, sample_rate, sample_width)


    def iir_filter(samps) :
        """Return samps run through y[n] = (127 * y[n-1] + x[n]) / 128, starting from y = 0.0."""
        fs      = 0.0
        out     = []
        for s in samps :
            fs  = ((fs * 127) + s) / 128
            out.append(fs)                                                                          # cheeseball iir filter to mess with the sound
        return(out)


    ( wavs, sr, w ) = read_wave("test_16_mono.wav")                                                 # actually, this file can be a stereo file, 8 or 16 bit, but we ignore the right channel
    if  wavs        :
        wv          = list(reversed(wavs[0]))

        write_our_wave("x_16_stereo_right_channel_backward.wav", [ wavs[0], wv      ], sr, 2)
        write_our_wave("x_8_mono.wav",                           [ wavs[0]          ], sr, 1)
        write_our_wave("x_8_stereo_left_channel_backward.wav",   [ wv,      wavs[0] ], sr, 1)


        ( wavs, sr, w ) = read_wave("x_16_stereo_right_channel_backward.wav")
        wav             = [ left + right for left, right in zip(wavs[0], wavs[1]) ]                # combine the two tracks to one mono (zip, not xrange, so this also runs on Python 3)
        write_our_wave("x_8_mono_forward_backward_together.wav", [ wav              ], sr, 1)       # note: write_our_wave() will scale the samples so it's ok that they are out of range


        ( wavs, sr, w ) = read_wave("x_8_mono.wav")
        write_our_wave("x_16_mono_filtered_from_8_bit.wav",      [ iir_filter(wavs[0]) ], sr, 2)    # note: write_our_wave() will scale the samples so they won't be really quiet. And it will convert them to ints


        ( wavs, sr, w ) = read_wave("x_8_stereo_left_channel_backward.wav")
        wv              = list(reversed(wavs[0]))
        write_our_wave("x_16_stereo_right_channel_filtered.wav", [ wv, iir_filter(wavs[1]) ], sr, 2)


#
#
# eof

