//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
#  error "CUDA synchronization primitives are only supported for sm_70 and up."
#endif

#ifndef _CUDA_BARRIER
#define _CUDA_BARRIER

#include "atomic"
#include "cstddef"

#include "detail/__config"

#include "detail/__pragma_push"

#include "detail/libcxx/include/barrier"

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// A byte count tagged with a compile-time alignment.
//
// `align` exposes the `_Alignment` template argument and `value` holds the
// run-time size given to the constructor. The memcpy_async overloads in this
// header that take an aligned_size_t use `_Alignment` to pick the copy width.
template<std::size_t _Alignment>
struct aligned_size_t {
    static constexpr std::size_t align = _Alignment;
    std::size_t value;

    // Explicit: a plain integer never silently acquires an alignment tag.
    __host__ __device__
    explicit aligned_size_t(size_t __s)
      : value(__s)
    {}

    // Implicit conversion back to the stored size.
    __host__ __device__
    operator size_t() const
    {
        return value;
    }
};

// Barrier for thread scope `_Sco`, parameterized on a completion function type
// `_CompletionF` that is forwarded to the base class.
//
// The synchronization operations come from the public base
// std::__barrier_base<_CompletionF, _Sco>; this class adds construction, the
// init() helpers for constructing a barrier in existing storage, and deletes
// the copy operations.
template<thread_scope _Sco, class _CompletionF = std::__empty_completion>
class barrier : public std::__barrier_base<_CompletionF, _Sco> {
    using __base = std::__barrier_base<_CompletionF, _Sco>;

    // The base's __try_wait is re-declared here, in the private section.
    using __base::__try_wait;

    // pipeline is granted access to the private members.
    template<thread_scope>
    friend class pipeline;

public:
    // Not copyable.
    barrier(const barrier &) = delete;
    barrier & operator=(const barrier &) = delete;

    barrier() = default;

    // Forwards the expected arrival count and the completion function to the base.
    _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
    barrier(std::ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
        : __base(__expected, __completion)
    {}

    // Constructs a barrier expecting `__expected` arrivals in the storage at `__b`
    // (placement new), using a default-constructed completion function.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected) {
        new (__b) barrier(__expected);
    }

    // Same as above, with an explicit completion function.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected, _CompletionF __completion) {
        new (__b) barrier(__expected, __completion);
    }
};

// Empty tag type; the barrier<thread_scope_block> specialization in this header
// publicly derives from it.
struct __block_scope_barrier_base {};

_LIBCUDACXX_END_NAMESPACE_CUDA

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA_DEVICE

// Forward declaration. It precedes the barrier<thread_scope_block> specialization
// because that class names this function in a friend declaration; the definition
// follows the class.
__device__
inline std::uint64_t * barrier_native_handle(barrier<thread_scope_block> & b);

_LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Specialization of barrier for block scope with the default (empty) completion function.
//
// The state lives in the `__barrier` member, which the constructor statically asserts to be at
// offset zero of the object. When compiled for sm_80 or newer and the object is in shared memory
// (checked with __isShared), operations are issued as PTX `mbarrier.*` instructions on that
// storage. In every other case they are forwarded to std::__barrier_base.
template<>
class barrier<thread_scope_block, std::__empty_completion> : public __block_scope_barrier_base {
    using __barrier_base = std::__barrier_base<std::__empty_completion, (int)thread_scope_block>;
    __barrier_base __barrier;

    // Grants barrier_native_handle access to the private storage.
    __device__
    friend inline std::uint64_t * device::_LIBCUDACXX_CUDA_ABI_NAMESPACE::barrier_native_handle(barrier<thread_scope_block> & b);

public:
    using arrival_token = typename __barrier_base::arrival_token;

private:
    // Predicate handed to the polling loop in wait(): reports whether the phase
    // identified by `__phase` has completed on `*__this`.
    struct __poll_tester {
        barrier const* __this;
        arrival_token __phase;

        _LIBCUDACXX_INLINE_VISIBILITY
        __poll_tester(barrier const* __this_, arrival_token&& __phase_)
          : __this(__this_)
          , __phase(_CUDA_VSTD::move(__phase_))
        {}

