// Copyright 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <nvhpc/pstl_config.hpp>

#if _NVHPC_STDPAR_USE_GCC

  // We shouldn't get here unless all three of these conditions are true.
  // Double check them anyway as a sanity check.

  #if __GNUC__ < 9
    #error Using GCC's parallel algorithm implementation requires building against the standard library from GCC 9 or newer
  #endif
  #if !__has_include(<tbb/tbb.h>)
    #error The header file <tbb/tbb.h> was not found.  Usig GCC's parallel algorithm implementation requires that TBB be available.
  #endif
  #if __cplusplus < 201703L
    #error Using GCC's parallel algorithm implementation requires C++17. Use the option -std=c++17 to enable C++17 mode.
  #endif

  // Use GCC's multicore implementation of parallel algorithms.
  #include_next <execution>

#else

// Use NVIDIA's implementation of parallel algorithms.

#ifndef _NVCOMPILER_EXECUTION_HEADER_
#define _NVCOMPILER_EXECUTION_HEADER_

// Dialect and toolchain requirements for NVIDIA's implementation:
//   * __pgnu_vsn below 60000 is a hard error (GCC 6 or newer is required);
//   * anything older than C++11 is a hard error;
//   * C++11 and C++14 are still accepted but draw a deprecation warning,
//     which can be silenced by defining STDPAR_IGNORE_DEPRECATED_CPP_DIALECT.
#if __pgnu_vsn < 60000
  #error The header file "<execution>" requires GCC 6 or newer
#elif __cplusplus < 201103L
  #error The header file "<execution>" requires at least C++11 mode
#elif __cplusplus < 201703L && !defined(STDPAR_IGNORE_DEPRECATED_CPP_DIALECT)
  #warning Use of the header "<execution>" in C++11 or C++14 mode is deprecated; C++17 mode will be required in a future release; define STDPAR_IGNORE_DEPRECATED_CPP_DIALECT to suppress this warning
#endif

#include <type_traits>

// Provide the execution-policy feature-test macro unless something has
// already defined it.  201902L is the value published for the library
// revision that added std::execution::unseq.
#ifndef __cpp_lib_execution
  #define __cpp_lib_execution 201902L
#endif

