#!/bin/bash
#
# This shell script (nvcc_wrapper) wraps both the host compiler and
# NVCC, if you are building legacy C or C++ code with CUDA enabled.
# The script remedies some differences between the interface of NVCC
# and that of the host compiler, in particular for linking.
# It also means that a legacy code doesn't need separate .cu files;
# it can just use .cpp files.
#
# Default settings: change those according to your machine.  For
# example, you may have two different wrappers with either icpc
# or g++ as their back-end compiler.  The defaults can be overwritten
# by using the usual arguments (e.g., -arch=sm_30 -ccbin icpc).

# GPU architecture passed to nvcc as -arch=<value> when the caller gives no
# -arch* option of their own.
default_arch="sm_35"
#default_arch="sm_50"

#
# The default C++ compiler.  It can be preset through the environment
# variable NVCC_WRAPPER_DEFAULT_COMPILER and is replaced by a -ccbin argument.
#
host_compiler=${NVCC_WRAPPER_DEFAULT_COMPILER:-"g++"}
#host_compiler="icpc"
#host_compiler="/usr/local/gcc/4.8.3/bin/g++"
#host_compiler="/usr/local/gcc/4.9.1/bin/g++"

#
# Internal variables
#
# All list variables below are plain strings that are extended while the
# command line is scanned and are word-split when the final command runs.
#

# C++ files
cpp_files=""

# Host compiler arguments (comma separated, as expected behind -Xcompiler)
xcompiler_args=""

# Cuda (NVCC) only arguments
cuda_args=""

# Arguments for both NVCC and Host compiler
shared_args=""

# Linker arguments in the form nvcc expects ("-Xlinker <arg>")
xlinker_args=""

# Linker arguments for the host-only command.  Initialised explicitly so a
# variable of the same name inherited from the environment cannot leak into
# the host command line.
host_linker_args=""

# Object files passable to NVCC
object_files=""

# Link objects for the host linker only
object_files_xlinker=""

# Shared libraries with version numbers are not handled correctly by NVCC
shared_versioned_libraries_host=""
shared_versioned_libraries=""

# Did the user set the architecture (-arch*)?
arch_set=0

# Did the user override the host compiler (-ccbin)?
ccbin_set=0

# Error code of compilation
error_code=0

# Do a dry run without actually compiling
dry_run=0

# Skip NVCC compilation and use host compiler directly
host_only=0

# Enable workaround for CUDA 6.5 for pragma ident
replace_pragma_ident=0

# 1 until the first host compiler argument has been stored in xcompiler_args
first_xcompiler_arg=1

# Directory that receives the temporary copies written by the pragma ident
# workaround
temp_dir=${TMPDIR:-/tmp}

# Check if we have an optimization argument already
optimization_applied=0

#echo "Arguments: $# $@"

# Append every given word to xcompiler_args, the comma-separated list that is
# later handed to nvcc behind a single -Xcompiler option.
append_xcompiler_arg() {
  local word
  for word in "$@"; do
    if [ $first_xcompiler_arg -eq 1 ]; then
      xcompiler_args=$word
      first_xcompiler_arg=0
    else
      xcompiler_args="$xcompiler_args,$word"
    fi
  done
}

