// -*- C++ -*-
//===--------------------------- semaphore --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX_SEMAPHORE
#define _LIBCUDACXX_SEMAPHORE

/*
    semaphore synopsis

namespace std {

template<ptrdiff_t least_max_value = implementation-defined>
class counting_semaphore
{
public:
static constexpr ptrdiff_t max() noexcept;

constexpr explicit counting_semaphore(ptrdiff_t desired);
~counting_semaphore();

counting_semaphore(const counting_semaphore&) = delete;
counting_semaphore& operator=(const counting_semaphore&) = delete;

void release(ptrdiff_t __update = 1);
void acquire();
bool try_acquire() noexcept;
template<class Rep, class Period>
    bool try_acquire_for(const chrono::duration<Rep, Period>& __rel_time);
template<class Clock, class Duration>
    bool try_acquire_until(const chrono::time_point<Clock, Duration>& __abs_time);

private:
ptrdiff_t counter; // exposition only
};

using binary_semaphore = counting_semaphore<1>;

}

*/

#ifndef __cuda_std__
#include <__config>
#include <__threading_support>
#include <atomic>
#include <cassert>
#include <__pragma_push>
#endif

#if defined(_LIBCUDACXX_USE_PRAGMA_GCC_SYSTEM_HEADER)
#pragma GCC system_header
#endif

#ifdef _LIBCUDACXX_HAS_NO_THREADS
# error <semaphore> is not supported on this single threaded system
#endif

#if _LIBCUDACXX_STD_VER < 11
# error <semaphore> is requires C++11 or later
#endif

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Counting semaphore implemented on top of a single atomic counter.
//
// _Sco is forwarded unchanged as the second template argument of
// __atomic_base (presumably the thread scope of the atomic -- confirm against
// __atomic_base). __least_max_value is not consulted by this primary template:
// max() always reports numeric_limits<ptrdiff_t>::max().
template<int _Sco, ptrdiff_t __least_max_value>
class __atomic_semaphore_base
{
    // Decrement-if-nonzero loop. Starts from the caller's observed value
    // __old; every failed compare_exchange_weak refreshes __old, so the loop
    // stops with `false` as soon as the counter is observed at zero.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __fetch_sub_if_slow(ptrdiff_t __old)
    {
        while (__old != 0) {
            if (__count.compare_exchange_weak(__old, __old - 1, memory_order_acquire, memory_order_relaxed))
                return true;
        }
        return false;
    }

    // Single attempt to take one unit: returns false immediately when the
    // counter reads zero, otherwise tries one CAS and falls back to the
    // retry loop above if that CAS fails.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __fetch_sub_if()
    {
        ptrdiff_t __old = __count.load(memory_order_acquire);
        if (__old == 0)
            return false;
        if(__count.compare_exchange_weak(__old, __old - 1, memory_order_acquire, memory_order_relaxed))
            return true;
        return __fetch_sub_if_slow(__old); // fail only if not __available
    }

    // Blocks (via the atomic's wait()) until the counter is observed to be
    // non-zero. Does not take a unit itself; the caller retries try_acquire().
    _LIBCUDACXX_INLINE_VISIBILITY
    void __wait_slow()
    {
        while (1) {
            ptrdiff_t const __old = __count.load(memory_order_acquire);
            if(__old != 0)
                break;
            __count.wait(__old, memory_order_relaxed);
        }
    }

    // Timed acquisition: hands a "try to decrement" predicate and the time
    // budget to __libcpp_thread_poll_with_backoff and returns its result.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __acquire_slow_timed(chrono::nanoseconds const& __rel_time)
    {
        return __libcpp_thread_poll_with_backoff([this]() {
            ptrdiff_t const __old = __count.load(memory_order_acquire);
            return __old != 0 && __fetch_sub_if_slow(__old);
        }, __rel_time);
    }
    __atomic_base<ptrdiff_t, _Sco> __count;

public:
    // Largest counter value this implementation claims to support.
    _LIBCUDACXX_INLINE_VISIBILITY
    static constexpr ptrdiff_t max() noexcept
    {
        return numeric_limits<ptrdiff_t>::max();
    }

