#!/home/conda/feedstock_root/build_artifacts/meds_1756643425244/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh/bin/python
from __future__ import print_function
# Python 2/3 compatibility shim: on Python 3 the names ``xrange`` and
# ``raw_input`` do not exist, so alias them to their Python 3 equivalents.
# Only NameError is caught so that unrelated failures are not hidden.
try:
    xrange
except NameError:
    xrange = range
    raw_input = input

import os
import numpy
import meds

from argparse import ArgumentParser
# Command-line interface for the viewer script.
#
# The module docstring (if any) is used as the description; the first
# positional parameter of ArgumentParser is ``prog``, so it must be passed
# by keyword to end up in the right place.
parser = ArgumentParser(description=__doc__)

parser.add_argument('filename', help="MEDS file name")

# comma separated list, split in main()
parser.add_argument('--types', help="types to show")

# inclusive range of object indices
parser.add_argument('--start', type=int, help="start with this object")
parser.add_argument('--end', type=int, help="end with this object")

parser.add_argument('--index-file',
                    help="read indices from this file")

parser.add_argument('--save', action='store_true',
                    help="save plot files for each object into cwd")
parser.add_argument('--show', action='store_true',
                    help="show each plot, with a prompt to the user")

parser.add_argument('--html-file',
                    help="write out an html file displaying the images")

parser.add_argument('--combine', action='store_true',
                    help="combine the images into a mosaic")

parser.add_argument('--has-coadd', action='store_true',
                    help="first cutout is the coadd")

def view(image, **keys):
    """
    Draw an image with biggles and return the plot object.

    The image is first passed through scale_image(), which clips values
    below zero, so pre-subtract any background as necessary.

    parameters
    ----------
    image: ndarray
        The image as a 2-d array.
    type: string
        Type of plot, 'dens', 'cont', 'dens-cont'.  (cont is short
        for contour, dens for density).  Default is 'dens'.
    nonlinear:
        Non-linear scale for an asinh scaling.  If not sent a linear scale
        is used.  See asinh_scale().
    autoscale:
        Used by the linear scaling, see linear_scale().
    title: string
        Title for the plot.
    xlabel: string
        Label for the X axis
    ylabel: string
        Label for the Y axis
    show: bool
        Whether to show the plot.  If not sent, the plot is shown
        unless a file is being written.
    levels: int
        Number of levels for contours. Default 8.  Send None to skip
        drawing the contours.
    zrange:
        z range for the contours.  Defaults to the 1st and 100th percentile
        of the scaled image.
    ccolor:
        Color for contours.  Default is 'grey' when a density plot is also
        drawn, otherwise 'black'.
    transpose: bool
        Transpose the image before plotting.  Default True.
    xdr, ydr:
        Data ranges for the axes, both must be sent to take effect.
    file: string
        Write the plot to this file name.  A name containing .png is
        written as a png file, anything else as an eps file.
    dims:
        [width,height] for png output and for display, default 800x800
    """
    import biggles

    scaled = scale_image(image, **keys)

    # by default the array is transposed so biggles displays it properly
    if keys.get('transpose', True):
        scaled = scaled.transpose()

    plot_type = keys.get('type', 'dens')

    plt = biggles.FramedPlot()
    if 'title' in keys:
        plt.title = keys['title']

    x, y, ranges = _extract_data_ranges(scaled.shape, **keys)

    with_density = 'dens' in plot_type
    if with_density:
        plt.add(biggles.Density(scaled, ranges))

    if 'cont' in plot_type:
        nlevels = keys.get('levels', 8)
        if nlevels is not None:
            if 'zrange' in keys:
                zrange = keys['zrange']
            else:
                zrange = numpy.percentile(scaled, [1.0, 100.0])

            # grey stands out on top of a density plot, black on a bare frame
            fallback_color = 'grey' if with_density else 'black'

            contours = biggles.Contours(
                scaled,
                x=x,
                y=y,
                color=keys.get('ccolor', fallback_color),
                zrange=zrange,
            )
            contours.levels = nlevels
            plt.add(contours)

    # keep the pixels square
    dim0, dim1 = scaled.shape[0], scaled.shape[1]
    plt.aspect_ratio = dim1/float(dim0)

    for label in ('xlabel', 'ylabel'):
        if label in keys:
            setattr(plt, label, keys[label])

    _writefile_maybe(plt, **keys)
    _show_maybe(plt, **keys)

    return plt


