#! /usr/bin/env python
from dipy.fixes import argparse as arg
import os
import warnings 
import dipy.io.pickles as pkl

import numpy as np
import scipy.io as sio

from nibabel import trackvis
import dipy.segment.quickbundles as qb

# Command-line interface for the script.
#
# NOTE: the arguments are parsed at module level (outside the ``__main__``
# guard below), so ``params`` is populated as soon as this file is executed
# or imported.
parser = arg.ArgumentParser(
    description='Segment a fiber-group into clusters using the QuickBundles algorithm (Garyfallidis et al. 2012)')

# Positional: the file holding the fibers to cluster.
parser.add_argument('in_file', action='store', metavar='File',
                    help='Trackvis fiber file (.trk) or Vistasoft fiber-group file (.mat)')

# ``type=float`` / ``type=int`` make argparse reject malformed numbers up
# front with a usage message, instead of failing later with a traceback
# after the input file has already been read.
parser.add_argument('--dist_thr', action='store', metavar='Float',
                    type=float,
                    help='Distance threshold (default: 30.)', default=30.)

parser.add_argument('--pts', action='store', metavar='Int',
                    type=int,
                    help='Points (default: 12)', default=12)

# Optional: where to pickle the QuickBundles object (nothing is pickled when
# the option is omitted).
parser.add_argument('--pkl_file', action='store', metavar='File',
                    help='Whether to save a QB object in pickle format. Provide full path to the file name (default: none)')

# Optional: base name for the generated files.
parser.add_argument('--out_file', action='store', metavar='File',
                    help='Full path to the output file name (defaults to the same location as the original input file)')


params = parser.parse_args()

def gimme_clusters(streamlines, th, pts):
    """Cluster streamlines with QuickBundles.

    Parameters
    ----------
    streamlines : sequence of arrays
        The streamlines to cluster; handed unchanged to ``qb.QuickBundles``.
    th : float
        Distance threshold, passed as ``dist_thr``.
    pts : int
        Passed as ``pts``.

    Returns
    -------
    tuple
        ``(bundles, clusters, centroids)``: the QuickBundles object, the
        result of ``bundles.clusters()`` and ``bundles.centroids``.
    """
    bundles = qb.QuickBundles(streamlines, dist_thr=th, pts=pts)
    return bundles, bundles.clusters(), bundles.centroids


def trk_output_name(out_file, in_file, suffix):
    """Build the name of a ``.trk`` output file.

    The extension of ``out_file`` (or of ``in_file`` when no output file was
    requested) is stripped, then ``suffix`` and ``'.trk'`` are appended.

    ``os.path.splitext`` is used rather than ``str.split('.')`` so that dots
    in directory names or a leading ``./`` do not truncate the path.

    Parameters
    ----------
    out_file : str or None
        User-requested output name; falsy means "derive from ``in_file``".
    in_file : str
        Path of the input file.
    suffix : str
        Text inserted between the stem and the extension (e.g. ``'_1'``).

    Returns
    -------
    str
        The output file name.
    """
    stem = os.path.splitext(out_file if out_file else in_file)[0]
    return stem + suffix + '.trk'