    // Initializes the counter to __count.
    _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
    __atomic_semaphore_base(ptrdiff_t __count) noexcept : __count(__count) { }

    ~__atomic_semaphore_base() = default;

    __atomic_semaphore_base(__atomic_semaphore_base const&) = delete;
    __atomic_semaphore_base& operator=(__atomic_semaphore_base const&) = delete;

    // Adds __update units to the counter, then wakes waiters: all of them
    // when more than one unit was added, otherwise a single one.
    _LIBCUDACXX_INLINE_VISIBILITY
    void release(ptrdiff_t __update = 1)
    {
        __count.fetch_add(__update, memory_order_release);
        if(__update > 1)
            __count.notify_all();
        else
            __count.notify_one();
    }

    // Takes one unit, blocking while the counter is zero.
    _LIBCUDACXX_INLINE_VISIBILITY
    void acquire()
    {
        while (!try_acquire())
            __wait_slow();
    }

    // Takes one unit if the counter is non-zero; never blocks on wait().
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire() noexcept
    {
        return __fetch_sub_if();
    }

    // Tries to take one unit, giving up once __abs_time has been reached.
    // The remaining time is converted with duration_cast so that time points
    // whose duration is not implicitly convertible to nanoseconds (e.g.
    // floating-point representations) are accepted as well.
    template <class Clock, class Duration>
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire_until(chrono::time_point<Clock, Duration> const& __abs_time)
    {
        if (try_acquire())
            return true;
        else
            return __acquire_slow_timed(
                chrono::duration_cast<chrono::nanoseconds>(__abs_time - Clock::now()));
    }

    // Tries to take one unit, giving up after __rel_time has elapsed.
    // duration_cast is used for the same reason as in try_acquire_until: any
    // duration<Rep, Period> must be accepted, not only those that convert
    // implicitly to nanoseconds.
    template <class Rep, class Period>
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire_for(chrono::duration<Rep, Period> const& __rel_time)
    {

        if (try_acquire())
            return true;
        else
            return __acquire_slow_timed(
                chrono::duration_cast<chrono::nanoseconds>(__rel_time));
    }
};

#ifndef _LIBCUDACXX_USE_NATIVE_SEMAPHORES

// Specialization for a maximum value of 1 (binary semaphore): the state is a
// single atomic int that is 1 when the semaphore is available and 0 when it
// has been taken.
template<int _Sco>
class __atomic_semaphore_base<_Sco, 1> {

    // Timed acquisition: repeatedly calls try_acquire() through
    // __libcpp_thread_poll_with_backoff with __rel_time as the time budget.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __acquire_slow_timed(chrono::nanoseconds const& __rel_time)
    {
        return __libcpp_thread_poll_with_backoff([this]() {
            return try_acquire();
        }, __rel_time);
    }
    __atomic_base<int, _Sco> __available;

public:
    // A binary semaphore holds at most one unit.
    _LIBCUDACXX_INLINE_VISIBILITY
    static constexpr ptrdiff_t max() noexcept { return 1; }

    // Initializes the availability flag from __available.
    _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
    __atomic_semaphore_base(ptrdiff_t __available) : __available(__available) { }

    ~__atomic_semaphore_base() = default;

    __atomic_semaphore_base(__atomic_semaphore_base const&) = delete;
    __atomic_semaphore_base& operator=(__atomic_semaphore_base const&) = delete;

    // Marks the semaphore available and wakes one waiter. Only an update of
    // exactly 1 is accepted (asserted).
    _LIBCUDACXX_INLINE_VISIBILITY
    void release(ptrdiff_t __update = 1)
    {
        assert(__update == 1);
        __available.store(1, memory_order_release);
        __available.notify_one();
    }

