#!/usr/bin/env python
import glob, sys, os
from numpy import array

# Shell letters in the order of the four flags written per atom.
_SHELLS = "spdf"
# A basis file name must be shorter than this many characters.
_MAX_BASIS_NAME = 25
_EV_TO_RY = 1.0 / 13.60566
_PROJECTOR_HELP = ("indicate orbital projectors using (s, p, d, or f). "
                   "For multiple, combine them (sp, pd, spd, etc.)")
# Lines written after the projector flags when the d shell is split.
_IRREP_LINES = {"t2g": "0 0 2 0\n01\n", "eg": "0 0 2 0\n10\n"}


def _ask(prompt):
    """Prompt the user and return the answer without surrounding whitespace."""
    return input(prompt).strip()


def _projector_line(uncorrelated, correlated=""):
    """Return the "s p d f" flag line: 0 = unused, 1 = projector, 2 = correlated.

    Flags are set by membership, so a repeated letter (e.g. "pp") cannot add
    up to 2 and be mistaken for a correlated shell.
    """
    flags = []
    for shell in _SHELLS:
        if shell == correlated:
            flags.append("2")
        elif shell in uncorrelated:
            flags.append("1")
        else:
            flags.append("0")
    return " ".join(flags)


def _is_shell_string(answer):
    """Return True when every character of *answer* is one of s, p, d, f."""
    return all(letter in _SHELLS for letter in answer)


def _window_line(answer):
    """Turn the energy-window answer into the text written to the file.

    Accepts "LOW HIGH" (taken as Ry), "LOW HIGH Ry" or "LOW HIGH eV" (unit is
    case-insensitive). Raises ValueError with a user-facing message when the
    answer is malformed, uses an unknown unit, or does not bracket zero.
    """
    tokens = answer.split()
    bounds = None
    if len(tokens) in (2, 3):
        try:
            bounds = (float(tokens[0]), float(tokens[1]))
        except ValueError:
            bounds = None
    if bounds is None:
        raise ValueError("Could not read an energy window from '{}'. Try again.".format(answer))
    lower, upper = bounds
    if not (lower < 0 and upper > 0):
        raise ValueError("The energy window ({}) does not contain the Fermi energy!".format(answer))

    unit = tokens[2].lower() if len(tokens) == 3 else "ry"
    if unit == "ev":
        # NOTE(review): two decimals in Ry is coarse (0.01 Ry ~ 0.14 eV); the
        # original precision is kept so existing outputs do not change.
        return "{0:0.2f} {1:0.2f}".format(lower * _EV_TO_RY, upper * _EV_TO_RY)
    if unit == "ry":
        # Keep the numbers exactly as typed so no precision is lost.
        return "{} {}".format(tokens[0], tokens[1])
    raise ValueError("Did not recognize the unit '{}'. Use Ry or eV.".format(tokens[2]))


def _ask_window():
    """Keep asking for the projection window until a valid one is given."""
    while True:
        window = _ask("Specify the projection window around eF (default unit is Ry, specify eV with -X.XX X.XX eV)\n")
        try:
            return _window_line(window)
        except ValueError as problem:
            print(problem)


def _write_basis(out, number, name):
    """Ask for the spherical-harmonics flavor of one atom and write it to *out*."""
    while True:
        sph_harm = _ask("What flavor of spherical harmonics do you want to use for ATOM {} ({})? (cubic/complex/fromfile)\n".format(number, name))
        if sph_harm in ("cubic", "complex", "fromfile"):
            break
        print("Did not recognize that input. Try again.")
    out.write(sph_harm + "\n")
    if sph_harm != "fromfile":
        return

    filename = _ask("name of file defining the basis?\n")
    if len(filename) >= _MAX_BASIS_NAME:
        print("{} is too long!".format(filename))
        rename = _ask("Rename the file to: \n")
        # The new name has to respect the same length limit.
        while not 0 < len(rename) < _MAX_BASIS_NAME:
            print("The file name must be between 1 and {} characters long.".format(_MAX_BASIS_NAME - 1))
            rename = _ask("Rename the file to: \n")
        if not os.path.isfile(filename):
            print("{} could not be found in the current directory!".format(filename))
            sys.exit(1)
        os.rename(filename, rename)
        filename = rename
    # The file name occupies its own line in both the short and renamed case.
    out.write(filename + "\n")


