#!/usr/bin/env python


# Make the corresponding MRtrix3 Python libraries available
import inspect, os, sys
# Resolve the 'lib' directory that sits one level above the directory holding
# this script (after resolving symlinks), and put it first on the import path.
this_script = os.path.realpath(inspect.getfile(inspect.currentframe()))
script_dir = os.path.dirname(this_script)
lib_folder = os.path.realpath(os.path.join(script_dir, os.pardir, 'lib'))
if not os.path.isdir(lib_folder):
  # Without the libraries nothing below can be imported, so bail out now
  sys.stderr.write('Unable to locate MRtrix3 Python libraries')
  sys.exit(1)
sys.path.insert(0, lib_folder)


from mrtrix3 import app, file, image, path, run #pylint: disable=redefined-builtin

from mrtrix3.path import allindir

def abspath(*arg):
  """Join the given path components and return the result as an absolute path."""
  joined = os.path.join(*arg)
  return os.path.abspath(joined)

def relpath(*arg):
  """Join the given path components and express the result relative to app.workingDir."""
  joined = os.path.join(*arg)
  return os.path.relpath(joined, start=app.workingDir)



class Input(object):
  """Plain record pairing one input image with its mask.

  Attributes:
    filename:        file name of the input image
    prefix:          part of the file name used to match image and mask
    directory:       directory containing the input image
    mask_filename:   file name of the matching mask (default: '')
    mask_directory:  directory containing the mask (default: '')
  """
  def __init__(self, filename, prefix, directory, mask_filename = '', mask_directory = ''):
    # Image location and matching key
    self.filename, self.prefix, self.directory = filename, prefix, directory
    # Mask location
    self.mask_filename, self.mask_directory = mask_filename, mask_directory

# Register author and synopsis, then describe the command-line interface
app.init(
  'David Raffelt (david.raffelt@florey.edu.au)',
  'Performs a global DWI intensity normalisation on a group of subjects using the median b=0 white matter value as the reference'
)
app.cmdline.addDescription(
  'The white matter mask is estimated from a population average FA template then warped back to each subject to perform the intensity normalisation. Note that bias field correction should be performed prior to this step.'
)

# Positional arguments, in the order they are expected on the command line
app.cmdline.add_argument(
  'input_dir',
  help='The input directory containing all DWI images'
)
app.cmdline.add_argument(
  'mask_dir',
  help='Input directory containing brain masks, corresponding to one per input image (with the same file name prefix)'
)
app.cmdline.add_argument(
  'output_dir',
  help='The output directory containing all of the intensity normalised DWI images'
)
app.cmdline.add_argument(
  'fa_template',
  help='The output population specific FA template, which is threshold to estimate a white matter mask'
)
app.cmdline.add_argument(
  'wm_mask',
  help='The output white matter mask (in template space), used to estimate the median b=0 white matter value for normalisation'
)

# Script-specific options. The threshold default is a string because the
# value is concatenated directly into the mrthreshold command line.
options = app.cmdline.add_argument_group('Options for the dwiintensitynorm script')
options.add_argument(
  '-fa_threshold',
  default='0.4',
  help='The threshold applied to the Fractional Anisotropy group template used to derive an approximate white matter mask (default: 0.4)'
)
app.parse()

# Locate the input images. The directory is stored relative to the working
# directory, and only the bare file names (no directory part) are listed.
app.args.input_dir = relpath(app.args.input_dir)
inputDir = app.args.input_dir
# isdir (rather than exists) also rejects a regular file passed by mistake,
# which would otherwise only fail when the directory is listed
if not os.path.isdir(inputDir):
  app.error('input directory not found')
inFiles = allindir(inputDir, dir_path=False)
# A group-wise normalisation is meaningless for fewer than two images, so
# this is a fatal error rather than an informational message
if len(inFiles) <= 1:
  app.error('not enough images found in input directory. More than one image is needed to perform a group-wise intensity normalisation')
app.console('performing global intensity normalisation on ' + str(len(inFiles)) + ' input images')

def _strip_postfix(name, postfix):
  """Return name without its trailing postfix.

  Unlike name.split(postfix)[0], this tolerates an empty postfix (split would
  raise ValueError) and only removes the postfix at the end of the name, even
  if the same text also occurs earlier in it.
  """
  if postfix and name.endswith(postfix):
    return name[:-len(postfix)]
  return name


# Locate the mask images; there must be exactly one per input image
app.args.mask_dir = relpath(app.args.mask_dir)
maskDir = app.args.mask_dir
# isdir (rather than exists) also rejects a regular file passed by mistake
if not os.path.isdir(maskDir):
  app.error('mask directory not found')
maskFiles = allindir(maskDir, dir_path=False)
if len(maskFiles) != len(inFiles):
  app.error('the number of images in the mask directory does not equal the number of images in the input directory')