def view_profile(image, **keys):
    """
    Plot the pixel values of an image against distance from a center.

    Every pixel contributes one point; the points are ordered by radius.

    parameters
    ----------
    image: ndarray
        The image as a 2-d array.
    cen: [c1,c2], optional
        Center in (row,col).  Defaults to the middle of the image,
        (shape-1)/2.
    plt: optional
        Existing plot object, passed on to biggles.plot
    show: bool
        Set False to not show the plot.  Default True.
    file: string
        Write the plot to this file name.  A name containing .png is
        written as a png file, anything else as an eps file.
    width, height: integers
        Size used when showing the plot
    **keys:
        Remaining keywords are passed on to biggles.plot
    """
    import biggles

    existing_plot = keys.pop('plt', None)
    show = keys.pop('show', True)
    cen = keys.pop('cen', None)

    if cen is not None:
        assert len(cen)==2,"cen must have two elements"
    else:
        cen = (numpy.array(image.shape)-1.0)/2.0

    nrows, ncols = image.shape[0], image.shape[1]
    row_grid, col_grid = numpy.mgrid[0:nrows, 0:ncols]

    # offsets of every pixel from the center
    drow = row_grid.astype('f8') - cen[0]
    dcol = col_grid.astype('f8') - cen[1]

    radius = numpy.sqrt(drow**2 + dcol**2).ravel()
    order = radius.argsort()
    radius = radius[order]
    values = image.ravel()[order]

    # biggles.plot should not pop up its own window; showing is handled
    # below
    keys['visible'] = False
    plt = biggles.plot(radius, values, plt=existing_plot, **keys)

    _writefile_maybe(plt, **keys)
    _show_maybe(plt, show=show, **keys)

    return plt



def _writefile_maybe(plt, **keys):
    """
    Write the plot to disk if a non-None 'file' keyword was sent.

    Environment variables and ~ are expanded in the file name.  A name
    containing '.png' is written with write_img using 'dims' (default
    800x800), anything else is written with write_eps.
    """
    fname = keys.get('file')
    if fname is None:
        return

    fname = os.path.expanduser(os.path.expandvars(fname))

    if '.png' not in fname:
        plt.write_eps(fname)
        return

    width_height = keys.get('dims', [800, 800])
    plt.write_img(width_height[0], width_height[1], fname)

def _show_maybe(plt, **keys):
    """
    Show the plot on screen, depending on the 'show' and 'file' keywords.

    If 'show' was not sent (or is None) the plot is shown only when no file
    is being written.  The window size comes from 'dims', else from
    'width'/'height', else 800x800.
    """
    show = keys.get('show')
    if show is None:
        # not explicitly requested: stay quiet when writing a file
        show = keys.get('file') is None

    if not show:
        return

    dims = keys.get('dims')
    if dims is None:
        width = keys.get('width')
        if width is not None:
            dims = [width, keys.get('height')]
        else:
            dims = [800, 800]

    plt.show(width=dims[0], height=dims[1])


def _extract_data_ranges(imshape, **keys):
    """
    Work out the axis ranges, and optionally coordinates, for an image.

    When both 'xdr' and 'ydr' are sent, x and y are evenly spaced
    coordinates spanning those ranges with imshape[0] and imshape[1] points.
    Otherwise x and y are None and the ranges run from -0.5 to size-0.5
    along each dimension.

    returns
    -------
    x, y, ranges
        ranges is ((xmin,ymin),(xmax,ymax))
    """
    n0, n1 = imshape[0], imshape[1]

    if not ('xdr' in keys and 'ydr' in keys):
        # pixel-edge convention; note Contours does not use this by default
        default_ranges = ((-0.5, -0.5), (n0-0.5, n1-0.5))
        return None, None, default_ranges

    xdr = keys['xdr']
    ydr = keys['ydr']

    lower = (xdr[0], ydr[0])
    upper = (xdr[1], ydr[1])

    x = numpy.linspace(xdr[0], xdr[1], n0)
    y = numpy.linspace(ydr[0], ydr[1], n1)

    return x, y, (lower, upper)
 
def make_combined_mosaic(imlist):
    """
    Tile a list of images into one big image.

    The images are laid out on the grid chosen by Grid, each occupying a
    tile the size of the first image, so all images must have the same
    shape.  Unfilled tiles are zero, so "sky subtracted" images look best
    when the grid is not fully packed.  The result is flipped vertically
    before being returned.
    """
    grid = Grid(len(imlist))
    tile_nrow, tile_ncol = imlist[0].shape

    mosaic = numpy.zeros( (grid.nrow*tile_nrow, grid.ncol*tile_ncol) )

    for i, im in enumerate(imlist):
        grid_row, grid_col = grid(i)

        row_slice = slice(grid_row*tile_nrow, (grid_row+1)*tile_nrow)
        col_slice = slice(grid_col*tile_ncol, (grid_col+1)*tile_ncol)

        mosaic[row_slice, col_slice] = im

    return numpy.flipud(mosaic)


def view_mosaic(imlist, titles=None, combine=False, **keys):
    """
    View a list of images, either as a table of plots or as one tiled image.

    parameters
    ----------
    imlist: list of ndarray
        The images to show.
    titles: sequence of strings, optional
        One title per image, used for the individual plots in the table.
    combine: bool
        If True, tile the images into a single image with
        make_combined_mosaic() and display that with view().  In this
        mode the 'title' keyword is removed and not forwarded.
    title: string, optional
        Title for the table of plots.
    aspect: float, optional
        Aspect ratio (ysize/xsize) of the table.  Default nrow/ncol.
    **keys:
        Other keywords are passed on to view(); 'file', 'show', 'dims' etc.
        are also used to write/show the final table.

    returns
    -------
    The biggles plot (combine=True) or table (combine=False).
    """
    import biggles

    title=keys.pop('title',None)

    if combine:
        imtot=make_combined_mosaic(imlist)
        return view(imtot, **keys)

    nimage=len(imlist)

    grid = Grid(nimage)

    tab=biggles.Table(grid.nrow,grid.ncol)
    tab.title=title

    # The individual plots must neither be shown nor written: only the
    # assembled table is.  Without clearing 'file' every sub-plot would be
    # written to the output file just to be overwritten by the table below.
    tkeys={}
    tkeys.update(keys)
    tkeys['show']=False
    tkeys['file']=None

    for i, im in enumerate(imlist):
        row,col = grid(i)

        subtitle = titles[i] if titles is not None else None

        tab[row,col] = view(im, title=subtitle, **tkeys)

    # aspect is ysize/xsize
    aspect=keys.get('aspect',None)
    if aspect is None:
        aspect = float(grid.nrow)/grid.ncol

    tab.aspect_ratio=aspect

    _writefile_maybe(tab, **keys)
    _show_maybe(tab, **keys)

    return tab