namespace nv {
namespace execution {

// Execution policy tag types.
//
// The standard policies are defined here, not in std::execution, so that
// argument-dependent lookup on a policy object searches nv::execution.  The
// standard names for these types are type aliases in std::execution.
//
// Every policy is an empty struct carrying a compiler-specific [[nv::...]]
// attribute that names one of the four standard policy categories.
struct [[nv::__seq_execution_policy]] sequenced_policy { };
struct [[nv::__unseq_execution_policy]] unsequenced_policy { };
struct [[nv::__par_execution_policy]] parallel_policy { };
struct [[nv::__par_unseq_execution_policy]] parallel_unsequenced_policy { };

// Custom execution policies.  These are the result types of the modifier
// functions below; all of them are tagged as parallel policies.
struct [[nv::__par_execution_policy]] parallel_openacc_policy { };
struct [[nv::__par_execution_policy]] parallel_required_policy { };
struct [[nv::__par_execution_policy]] parallel_on_cpu_policy { };
struct [[nv::__par_execution_policy]] parallel_on_gpu_policy { };
struct [[nv::__par_execution_policy]] parallel_required_openacc_policy { };
struct [[nv::__par_execution_policy]] parallel_required_on_cpu_policy { };
struct [[nv::__par_execution_policy]] parallel_required_on_gpu_policy { };

// Temporary to ease transition to new naming scheme.
// C++17 and later get a single inline constexpr object; older dialects get a
// non-const object with internal linkage in each translation unit.
#if __cplusplus >= 201703L
inline constexpr parallel_required_openacc_policy openacc_par{};
#else
static parallel_required_openacc_policy openacc_par{};
#endif

// Custom execution policy modifiers
//
// Each modifier is an overload set with one overload per policy type.  An
// overload takes a policy object by value and returns a value-initialized
// object of the resulting policy type; combinations that are not supported
// are declared "= delete" so that using them is a compile-time error.

// openacc
//
//   sequenced_policy                  -> sequenced_policy
//   unsequenced_policy                -> (deleted)
//   parallel_policy                   -> parallel_openacc_policy
//   parallel_unsequenced_policy       -> parallel_openacc_policy
//   parallel_openacc_policy           -> parallel_openacc_policy
//   parallel_required_policy          -> parallel_required_openacc_policy
//   parallel_on_cpu_policy            -> (deleted)
//   parallel_on_gpu_policy            -> parallel_openacc_policy
//   parallel_required_openacc_policy  -> parallel_required_openacc_policy
//   parallel_required_on_cpu_policy   -> (deleted)
//   parallel_required_on_gpu_policy   -> parallel_required_openacc_policy
constexpr auto openacc(sequenced_policy) -> sequenced_policy {
  return {};
}
constexpr auto openacc(unsequenced_policy) -> sequenced_policy = delete;
constexpr auto openacc(parallel_policy) -> parallel_openacc_policy {
  return {};
}
constexpr auto openacc(parallel_unsequenced_policy) -> parallel_openacc_policy {
  return {};
}
constexpr auto openacc(parallel_openacc_policy) -> parallel_openacc_policy {
  return {};
}
constexpr auto openacc(parallel_required_policy)
                                          -> parallel_required_openacc_policy {
  return {};
}
constexpr auto openacc(parallel_on_cpu_policy) -> sequenced_policy = delete;
constexpr auto openacc(parallel_on_gpu_policy) -> parallel_openacc_policy {
  return {};
}
constexpr auto openacc(parallel_required_openacc_policy)
                                          -> parallel_required_openacc_policy {
  return {};
}
constexpr auto openacc(parallel_required_on_cpu_policy) -> sequenced_policy
                                                                      = delete;
constexpr auto openacc(parallel_required_on_gpu_policy)
                                          -> parallel_required_openacc_policy {
  return {};
}

// required
//
//   sequenced_policy                  -> sequenced_policy
//   unsequenced_policy                -> (deleted)
//   parallel_policy                   -> parallel_required_policy
//   parallel_unsequenced_policy       -> parallel_required_policy
//   parallel_openacc_policy           -> parallel_required_openacc_policy
//   parallel_required_policy          -> parallel_required_policy
//   parallel_on_cpu_policy            -> parallel_required_on_cpu_policy
//   parallel_on_gpu_policy            -> parallel_required_on_gpu_policy
//   parallel_required_openacc_policy  -> parallel_required_openacc_policy
//   parallel_required_on_cpu_policy   -> parallel_required_on_cpu_policy
//   parallel_required_on_gpu_policy   -> parallel_required_on_gpu_policy
constexpr auto required(sequenced_policy) -> sequenced_policy {
  return {};
}
constexpr auto required(unsequenced_policy) -> sequenced_policy = delete;
constexpr auto required(parallel_policy) -> parallel_required_policy {
  return {};
}
constexpr auto required(parallel_unsequenced_policy)
                                                  -> parallel_required_policy {
  return {};
}
constexpr auto required(parallel_openacc_policy)
                                          -> parallel_required_openacc_policy {
  return {};
}
constexpr auto required(parallel_required_policy) -> parallel_required_policy {
  return {};
}
constexpr auto required(parallel_on_cpu_policy)
                                           -> parallel_required_on_cpu_policy {
  return {};
}
constexpr auto required(parallel_on_gpu_policy)
                                           -> parallel_required_on_gpu_policy {
  return {};
}
constexpr auto required(parallel_required_openacc_policy)
                                          -> parallel_required_openacc_policy {
  return {};
}
constexpr auto required(parallel_required_on_cpu_policy)
                                           -> parallel_required_on_cpu_policy {
  return {};
}
constexpr auto required(parallel_required_on_gpu_policy)
                                           -> parallel_required_on_gpu_policy {
  return {};
}

// on_cpu
//
//   sequenced_policy                  -> sequenced_policy
//   unsequenced_policy                -> unsequenced_policy
//   parallel_policy                   -> parallel_on_cpu_policy
//   parallel_unsequenced_policy       -> parallel_on_cpu_policy
//   parallel_openacc_policy           -> (deleted)
//   parallel_required_policy          -> parallel_required_on_cpu_policy
//   parallel_on_cpu_policy            -> parallel_on_cpu_policy
//   parallel_on_gpu_policy            -> (deleted)
//   parallel_required_openacc_policy  -> (deleted)
//   parallel_required_on_cpu_policy   -> parallel_required_on_cpu_policy
//   parallel_required_on_gpu_policy   -> (deleted)
constexpr auto on_cpu(sequenced_policy) -> sequenced_policy {
  return {};
}
constexpr auto on_cpu(unsequenced_policy) -> unsequenced_policy {
  return {};
}
constexpr auto on_cpu(parallel_policy) -> parallel_on_cpu_policy {
  return {};
}
constexpr auto on_cpu(parallel_unsequenced_policy) -> parallel_on_cpu_policy {
  return {};
}
constexpr auto on_cpu(parallel_openacc_policy) -> sequenced_policy = delete;
constexpr auto on_cpu(parallel_required_policy)
                                           -> parallel_required_on_cpu_policy {
  return {};
}
constexpr auto on_cpu(parallel_on_cpu_policy) -> parallel_on_cpu_policy {
  return {};
}
constexpr auto on_cpu(parallel_on_gpu_policy) -> sequenced_policy = delete;
constexpr auto on_cpu(parallel_required_openacc_policy) -> sequenced_policy
                                                                      = delete;
constexpr auto on_cpu(parallel_required_on_cpu_policy)
                                           -> parallel_required_on_cpu_policy {
  return {};
}
constexpr auto on_cpu(parallel_required_on_gpu_policy) -> sequenced_policy
                                                                      = delete;

// on_gpu
//
//   sequenced_policy                  -> (deleted)
//   unsequenced_policy                -> (deleted)
//   parallel_policy                   -> parallel_on_gpu_policy
//   parallel_unsequenced_policy       -> parallel_on_gpu_policy
//   parallel_openacc_policy           -> parallel_openacc_policy
//   parallel_required_policy          -> parallel_required_on_gpu_policy
//   parallel_on_cpu_policy            -> (deleted)
//   parallel_on_gpu_policy            -> parallel_on_gpu_policy
//   parallel_required_openacc_policy  -> parallel_required_openacc_policy
//   parallel_required_on_cpu_policy   -> (deleted)
//   parallel_required_on_gpu_policy   -> parallel_required_on_gpu_policy
constexpr auto on_gpu(sequenced_policy) -> sequenced_policy = delete;
constexpr auto on_gpu(unsequenced_policy) -> sequenced_policy = delete;
constexpr auto on_gpu(parallel_policy) -> parallel_on_gpu_policy {
  return {};
}
constexpr auto on_gpu(parallel_unsequenced_policy) -> parallel_on_gpu_policy {
  return {};
}
constexpr auto on_gpu(parallel_openacc_policy) -> parallel_openacc_policy {
  return {};
}
constexpr auto on_gpu(parallel_required_policy)
                                           -> parallel_required_on_gpu_policy {
  return {};
}
constexpr auto on_gpu(parallel_on_cpu_policy) -> sequenced_policy = delete;
constexpr auto on_gpu(parallel_on_gpu_policy) -> parallel_on_gpu_policy {
  return {};
}
constexpr auto on_gpu(parallel_required_openacc_policy)
                                          -> parallel_required_openacc_policy {
  return {};
}
constexpr auto on_gpu(parallel_required_on_cpu_policy) -> sequenced_policy
                                                                      = delete;
constexpr auto on_gpu(parallel_required_on_gpu_policy)
                                           -> parallel_required_on_gpu_policy {
  return {};
}

} // namespace execution
} // namespace nv

