#!/usr/bin/env python
#
# Copyright (C) 2009-2021 ABINIT Group (Damien Caliste)
#
# This file is part of the ABINIT software package. For license information,
# please see the COPYING file in the top-level directory of the ABINIT source
# distribution.
#
from __future__ import print_function, division, absolute_import #, unicode_literals

import sys, os, re

py2 = sys.version_info[0] <= 2
if py2:
    import cPickle as pickle
    #import pickle as pickle
else:
    import pickle

# Directories (relative to the current working directory) holding the
# binding sources and the Fortran input-variable sources.
parserdir = os.path.join("bindings", "parser")
iovarsdir = os.path.join("src", "57_iovars")

# dtset.pickle has been generated by abilint...
# It is used below as a mapping from attribute name to a description
# sequence whose first item is the Fortran type string and whose last item
# contains "(" when the attribute is an array.
# NOTE: unpickling can execute arbitrary code; only run this script on a
# dtset.pickle produced by a trusted build of the source tree.
with open(os.path.join(parserdir, "dtset.pickle"), "rb") as pickle_file:
  dtset = pickle.load(pickle_file)

# Build the text of the generated C/F90/Python binding sources.
# Four strings are assembled here:
#  C  : an enum giving one id per attribute of dtset;
#       this is directly put in ab7_invars.h.
#  Cs : a table giving the type tag of each attribute, in the same order
#       as the enum; put in a separate file, later included in
#       ab7_invars_c.c.
#  Py : a C function filling the dictionary of all available attributes;
#       put in a separate file, later included in ab7_invars_py.c.
#  F90: one integer parameter per attribute holding its id; put in a
#       separate file, later included in ab7_invars_f90.f90.
comment = "This file has been automatically generated, do not modify."

enum_entries = []
type_entries = []
f90_entries = []
py_entries = []
for (ident, (name, spec)) in enumerate(dtset.items()):
  base_type = spec[0]
  # A "(" in the last description item marks an array attribute.
  if "(" in spec[-1]:
    if base_type.startswith("integer"):
      tag = "_INT_ARRAY"
    elif base_type.startswith("real(dp)"):
      tag = "_DOUBLE_ARRAY"
    else:
      tag = "_OTHER"
  elif base_type == "integer":
    tag = "_INT_SCALAR"
  elif base_type == "real(dp)":
    tag = "_DOUBLE_SCALAR"
  else:
    tag = "_OTHER"

  # Every attribute gets an id, whatever its type tag; ids follow the
  # iteration order of dtset so the enum, the type table and the Fortran
  # parameters stay aligned.
  enum_entries.append("  AB7_INVARS_%-15s,  /* %-15s */\n" % (name.upper(), tag))
  type_entries.append("           %-15s,  /* %-15s */\n" % (tag, name))
  f90_entries.append("integer, parameter, public :: ab7_invars_%-15s = %5d\n" % (name, ident))
  py_entries.append("  PyDict_SetItemString(DictIds, \"%s\", Py_BuildValue(\"ii\", AB7_INVARS_%s, %s));\n" % (name, name.upper(), tag))

F90 = "! %s\n" % comment + "".join(f90_entries)

# The enum is closed by AB7_INVARS_N_IDS, which therefore equals the number
# of attributes.
C = ("/* %s */\n" % comment
     + "typedef enum\n"
     + "{\n"
     + "".join(enum_entries)
     + "  AB7_INVARS_N_IDS\n"
     + "} Ab7InvarsIds;\n"
     + "\n")

# The type table gets one trailing _OTHER entry after the per-attribute ones.
Cs = ("/* %s */\n" % comment
      + "#ifndef AB7_INVARS_C_H\n"
      + "#define AB7_INVARS_C_H\n"
      + "\n"
      + "static Ab7InvarsTypes ab7_invars_types[] = {\n"
      + "".join(type_entries)
      + "           %-15s\n" % "_OTHER"
      + "};\n"
      + "\n"
      + "#endif\n")

Py = ("static PyObject* DictIds = NULL;\n"
      + "static void _init_dict_ids(void)\n"
      + "{\n"
      + "  DictIds = PyDict_New();\n"
      + "\n"
      + "".join(py_entries)
      + "}\n")

# Instantiate ab7_invars.h from its template by inserting the generated enum.
# All files are handled through context managers so every handle is closed,
# including the template read handle and on write errors.
with open(os.path.join(parserdir, "ab7_invars.h.tmpl"), "rt") as tmpl:
  src = tmpl.read()
# Plain string substitution: the generated text must be inserted verbatim,
# whereas re.sub() would interpret backslashes and group references in it.
src = src.replace("@ATTRIBUTES@", C)
with open(os.path.join(parserdir, "ab7_invars.h"), "wt") as out:
  out.write(src)

# Write the remaining generated sources as they are.
for (directory, filename, content) in (
    (parserdir, "ab7_invars_c.h", Cs),
    (iovarsdir, "ab7_invars_f90.inc", F90),
    (parserdir, "ab7_invars_py.h", Py)):
  with open(os.path.join(directory, filename), "w") as header:
    header.write(content)