class Grid(object):
    """
    Lay out a number of plots on a rectangular grid.  The grid dimensions
    are picked from the number of plots.

    attributes
    ----------
    nplot, nrow, ncol, nplot_tot (= nrow*ncol)

    example
    -------
    grid=Grid(n)

    for i in xrange(n):
        row,col = grid(i)

        # same thing as grid.get_rowcol(i)

        plot_table[row,col] = plot(...)
    """
    def __init__(self, nplot):
        self.set_grid(nplot)

    def set_grid(self, nplot):
        """
        choose nrow,ncol for the given number of plots
        """
        from math import sqrt

        self.nplot = nplot

        if nplot == 8:
            # special case: 2x4 rather than 3x3
            shape = (2, 4)
        else:
            root = int(sqrt(nplot))
            if root*root == nplot:
                shape = (root, root)
            elif nplot <= root*(root+1):
                shape = (root, root+1)
            else:
                shape = (root+1, root+1)

        self.nrow, self.ncol = shape
        self.nplot_tot = self.nrow*self.ncol

    def get_rowcol(self, index):
        """
        get the (row,col) grid position for a plot index

        The index advances along the columns of a row first, then moves to
        the next row.

        parameters
        ----------
        index: int
            Index in the grid

        example
        -------
        nplot=7
        grid=Grid(nplot)
        arr=biggles.FramedArray(grid.nrow, grid.ncol)

        for i in xrange(nplot):
            row,col=grid.get_rowcol(i)
            arr[row,col].add( ... )
        """

        last = self.nplot_tot - 1
        if index > last:
            raise ValueError("index too large %d > %d" % (index,last))

        return divmod(index, self.ncol)

    def __call__(self, index):
        return self.get_rowcol(index)

def get_grid(nplot):
    """
    Choose grid dimensions for a number of plots.

    parameters
    ----------
    nplot: int
        Number of plots in the grid

    returns
    -------
    (nrow, ncol)
    """
    from math import sqrt

    root = int(sqrt(nplot))

    # perfect square: root x root
    if root*root == nplot:
        return (root, root)

    # one extra column is enough
    if nplot <= root*(root+1):
        return (root, root+1)

    # otherwise grow both dimensions
    return (root+1, root+1)

def get_grid_rowcol(nplot, index):
    """
    get the grid position given the number of plots

    move along columns first

    parameters
    ----------
    nplot: int
        Number of plots in the grid
    index: int
        Index of the plot

    returns
    -------
    (row, col)

    example
    -------
    nplot=7
    nrow, ncol = get_grid(nplot)
    arr=biggles.FramedArray(nrow, ncol)

    for i in xrange(nplot):
        row,col=get_grid_rowcol(nplot, i)
        arr[row,col].add( ... )
    """

    nrow, ncol = get_grid(nplot)

    # floor division: plain "/" is true division on python 3 and would
    # produce a float row that cannot be used as an index
    row = index // ncol
    col = index % ncol

    return row,col


def bytescale(im):
    """
    Convert an image with values in [0,1] to unsigned bytes.

    The values are multiplied by 255 and cast to 'u1', so the output
    covers [0,255].
    """
    scaled = im*255
    return scaled.astype('u1')

def write_image(filename, image, **keys):
    """
    Save an image array to an image file using PIL.

    The image should already be bytescaled to [0,255] with type 'u1', see
    scale_image() and bytescale().

    Environment variables and ~ in the file name are expanded.  Extra
    keywords, such as quality for jpegs, are passed on to the PIL save
    method.
    """
    from PIL import Image

    pil_image = Image.fromarray(image)

    path = os.path.expanduser(os.path.expandvars(filename))
    pil_image.save(path, **keys)