namespace std {

namespace execution {

// Standard spellings of the four policy types.  The types themselves are
// defined in ::nv::execution; these are plain aliases for them.
using sequenced_policy            = ::nv::execution::sequenced_policy;
using unsequenced_policy          = ::nv::execution::unsequenced_policy;
using parallel_policy             = ::nv::execution::parallel_policy;
using parallel_unsequenced_policy = ::nv::execution::parallel_unsequenced_policy;

// One global object per standard policy.  C++17 and later use inline
// constexpr variables; earlier dialects use non-const objects with internal
// linkage instead.
#if __cplusplus >= 201703L
inline constexpr sequenced_policy            seq{};
inline constexpr unsequenced_policy          unseq{};
inline constexpr parallel_policy             par{};
inline constexpr parallel_unsequenced_policy par_unseq{};
#else
static sequenced_policy            seq{};
static unsequenced_policy          unseq{};
static parallel_policy             par{};
static parallel_unsequenced_policy par_unseq{};
#endif

} // namespace execution

// std::is_execution_policy
//
// The primary template reports false.  The explicit specializations below
// report true for exactly the four standard policy types and the seven
// custom policy types from ::nv::execution.

template <class _Policy>
struct is_execution_policy : false_type { };

// Standard policies.
template <>
struct is_execution_policy<execution::sequenced_policy> : true_type { };

template <>
struct is_execution_policy<execution::unsequenced_policy> : true_type { };

template <>
struct is_execution_policy<execution::parallel_policy> : true_type { };

template <>
struct is_execution_policy<execution::parallel_unsequenced_policy>
  : true_type { };

// Custom policies.
template <>
struct is_execution_policy<::nv::execution::parallel_openacc_policy>
  : true_type { };

template <>
struct is_execution_policy<::nv::execution::parallel_required_policy>
  : true_type { };

template <>
struct is_execution_policy<::nv::execution::parallel_on_cpu_policy>
  : true_type { };

template <>
struct is_execution_policy<::nv::execution::parallel_on_gpu_policy>
  : true_type { };

template <>
struct is_execution_policy<::nv::execution::parallel_required_openacc_policy>
  : true_type { };

template <>
struct is_execution_policy<::nv::execution::parallel_required_on_cpu_policy>
  : true_type { };

template <>
struct is_execution_policy<::nv::execution::parallel_required_on_gpu_policy>
  : true_type { };

// Variable-template shorthand for the trait; only declared in C++17 and
// later, where inline variables are available.
#if __cplusplus >= 201703L
template <class _Policy>
inline constexpr bool is_execution_policy_v =
    is_execution_policy<_Policy>::value;
#endif

} // namespace std