        inline _LIBCUDACXX_INLINE_VISIBILITY
        bool operator()() const
        {
            return __this->__try_wait(__phase);
        }
    };

    // Single non-blocking test of whether the phase identified by `__phase` has completed.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_wait(arrival_token __phase) const {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__barrier)) {
            int __ready = 0;
            // mbarrier.test_wait sets predicate p; selp turns it into 0/1 in __ready.
            asm volatile ("{\n\t"
                    ".reg .pred p;\n\t"
                    "mbarrier.test_wait.shared.b64 p, [%1], %2;\n\t"
                    "selp.b32 %0, 1, 0, p;\n\t"
                    "}"
                : "=r"(__ready)
                : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))), "l"(__phase)
                : "memory");
            return __ready;
        }
        else
#endif
        {
            return __barrier.__try_wait(std::move(__phase));
        }
    }

    // pipeline is granted access to the private members.
    template<thread_scope>
    friend class pipeline;

public:
    barrier() = default;

    // Not copyable.
    barrier(const barrier &) = delete;
    barrier & operator=(const barrier &) = delete;

    // Constructs a barrier expecting `__expected` arrivals by delegating to init().
    _LIBCUDACXX_INLINE_VISIBILITY
    barrier(std::ptrdiff_t __expected, std::__empty_completion __completion = std::__empty_completion()) {
        // The storage address is used interchangeably with the object address, so it must sit at offset zero.
        static_assert(_LIBCUDACXX_OFFSET_IS_ZERO(barrier<thread_scope_block>, __barrier), "fatal error: bad barrier layout");
        init(this, __expected, __completion);
    }

    // On sm_80+ a barrier in shared memory is invalidated with mbarrier.inval.
    _LIBCUDACXX_INLINE_VISIBILITY
    ~barrier() {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__barrier)) {
            asm volatile ("mbarrier.inval.shared.b64 [%0];"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
        }
#endif
    }

    // Initializes the barrier at `__b` to expect `__expected` arrivals:
    // mbarrier.init for shared memory on sm_80+, otherwise placement new of __barrier_base.
    // `__completion` is accepted for signature compatibility and is not used.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected, std::__empty_completion __completion = std::__empty_completion()) {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__b->__barrier)) {
            asm volatile ("mbarrier.init.shared.b64 [%0], %1;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__b->__barrier))),
                    "r"(static_cast<std::uint32_t>(__expected))
                : "memory");
        }
        else
#endif
        {
            new (&__b->__barrier) __barrier_base(__expected);
        }
    }

    // Registers `__update` arrivals and returns a token for the phase arrived at.
    _LIBCUDACXX_NODISCARD_ATTRIBUTE _LIBCUDACXX_INLINE_VISIBILITY
    arrival_token arrive(std::ptrdiff_t __update = 1)
    {
#if __CUDA_ARCH__
        if (__isShared(&__barrier)) {
            arrival_token __token;
#if __CUDA_ARCH__ >= 800
            // All but one of the arrivals are registered with noComplete;
            // the final plain arrive yields the token that is returned.
            if (__update > 1) {
                asm volatile ("mbarrier.arrive.noComplete.shared.b64 %0, [%1], %2;"
                    : "=l"(__token)
                    : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))),
                        "r"(static_cast<std::uint32_t>(__update - 1))
                    : "memory");
            }
            asm volatile ("mbarrier.arrive.shared.b64 %0, [%1];"
                : "=l"(__token)
                : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
#else
            // Aggregate the arrivals of converged lanes that target the same barrier with the
            // same update, so a single leader lane performs one combined arrive.
            //
            // The converged-lane mask is sampled exactly once: both match operations must be
            // executed with the same participant mask, and two separate __activemask() calls
            // are not guaranteed to return the same value.
            unsigned int const __mask = __activemask();
            unsigned int __activeA = __match_any_sync(__mask, __update);
            unsigned int __activeB = __match_any_sync(__mask, reinterpret_cast<std::uintptr_t>(&__barrier));
            unsigned int __active = __activeA & __activeB;
            int __inc = __popc(__active) * __update;

            // `%%` is the inline-asm escape for a literal '%' in front of a special register name.
            unsigned __laneid;
            asm volatile ("mov.u32 %0, %%laneid;" : "=r"(__laneid));
            // Lowest lane of the matching group acts as the leader.
            int __leader = __ffs(__active) - 1;

            if(__leader == static_cast<int>(__laneid))
            {
                __token = __barrier.arrive(__inc);
            }
            // Broadcast the leader's token to every lane of the group.
            __token = __shfl_sync(__active, __token, __leader);
#endif
            return __token;
        }
        else