    // Takes the semaphore, blocking on the atomic while it reads 0.
    _LIBCUDACXX_INLINE_VISIBILITY
    void acquire()
    {
        while (!try_acquire())
            __available.wait(0, memory_order_relaxed);
    }

    // Atomically swaps in 0; succeeds iff the previous value was 1.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire() noexcept
    {
        return 1 == __available.exchange(0, memory_order_acquire);
    }

    // Tries to take the semaphore, giving up once __abs_time has been
    // reached. The remaining time is converted with duration_cast so that
    // time points whose duration does not convert implicitly to nanoseconds
    // (e.g. floating-point representations) are accepted as well.
    template <class Clock, class Duration>
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire_until(chrono::time_point<Clock, Duration> const& __abs_time)
    {
        if (try_acquire())
            return true;
        else
            return __acquire_slow_timed(
                chrono::duration_cast<chrono::nanoseconds>(__abs_time - Clock::now()));
    }

    // Tries to take the semaphore, giving up after __rel_time has elapsed.
    // duration_cast makes every duration<Rep, Period> acceptable, not only
    // those implicitly convertible to nanoseconds.
    template <class Rep, class Period>
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire_for(chrono::duration<Rep, Period> const& __rel_time)
    {
        if (try_acquire())
            return true;
        else
            return __acquire_slow_timed(
                chrono::duration_cast<chrono::nanoseconds>(__rel_time));
    }
};

#else

// Semaphore built on the platform semaphore (__libcpp_semaphore_t), optionally
// accelerated by two atomic side buffers:
//
//  * __frontbuffer packs two counters into one word: the bits above bit 31
//    hold units that can be taken without touching the native semaphore, and
//    the low 32 bits count threads that have announced they are (about to be)
//    waiting on the native semaphore. This layout assumes ptrdiff_t is at
//    least 64 bits wide.
//  * __backbuffer holds released units that have not been posted to the
//    native semaphore yet; they are posted at most two at a time.
//
// The __libcpp_semaphore_* helpers are treated as returning a truthy value on
// success (their results are asserted); confirm against __threading_support.
template<int _Sco>
class __sem_semaphore_base {

    // After a successful native acquire, move up to two pending units from
    // the back buffer to the native semaphore. Returns __success unchanged.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __backfill(bool __success)
    {
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
        if(__success) {
            // Optimistically claim two units, then give back whatever was
            // not actually there.
            auto const __back_amount = __backbuffer.fetch_sub(2, memory_order_acquire);
            bool const __post_one = __back_amount > 0;
            bool const __post_two = __back_amount > 1;
            // Named __posted (not __success) so the parameter is not shadowed.
            auto const __posted = (!__post_one || __libcpp_semaphore_post(&__semaphore)) &&
                                  (!__post_two || __libcpp_semaphore_post(&__semaphore));
            assert(__posted);
            if(!__post_one || !__post_two)
                __backbuffer.fetch_add(!__post_one ? 2 : 1, memory_order_relaxed);
        }
#endif
        return __success;
    }

    // Tries to obtain a unit from the front buffer without blocking on the
    // native semaphore. Returns true if a unit was taken. Returns false after
    // having registered this thread as a waiter in the front buffer; the
    // caller must then wait on the native semaphore and call __try_done().
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_acquire_fast()
    {
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER

        // Briefly poll for an available unit before committing to a wait.
        ptrdiff_t __old;
        __libcpp_thread_poll_with_backoff([&]() {
            __old = __frontbuffer.load(memory_order_relaxed);
            return 0 != (__old >> 32);
        }, chrono::microseconds(5));

        // always steal if you can
        while(__old >> 32)
            if(__frontbuffer.compare_exchange_weak(__old, __old - (1ll << 32), memory_order_acquire))
                return true;
        // record we're waiting
        __old = __frontbuffer.fetch_add(1ll, memory_order_release);
        // ALWAYS steal if you can!
        while(__old >> 32)
            if(__frontbuffer.compare_exchange_weak(__old, __old - (1ll << 32), memory_order_acquire))
                break;
        // not going to wait after all
        if(__old >> 32)
            return __try_done(true);
#endif
        // the wait has begun...
        return false;
    }