#endif

// Follow-up includes.  This group sits outside the
// _NVCOMPILER_EXECUTION_HEADER_ include guard, so it is evaluated every time
// this header is processed.
//
// For each of <algorithm>, <numeric> and <memory>: when the matching
// _NVCOMPILER_*_HEADER_ macro is defined and the matching
// _NVCOMPILER_*_EXECUTION_HEADER_ macro is not, include the corresponding
// nvhpc/*_execution.hpp file.  Nothing is included while
// __NVCOMPILER_PROCESSING_THRUST_INCLUDES is defined.
// NOTE(review): presumably the *_HEADER_ macros are set by NVIDIA's wrappers
// for those standard headers and the *_EXECUTION_HEADER_ macros by the
// included files themselves -- confirm against those headers.
#ifndef __NVCOMPILER_PROCESSING_THRUST_INCLUDES
  #ifdef _NVCOMPILER_ALGORITHM_HEADER_
    #ifndef _NVCOMPILER_ALGORITHM_EXECUTION_HEADER_
      #include <nvhpc/algorithm_execution.hpp>
    #endif
  #endif
  #ifdef _NVCOMPILER_NUMERIC_HEADER_
    #ifndef _NVCOMPILER_NUMERIC_EXECUTION_HEADER_
      #include <nvhpc/numeric_execution.hpp>
    #endif
  #endif
  #ifdef _NVCOMPILER_MEMORY_HEADER_
    #ifndef _NVCOMPILER_MEMORY_EXECUTION_HEADER_
      #include <nvhpc/memory_execution.hpp>
    #endif
  #endif
#endif

#endif
