diff options
Diffstat (limited to 'ggml-sycl/dpct/helper.hpp')
-rw-r--r-- | ggml-sycl/dpct/helper.hpp | 414 |
1 files changed, 185 insertions, 229 deletions
diff --git a/ggml-sycl/dpct/helper.hpp b/ggml-sycl/dpct/helper.hpp index 017fd6ee..1ff29721 100644 --- a/ggml-sycl/dpct/helper.hpp +++ b/ggml-sycl/dpct/helper.hpp @@ -588,266 +588,222 @@ namespace dpct out = prop; } - /// dpct device extension - class device_ext : public sycl::device - { - typedef std::mutex mutex_type; - - public: - device_ext() : sycl::device(), _ctx(*this) {} - ~device_ext() - { - std::lock_guard<mutex_type> lock(m_mutex); - clear_queues(); - } - device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) - { - std::lock_guard<mutex_type> lock(m_mutex); - init_queues(); - } - - int is_native_atomic_supported() { return 0; } - int get_major_version() const - { - return dpct::get_major_version(*this); - } - - int get_minor_version() const - { - return dpct::get_minor_version(*this); - } - - int get_max_compute_units() const - { - return get_device_info().get_max_compute_units(); - } - - /// Return the maximum clock frequency of this device in KHz. - int get_max_clock_frequency() const - { - return get_device_info().get_max_clock_frequency(); - } - - int get_integrated() const { return get_device_info().get_integrated(); } - - int get_max_sub_group_size() const - { - return get_device_info().get_max_sub_group_size(); - } - - int get_max_register_size_per_work_group() const - { - return get_device_info().get_max_register_size_per_work_group(); - } - - int get_max_work_group_size() const - { - return get_device_info().get_max_work_group_size(); - } - - int get_mem_base_addr_align() const - { - return get_info<sycl::info::device::mem_base_addr_align>(); - } - - size_t get_global_mem_size() const - { - return get_device_info().get_global_mem_size(); - } - - size_t get_max_mem_alloc_size() const - { - return get_device_info().get_max_mem_alloc_size(); - } - - /// Get the number of bytes of free and total memory on the SYCL device. - /// \param [out] free_memory The number of bytes of free memory on the SYCL device. - /// \param [out] total_memory The number of bytes of total memory on the SYCL device. - void get_memory_info(size_t &free_memory, size_t &total_memory) - { - total_memory = get_device_info().get_global_mem_size(); - const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " - "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " - "use total memory as free memory"; + /// dpct device extension + class device_ext : public sycl::device { + typedef std::mutex mutex_type; + + public: + device_ext() : sycl::device() {} + ~device_ext() { + std::lock_guard<mutex_type> lock(m_mutex); + clear_queues(); + } + device_ext(const sycl::device &base) : sycl::device(base) { + std::lock_guard<mutex_type> lock(m_mutex); + init_queues(); + } + + int is_native_atomic_supported() { return 0; } + int get_major_version() const { return dpct::get_major_version(*this); } + + int get_minor_version() const { return dpct::get_minor_version(*this); } + + int get_max_compute_units() const { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { return get_device_info().get_integrated(); } + + int get_max_sub_group_size() const { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const { + return get_info<sycl::info::device::mem_base_addr_align>(); + } + + size_t get_global_mem_size() const { + return get_device_info().get_global_mem_size(); + } + + size_t get_max_mem_alloc_size() const { + return get_device_info().get_max_mem_alloc_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the + /// SYCL device. \param [out] total_memory The number of bytes of total + /// memory on the SYCL device. + void get_memory_info(size_t &free_memory, size_t &total_memory) { + total_memory = get_device_info().get_global_mem_size(); + const char *warning_info = + "get_memory_info: [warning] ext_intel_free_memory is not " + "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " + "use total memory as free memory"; #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) - if (!has(sycl::aspect::ext_intel_free_memory)) - { - std::cerr << warning_info << std::endl; - free_memory = total_memory; - } - else - { - free_memory = get_info<sycl::ext::intel::info::device::free_memory>(); - } + if (!has(sycl::aspect::ext_intel_free_memory)) { + std::cerr << warning_info << std::endl; + free_memory = total_memory; + } else { + free_memory = get_info<sycl::ext::intel::info::device::free_memory>(); + } #else - std::cerr << warning_info << std::endl; - free_memory = total_memory; + std::cerr << warning_info << std::endl; + free_memory = total_memory; #if defined(_MSC_VER) && !defined(__clang__) #pragma message("Querying the number of bytes of free memory is not supported") #else #warning "Querying the number of bytes of free memory is not supported" #endif #endif - } + } - void get_device_info(device_info &out) const - { - dpct::get_device_info(out, *this); - } + void get_device_info(device_info &out) const { + dpct::get_device_info(out, *this); + } - device_info get_device_info() const - { - device_info prop; - dpct::get_device_info(prop, *this); - return prop; - } + device_info get_device_info() const { + device_info prop; + dpct::get_device_info(prop, *this); + return prop; + } - void reset() - { - std::lock_guard<mutex_type> lock(m_mutex); - clear_queues(); - init_queues(); - } + void reset() { + std::lock_guard<mutex_type> lock(m_mutex); + clear_queues(); + init_queues(); + } - sycl::queue &in_order_queue() { return *_q_in_order; } + sycl::queue &in_order_queue() { return _q_in_order; } - sycl::queue &out_of_order_queue() { return *_q_out_of_order; } + sycl::queue &out_of_order_queue() { return _q_out_of_order; } - sycl::queue &default_queue() - { - return in_order_queue(); - } + sycl::queue &default_queue() { return in_order_queue(); } - void queues_wait_and_throw() - { - std::unique_lock<mutex_type> lock(m_mutex); - std::vector<std::shared_ptr<sycl::queue>> current_queues( - _queues); - lock.unlock(); - for (const auto &q : current_queues) - { - q->wait_and_throw(); - } - // Guard the destruct of current_queues to make sure the ref count is safe. - lock.lock(); + void queues_wait_and_throw() { + std::unique_lock<mutex_type> lock(m_mutex); + lock.unlock(); + for (auto &q : _queues) { + q.wait_and_throw(); } + // Guard the destruct of current_queues to make sure the ref count is + // safe. + lock.lock(); + } - sycl::queue *create_queue(bool enable_exception_handler = false) - { - return create_in_order_queue(enable_exception_handler); - } + sycl::queue create_queue(bool enable_exception_handler = false) { + return create_in_order_queue(enable_exception_handler); + } - sycl::queue *create_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { - return create_in_order_queue(context, device, enable_exception_handler); - } + sycl::queue create_queue(sycl::device device, + bool enable_exception_handler = false) { + return create_in_order_queue(device, enable_exception_handler); + } - sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { - std::lock_guard<mutex_type> lock(m_mutex); - return create_queue_impl(enable_exception_handler, - sycl::property::queue::in_order()); - } + sycl::queue create_in_order_queue(bool enable_exception_handler = false) { + std::lock_guard<mutex_type> lock(m_mutex); + return create_queue_impl(enable_exception_handler, + sycl::property::queue::in_order()); + } - sycl::queue *create_in_order_queue(sycl::context context, sycl::device device, + sycl::queue create_in_order_queue(sycl::device device, bool enable_exception_handler = false) { - std::lock_guard<mutex_type> lock(m_mutex); - return create_queue_impl(context, device, enable_exception_handler, - sycl::property::queue::in_order()); - } - - sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { - std::lock_guard<mutex_type> lock(m_mutex); - return create_queue_impl(enable_exception_handler); - } - - void destroy_queue(sycl::queue *&queue) - { - std::lock_guard<mutex_type> lock(m_mutex); - _queues.erase(std::remove_if(_queues.begin(), _queues.end(), - [=](const std::shared_ptr<sycl::queue> &q) -> bool - { - return q.get() == queue; - }), - _queues.end()); - queue = nullptr; - } - void set_saved_queue(sycl::queue *q) - { - std::lock_guard<mutex_type> lock(m_mutex); - _saved_queue = q; - } - sycl::queue *get_saved_queue() const - { - std::lock_guard<mutex_type> lock(m_mutex); - return _saved_queue; - } - sycl::context get_context() const { return _ctx; } - - private: - void clear_queues() - { - _queues.clear(); - _q_in_order = _q_out_of_order = _saved_queue = nullptr; - } - - void init_queues() - { - _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); - _q_out_of_order = create_queue_impl(true); - _saved_queue = &default_queue(); + std::lock_guard<mutex_type> lock(m_mutex); + return create_queue_impl(device, enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue create_out_of_order_queue( + bool enable_exception_handler = false) { + std::lock_guard<mutex_type> lock(m_mutex); + return create_queue_impl(enable_exception_handler); + } + + void destroy_queue(sycl::queue queue) { + std::lock_guard<mutex_type> lock(m_mutex); + _queues.clear(); + } + void set_saved_queue(sycl::queue q) { + std::lock_guard<mutex_type> lock(m_mutex); + _saved_queue = q; + } + sycl::queue get_saved_queue() const { + std::lock_guard<mutex_type> lock(m_mutex); + return _saved_queue; + } + + private: + void clear_queues() { _queues.clear(); } + + void init_queues() { + _q_in_order = + create_queue_impl(true, sycl::property::queue::in_order()); + _q_out_of_order = create_queue_impl(true); + _saved_queue = default_queue(); + } + + /// Caller should acquire resource \p m_mutex before calling this + /// function. + template <class... Properties> + sycl::queue create_queue_impl(bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; } - - /// Caller should acquire resource \p m_mutex before calling this function. - template <class... Properties> - sycl::queue *create_queue_impl(bool enable_exception_handler, - Properties... properties) - { - sycl::async_handler eh = {}; - if (enable_exception_handler) - { - eh = exception_handler; - } - _queues.push_back(std::make_shared<sycl::queue>( - _ctx, *this, eh, - sycl::property_list( + auto q = sycl::queue(*this, eh, + sycl::property_list( #ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), + sycl::property::queue::enable_profiling(), #endif - properties...))); + properties...)); + _queues.push_back(q); - return _queues.back().get(); - } + return _queues.back(); + } - template <class... Properties> - sycl::queue *create_queue_impl(sycl::context context, sycl::device device, + template <class... Properties> + sycl::queue create_queue_impl(sycl::device device, bool enable_exception_handler, Properties... properties) { - sycl::async_handler eh = {}; - if (enable_exception_handler) { - eh = exception_handler; - } - _queues.push_back(std::make_shared<sycl::queue>( - context, device, eh, - sycl::property_list( - #ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), - #endif - properties...))); - - return _queues.back().get(); + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; } - - void get_version(int &major, int &minor) const - { - detail::get_version(*this, major, minor); - } - sycl::queue *_q_in_order, *_q_out_of_order; - sycl::queue *_saved_queue; - sycl::context _ctx; - std::vector<std::shared_ptr<sycl::queue>> _queues; - mutable mutex_type m_mutex; + _queues.push_back( + sycl::queue(device, eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back(); + } + + void get_version(int &major, int &minor) const { + detail::get_version(*this, major, minor); + } + sycl::queue _q_in_order, _q_out_of_order; + sycl::queue _saved_queue; + std::vector<sycl::queue> _queues; + mutable mutex_type m_mutex; }; + /// device manager class dev_mgr { |