# -*- coding: utf-8 -*-
# 2009-12-08 Moritz

from obspy.arclink import Client
from obspy.core import UTCDateTime
import obspy, os, pickle
import matplotlib.pyplot as plt
from obspy.signal import highpass, lowpass, pazToFreqResp, seisSim
import numpy as np
plt.close('all')


# Initialize the ArcLink client and the start of the analysis window.
cl = Client()
t = UTCDateTime("2008,133,6:30:48.640")  # year 2008, day-of-year 133

#
# Download the data
#
# file_list collects one [waveform file, poles-and-zeros file] pair per
# station, in the order FURT, FUR; the processing sections below index it
# positionally.
file_list = []
channel = {'FURT':'EHN', 'FUR':'BHN'}
for network, station in zip(['BW','GR'], ['FURT', 'FUR']):
    mseed_file = "%s.%s.%s.%s.D.%s" % (network ,station,'',channel[station],t.strftime("%Y.%j"))
    paz_file = "%s.%s.%s.%s.D.paz" % (network ,station,'',channel[station])
    file_list.append([mseed_file, paz_file])
    # Download two hours of waveform data, if not already cached on disk
    if not os.path.exists(mseed_file):
        cl.saveWaveform(mseed_file, network, station, "", channel[station], t, t+7200)
    # Download the poles and zeros, if not already cached on disk
    if not os.path.exists(paz_file):
        paz = cl.getPAZ(network, station, "", channel[station], t, t+1800)
        # Close the handle deterministically so the pickle is fully flushed
        # before a later section reads the file back.
        with open(paz_file, 'wb') as paz_fh:
            pickle.dump(paz, paz_fh)


# Cutoff for highpass to compare
f0 = .2

#
# Process FURT, LE3D, EHN=200Hz
#
tr1 = obspy.read(file_list[0][0])[0]
tr1.trim(t, t+7200)
tr1.data = tr1.data.astype('float32')
# Downsample by keeping every 10th sample (200Hz -> 20Hz for the nominal EHN
# rate). The decimated rate is derived from the trace's actual sampling rate
# instead of being hard-coded, and the lowpass applied before decimation sits
# at the Nyquist frequency of the decimated rate (10Hz for 200Hz input).
decimation = 10
df = tr1.stats.sampling_rate / decimation
tr1.data = lowpass(tr1.data, 0.5 * df, df=tr1.stats.sampling_rate, corners=2)
# NOTE: the factor 2.0 is empirical; the original author did not know why it
# is needed.
tr1.data = tr1.data[::decimation]*2.0
npts = len(tr1.data)
# The PAZ pickle is a cache written by this script itself; never point this at
# an untrusted file. The handle is closed as soon as the data is loaded.
with open(file_list[0][1], 'rb') as paz_fh:
    paz1 = pickle.load(paz_fh)
# inst_sim=None: no target instrument is passed to seisSim.
data1 = seisSim(tr1.data, df, paz1,
                inst_sim=None, water_level=300.0)
data1 /= paz1['sensitivity']
data1 = highpass(data1, f0, df=df)


#
# Process FUR, STS-2, BHN=20Hz
#
tr2 = obspy.read(file_list[1][0])[0]
tr2.trim(t, t+7200)
tr2.data = tr2.data.astype('float32')
# No decimation here: sampling rate and sample count are taken from the trace
df, npts = tr2.stats.sampling_rate, tr2.stats.npts
# The PAZ pickle is a cache written by this script itself; never point this at
# an untrusted file. The handle is closed as soon as the data is loaded.
with open(file_list[1][1], 'rb') as paz_fh:
    paz2 = pickle.load(paz_fh)
# inst_sim=None: no target instrument is passed to seisSim.
data2 = seisSim(tr2.data, df, paz2,
                inst_sim=None, water_level=300.0)
data2 /= paz2['sensitivity']
data2 = highpass(data2, f0, df=df)

#
# Sensitivity-only scaling of the traces (no seisSim), followed by the same
# f0 highpass; these are the curves plotted under the "Original" label.
#
# NOTE(review): one df is used for both traces, which presumes tr1 and tr2
# share the same sampling rate at this point -- confirm.
for _trace, _paz in ((tr1, paz1), (tr2, paz2)):
    _trace.data = highpass(_trace.data / _paz['sensitivity'], f0, df=df)

#
# The plotting part
#
# Time series: sensitivity-scaled traces in the upper panel, seisSim-corrected
# traces in the lower panel, both against time in seconds from the first
# sample.
# NOTE(review): a single time axis built from npts and df is used for all four
# curves, which presumes both stations have npts samples at df Hz -- confirm.
time = np.arange(npts, dtype='f') / df
fig1 = plt.figure(1)

ax = fig1.add_subplot(211)
ax.plot(time, tr2.data, 'r-', label='GR.FUR STS2')
ax.plot(time, tr1.data, 'k-', label='BW.FURT LE3D')
ax.set_ylabel("Original")
ax.legend()

# The lower panel shares its x axis with the upper one
ax_corr = fig1.add_subplot(212, sharex=ax)
ax_corr.plot(time, data2, 'r-', label='GR.FUR STS2')
ax_corr.plot(time, data1, 'k-', label='BW.FURT LE3D')
ax_corr.set_xlabel("Time [s]")
ax_corr.set_ylabel("Corrected")
ax_corr.legend()

fig1.suptitle("China earthquake, 2008-05-12")
#
# Frequency response: magnitude of both instruments' responses computed from
# their poles, zeros and gain, drawn on log-log axes with the highpass cutoff
# f0 marked as a vertical line.
nfft = 2**20
fig2 = plt.figure(2)
ax_resp = fig2.add_subplot(111)

# 1./df is the sample spacing passed to pazToFreqResp; with freq=True the call
# is unpacked into the response and the matching frequency vector.
h1, f1 = pazToFreqResp(paz1['poles'], paz1['zeros'], paz1['gain'],
                       1./df, nfft, freq=True)
h2, f2 = pazToFreqResp(paz2['poles'], paz2['zeros'], paz2['gain'],
                       1./df, nfft, freq=True)

ax_resp.loglog(f1, abs(h1), 'k-', label='BW.FURT LE3D')
ax_resp.loglog(f2, abs(h2), 'r-', label='GR.FUR STS2')
ax_resp.axvline(f0, label='Highpass, f0')
ax_resp.set_xlabel("Frequency [Hz]")
ax_resp.grid()
ax_resp.legend(loc='lower right')

plt.show()
