#!/Users/ktietz/demo/mc3/conda-bld/neon_1629799682079/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh/bin/python
# ----------------------------------------------------------------------------
# Copyright 2015-2016 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Main neon run script.

This script will parse a YAML formatted model configuration file, instantiate the model, and run
a fit.

To run this script, first activate the virtualenv and then run
"python <path to script>/neon <yaml config file> [options]"

Some of the commonly used arguments to this script are:

    -b | --backend [ gpu | cpu | mkl ] : backend to use (gpu backend works on Pascal,
                                         Maxwell and Kepler)

    -e | --epochs [ epochs ] : number of epochs to run fit

    -d | --datatype [ f16 | f32 | f64 ] : floating point data precision to use, note 64-bit
                                            float only works with cpu backend

    -l | --log [ LOGFILE ] : log file for messages, use the -v option to control logging level, if
                             this option is not set then log messages will only go to the screen

    -v[vv] | --verbose : verbosity level

Args that start with '--' (eg. --data_dir) can also be set in a config file (~/nervana/neon.cfg or
specified via -c) by using .ini or .yaml-style syntax (eg. data_dir=value). If an arg is specified
in more than one place, command-line values override config file values which override defaults.

For details of all the command line arguments run this script with the --help option.
"""
import sys
import yaml

from neon import logger
from neon.backends import gen_backend
from neon.callbacks.callbacks import Callbacks
import neon.data
from neon.models import Model
from neon.util.argparser import NeonArgparser
from neon.util.yaml_parse import create_objects


def _option_on_cmdline(long_opt, short_opt, argv=None):
    """
    Report whether an option was explicitly passed on the command line.

    Only real option tokens are considered: the script path, positional
    arguments and the names of unrelated options are ignored.  Substring
    matching is deliberately avoided because, for example, "-b" occurs
    inside "--batch_size" and "-e" occurs inside "--eval_freq".

    Arguments:
        long_opt (str): Long spelling of the option, e.g. "--backend".
        short_opt (str): Short spelling of the option, e.g. "-b".
        argv (list, optional): Argument tokens to inspect, without the program
                               name.  Defaults to sys.argv[1:].

    Returns:
        bool: True if the option appears as "--opt", "--opt=value", a prefix
              abbreviation of the long spelling, "-o" or "-oVALUE".
    """
    if argv is None:
        # argv[0] is the script path and must never be treated as an option
        argv = sys.argv[1:]

    for token in argv:
        if token == "--":
            # conventional end-of-options marker: the rest is positional
            break
        if token.startswith("--"):
            name = token.split("=", 1)[0]
            # accept prefix abbreviations of the long spelling as well; this
            # runs after the parser has accepted argv, so an ambiguous prefix
            # would presumably already have been rejected there
            if len(name) > 2 and long_opt.startswith(name):
                return True
        elif token.startswith(short_opt):
            # "-b gpu" as well as the attached form "-bgpu"
            return True
    return False


def parse_args():
    """
    Parse command line arguments, load the YAML file and generate the backend.

    Backend name and number of epochs are taken from the command line when the
    corresponding option was explicitly given there, otherwise from the YAML
    file, otherwise from a fallback ('cpu' and 1 respectively).  The batch
    size is taken from the YAML file when present, otherwise from the parsed
    arguments.

    Returns:
        tuple: Contains YAML elements, command line arguments, backend name,
               number of epochs and batch size.
    """
    # setup the arg parser
    parser = NeonArgparser(__doc__)
    parser.add_yaml_arg()

    # parse the cmd line args, backend generation is deferred until the
    # YAML settings are known
    args = parser.parse_args(gen_be=False)

    # load yaml
    yaml_str = args.yaml_file.read()
    root_yaml = yaml.safe_load(yaml_str)

    batch_size = root_yaml.get('batchsize', args.batch_size)

    # backend and epochs have defaults in the parser, so the parsed values
    # alone cannot tell us whether the user supplied them; inspect argv
    if _option_on_cmdline("--backend", "-b"):
        # command line overrides yaml setting
        be_name = args.backend
    else:
        be_name = root_yaml.get('backend', 'cpu')

    if _option_on_cmdline("--epochs", "-e"):
        # command line overrides yaml setting
        num_epochs = args.epochs
    else:
        num_epochs = root_yaml.get('epochs', 1)

    gen_backend(backend=be_name,
                rng_seed=args.rng_seed,
                device_id=args.device_id,
                batch_size=batch_size,
                datatype=args.datatype,
                stochastic_round=args.rounding)

    return root_yaml, args, be_name, num_epochs, batch_size


def load_data(data_dir=".", backend_obj=None):
    """
    Load the dataset named in the YAML configuration.

    Reads the module-level ``root_yaml`` mapping.  For the 'I1K' dataset
    (name compared case-insensitively) aeon data loaders are built from the
    'train_config' and 'test_config' entries of the dataset section.  Any
    other name is upper-cased, looked up as an attribute of ``neon.data``,
    instantiated, and its ``gen_iterators()`` result supplies the 'train' and
    'valid' iterators.

    Arguments:
        data_dir (str, optional): Local directory used to build the I1K
                                  manifest and cache paths.  Defaults to
                                  current directory
        backend_obj (optional): Backend handed to the aeon data loaders; its
                                ``bsz`` attribute sets the minibatch size.
                                Only used for the I1K dataset.
    Returns:
        tuple: Contains data iterator objects for training and test datasets.
    Raises:
        ValueError: If the configuration has no 'dataset' section.
        SystemExit: With status 1 if setting up the I1K loaders fails with
                    an OSError, IOError or ValueError.
    """
    # validate before the first lookup so that a missing section is reported
    # with a clear message rather than a bare KeyError
    if 'dataset' not in root_yaml:
        raise ValueError('dataset not specified in configuration file')

    dataset_name = root_yaml['dataset']['name'].upper()

    if dataset_name != 'I1K':
        dataset = getattr(neon.data, dataset_name)()
        dataiters = dataset.gen_iterators()
        return dataiters['train'], dataiters['valid']

    try:
        # imported lazily so these are only needed for I1K runs; kept inside
        # the try block so an OSError raised while importing is reported too
        import os
        import numpy as np
        from neon.data.dataloader_transformers import OneHot, TypeCast, BGRMeanSubtract
        from neon.data.aeon_shim import AeonDataLoader

        train_config = root_yaml['dataset']['train_config']
        test_config = root_yaml['dataset']['test_config']
        train_config['manifest_filename'] = os.path.join(data_dir,
                                                         'i1k-extracted',
                                                         'train-index.csv')

        test_config['manifest_filename'] = os.path.join(data_dir,
                                                        'i1k-extracted',
                                                        'val-index.csv')

        for config in (train_config, test_config):
            config['type'] = 'image,label'
            config['minibatch_size'] = backend_obj.bsz
            config['macrobatch_size'] = backend_obj.bsz * 12
            config['cache_directory'] = os.path.join(data_dir, 'i1k-cache')

        def wrap_dataloader(dl):
            # one-hot the labels (1000 classes), cast the images to float32
            # and apply BGR mean subtraction to the images
            dl = OneHot(dl, index=1, nclasses=1000)
            dl = TypeCast(dl, index=0, dtype=np.float32)
            dl = BGRMeanSubtract(dl, index=0)
            return dl

        train = wrap_dataloader(AeonDataLoader(train_config, backend_obj))
        test = wrap_dataloader(AeonDataLoader(test_config, backend_obj))

    except (OSError, IOError, ValueError) as err:
        logger.error(err)
        # non-zero status so callers and shells see that the run failed
        sys.exit(1)

    return train, test


if __name__ == "__main__":
    # Train and test the specified model: parse the configuration, build the
    # model objects, load the data, then either benchmark or fit.
    root_yaml, args, be_name, num_epochs, batch_size = parse_args()

    object_kwargs = dict(be_type=be_name,
                         batch_size=batch_size,
                         rng_seed=args.rng_seed,
                         device_id=args.device_id,
                         default_dtype=args.datatype,
                         stochastic_rounding=args.rounding)
    model, cost, optim = create_objects(root_yaml, **object_kwargs)

    # optionally restore previously saved parameters before touching the data
    if args.model_file:
        model.load_params(args.model_file)

    train, test = load_data(data_dir=args.data_dir, backend_obj=model.be)

    # configure callbacks, the test set doubles as the evaluation set
    callbacks = Callbacks(model, eval_set=test, **args.callback_args)

    # at the highest verbosity also report platform, cpu and model details
    if args.verbose > 2:
        from neon.util.display_information import (display_platform_information,
                                                   display_cpu_information,
                                                   display_model_params)
        display_platform_information()
        display_cpu_information()
        display_model_params(args, root_yaml)

    if args.profile is not True:
        # regular run: train the model
        model.fit(train,
                  optimizer=optim,
                  num_epochs=num_epochs,
                  cost=cost,
                  callbacks=callbacks)
    else:
        # profiling run: rebuild the model from its layers and benchmark it
        from neon.benchmark import Benchmark

        inference = args.profile_inference
        model = Model(layers=model.layers, optimizer=optim)
        model.initialize(train, cost=cost)
        bench = Benchmark(model)
        # 'time' is the only profiling method acted on here
        if args.profiling_method == 'time':
            stats = bench.time(train,
                               inference=inference,
                               niterations=args.profile_iterations)
            bench.print_stats(stats, nskip=args.profile_iter_skip)