#endif
        {
            return __barrier.arrive(__update);
        }
    }

    // Blocks until the phase identified by `__phase` completes, polling __try_wait with backoff.
    _LIBCUDACXX_INLINE_VISIBILITY
    void wait(arrival_token && __phase) const
    {
        _CUDA_VSTD::__libcpp_thread_poll_with_backoff(__poll_tester(this, _CUDA_VSTD::move(__phase)));
    }

    // Arrives once and waits for the resulting phase to complete.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_wait()
    {
        wait(arrive());
    }

    // Arrives and drops out of the barrier: mbarrier.arrive_drop for shared memory on sm_80+,
    // otherwise forwarded to __barrier_base. The token is discarded.
    _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_drop()
    {
#if __CUDA_ARCH__ >= 800
        if (__isShared(&__barrier)) {
            asm volatile ("mbarrier.arrive_drop.shared.b64 _, [%0];"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
                : "memory");
        }
        else
#endif
        {
            __barrier.arrive_and_drop();
        }
    }

    // Largest supported expected count: 2^20 - 1.
    // NOTE(review): presumably the mbarrier count limit -- confirm against the PTX ISA.
    _LIBCUDACXX_INLINE_VISIBILITY
    static constexpr ptrdiff_t max() noexcept
    {
        return (1 << 20) - 1;
    }
};

_LIBCUDACXX_END_NAMESPACE_CUDA

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA_DEVICE

// Returns the address of the barrier's internal storage, reinterpreted as a
// pointer to a 64-bit word.
__device__
inline std::uint64_t * barrier_native_handle(barrier<thread_scope_block> & b) {
    auto * const __storage = &b.__barrier;
    return reinterpret_cast<std::uint64_t *>(__storage);
}

_LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Thread-scope barrier with the default (empty) completion function.
//
// It reuses the block-scope specialization through private inheritance and
// re-exports only the members named in the using-declarations below.
template<>
class barrier<thread_scope_thread, std::__empty_completion> : private barrier<thread_scope_block> {
    using __base = barrier<thread_scope_block>;

public:
    // Inherit the constructors of the block-scope implementation.
    using __base::__base;

    // Operations re-exported unchanged from the block-scope implementation.
    using __base::max;
    using __base::arrive;
    using __base::arrive_and_drop;
    using __base::arrive_and_wait;
    using __base::wait;

    // Initializes the barrier at `__b` by delegating to the block-scope init()
    // on the base subobject.
    _LIBCUDACXX_INLINE_VISIBILITY
    friend void init(barrier * __b, std::ptrdiff_t __expected, std::__empty_completion __completion = std::__empty_completion()) {
        __base * const __as_base = static_cast<__base *>(__b);
        init(__as_base, __expected, __completion);
    }
};

// Synchronous copy of `__total_size` bytes from `__source` to `__destination`,
// shared between the members of a group.
//
// With `__stride == 1` the whole range is copied in one memcpy. Otherwise the
// range is split into chunks of `_Alignment` bytes and this call copies the
// chunks numbered `__rank`, `__rank + __stride`, `__rank + 2 * __stride`, ...
//
// The last chunk is clamped to the bytes that remain, so exactly
// `__total_size` bytes are touched in total even when `__total_size` is not a
// multiple of `_Alignment` (matching what the `__stride == 1` path copies).
template<std::size_t _Alignment>
_LIBCUDACXX_INLINE_VISIBILITY
inline void __strided_memcpy(char * __destination, char const * __source, std::size_t __total_size, std::size_t __rank, std::size_t __stride = 1) {
    if (__stride == 1) {
        memcpy(__destination, __source, __total_size);
    }
    else {
        for (std::size_t __offset = __rank * _Alignment; __offset < __total_size; __offset += __stride * _Alignment) {
            // __offset < __total_size holds here, so the subtraction cannot wrap.
            std::size_t const __remaining = __total_size - __offset;
            if (__remaining >= _Alignment) {
                // Full chunk: constant size keeps the copy a fixed-width access.
                memcpy(__destination + __offset, __source + __offset, _Alignment);
            }
            else {
                // Partial tail chunk: never read or write past the end of the range.
                memcpy(__destination + __offset, __source + __offset, __remaining);
            }
        }
    }
}