# Build the "case" stanzas of the Fortran access routines.
# One stanza list is kept per target routine: integer scalars, real scalars,
# integer arrays, real arrays, and array shapes.
scalar_int_cases = []
scalar_flt_cases = []
array_int_cases = []
array_flt_cases = []
array_shp_cases = []
for (name, spec) in dtset.items():
  base_type = spec[0]
  case_label = "    case (ab7_invars_%s)\n" % name

  if "(" not in spec[-1]:
    # Scalar attribute: plain assignment. Scalars that are neither integer
    # nor real(dp) get no accessor.
    stanza = case_label + "      value = token%%dtsets(idtset)%%%s\n" % name
    if base_type == "integer":
      scalar_int_cases.append(stanza)
    elif base_type == "real(dp)":
      scalar_flt_cases.append(stanza)
    continue

  # Array attribute: copy it flattened, after checking that the caller's
  # buffer length n matches the total number of elements.
  stanza = "".join((
    case_label,
    "      n_dt = product(shape(token%%dtsets(idtset)%%%s))\n" % name,
    "      if (n_dt /= n) then\n",
    "        errno = AB7_ERROR_INVARS_SIZE\n",
    "      else\n",
    "        values = reshape(token%%dtsets(idtset)%%%s, (/ n_dt /))\n" % name,
    "      end if\n",
  ))
  if base_type.startswith("integer"):
    array_int_cases.append(stanza)
  elif base_type.startswith("real(dp)"):
    array_flt_cases.append(stanza)

  # The shape query is generated for every array, whatever its element type.
  array_shp_cases.append(case_label)
  array_shp_cases.append("      ndim = size(shape(token%%dtsets(idtset)%%%s))\n" % name)
  array_shp_cases.append("      dims(1:ndim) = shape(token%%dtsets(idtset)%%%s)\n" % name)
  #if ("pointer" in spec[0]):
  #  print("nullify(dtset%%%s)" % name)

int_sca = "".join(scalar_int_cases)
flt_sca = "".join(scalar_flt_cases)
int_arr = "".join(array_int_cases)
flt_arr = "".join(array_flt_cases)
shp_arr = "".join(array_shp_cases)

# Emit the Fortran accessor routines into ab7_invars_f90_get.f90.
def _getter_source(name, out_args, in_args, out_decl, cases, extra_locals=()):
  """Return the Fortran source of one accessor subroutine.

  name         -- name of the generated subroutine.
  out_args     -- output dummy arguments, as they appear in the argument list
                  between dtsetsId and id.
  in_args      -- names declared "integer, intent(in)" besides dtsetsId.
  out_decl     -- declaration statement for the output arguments.
  cases        -- pre-built "case" stanzas placed inside the select block.
  extra_locals -- additional local declaration lines (already indented).

  The dataset index is validated against the actual bounds of
  token%dtsets.  The former test "idtset < 0 .or. idtset > size(...)"
  accepted size+1 distinct indices for an array of size elements, so one
  out-of-bounds index always slipped through.
  """
  head = [
    "",
    "  subroutine %s(dtsetsId, %s, id, idtset, errno)" % (name, out_args),
    "    integer, intent(in) :: dtsetsId",
    "    integer, intent(in) :: %s" % in_args,
    "    %s" % out_decl,
    "    integer, intent(out) :: errno",
    "",
    "    type(dtsets_list), pointer :: token",
  ]
  head.extend(extra_locals)
  head.extend([
    "",
    "    call get_token(token, dtsetsId)",
    "    if (.not. associated(token)) then",
    "      errno = AB7_ERROR_OBJ",
    "      return",
    "    end if",
    "    if (idtset < lbound(token%dtsets, 1) .or. idtset > ubound(token%dtsets, 1)) then",
    "      errno = AB7_ERROR_INVARS_ID",
    "      return",
    "    end if",
    "    ",
    "    errno = AB7_NO_ERROR",
    "    select case (id)",
    "",
  ])
  tail = [
    "",
    "    case default",
    "      errno = AB7_ERROR_INVARS_ATT",
    "    end select",
    "  end subroutine %s" % name,
    "",
  ]
  return "\n".join(head) + cases + "\n".join(tail)

# (name, output args, extra intent(in) names, output declaration,
#  case stanzas, extra local declarations) for each generated routine.
getters = (
  ("ab7_invars_get_integer", "value", "id, idtset",
   "integer, intent(out) :: value", int_sca, ()),
  ("ab7_invars_get_real", "value", "id, idtset",
   "real(dp), intent(out) :: value", flt_sca, ()),
  ("ab7_invars_get_integer_array", "values, n", "id, idtset, n",
   "integer, intent(out) :: values(n)", int_arr, ("    integer :: n_dt",)),
  ("ab7_invars_get_real_array", "values, n", "id, idtset, n",
   "real(dp), intent(out) :: values(n)", flt_arr, ("    integer :: n_dt",)),
  ("ab7_invars_get_shape", "dims, ndim", "id, idtset",
   "integer, intent(out) :: dims(7), ndim", shp_arr, ()),
)

# The context manager closes the output file even if a write fails.
with open(os.path.join(iovarsdir, "ab7_invars_f90_get.f90"), "w") as subs:
  subs.write("! %s\n" % comment)
  for getter in getters:
    subs.write(_getter_source(*getter))