# The part of each mask file name preceding the postfix shared by all masks
# is the key used to pair a mask with its input image
maskCommonPostfix = path.commonPostfix(maskFiles)
maskPrefixes = [_strip_postfix(m, maskCommonPostfix) for m in maskFiles]

# Pair every input image with the mask carrying the same prefix
commonPostfix = path.commonPostfix(inFiles)
# These do not change between iterations, so resolve them once
inputAbsDir = path.fromUser(inputDir, False)
maskAbsDir = path.fromUser(maskDir, False)
input_list = []
for i in inFiles:
  subj_prefix = _strip_postfix(i, commonPostfix)
  if subj_prefix not in maskPrefixes:
    app.error ('no matching mask image was found for input image ' + i)
  image.check3DNonunity(os.path.join(inputAbsDir, i))
  index = maskPrefixes.index(subj_prefix)
  input_list.append(Input(i, subj_prefix, inputAbsDir, maskFiles[index], maskAbsDir))

# Check every requested output location (template, mask, then directory)
# before any processing starts, then create the output directory
for output_path in (app.args.fa_template, app.args.wm_mask, app.args.output_dir):
  app.checkOutputPath(output_path)
file.makeDir(app.args.output_dir)

app.makeTempDir()

# Copy the mask directory (recursively, dereferencing symlinks) into the
# temporary directory under its own base name. This happens before changing
# directory, while the relative maskDir path is still valid.
maskDirName = os.path.basename(os.path.normpath(maskDir))
maskTempDir = os.path.join(app.tempDir, maskDirName)
run.command(' '.join(['cp -R -L', maskDir, maskTempDir]))

app.gotoTempDir()

# Fit a tensor to each subject's DWI within its mask and write the resulting
# FA map to fa/<prefix>.mif
file.makeDir('fa')
progress = app.progressBar('Computing FA images', len(input_list))
for subject in input_list:
  dwi_path = abspath(subject.directory, subject.filename)
  mask_path = abspath(subject.mask_directory, subject.mask_filename)
  fa_path = os.path.join('fa', subject.prefix + '.mif')
  run.command('dwi2tensor ' + dwi_path + ' -mask ' + mask_path + ' - | tensor2metric - -fa ' + fa_path)
  progress.increment()
progress.done()

# Build the population FA template from the per-subject FA maps, keeping the
# intermediate files (-nocleanup) in ./population_template
app.console('Generating FA population template')
template_command = ('population_template fa -mask_dir ' + maskTempDir
                    + ' fa_template.mif'
                    + ' -type rigid_affine_nonlinear'
                    + ' -rigid_scale 0.25,0.5,0.8,1.0'
                    + ' -affine_scale 0.7,0.8,1.0,1.0'
                    + ' -nl_scale 0.5,0.75,1.0,1.0,1.0'
                    + ' -nl_niter 5,5,5,5,5'
                    + ' -tempdir population_template'
                    + ' -linear_no_pause'
                    + ' -nocleanup')
run.command(template_command)

# Threshold the FA template at the user-supplied value to obtain the mask
app.console('Generating WM mask in template space')
threshold_command = 'mrthreshold fa_template.mif -abs ' + app.args.fa_threshold + ' template_wm_mask.mif'
run.command(threshold_command)

# For each subject: warp the template WM mask into subject space, then use it
# to normalise that subject's DWI into the user's output directory
progress = app.progressBar('Intensity normalising subject images', len(input_list))
file.makeDir('wm_mask_warped')
for subject in input_list:
  subject_image = subject.prefix + '.mif'
  warp_path = os.path.join('population_template', 'warps', subject_image)
  warped_mask = os.path.join('wm_mask_warped', subject_image)
  fa_path = os.path.join('fa', subject_image)
  run.command('mrtransform template_wm_mask.mif -interp nearest -warp_full ' + warp_path + ' ' + warped_mask + ' -from 2 -template ' + fa_path)

  normalised_path = path.fromUser(os.path.join(app.args.output_dir, subject.filename), True)
  force_option = ' -force' if app.forceOverwrite else ''
  run.command('dwinormalise ' + abspath(subject.directory, subject.filename) + ' ' + warped_mask + ' ' + normalised_path + force_option)
  progress.increment()
progress.done()

# Copy the template-space WM mask and the FA template to the paths the user
# asked for, overwriting only if overwriting was requested
app.console('Exporting template images to user locations')
force_option = ' -force' if app.forceOverwrite else ''
exports = (
  ('template_wm_mask.mif', app.args.wm_mask),
  ('fa_template.mif', app.args.fa_template),
)
for source, destination in exports:
  run.command('mrconvert ' + source + ' ' + path.fromUser(destination, True) + force_option)

app.complete()