#if __CUDA_ARCH__ >= 800
// Primary template, used for alignments that have no dedicated specialization.
// The bytes are copied synchronously with __strided_memcpy and the function
// returns false.
template<std::size_t _Alignment, bool _Large = (_Alignment > 16)>
struct __memcpy_async_impl {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size, std::size_t __rank, std::size_t __stride) {
        bool const __issued_async = false;
        __strided_memcpy<_Alignment>(__destination, __source, __total_size, __rank, __stride);
        return __issued_async;
    }
};

// 4-byte case: issues one cp.async of 4 bytes for every chunk owned by `__rank`
// (chunks `__rank`, `__rank + __stride`, ...) and returns true.
template<>
struct __memcpy_async_impl<4, false> {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size, std::size_t __rank, std::size_t __stride) {
        std::size_t const __step = __stride * 4;
        std::size_t __offset = __rank * 4;
        while (__offset < __total_size) {
            asm volatile ("cp.async.ca.shared.global [%0], [%1], 4;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(__destination + __offset))),
                    "l"(__source + __offset)
                : "memory");
            __offset += __step;
        }
        return true;
    }
};

// 8-byte case: issues one cp.async of 8 bytes for every chunk owned by `__rank`
// (chunks `__rank`, `__rank + __stride`, ...) and returns true.
template<>
struct __memcpy_async_impl<8, false> {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size, std::size_t __rank, std::size_t __stride) {
        std::size_t const __step = __stride * 8;
        std::size_t __offset = __rank * 8;
        while (__offset < __total_size) {
            asm volatile ("cp.async.ca.shared.global [%0], [%1], 8;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(__destination + __offset))),
                    "l"(__source + __offset)
                : "memory");
            __offset += __step;
        }
        return true;
    }
};

// 16-byte case: issues one cp.async of 16 bytes for every chunk owned by `__rank`
// (chunks `__rank`, `__rank + __stride`, ...) and returns true.
template<>
struct __memcpy_async_impl<16, false> {
    __device__ static inline bool __copy(char * __destination, char const * __source, std::size_t __total_size, std::size_t __rank, std::size_t __stride) {
        std::size_t const __step = __stride * 16;
        std::size_t __offset = __rank * 16;
        while (__offset < __total_size) {
            asm volatile ("cp.async.ca.shared.global [%0], [%1], 16;"
                :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(__destination + __offset))),
                    "l"(__source + __offset)
                : "memory");
            __offset += __step;
        }
        return true;
    }
};

// Alignments larger than 16 bytes (`_Large == true`) reuse the 16-byte implementation.
template<std::size_t _Alignment>
struct __memcpy_async_impl<_Alignment, true>
  : public __memcpy_async_impl<16, false>
{};
#endif

// Ties copies issued by __memcpy_async to `__barrier`.
//
// Only has an effect when compiled for sm_80 or newer and `__is_async` is true:
//  - if `_Is_mbarrier` holds and the barrier object is in shared memory, a
//    cp.async.mbarrier.arrive is issued on the barrier's address;
//  - otherwise the calling thread executes cp.async.wait_all.
template<thread_scope _Sco, typename _CompF, bool _Is_mbarrier = (_Sco >= thread_scope_block) && std::is_same<_CompF, std::__empty_completion>::value>
_LIBCUDACXX_INLINE_VISIBILITY
inline void __memcpy_async_synchronize(barrier<_Sco, _CompF> & __barrier, bool __is_async) {
#if __CUDA_ARCH__ >= 800
    // Nothing to do when the copy was performed synchronously.
    if (!__is_async) {
        return;
    }

    bool const __use_mbarrier = _Is_mbarrier && __isShared(&__barrier);
    if (!__use_mbarrier) {
        asm volatile ("cp.async.wait_all;"
            ::: "memory");
        return;
    }

    asm volatile ("cp.async.mbarrier.arrive.shared.b64 [%0];"
        :: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier)))
        : "memory");