def _write_correlated(out):
    """Ask for the projectors of a correlated atom and write its lines to *out*."""
    while True:
        proj = _ask("Specify the correlated orbital? (d,f)\n")
        if len(proj) == 1 and proj in _SHELLS:
            break
        print("Did not recognize that input. Try again.")

    while True:
        non_corr = _ask("projectors for non-correlated orbitals? (type h for help)\n")
        if non_corr == "h":
            print(_PROJECTOR_HELP)
        elif proj in non_corr:
            print("Error: User can not choose orbital {} as both correlated and uncorrelated!".format(proj))
        elif not _is_shell_string(non_corr):
            print("Did not recognize that input. Try again.")
        else:
            out.write(_projector_line(non_corr, proj) + "\n")
            break

    # Only the d shell can be split; every other shell keeps "n".
    irrep = "n"
    if proj == "d":
        irrep = _ask("Split this orbital into it's irreps? (t2g/eg/n)\n")

    soc = _ask("Do you want to include soc? (y/n)\n")
    if soc == "y":
        if irrep in _IRREP_LINES:
            print("Warning: For SOC, dmftproj will use the entire d-shell. Using entire d-shell!")
        out.write("0 0 0 0\n")
        out.write("1\n")
    else:
        out.write(_IRREP_LINES.get(irrep, "0 0 0 0\n"))
        out.write("0\n")


def _write_uncorrelated(out):
    """Ask for the projectors of a non-correlated atom and write them to *out*."""
    while True:
        proj = _ask("Specify the projectors that you would like to include? (type h for help)\n")
        if proj == "h":
            print(_PROJECTOR_HELP)
        elif not _is_shell_string(proj):
            print("Did not recognize that input. Try again.")
        else:
            out.write(_projector_line(proj) + "\n")
            out.write("0 0 0 0\n")
            break


def write_indmftpr():
    """
    Usage: init_dmftpr

    **case.struct file is required for this script.**

    An interactive script to generate the input file (case.indmftpr) for
    the dmftproj program.

    The case name is the name of the current working directory. The atoms and
    their multiplicities are read from <case>.struct (lines containing "NPT"
    and "MULT"); everything else is asked on standard input.

    Exits with status 0 when the user declines to overwrite an existing
    file, -1 when <case>.struct is missing, and 1 when a basis file that has
    to be renamed cannot be found.
    """
    case = os.path.basename(os.getcwd())
    target = case + ".indmftpr"
    struct_path = case + ".struct"

    if os.path.isfile(target):
        found = _ask("Previous {}.indmftpr detected! Continue? (y/n)\n".format(case))
        if found == "n":
            sys.exit(0)

    print("Preparing dmftproj input file : {}\n".format(target))
    # Check and read the struct file before opening the output for writing,
    # so a missing struct file cannot truncate an existing .indmftpr file.
    if not os.path.isfile(struct_path):
        print("Could not identify a case.struct file!")
        sys.exit(-1)
    with open(struct_path, "r") as handle:
        struct = handle.readlines()

    species = [line.split()[0] for line in struct if "NPT" in line]
    mult = [line.split("=")[1].split()[0] for line in struct if "MULT" in line]
    num_atoms = len(species)
    print("number of atoms = {} ({})\n".format(num_atoms, " ".join(species)))

    with open(target, "w") as out:
        out.write("{}\n".format(num_atoms))
        out.write(" ".join(mult) + "\n")
        # presumably the maximum l (four flags per atom cover l = 0..3) — TODO confirm
        out.write("3\n")
        for number, name in enumerate(species, start=1):
            _write_basis(out, number, name)
            corr = _ask("Do you want to treat ATOM {} ({}) as correlated (y/n)?\n".format(number, name))
            if corr == "y":
                _write_correlated(out)
            else:  # still identify the projectors
                _write_uncorrelated(out)
        out.write(_ask_window() + "\n")
    print("initialize {} file ok!".format(target))


# Run the interactive generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    write_indmftpr()