def multiview(image, **keys):
    """
    View the image and also some cross-sections through it.  Good for
    postage stamp type images

    The result is a 1x2 biggles table: the image on the left and, on the
    right, histograms of the column and row passing through the center.

    parameters
    ----------
    image: ndarray
        The image as a 2-d array.
    cen: [c1,c2], optional
        (row,col) through which the cross-sections are taken.  Defaults to
        the middle of the image.  Values are converted to int.
    **keys:
        Other keywords are passed on to view(); 'file', 'show', 'dims' etc.
        apply to the final table.

    returns
    -------
    The biggles table.
    """

    # copy is not imported at the top of the file
    import copy
    import biggles

    cen = keys.get('cen', None)
    if cen is None:
        # use the middle as center
        cen = [
            int(round( (image.shape[0]-1)/2. )),
            int(round( (image.shape[1]-1)/2. )),
        ]
    else:
        # the center is used as an array index below
        cen = [int(cen[0]), int(cen[1])]

    # the image plot itself should be neither shown nor written, only the
    # table is
    keys2 = copy.copy(keys)
    keys2['show'] = False
    keys2['file'] = None

    imp = view(image, **keys2)

    # cross-sections through the center
    imrows = image[:, cen[1]]
    imcols = image[cen[0], :]

    # in case xdr, ydr were sent
    x, y, ranges = _extract_data_ranges(image.shape, **keys)
    if x is not None:
        x0 = ranges[0][0]
        y0 = ranges[0][1]
        xbinsize=x[1]-x[0]
        ybinsize=y[1]-y[0]
    else:
        x0=0
        y0=0
        xbinsize=1
        ybinsize=1


    crossplt = biggles.FramedPlot()
    hrows = biggles.Histogram(imrows, x0=y0, binsize=ybinsize, color='blue')
    hrows.label = 'Center rows'
    hcols = biggles.Histogram(imcols, x0=x0, binsize=xbinsize, color='red')
    hcols.label = 'Center columns'

    key = biggles.PlotKey(0.1, 0.9, [hrows, hcols])

    crossplt.add(hrows, hcols, key)
    crossplt.aspect_ratio=1

    # leave some head room above the histograms for the key
    yr = crossplt._limits1().yrange()
    yrange = (yr[0], yr[1]*1.2)
    crossplt.yrange = yrange


    tab = biggles.Table( 1, 2 )

    tab[0,0] = imp
    tab[0,1] = crossplt


    _writefile_maybe(tab, **keys)
    _show_maybe(tab, **keys)
    return tab

def compare_images(im1, im2, **keys):
    """
    Plot two images, their difference and optionally cross-sections.

    The result is a 2-row biggles table holding im1, im2 and the residual
    im2-im1.  When cross_sections is True a second row shows histograms of
    the central row and column of each of the three images.  A chi^2 label
    is added to the residual plot.

    parameters
    ----------
    im1, im2: ndarray
        Images with the same shape.
    skysig: float, optional
        If sent, the label shows sum(resid^2)/skysig^2/dof, otherwise
        sum(resid^2)/dof.
    dof: optional
        Degrees of freedom for the label, default im1.size
    cross_sections: bool
        Show the cross-section plots.  Default True.
    cen: [c1,c2], optional
        (row,col) through which cross-sections are taken, default the middle
        of im1.
    label1, label2: string
        Labels for the images, default 'im1' and 'im2'
    color1, color2, colordiff: string
        Histogram colors, default 'blue', 'orange' and 'red'
    title: string, optional
        Title for the table.
    **keys:
        Other keywords are passed on to view(); 'file', 'show', 'dims' etc.
        apply to the final table.

    raises
    ------
    ValueError if the shapes differ

    side effects
    ------------
    Sets the biggles default 'fontsize_min' to 1.
    """
    # copy is not imported at the top of the file
    import copy
    import biggles

    skysig=keys.get('skysig',None)
    dof=keys.get('dof',None)
    cross_sections=keys.get('cross_sections',True)

    color1=keys.get('color1','blue')
    color2=keys.get('color2','orange')
    colordiff=keys.get('colordiff','red')

    nrow=2
    if cross_sections:
        ncol=3
    else:
        ncol=2

    label1=keys.get('label1','im1')
    label2=keys.get('label2','im2')

    cen=keys.get('cen',None)
    if cen is None:
        cen = [(im1.shape[0]-1)/2., (im1.shape[1]-1)/2.]

    labelres='%s-%s' % (label2,label1)

    biggles.configure( 'default', 'fontsize_min', 1.)

    if im1.shape != im2.shape:
        raise ValueError("images must be the same shape")


    resid = im2-im1

    tab=biggles.Table(nrow,ncol)
    if 'title' in keys:
        tab.title=keys['title']

    # the individual plots are neither shown nor written, only the table
    tkeys=copy.deepcopy(keys)
    tkeys['show']=False
    tkeys['file']=None
    im1plt=view(im1, **tkeys)
    im2plt=view(im2, **tkeys)

    # the residuals are always shown on a linear scale
    tkeys['nonlinear']=None
    # this has no effect
    tkeys['min'] = resid.min()
    tkeys['max'] = resid.max()
    residplt=view(resid, **tkeys)

    if dof is None:
        dof=im1.size

    if skysig is not None:
        chi2per = (resid**2).sum()/skysig**2/dof
        labtext = r'$\chi^2/dof$: %0.2f' % chi2per
    else:
        chi2per = (resid**2).sum()/dof
        labtext = r'$\chi^2/npix$: %.3e' % chi2per

    lab = biggles.PlotLabel(0.1,0.1,
                            labtext,
                            color='red',
                            halign='left')
    residplt.add(lab)

    im1plt.title=label1
    im2plt.title=label2
    residplt.title=labelres


    # cross-sections
    if cross_sections:
        cen0=int(cen[0])
        cen1=int(cen[1])
        im1rows = im1[:,cen1]
        im1cols = im1[cen0,:]
        im2rows = im2[:,cen1]
        im2cols = im2[cen0,:]
        resrows = resid[:,cen1]
        rescols = resid[cen0,:]

        him1rows = biggles.Histogram(im1rows, color=color1)
        him1cols = biggles.Histogram(im1cols, color=color1)
        him2rows = biggles.Histogram(im2rows, color=color2)
        him2cols = biggles.Histogram(im2cols, color=color2)
        hresrows = biggles.Histogram(resrows, color=colordiff)
        hrescols = biggles.Histogram(rescols, color=colordiff)

        # only the rows plot carries the key
        him1rows.label = label1
        him2rows.label = label2
        hresrows.label = labelres
        key = biggles.PlotKey(0.1,0.9,[him1rows,him2rows,hresrows])

        rplt=biggles.FramedPlot()
        rplt.add( him1rows, him2rows, hresrows,key )
        rplt.xlabel = 'Center Rows'

        cplt=biggles.FramedPlot()
        cplt.add( him1cols, him2cols, hrescols )
        cplt.xlabel = 'Center Columns'

        rplt.aspect_ratio=1
        cplt.aspect_ratio=1

        tab[0,0] = im1plt
        tab[0,1] = im2plt
        tab[0,2] = residplt
        tab[1,0] = rplt
        tab[1,1] = cplt
    else:
        tab[0,0] = im1plt
        tab[0,1] = im2plt
        tab[1,0] = residplt

    _writefile_maybe(tab, **keys)
    _show_maybe(tab, **keys)

    return tab