#endif
}

// Copies `__size` bytes from `__source` to `__destination`, dividing the work
// between the members of `__group`, then calls __memcpy_async_synchronize on `__sync`.
//
// When compiled for sm_80 or newer and the destination is in shared memory while
// the source is in global memory, the copy is dispatched to __memcpy_async_impl.
// In every other case it is performed with __strided_memcpy.
template<std::size_t _Native_alignment, typename _Group, typename _Sync>
_LIBCUDACXX_INLINE_VISIBILITY
void inline __memcpy_async(_Group const & __group, char * __destination, char const * __source, std::size_t __size, _Sync & __sync) {
    std::size_t const __rank = __group.thread_rank();
    std::size_t const __count = __group.size();

    // Becomes true only if the selected implementation reports an asynchronous copy.
    bool __is_async = false;

#if __CUDA_ARCH__ >= 800
    if (__isShared(__destination) && __isGlobal(__source)) {
        if (_Native_alignment >= 4) {
            // The static alignment already selects a cp.async width.
            __is_async = __memcpy_async_impl<_Native_alignment>::__copy(__destination, __source, __size, __rank, __count);
        }
        else {
            auto const __destination_bits = reinterpret_cast<std::uintptr_t>(__destination);
            auto const __source_bits = reinterpret_cast<std::uintptr_t>(__source);

            // The lowest set bit of the three values combined gives their common alignment;
            // __ffs reports its position counting from 1.
            auto const __lowest_set = __ffs(__source_bits | __destination_bits | __size);

            switch (__lowest_set) {
                case 1: // fallthrough
                case 2: __is_async = __memcpy_async_impl<1>::__copy(__destination, __source, __size, __rank, __count); break;
                case 3: __is_async = __memcpy_async_impl<4>::__copy(__destination, __source, __size, __rank, __count); break;
                case 4: __is_async = __memcpy_async_impl<8>::__copy(__destination, __source, __size, __rank, __count); break;
                default: __is_async = __memcpy_async_impl<16>::__copy(__destination, __source, __size, __rank, __count); break;
            }
        }
    }
    else
#endif
    {
        __strided_memcpy<_Native_alignment>(__destination, __source, __size, __rank, __count);
    }

    __memcpy_async_synchronize(__sync, __is_async);
}

// Minimal group type describing a single thread: one member with rank 0 and a
// sync() that does nothing. The memcpy_async overloads without a group argument
// pass an instance of it to the group-taking overloads.
struct __single_thread_group {
    // Nothing to synchronize with.
    _LIBCUDACXX_INLINE_VISIBILITY
    void sync() const
    {}

    // The group has exactly one member.
    _LIBCUDACXX_INLINE_VISIBILITY
    constexpr std::size_t size() const
    {
        return 1;
    }

    // That member has rank 0.
    _LIBCUDACXX_INLINE_VISIBILITY
    constexpr std::size_t thread_rank() const
    {
        return 0;
    }
};

// Group copy of `__size` bytes between typed pointers; the copy width is derived
// from alignof(_Tp). The work is forwarded to __memcpy_async together with `__barrier`.
template<typename _Group, class _Tp, thread_scope _Sco, typename _CompF>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(_Group const & __group, _Tp * __destination, _Tp const * __source, std::size_t __size, barrier<_Sco, _CompF> & __barrier) {
    // NVCC with GCC 4.8 misclassifies some user-defined types that _are_ trivially copyable
    // as not trivially copyable, so the check below is compiled out for that compiler to keep
    // such types usable with memcpy_async.
    // FIXME: remove the #if once GCC 4.8 is no longer supported.
#if !defined(_LIBCUDACXX_COMPILER_GCC) || _GNUC_VER > 408
    static_assert(std::is_trivially_copyable<_Tp>::value, "memcpy_async requires a trivially copyable type");
#endif

    char * const __destination_bytes = reinterpret_cast<char *>(__destination);
    char const * const __source_bytes = reinterpret_cast<char const *>(__source);
    __memcpy_async<alignof(_Tp)>(__group, __destination_bytes, __source_bytes, __size, __barrier);
}