# Classify every command line argument into the wrapper's argument lists
# (shared_args, cuda_args, xcompiler_args, xlinker_args, host_linker_args,
# object file lists, cpp_files) and set the wrapper's mode flags.
# Arguments: the wrapper's own command line.
parse_arguments() {
  while [ $# -gt 0 ]
  do
    case $1 in
    #show the executed command
    --show|--nvcc-wrapper-show)
      dry_run=1
      ;;
    #run host compilation only
    --host-only)
      host_only=1
      ;;
    #replace '#pragma ident' with '#ident' this is needed to compile OpenMPI due to a configure script bug and a non standardized behaviour of pragma with macros
    --replace-pragma-ident)
      replace_pragma_ident=1
      ;;
    #handle source files to be compiled as cuda files
    *.cpp|*.cxx|*.cc|*.C|*.c++|*.cu)
      cpp_files="$cpp_files $1"
      ;;
    # Ensure we only have one optimization flag because NVCC doesn't allow multiple
    -O*)
      if [ $optimization_applied -eq 1 ]; then
         # Diagnostics go to stderr so they cannot end up in -E / -M / --show output.
         echo "nvcc_wrapper - *warning* you have set multiple optimization flags (-O*), only the first is used because nvcc can only accept a single optimization setting." >&2
      else
         shared_args="$shared_args $1"
         optimization_applied=1
      fi
      ;;
    #Handle shared args (valid for both nvcc and the host compiler)
    -D*|-c|-I*|-L*|-l*|-g|--help|--version|-E|-M|-shared)
      shared_args="$shared_args $1"
      ;;
    #Handle shared args that have an argument
    -o|-MT)
      shared_args="$shared_args $1 $2"
      shift
      ;;
    #Handle known nvcc args
    -gencode*|--dryrun|--verbose|--keep|--keep-dir*|-G|--relocatable-device-code*|-lineinfo|-expt-extended-lambda|--resource-usage|-Xptxas*)
      cuda_args="$cuda_args $1"
      ;;
    #Handle more known nvcc args
    --expt-extended-lambda|--expt-relaxed-constexpr)
      cuda_args="$cuda_args $1"
      ;;
    #Handle known nvcc args that have an argument
    -rdc|-maxrregcount|--default-stream)
      cuda_args="$cuda_args $1 $2"
      shift
      ;;
    #Handle c++11 setting
    --std=c++11|-std=c++11)
      shared_args="$shared_args $1"
      ;;
    #strip off -std=c++98 due to nvcc warnings and Tribits will place both -std=c++11 and -std=c++98
    -std=c++98|--std=c++98)
      ;;
    #strip off pedantic because it produces endless warnings about #LINE added by the preprocessor
    -pedantic|-Wpedantic|-ansi)
      ;;
    #strip off -Woverloaded-virtual to avoid "cc1: warning: command line option ‘-Woverloaded-virtual’ is valid for C++/ObjC++ but not for C"
    -Woverloaded-virtual)
      ;;
    #strip -Xcompiler because we add it
    -Xcompiler)
      append_xcompiler_arg "$2"
      shift
      ;;
    #strip off "-x cu" because we add that
    -x)
      if [[ $2 != "cu" ]]; then
        append_xcompiler_arg -x "$2"
      fi
      shift
      ;;
    #Handle -ccbin (if it is not set we can set it to a default value)
    -ccbin)
      cuda_args="$cuda_args $1 $2"
      ccbin_set=1
      host_compiler=$2
      shift
      ;;
    #Handle -arch argument (if it is not set use a default)
    -arch*)
      cuda_args="$cuda_args $1"
      arch_set=1
      ;;
    #Handle -Xcudafe argument
    -Xcudafe)
      cuda_args="$cuda_args -Xcudafe $2"
      shift
      ;;
    #Handle args that should be sent to the linker.
    #The pattern requires the comma so that warning flags such as
    #-Wlogical-op or -Wlong-long are not mistaken for linker options.
    -Wl,*)
      #nvcc wants the payload behind -Xlinker ...
      xlinker_args="$xlinker_args -Xlinker ${1#-Wl,}"
      #... while the host compiler driver needs the original -Wl,<args> form
      host_linker_args="$host_linker_args $1"
      ;;
    #Handle object files: -x cu applies to all input files, so give them to linker, except if only linking
    *.a|*.so|*.o|*.obj)
      object_files="$object_files $1"
      object_files_xlinker="$object_files_xlinker -Xlinker $1"
      ;;
    #Handle object files which always need to use "-Xlinker": -x cu applies to all input files, so give them to linker, except if only linking
    @*|*.dylib)
      object_files="$object_files -Xlinker $1"
      object_files_xlinker="$object_files_xlinker -Xlinker $1"
      ;;
    #Handle shared libraries with *.so.* names which nvcc can't do.
    *.so.*)
      shared_versioned_libraries_host="$shared_versioned_libraries_host $1"
      shared_versioned_libraries="$shared_versioned_libraries -Xlinker $1"
      ;;
    #All other args are sent to the host compiler
    *)
      append_xcompiler_arg "$1"
      ;;
    esac

    shift
  done
}