    // Finishes a wait started by __try_acquire_fast(): unregisters this
    // thread as a waiter and, on success, backfills the native semaphore.
    // Returns __success.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_done(bool __success)
    {
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
        // record we're NOT waiting
        __frontbuffer.fetch_sub(1ll, memory_order_release);
#endif
        return __backfill(__success);
    }

    // Releases __post_amount units through the native semaphore. With the
    // back buffer enabled at most two units are posted directly and the rest
    // are parked in __backbuffer; otherwise every unit is posted in a loop.
    _LIBCUDACXX_INLINE_VISIBILITY
    void __release_slow(ptrdiff_t __post_amount)
    {
    #ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
        bool const __post_one = __post_amount > 0;
        bool const __post_two = __post_amount > 1;
        if(__post_amount > 2)
            __backbuffer.fetch_add(__post_amount - 2, memory_order_acq_rel);
        auto const __success = (!__post_one || __libcpp_semaphore_post(&__semaphore)) &&
                             (!__post_two || __libcpp_semaphore_post(&__semaphore));
        assert(__success);
    #else
        for(; __post_amount; --__post_amount) {
            auto const __success = __libcpp_semaphore_post(&__semaphore);
            assert(__success);
        }
    #endif
    }

    __libcpp_semaphore_t __semaphore;
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
    __atomic_base<ptrdiff_t, _Sco> __frontbuffer;
#endif
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
    __atomic_base<ptrdiff_t, _Sco> __backbuffer;
#endif

public:
    // Largest supported counter value, taken from _LIBCUDACXX_SEMAPHORE_MAX.
    static constexpr ptrdiff_t max() noexcept {
        return _LIBCUDACXX_SEMAPHORE_MAX;
    }

    // Creates the semaphore with __count available units. With the front
    // buffer enabled the units live in __frontbuffer and the native semaphore
    // starts at 0; otherwise the native semaphore is initialized to __count.
    _LIBCUDACXX_INLINE_VISIBILITY
    __sem_semaphore_base(ptrdiff_t __count = 0) : __semaphore()
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
    , __frontbuffer(__count << 32)
#endif
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
    , __backbuffer(0)
#endif
    {
        assert(__count <= max());
        auto const __success =
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
            __libcpp_semaphore_init(&__semaphore, 0);
#else
            __libcpp_semaphore_init(&__semaphore, __count);
#endif
        assert(__success);
    }

    // Destroys the native semaphore. Asserts that no thread is still
    // registered as a waiter (low 32 bits of the front buffer are zero).
    _LIBCUDACXX_INLINE_VISIBILITY
    ~__sem_semaphore_base() {
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
        assert(0 == (__frontbuffer.load(memory_order_relaxed) & ~0u));
#endif
        auto const __success = __libcpp_semaphore_destroy(&__semaphore);
        assert(__success);
    }

    __sem_semaphore_base(const __sem_semaphore_base&) = delete;
    __sem_semaphore_base& operator=(const __sem_semaphore_base&) = delete;

    // Releases __update units. While no waiter is registered the units are
    // added directly to the front buffer; once a waiter is seen the release
    // goes through the native semaphore instead.
    _LIBCUDACXX_INLINE_VISIBILITY
    void release(ptrdiff_t __update = 1)
    {
#ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
        // boldly assume the semaphore is taken but uncontended
        ptrdiff_t __old = 0;
        // try to fast-release as long as it's uncontended
        // The waiter count is the low 32 bits, so the mask is ~0u (the same
        // mask the destructor uses). A `~0ul` mask would change width with
        // the platform's `unsigned long` and, where that is 64 bits, would
        // also test the available-unit bits.
        while(0 == (__old & ~0u))
            if(__frontbuffer.compare_exchange_weak(__old, __old + (__update << 32), memory_order_acq_rel))
                return;
#endif
        // slow-release it is
        __release_slow(__update);
    }