def image_read_text(fname):
    """
    Read an image stored in the simple text format

        nrows ncols
        im00 im01 im02...
        im10 im11 im12...

    The first line holds the dimensions; the following nrows*ncols
    values are read as 'f8' and returned with shape (nrows,ncols).
    """

    with open(fname) as fobj:
        header = fobj.readline().split()
        nrows = int(header[0])
        ncols = int(header[1])

        # the rest of the file holds the pixel values
        flat = numpy.fromfile(
            fobj,
            sep=' ',
            count=(nrows*ncols),
            dtype='f8',
        )
        image = flat.reshape(nrows, ncols)

    return image


def _get_max_image(im1, im2, im3):
    """
    Get the pixel-by-pixel maximum of three images.

    The result is a copy of im1 (same dtype) in which every pixel that is
    larger in im2 or im3 has been replaced.  The inputs are not modified.
    """
    maximage = im1.copy()

    for other in (im2, im3):
        # boolean mask; also handles the case of a single larger pixel
        bigger = other > maximage
        if bigger.any():
            maximage[bigger] = other[bigger]

    return maximage

def _fix_hard_satur(r, g, b, satval):
    """
    Clip to satval but preserve the color

    Wherever the brightest of the three images exceeds satval, all three
    are multiplied by satval/max so that the brightest equals satval and
    the ratios between them are unchanged.

    r, g and b are modified in place.  Returns the (clipped) maximum image.
    """

    # make sure you send scales such that this occurs at
    # a reasonable place for your images

    maximage = _get_max_image(r, g, b)

    saturated = maximage > satval
    if saturated.any():
        # this preserves color
        fac = satval/maximage[saturated]
        r[saturated] *= fac
        g[saturated] *= fac
        b[saturated] *= fac
        maximage[saturated] = satval

    return maximage

def _fix_rgb_satur(r,g,b,fac):
    """
    Fix the factor so we don't saturate the
    RGB image (> 1)

    Wherever any of r*fac, g*fac or b*fac exceeds 1, fac is replaced by
    1/max(r,g,b), so that the brightest band maps to exactly 1 there.

    fac is modified in place; nothing is returned.
    """
    maximage = _get_max_image(r, g, b)

    saturated = (r*fac > 1) | (g*fac > 1) | (b*fac > 1)
    if saturated.any():
        # this preserves color
        fac[saturated] = 1.0/maximage[saturated]