// Group copy between typed pointers with a caller-supplied alignment tag. The copy
// width is the larger of alignof(_Tp) and `_Alignment` (`_Larger_alignment`).
template<typename _Group, class _Tp, std::size_t _Alignment, thread_scope _Sco, typename _CompF, std::size_t _Larger_alignment = (alignof(_Tp) > _Alignment) ? alignof(_Tp) : _Alignment>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(_Group const & __group, _Tp * __destination, _Tp const * __source, aligned_size_t<_Alignment> __size, barrier<_Sco, _CompF> & __barrier) {
    // NVCC with GCC 4.8 misclassifies some user-defined types that _are_ trivially copyable
    // as not trivially copyable, so the check below is compiled out for that compiler to keep
    // such types usable with memcpy_async.
    // FIXME: remove the #if once GCC 4.8 is no longer supported.
#if !defined(_LIBCUDACXX_COMPILER_GCC) || _GNUC_VER > 408
    static_assert(std::is_trivially_copyable<_Tp>::value, "memcpy_async requires a trivially copyable type");
#endif

    char * const __destination_bytes = reinterpret_cast<char *>(__destination);
    char const * const __source_bytes = reinterpret_cast<char const *>(__source);
    // aligned_size_t converts implicitly to the byte count it carries.
    std::size_t const __byte_count = __size;
    __memcpy_async<_Larger_alignment>(__group, __destination_bytes, __source_bytes, __byte_count, __barrier);
}

// Typed copy without a group argument: forwards to the group overloads using a
// __single_thread_group.
template<class _Tp, typename _Size, thread_scope _Sco, typename _CompF>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(_Tp * __destination, _Tp const * __source, _Size __size, barrier<_Sco, _CompF> & __barrier) {
    __single_thread_group __self;
    memcpy_async(__self, __destination, __source, __size, __barrier);
}

// Group copy between untyped pointers: nothing is known about the alignment, so
// __memcpy_async is instantiated with an alignment of 1.
template<typename _Group, thread_scope _Sco, typename _CompF>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(_Group const & __group, void * __destination, void const * __source, std::size_t __size, barrier<_Sco, _CompF> & __barrier) {
    char * const __destination_bytes = static_cast<char *>(__destination);
    char const * const __source_bytes = static_cast<char const *>(__source);
    __memcpy_async<1>(__group, __destination_bytes, __source_bytes, __size, __barrier);
}

// Group copy between untyped pointers with a caller-supplied alignment tag:
// __memcpy_async is instantiated with `_Alignment`.
template<typename _Group, std::size_t _Alignment, thread_scope _Sco, typename _CompF>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(_Group const & __group, void * __destination, void const * __source, aligned_size_t<_Alignment> __size, barrier<_Sco, _CompF> & __barrier) {
    char * const __destination_bytes = static_cast<char *>(__destination);
    char const * const __source_bytes = static_cast<char const *>(__source);
    // aligned_size_t converts implicitly to the byte count it carries.
    std::size_t const __byte_count = __size;
    __memcpy_async<_Alignment>(__group, __destination_bytes, __source_bytes, __byte_count, __barrier);
}

// Untyped copy without a group argument: forwards to the group overloads using a
// __single_thread_group.
template<typename _Size, thread_scope _Sco, typename _CompF>
_LIBCUDACXX_INLINE_VISIBILITY
void memcpy_async(void * __destination, void const * __source, _Size __size, barrier<_Sco, _CompF> & __barrier) {
    __single_thread_group __self;
    memcpy_async(__self, __destination, __source, __size, __barrier);
}

_LIBCUDACXX_END_NAMESPACE_CUDA

#include "detail/__pragma_pop"

#endif //_CUDA_BARRIER