    // Takes one unit, blocking on the native semaphore if the fast path
    // could not provide one.
    _LIBCUDACXX_INLINE_VISIBILITY
    void acquire()
    {
        if(!__try_acquire_fast())
            __try_done(__libcpp_semaphore_wait(&__semaphore));
    }

    // Attempts to take one unit with a zero timeout.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire() noexcept
    {
        return try_acquire_for(chrono::nanoseconds(0));
    }

    // Tries to take one unit, giving up once __abs_time has been reached.
    // The time left is `__abs_time - now`, clamped at zero when the deadline
    // has already passed. The clamp is written out explicitly: an unqualified
    // call to max(a, b) here would resolve to the zero-argument static member
    // max() of this class rather than to a two-argument maximum.
    template <class Clock, class Duration>
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire_until(chrono::time_point<Clock, Duration> const& __abs_time)
    {
        auto const __now = Clock::now();
        chrono::nanoseconds const __remaining = (__abs_time > __now)
            ? chrono::duration_cast<chrono::nanoseconds>(__abs_time - __now)
            : chrono::nanoseconds(0);
        return try_acquire_for(__remaining);
    }

    // Tries to take one unit: first via the fast path, then by a timed wait
    // on the native semaphore bounded by __rel_time.
    template <class Rep, class Period>
    _LIBCUDACXX_INLINE_VISIBILITY
    bool try_acquire_for(chrono::duration<Rep, Period> const& __rel_time)
    {
        return __try_acquire_fast() ||
               __try_done(__libcpp_semaphore_wait_timed(&__semaphore, __rel_time));
    }
};

#endif // _LIBCUDACXX_USE_NATIVE_SEMAPHORES

// Implementation selector for counting_semaphore.
//
// When _LIBCUDACXX_USE_NATIVE_SEMAPHORES is defined, __sem_semaphore_base is
// chosen whenever __least_max_value fits within its max(); larger values fall
// back to __atomic_semaphore_base. Without that macro the atomic
// implementation is always used.
template<ptrdiff_t __least_max_value, int _Sco>
using __semaphore_base =
#ifdef _LIBCUDACXX_USE_NATIVE_SEMAPHORES
    typename conditional<__least_max_value <= __sem_semaphore_base<_Sco>::max(),
                        __sem_semaphore_base<_Sco>,
                        __atomic_semaphore_base<_Sco, __least_max_value>>::type
#else
    __atomic_semaphore_base<_Sco, __least_max_value>
#endif
    ;

// Public counting semaphore. All operations (release, acquire, try_acquire*,
// max) are inherited from the implementation selected by __semaphore_base,
// instantiated with 0 as its second argument.
template<ptrdiff_t __least_max_value = INT_MAX>
class counting_semaphore : public __semaphore_base<__least_max_value, 0>
{
    // Shorthand for the selected implementation.
    typedef __semaphore_base<__least_max_value, 0> __base_type;

    // The selected implementation must be able to represent the requested
    // maximum.
    static_assert(__least_max_value <= __base_type::max(), "");

public:
    // Constructs the semaphore with __count initially available units.
    _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
    counting_semaphore(ptrdiff_t __count = 0)
        : __base_type(__count)
    {
    }

    ~counting_semaphore() = default;

    // Semaphores are neither copyable nor copy-assignable.
    counting_semaphore(const counting_semaphore&) = delete;
    counting_semaphore& operator=(const counting_semaphore&) = delete;
};

// Semaphore whose least maximum value is 1.
using binary_semaphore = counting_semaphore<1>;

_LIBCUDACXX_END_NAMESPACE_STD

#ifndef __cuda_std__
#include <__pragma_pop>
#endif //__cuda_std__

#endif //_LIBCUDACXX_SEMAPHORE