def get_color_image(imr, img, imb, **keys):
    """
    Create a color image.

    The idea here is that, after applying the asinh scaling, the color image
    should basically be between [0,1] for all filters.  Any place where a value
    is > 1 the intensity will be scaled back in all but the brightest filter
    but color preserved.

    In other words, you develop a set of pre-scalings for the images so that
    after multiplying by

        asinh(I/nonlinear)/(I/nonlinear)

    the numbers will be mostly between [0,1].  You can send scales using the
    scales= keyword

    It can actually be good to have some color saturation so don't be too
    aggressive.  You'll have to play with the numbers for each image.

    Note also the image is clipped at zero.

    parameters
    ----------
    imr, img, imb: ndarray
        The red, green and blue images, 2-d with the same shape.  They are
        converted to 'f4' copies, the inputs are not modified.
    nonlinear: float
        Non-linear scale for the asinh scaling, default 1.0
    scales: sequence of 3 numbers, optional
        Multiplicative scales applied to r,g,b after clipping at zero.
    satval: float, optional
        If sent, the scaled images are clipped at this value, preserving
        color.

    returns
    -------
    Array with shape (nrow, ncol, 3) holding the scaled r,g,b images.

    TODO:
        Implement a "saturation" level in the raw image values as in
        djs_rgb_make.  Even better, implement an outside function to do this.
    """

    nonlinear=keys.get('nonlinear',1.0)
    scales=keys.get('scales',None)
    satval=keys.get('satval',None)

    r = imr.astype('f4')
    g = img.astype('f4')
    b = imb.astype('f4')

    # clip at zero, in place
    r.clip(0.,r.max(),r)
    g.clip(0.,g.max(),g)
    b.clip(0.,b.max(),b)

    if scales is not None:
        r *= scales[0]
        g *= scales[1]
        b *= scales[2]

    if satval is not None:
        # note using rescaled images so the satval
        # means the same thing (e.g. in terms of real flux)
        # r,g,b are modified in place
        _fix_hard_satur(r,g,b,satval)


    # average images and divide by the nonlinear factor
    fac=1./nonlinear/3.
    I = fac*(r + g + b)

    # make sure we don't divide by zero
    # due to clipping, average value is zero only if all are zero
    empty = I <= 0
    if empty.any():
        I[empty] = 1./3. # value doesn't matter images are zero

    f = numpy.arcsinh(I)/I

    # limit to values < 1
    # make sure you send scales such that this occurs at
    # a reasonable place for your images
    _fix_rgb_satur(r,g,b,f)

    R = r*f
    G = g*f
    B = b*f

    st=R.shape
    colorim=numpy.zeros( (st[0], st[1], 3) )

    colorim[:,:,0] = R[:,:]
    colorim[:,:,1] = G[:,:]
    colorim[:,:,2] = B[:,:]

    return colorim





def scale_image(im, **keys):
    """
    Scale an image for display.

    If the 'nonlinear' keyword is sent and not None the asinh scaling is
    used, see asinh_scale().  Otherwise the image is scaled linearly, see
    linear_scale(), which receives all the keywords.
    """
    nonlinear = keys.get('nonlinear', None)
    if nonlinear is None:
        return linear_scale(im, **keys)

    return asinh_scale(im, nonlinear)

def linear_scale(im, **keys):
    """
    Scale the image linearly and clip it to [0,1].

    parameters
    ----------
    im: ndarray
        The image.  It is converted to an 'f4' copy, the input is not
        modified.
    autoscale: bool
        If True (the default) the image is divided by its maximum value so
        that the brightest pixel is 1.0.  The rescaling is skipped when the
        maximum is not positive; such an image has nothing above zero to
        show and dividing would produce nan/inf or flip the sign.

    returns
    -------
    The scaled 'f4' image with values in [0,1].
    """
    autoscale=keys.get('autoscale',True)

    I=im.astype('f4')

    if autoscale:
        maxval=I.max()
        if maxval > 0:
            I *= (1.0/maxval)

    I.clip(0.0, 1.0, I)

    return I

def asinh_scale(im, nonlinear):
    """
    Scale the image using an asinh stretch

        I = image*f
        f=asinh(image/nonlinear)/(image/nonlinear)

    The result is clipped to [0,1], so you are responsible for pre-scaling
    your image such that it does not saturate.  Values < 0 are clipped.

    parameters
    ----------
    im: ndarray
        The image.  It is not modified.
    nonlinear: float
        The non-linear scale.

    returns
    -------
    The scaled image with values in [0,1].
    """
    I=im.astype('f4')

    I *= (1./nonlinear)

    # make sure we don't divide by zero
    nonpos = I <= 0
    if nonpos.any():
        # value doesn't matter: these pixels are zero or negative in the
        # image and end up clipped to zero
        I[nonpos] = 1.

    f = numpy.arcsinh(I)/I

    imout = im*f

    imout.clip(0.0, 1.0, imout)

    return imout



def rebin(im, factor, dtype=None):
    """
    Rebin the image so there are fewer pixels.  The pixels are simply
    averaged.

    parameters
    ----------
    im: ndarray
        2-d image.
    factor: int
        Number of pixels to average along each dimension.  Both
        dimensions of the image must be divisible by it.
    dtype: optional
        If sent, the image is converted to this type before averaging.

    returns
    -------
    Image with shape (nrow/factor, ncol/factor) holding the block averages.

    raises
    ------
    ValueError if factor < 1 or the shape is not divisible by factor
    """
    factor=int(factor)
    if factor < 1:
        raise ValueError("rebin factor must be >= 1")

    s = im.shape
    if ( (s[0] % factor) != 0
            or (s[1] % factor) != 0):
        raise ValueError("shape in each dim (%d,%d) must be "
                   "divisible by factor (%d)" % (s[0],s[1],factor))

    # integer division: reshape needs integer dimensions
    nrow = s[0]//factor
    ncol = s[1]//factor

    if dtype is None:
        a=im
    else:
        a=im.astype(dtype)

    # axes 1 and 3 run over the pixels inside each block; after summing
    # over axis 1 the remaining in-block axis is axis 2
    blocks = a.reshape(nrow, factor, ncol, factor)
    return blocks.sum(1).sum(2)/factor/factor