if __name__ == "__main__":
    # Builtin float/int replace the ``np.float``/``np.int`` aliases, which
    # are no longer available in current NumPy releases.
    dist_thr = float(params.dist_thr)
    pts = int(params.pts)

    if params.in_file.endswith('.trk'):

        tracks, hdr = trackvis.read(params.in_file)
        # Keep only the first element of each track record.
        streamlines = [fibers[0] for fibers in tracks]
        bundles, clusters, centroids = gimme_clusters(streamlines,
                                                      dist_thr,
                                                      pts)

        # One output file per cluster, numbered from 1.
        for c in clusters:
            # Index the list element by element: streamlines may differ in
            # length, and stacking such a list into one array is an error in
            # recent NumPy.
            # NOTE(review): assumes clusters[c]['indices'] is a sequence of
            # integer positions into ``streamlines`` - confirm.
            new_streamlines = [(streamlines[i], None, None)
                               for i in clusters[c]['indices']]

            out_fname_clust = trk_output_name(params.out_file,
                                              params.in_file,
                                              '_%s' % (c + 1))
            trackvis.write(out_fname_clust, new_streamlines, hdr)

        # And one file holding all the centroids.
        out_fname_cent = trk_output_name(params.out_file,
                                         params.in_file,
                                         '_centroids')
        trackvis.write(out_fname_cent,
                       [(s, None, None) for s in centroids],
                       hdr)

    elif params.in_file.endswith('.mat'):
        # Read the file once as a mat-object ...
        fg_mat = sio.loadmat(params.in_file, struct_as_record=False,
                             squeeze_me=True)
        # ... so that the fiber-group can be accessed through attributes:
        fg = fg_mat['fg']
        # And read it again as non-struct dicts, which are later updated and
        # passed back to savemat (one for the clusters, one for the
        # centroids):
        fg_mat_clust = sio.loadmat(params.in_file, squeeze_me=True)
        fg_mat_cent = sio.loadmat(params.in_file, squeeze_me=True)
        name_orig = fg.name
        # Each fiber is transposed before clustering and transposed back
        # before saving. A plain list is used because fibers of different
        # lengths cannot be stacked into a single array.
        streamlines = [np.array(ff).T for ff in fg.fibers]
        bundles, clusters, centroids = gimme_clusters(streamlines,
                                                      dist_thr,
                                                      pts)
        # We will give each cluster a different random RGB color:
        cmap_arr = np.random.rand(len(clusters), 3)
        # We prepare struct arrays to assign into; every field holds a
        # Python object:
        dt = [('name', 'O'),
              ('colorRgb', 'O'),
              ('thickness', 'O'),
              ('visible', 'O'),
              ('seeds', 'O'),
              ('seedRadius', 'O'),
              ('seedVoxelOffsets', 'O'),
              ('params', 'O'),
              ('fibers', 'O'),
              ('query_id', 'O')]
        cluster_arr = np.empty(len(clusters), dtype=dt)
        centroid_arr = np.empty(len(clusters), dtype=dt)

        for c in clusters:
            # NOTE(review): assumes clusters[c]['indices'] is a sequence of
            # integer positions into ``streamlines`` - confirm.
            indices = clusters[c]['indices']
            new_fibers = np.empty(len(indices), dtype=object)
            for f, idx in enumerate(indices):
                new_fibers[f] = streamlines[idx].T

            # The cluster and its centroid share the same color.
            color = cmap_arr[c].astype(float).reshape(1, 3)

            cluster_arr[c]['fibers'] = new_fibers
            cluster_arr[c]['colorRgb'] = color
            cluster_arr[c]['name'] = str(name_orig) + '_%s' % (c + 1)
            centroid_arr[c]['fibers'] = centroids[c].T
            centroid_arr[c]['colorRgb'] = color

            # Any field not set above is copied over from the original
            # fiber-group.
            for field, _ in dt:
                if cluster_arr[c][field] is None:
                    cluster_arr[c][field] = fg_mat_clust['fg'][field]
                if centroid_arr[c][field] is None:
                    centroid_arr[c][field] = fg_mat_cent['fg'][field]

        fg_mat_clust['fg'] = cluster_arr
        fg_mat_cent['fg'] = centroid_arr

        # NOTE: unlike the .trk branch, ``out_file`` is used verbatim as a
        # prefix here (no extension is stripped from it).
        if params.out_file:
            out_fname_clust = params.out_file + '_cluster.mat'
            out_fname_cent = params.out_file + '_centroids.mat'
        else:
            in_stem = os.path.splitext(params.in_file)[0]
            out_fname_clust = in_stem + '_cluster.mat'
            out_fname_cent = in_stem + '_centroids.mat'

        # Warnings raised while saving are deliberately silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sio.savemat(out_fname_cent, fg_mat_cent)
            sio.savemat(out_fname_clust, fg_mat_clust)

    else:
        raise ValueError('Unrecognized file-format')

    # Optionally pickle the QuickBundles object itself.
    if params.pkl_file:
        pkl.save_pickle(params.pkl_file, bundles)
