from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import sys

class Geobasemap(Basemap):
    """
    Geobasemap: a Basemap subclass for paleogeographic maps.

    It inherits every Basemap method and adds two drawing methods:

    * ``drawcoasts`` -- reconstructed coastlines for a given age, read from
      ``./<coast_dir>/Coast_<age>.00Ma.xy``;
    * ``drawplates`` -- reconstructed plate polygons for a given age, read
      from ``./<recons_type>/topology_<age>.00Ma.xy``.

    Both inputs are parsed as GMT-style multi-segment ASCII files: lines
    starting with ``>`` separate segments, other lines hold "lon lat".
    """

    @staticmethod
    def _read_gmt_segments(path, max_lon_jump=150.0):
        """Read a GMT multi-segment ``.xy`` file into lon/lat segments.

        Lines whose first token starts with ``>`` close the current segment.
        Blank lines and lines starting with ``#`` are ignored.  Every other
        line must hold at least two numbers: longitude, then latitude.

        A segment is also split wherever two consecutive longitudes differ
        by more than ``max_lon_jump`` degrees.  This keeps a polygon that
        crosses the map edge from being drawn as a line across the globe.

        Parameters
        ----------
        path : str
            Path of the file to read.
        max_lon_jump : float
            Longitude jump (degrees) that forces a new segment.

        Returns
        -------
        (lons, lats) : tuple of two lists
            ``lons[i]``/``lats[i]`` are the coordinate lists of segment i.
            No returned segment is empty.
        """
        lons, lats = [], []
        lon, lat = [], []
        with open(path, 'r') as fid:
            for line in fid:
                fields = line.split()
                if not fields or fields[0].startswith('#'):
                    continue
                if fields[0].startswith('>'):
                    # Segment header: close whatever has been collected.
                    if lon:
                        lons.append(lon)
                        lats.append(lat)
                        lon, lat = [], []
                    continue
                x, y = float(fields[0]), float(fields[1])
                if lon and abs(lon[-1] - x) > max_lon_jump:
                    lons.append(lon)
                    lats.append(lat)
                    lon, lat = [], []
                lon.append(x)
                lat.append(y)
        # The last segment is terminated by EOF, not by a '>' header.
        if lon:
            lons.append(lon)
            lats.append(lat)
        return lons, lats

    @staticmethod
    def _normalize_output_form(output_form):
        """Return 'line' for 'line' and 'dot' for 'dot'/'dots'.

        Any other value terminates the program via ``sys.exit``.
        """
        if output_form == 'line':
            return 'line'
        if output_form in ('dot', 'dots'):
            return 'dot'
        sys.exit('0o0oLALA! Type of output_form is not defined: %s'
                 % output_form)

    def _plot_segments(self, lons, lats, output_form, PolygCol, label=None):
        """Project each lon/lat segment with this map and draw it.

        ``output_form`` must already be normalized ('line' or 'dot').  When
        ``label`` is given, only the first segment carries it, so a legend
        gets a single entry.
        """
        for i, (seg_lon, seg_lat) in enumerate(zip(lons, lats)):
            x, y = self(seg_lon, seg_lat)
            style = {'color': PolygCol}
            if i == 0 and label is not None:
                style['label'] = label
            if output_form == 'line':
                self.plot(x, y, linewidth=1.5, **style)
            else:
                self.scatter(x, y, 2, marker='o', **style)

    def drawcoasts(self, age=None, output_form='line',
                   PolygCol='k', coast_dir='Coast_200Ma'):
        """
        Geobasemap.drawcoasts :
           Plots the coastline polygons for a given age of the Earth.

        ===========       ===========================================
        Keyword           Description
        ===========       ===========================================
        age               Age whose coastlines are drawn; the file read
                             is ./<coast_dir>/Coast_<age>.00Ma.xy
        output_form       'line' for solid lines, or 'dot'/'dots' for
                             dotted (scatter) coastlines
        PolygCol          matplotlib color of the coastlines
        coast_dir         Directory holding the coastline files
        """
        if age is None:
            sys.exit('Ouch! geodynamic_map.drawcoasts needs the age of the coastlines')
        form = self._normalize_output_form(output_form)
        path = './' + coast_dir + '/Coast_' + str(age) + '.00Ma.xy'
        lons, lats = self._read_gmt_segments(path)
        self._plot_segments(lons, lats, form, PolygCol)

    def drawplates(self, age=None, output_form='line', recons_type='Plates_200Ma',
                   PolygCol='k'):
        """
        Geobasemap.drawplates :
           Plots the plate polygons for a given age of the Earth.

        ===========       ===========================================
        Keyword           Description
        ===========       ===========================================
        age               Age whose plate polygons are drawn; the file read
                             is ./<recons_type>/topology_<age>.00Ma.xy
        output_form       'line' for solid lines, or 'dot'/'dots' for
                             dotted (scatter) plate boundaries
        recons_type       Directory of the reconstruction; also used as the
                             legend label of the drawn polygons
        PolygCol          matplotlib color of the polygons
        """
        if age is None:
            sys.exit('Ouch! geodynamic_map.drawplates needs the age of the plate')
        form = self._normalize_output_form(output_form)
        path = './' + recons_type + '/topology_' + str(age) + '.00Ma.xy'
        lons, lats = self._read_gmt_segments(path)
        self._plot_segments(lons, lats, form, PolygCol, label=recons_type)