def boost( a, factor):
    """
    Resize an array to larger shape, simply duplicating values.

    Every element is repeated factor times along every dimension, so the
    output shape is the input shape times factor.

    parameters
    ----------
    a: ndarray
        The array to enlarge.
    factor: int
        The boost factor, must be >= 1

    returns
    -------
    A new, enlarged array.

    raises
    ------
    ValueError if factor < 1
    """
    factor=int(factor)
    if factor < 1:
        raise ValueError("boost factor must be >= 1")

    # repeating along each axis in turn gives out[i,j,..] = a[i//factor,
    # j//factor, ..] without relying on floating point index arithmetic
    boosted = a
    for axis in range(a.ndim):
        boosted = numpy.repeat(boosted, factor, axis=axis)

    return boosted

def expand(image, new_dims, padval=0, verbose=False):
    """
    Pad an image out to a larger size.

    The original pixels are kept at the start of each dimension and the
    added pixels are set to padval.  This does not change the pixel scale,
    it only adds pixels; see boost to change the pixel scale.

    Each output dimension is the larger of the existing and requested size.
    If neither requested dimension exceeds the existing one, the original
    image itself is returned.
    """

    old_nrow, old_ncol = image.shape[0], image.shape[1]

    needs_rows = new_dims[0] > old_nrow
    needs_cols = new_dims[1] > old_ncol
    if not (needs_rows or needs_cols):
        return image

    out_nrow = max(old_nrow, new_dims[0])
    out_ncol = max(old_ncol, new_dims[1])

    if verbose:
        print("  expanding image from",image.shape,"to:",[out_nrow,out_ncol])

    padded = numpy.empty( (out_nrow, out_ncol), dtype=image.dtype )
    padded[:] = padval

    # copy the original pixels into the corner
    padded[0:old_nrow, 0:old_ncol] = image[0:old_nrow, 0:old_ncol]

    return padded




