"""
The simplest finite-difference (FD) algorithm for acoustic waves you will ever find!
"""

import numpy as np
import matplotlib.pyplot as plt
# Close all existing figures before the script creates its own.
plt.close('all')
plt.ion() # set interactive mode

# Simulation parameters.
nx = 200        # grid points in x
nz = 200        # grid points in z
nt = 1000       # number of time steps
dx = 10.        # grid increment in x (the time loop uses it for z as well)
dt = .001       # time step (s)
c0 = 3000.      # velocity (can be an array)
# Floor division keeps the source indices integers on Python 2 and 3 alike;
# plain `/` yields a float on Python 3, which is not a valid array index.
isx = nx // 2   # source index x
isz = nz // 2   # source index z
ist = 100       # shifting of source time function (in time steps)
f0 = 100.       # dominant frequency of source (Hz)
T = 1 / f0      # dominant period (s)

# Receivers: x indices are listed explicitly, every receiver sits at z index 5.
irx = np.array([60, 80, 100, 120, 140])
irz = np.full_like(irx, 5)
nr = irx.size
# One seismogram row per receiver, one column per time step.
seis = np.zeros((nr, nt))

# Pressure field at the previous, current and next time step, plus the two
# second-derivative work arrays; all start out as zeros on the (nz, nx) grid.
grid_shape = (nz, nx)
p, pold, pnew, pxx, pzz = (np.zeros(grid_shape) for _ in range(5))

# Velocity model: c0 everywhere on the grid.
c = np.zeros(grid_shape) + c0
# Example of a fault zone (a vertical strip with 20% lower velocity);
# uncomment to enable:
# c[:, nx//2 - 5:nx//2 + 5] *= .8

# Source time function: first time derivative of a Gaussian that is centred on
# time step `ist` and whose width is set by the dominant period T.
# Evaluated for all nt time steps at once instead of looping sample by sample.
gauss = np.exp(-1.0 / T**2 * ((np.arange(nt) - ist) * dt)**2)
# Taking the difference loses one sample, so the last entry of `src` is left
# at zero to keep one source value per time step.
src = np.zeros(nt)
src[:-1] = np.diff(gauss) / dt

# Initialize animated plot of the pressure field. The colour range is fixed to
# the extremes of the source time function so it does not rescale per frame.
image = plt.imshow(pnew, interpolation='nearest', animated=True,
                   vmin=src.min(), vmax=src.max())
# Mark every receiver with '+' and the source with 'o'.
for rx, rz in zip(irx, irz):
    plt.text(rx, rz, '+')
plt.text(isx, isz, 'o')
plt.colorbar()
plt.xlabel('ix')
plt.ylabel('iz')

# Time extrapolation: advance the pressure field one time step per iteration.
for it in range(nt):
    # Second spatial derivatives from the three-point stencil, written for all
    # interior points at once. Be careful around the boundaries: the outermost
    # columns of pxx and rows of pzz are never written and keep their initial
    # zeros. p has shape (nz, nx), so axis 1 is x and axis 0 is z; the same
    # grid increment dx is used in both directions.
    pxx[:, 1:-1] = (p[:, 2:] - 2*p[:, 1:-1] + p[:, :-2]) / dx**2
    pzz[1:-1, :] = (p[:-2, :] - 2*p[1:-1, :] + p[2:, :]) / dx**2

    # Next field from the current and previous ones, then inject the source
    # sample for this time step at the source grid point.
    pnew = 2*p - pold + dt**2 * c**2 * (pxx + pzz)
    pnew[isz, isx] += src[it]

    # Report progress every step; refresh the plot every 20th iteration.
    print(it)
    if it % 20 == 0:
        image.set_data(pnew)
        plt.draw()
        # plt.draw() only requests a redraw; pausing briefly runs the GUI
        # event loop so the new frame is actually shown.
        plt.pause(0.001)

    # Shift the time levels: current becomes previous, next becomes current.
    pold, p = p, pnew

    # Save seismograms: sample the updated field at every receiver position.
    seis[:, it] = p[irz, irx]



#
# Plot the source time function and the seismograms
#
time = np.arange(nt) * dt  # time axis, one value per time step

# Figure 2: source time function.
plt.figure(2)
plt.plot(time, src)
plt.xlabel('Time (s) ')
plt.ylabel('Source amplitude ')

# Figure 3: one trace per receiver. Each trace is shifted upwards by a multiple
# of the largest recorded amplitude so the traces are stacked rather than
# drawn on top of each other. The traces are drawn once only.
plt.figure(3)
ymax = seis.max()
for ir, trace in enumerate(seis):
    plt.plot(time, trace + ymax*ir)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')

# Figure 4: the velocity model used for the simulation.
plt.figure(4)
plt.title('Velocity Model')
plt.imshow(c)
plt.xlabel('ix')
plt.ylabel('iz')
plt.colorbar()
      
## Here it would be good to save the data to a file readable by Python