#-----------------------------------------------
def geog_plot(fileadd, titletxt='PLOT', proj='moll', minlat=-90,
              maxlat=+90, minlon=-180, maxlon=+180, nx=360*1.0, ny=180*1.0,
              MAXV=None, MINV=None, colorlabel='m', show_fg=True, save_fg=False,
              agenum=0, agenum2=None, plate_plot=False,
              coast_plot=True, palletcolor='bwr', fileout=None):
    """
    geog_plot :

       Interpolates a scattered "lon lat value" text file (gmt.xyz format)
       onto a regular lon/lat grid and draws it as filled contours on a map.

    ===========|=========|=================================================
    Keyword    |Default  | Description
    ===========|=========|=================================================
    fileadd    | ----    | Name of the file to read (3 columns: lon lat value)
    titletxt   |'PLOT'   | Title of the plot
    proj       |'moll'   | Projection; 'moll', 'robin' and 'mill' are supported
    minlat     |-90      | Minimum latitude of the interpolation grid
    maxlat     |+90      | Maximum latitude of the interpolation grid
    minlon     |-180     | Minimum longitude of the interpolation grid
    maxlon     |+180     | Maximum longitude of the interpolation grid
    nx         |360      | Number of longitudinal cells (nx+1 grid nodes)
    ny         |180      | Number of latitudinal cells (ny+1 grid nodes)
    MAXV       |None     | Maximum value of the colorbar
    MINV       |None     | Minimum value of the colorbar
               |         | If both are None: MAXV = min(max(A), |min(A)|)
               |         |   and MINV = -MAXV (symmetric around zero).
               |         | If only one is None, it is set to the data's
               |         |   max (MAXV) or min (MINV).
    colorlabel |'m'      | Label of the colorbar
    show_fg    |True     | Display the figure at the end
    save_fg    |False    | Save the figure as "<fileout>.png"
    agenum     |0        | Geological age used for coastlines/plates
    agenum2    |None     | If not None, plates of this second age are also
               |         |   drawn (in green) for comparison
    plate_plot |False    | Draw plate boundaries (only when coast_plot is False)
    coast_plot |True     | Draw coastlines
    palletcolor|'bwr'    | Name of the matplotlib colormap
    fileout    |None     | Base name of the saved figure; defaults to fileadd
    """
    import numpy as np
    import matplotlib.tri as mtri

    # ndmin=2 keeps a one-row file two-dimensional for the column slicing.
    lon_lat_dat = np.loadtxt(fileadd, ndmin=2)
    if fileout is None:
        fileout = fileadd

    fig = plt.figure(figsize=(16, 12), num=1, dpi=100)
    ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
    proj_name = proj.lower()
    if proj_name in ('robin', 'moll'):
        m = Geobasemap(projection=proj_name, lat_0=0, lon_0=0, resolution='i')
    elif proj_name == 'mill':
        m = Geobasemap(projection='mill', resolution='i')
    else:
        sys.exit('%s is not implemented yet!' % proj)

    # numpy.linspace needs an integer sample count; nx/ny default to floats.
    tran_lons = np.linspace(minlon, maxlon, int(nx) + 1)
    tran_lats = np.linspace(minlat, maxlat, int(ny) + 1)
    mid_lon, mid_lat = np.meshgrid(tran_lons, tran_lats)
    lon, lat = m(mid_lon, mid_lat)
    # Linear interpolation on a Delaunay triangulation of the samples. This is
    # the method matplotlib.mlab.griddata(interp='linear') used; that function
    # is no longer available in current Matplotlib. Grid nodes outside the
    # convex hull of the samples come back masked.
    triang = mtri.Triangulation(lon_lat_dat[:, 0], lon_lat_dat[:, 1])
    interpolator = mtri.LinearTriInterpolator(triang, lon_lat_dat[:, 2])
    data = interpolator(mid_lon, mid_lat)

    if coast_plot:
        m.drawcoasts(age=agenum, PolygCol='k')

    # Plates are only drawn when coastlines are not.
    if plate_plot and not coast_plot:
        m.drawplates(age=agenum, PolygCol='k')
        if agenum2 is not None:
            m.drawplates(age=agenum2, PolygCol='g')

    m.drawparallels(np.arange(-90., 91, 90.))
    m.drawmeridians(np.arange(-180., 181., 60.))

    if MAXV is None and MINV is None:
        # Symmetric range that the data reaches on both sides of zero.
        MAXV = np.min([np.max(data), abs(np.min(data))])
        MINV = -1 * MAXV
    elif MAXV is None:
        MAXV = np.max(data)
    elif MINV is None:
        MINV = np.min(data)

    calvs = np.linspace(MINV, MAXV, 63)

    palette = plt.get_cmap(palletcolor)
    im1 = m.contourf(lon, lat, data, calvs, cmap=palette, extend="both")
    cb = m.colorbar(im1, "bottom", size="2%", pad='2%')
    cb.set_label(colorlabel)
    ax.set_title(titletxt)
    if save_fg:
        plt.savefig(fileout+'.png', facecolor=fig.get_facecolor(), transparent=True)
        print('File Saved '+ fileout+'.png ')
    if show_fg:
        plt.show()
    plt.close(1)
  