class Viewer(dict):
    """
    View the cutouts for objects in a MEDS file.

    parameters
    ----------
    filename: string
        The MEDS file name.
    combine: bool
        Combine the cutouts of an object into a single mosaic image rather
        than a table of plots.
    cutout_types: list of strings, optional
        The cutout types to plot, default ['image'].
    has_coadd: bool
        If True, the first cutout is labeled 'coadd'.

    side effects
    ------------
    Sets the biggles default 'fontsize_min' to 1.2
    """
    def __init__(self, filename, combine=False, cutout_types=None, has_coadd=False):
        import biggles
        biggles.configure( 'default', 'fontsize_min', 1.2)
        self.has_coadd=has_coadd
        self.filename=filename
        self.combine=combine
        self.m=meds.MEDS(filename)

        if cutout_types is None:
            cutout_types=['image']

        self.cutout_types=cutout_types

        # base name without the fits extension, used to name output files
        bname=os.path.basename(self.filename)
        bname=bname.replace('.fits.fz','')
        bname=bname.replace('.fits','')

        self.bname=bname

    def go(self,
           save=False,
           show=True,
           start=None,
           end=None,
           indices=None,
           html_file=None,
           **kw):
        """
        Plot the requested objects.

        Can send start,end or a list of indices

        parameters
        ----------
        save: bool
            Write a png file for each object and cutout type into the cwd.
        show: bool
            Show each plot and prompt the user before moving on; entering
            q stops the loop.
        start, end: int, optional
            Inclusive range of object indices.
        indices: sequence of int, optional
            Explicit indices; takes precedence over start,end.
        html_file: string, optional
            Write an html file displaying the saved images.  Implies
            save=True.
        **kw:
            Extra keywords for plot_one

        raises
        ------
        RuntimeError if neither saving nor showing was requested
        """

        dohtml = html_file is not None
        if dohtml:
            # the html page refers to the saved images
            save=True

        if not save and not show:
            raise RuntimeError("you should either save or show or both!")

        html_fobj=None
        if dohtml:
            html_fobj = self._start_html(html_file)

        # the try/finally makes sure the html file is closed however the
        # loop ends
        try:
            indices = self._get_indices(start=start, end=end, indices=indices)

            ntot=indices.size
            print("viewing",ntot,"objects")

            for i,index in enumerate(indices):

                objid = self.m['id'][index]
                print("index: %d id: %d" % (index,objid))
                print("%d/%d %.1f%%" % (i+1,ntot,100*(i+1)/float(ntot)))

                if self.m['ncutout'][index] == 0:
                    print("    No cutouts found")
                    continue

                if dohtml:
                    self._html_begin_line(html_fobj)

                for cutout_type in self.cutout_types:
                    print("    ",cutout_type)

                    title='%06d %d %s' % (index,objid,cutout_type)

                    if save:
                        fname=self._get_png_filename(index,cutout_type)
                        print("    writing:",fname)
                    else:
                        fname=None

                    self.plot_one(
                        index,
                        cutout_type,
                        show=show,
                        file=fname,
                        dims=(800,800),
                        title=title,
                        **kw
                    )

                    if dohtml:
                        self._html_add_item(html_fobj, fname)

                if dohtml:
                    self._html_end_line(html_fobj)

                if show:
                    if raw_input('hit a key (q to quit): ').lower()=='q':
                        # leave the loop rather than return, so the html
                        # footer still gets written
                        break

            if dohtml:
                self._end_html(html_fobj)
        finally:
            if html_fobj is not None:
                html_fobj.close()

    def plot_one(self, index, cutout_type, **kw):
        """
        Plot all cutouts of one type for one object.

        The cutouts are rescaled for display depending on the type, then
        passed to view_mosaic along with titles built from the file_id
        entries of the object.

        parameters
        ----------
        index: int
            Index of the object in the MEDS file.
        cutout_type: string
            'uberseg' or a type understood by the MEDS get_cutout_list
            method, e.g. 'image', 'psf', 'seg', 'bmask'
        **kw:
            Keywords for view_mosaic

        returns
        -------
        The plot object returned by view_mosaic
        """
        if cutout_type=='uberseg':
            imlist = self.m.get_uberseg_list(index)
        else:
            imlist = self.m.get_cutout_list(index, type=cutout_type)

        if cutout_type == 'image':
            wtlist = self.m.get_cutout_list(index, type='weight')

            for im,wt in zip(imlist,wtlist):

                wnan = numpy.where(numpy.isnan(im))
                if wnan[0].size > 0:
                    print("converting %d/%d nan" % (wnan[0].size, im.size))
                    im[wnan] = 0.0

                wtmax = wt.max()
                if wtmax > 0.0:
                    # lift the image by the noise level implied by the
                    # largest weight, then normalize to a maximum of one
                    std = numpy.sqrt(1.0/wtmax)
                    im += std
                    immax = im.max()
                    if immax > 0:
                        im *= (1.0/immax)

        elif cutout_type == 'psf':
            scaled=[]
            for im in imlist:
                # log scale, shifted and normalized to [0,1]
                im = im.clip(min=1.0e-6)

                im = numpy.log10(im)
                im -= im.min()
                immax = im.max()
                if immax > 0:
                    im *= (1.0/immax)
                scaled.append(im)
            imlist=scaled

        elif cutout_type in ['seg','bmask']:
            scaled=[]
            for im in imlist:
                # log scale; the offset avoids log10(0)
                im += 1
                scaled.append(numpy.log10(im))
            imlist=scaled

        if self.has_coadd:
            titles=[str(i) for i in self.m['file_id'][index,1:]]
            titles = ['coadd'] + titles
        else:
            titles=[str(i) for i in self.m['file_id'][index,:]]

        kw['titles']=titles
        plt=view_mosaic(imlist, combine=self.combine, **kw)
        return plt

    def _get_png_filename(self, index,cutout_type):
        """
        name of the png file for this object index and cutout type
        """
        return '%s-%06d-%s.png' % (self.bname,index,cutout_type)

    def _get_indices(self, start=None, end=None, indices=None):
        """
        Get the array of object indices to process.

        Explicit indices take precedence, then the inclusive range
        start,end (start defaults to 0, end to the last object), else all
        objects in the file.
        """

        if indices is not None:
            # accept any sequence; go() needs an array
            ind_out = numpy.asarray(indices)
        elif start is not None or end is not None:
            if start is None:
                start=0

            if end is None:
                end=self.m.size
            else:
                # end is inclusive
                end=end+1

            ind_out = numpy.arange(start,end)

        else:
            ind_out = numpy.arange(self.m.size)

        return ind_out

    def _html_add_item(self, fobj, fname):
        """
        write a table cell showing the image file
        """
        fobj.write("<td><img src=\"%s\"></td>" % fname)

    def _html_begin_line(self, fobj):
        """
        begin a table row
        """
        fobj.write("        <tr>")

    def _html_end_line(self, fobj):
        """
        end a table row
        """
        fobj.write("</tr>\n")

    def _start_html(self, html_file):
        """
        open the html file, write the header and return the file object

        The caller is responsible for closing the file.
        """
        fobj = open(html_file,'w')
        fobj.write(_html_head)
        return fobj

    def _end_html(self, fobj):
        """
        write the html footer; the file is not closed here
        """
        fobj.write(_html_foot)


# Text written at the start of the html file by Viewer._start_html: opens
# the document, the body (black background) and the table that holds one
# row of images per object.
_html_head="""
<html>
    <body bgcolor=black>

    <table>
"""
# Text written at the end of the html file by Viewer._end_html: closes the
# table, body and document opened by _html_head.
_html_foot="""
    </table>
    </body>
</html>
"""


def load_indices(fname):
    """
    Read object indices from a text file with one index per line.

    Returns the indices as an 'i8' array.
    """
    indices = numpy.fromfile(fname, dtype='i8', sep='\n')
    return indices
        
def main():
    """
    Entry point for the script: parse the command line, build the Viewer
    and run it over the requested objects.
    """
    args = parser.parse_args()

    # --types is a comma separated list
    cutout_types = args.types
    if cutout_types is not None:
        cutout_types = cutout_types.split(',')

    # explicit indices from a file, if requested
    if args.index_file is None:
        indices = None
    else:
        indices = load_indices(args.index_file)

    viewer = Viewer(
        args.filename,
        combine=args.combine,
        cutout_types=cutout_types,
        has_coadd=args.has_coadd,
    )

    viewer.go(
        save=args.save,
        show=args.show,
        start=args.start,
        end=args.end,
        indices=indices,
        html_file=args.html_file,
    )
    


# Only run when executed as a script, so the functions in this file can be
# imported without parsing the command line and starting the viewer.
if __name__ == '__main__':
    main()