parse_arguments "$@"

# No -ccbin on the command line: tell nvcc to use the default host compiler.
test $ccbin_set -ne 1 && cuda_args="$cuda_args -ccbin $host_compiler"

# No -arch* on the command line: fall back to the default architecture.
test $arch_set -ne 1 && cuda_args="$cuda_args -arch=$default_arch"

# Build the two candidate command lines from the collected argument lists:
#   nvcc_command - nvcc invocation; host compiler flags travel as one
#                  comma-separated value behind -Xcompiler.
#   host_command - direct host compiler invocation used by --host-only.
compose_commands() {
  #Compose compilation command
  nvcc_command="nvcc $cuda_args $shared_args $xlinker_args $shared_versioned_libraries"
  if [ "${first_xcompiler_arg:-1}" -eq 0 ]; then
    nvcc_command="$nvcc_command -Xcompiler $xcompiler_args"
  fi

  #Compose host only command
  # xcompiler_args is comma-joined for nvcc's -Xcompiler; the host compiler
  # needs each flag as a separate word, so the commas become spaces here.
  host_command="$host_compiler $shared_args ${xcompiler_args//,/ } $host_linker_args $shared_versioned_libraries_host"
}

compose_commands

#nvcc does not accept '#pragma ident SOME_MACRO_STRING' but it does accept '#ident SOME_MACRO_STRING'
#
# When --replace-pragma-ident was given, every source file in cpp_files that
# has a line mentioning "pragma", "ident" and "#" is copied to temp_dir with
# '#pragma ident' rewritten to '#ident', and cpp_files is updated to point at
# the copy.  Files without such a line are kept as they are.
rewrite_pragma_ident() {
  local file flat_name patched_file patched_list=""

  if [ "${replace_pragma_ident:-0}" -ne 1 ]; then
    return 0
  fi

  # Unquoted on purpose: cpp_files is a space-separated list.
  for file in $cpp_files
  do
    if grep pragma "$file" | grep ident | grep -q "#"
    then
      # Flatten directory separators so that sources given with a relative or
      # absolute directory still map to a file directly inside temp_dir; a
      # bare file name keeps the name it always had.
      flat_name=${file//\//_}
      patched_file="$temp_dir/nvcc_wrapper_tmp_$flat_name"
      if sed 's/#[\ \t]*pragma[\ \t]*ident/#ident/g' "$file" > "$patched_file"
      then
        patched_list="$patched_list $patched_file"
      else
        # Better to compile the untouched source than to hand the compiler a
        # missing or truncated temporary file.
        echo "nvcc_wrapper - *warning* could not write $patched_file, compiling $file unmodified." >&2
        patched_list="$patched_list $file"
      fi
    else
      patched_list="$patched_list $file"
    fi
  done
  cpp_files=$patched_list
}

rewrite_pragma_ident

# Attach the input files to both command lines.
if [ -n "$cpp_files" ]; then
  # Compiling: "-x cu" would apply to every input, so for nvcc the object
  # files travel behind -Xlinker and the sources follow "-x cu".
  nvcc_command="$nvcc_command $object_files_xlinker -x cu $cpp_files"
  host_command="$host_command $object_files $cpp_files"
else
  # Nothing to compile: both commands receive the plain object file list.
  nvcc_command="$nvcc_command $object_files"
  host_command="$host_command $object_files"
fi

# Dry run: print the command that would be executed, then stop successfully.
if test $dry_run -eq 1; then
  shown_command=$nvcc_command
  if test $host_only -eq 1; then
    shown_command=$host_command
  fi
  # Left unquoted so the command prints as single-space separated words.
  echo $shown_command
  exit 0
fi

# Execute the selected command.  The expansions are deliberately unquoted:
# the command lines are plain strings that rely on word splitting.
case $host_only in
  1) $host_command ;;
  *) $nvcc_command ;;
esac
error_code=$?

# Hand the compiler's exit status back to the caller.
exit $error_code
