summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-sycl/dpct/helper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-sycl/dpct/helper.hpp')
-rw-r--r--ggml/src/ggml-sycl/dpct/helper.hpp21
1 files changed, 17 insertions, 4 deletions
diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp
index 4aaa76bf..fe4a8f74 100644
--- a/ggml/src/ggml-sycl/dpct/helper.hpp
+++ b/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -874,7 +874,7 @@ namespace dpct
inline std::string get_preferred_gpu_platform_name() {
std::string result;
- std::string filter = "level-zero";
+ std::string filter = "";
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
if (env) {
if (std::strstr(env, "level_zero")) {
@@ -892,11 +892,24 @@ namespace dpct
else {
throw std::runtime_error("invalid device filter: " + std::string(env));
}
+ } else {
+ auto default_device = sycl::device(sycl::default_selector_v);
+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
+
+ if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
+ filter = "level-zero";
+ }
+ else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
+ filter = "cuda";
+ }
+ else if (std::strstr(default_platform_name.c_str(), "HIP")) {
+ filter = "hip";
+ }
}
- auto plaform_list = sycl::platform::get_platforms();
+ auto platform_list = sycl::platform::get_platforms();
- for (const auto& platform : plaform_list) {
+ for (const auto& platform : platform_list) {
auto devices = platform.get_devices();
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
return d.is_gpu();
@@ -975,7 +988,7 @@ namespace dpct
if (backend == "opencl:cpu") return 4;
if (backend == "opencl:acc") return 5;
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
static bool compare_backend(std::string &backend1, std::string &backend2) {
return convert_backend_index(backend1) < convert_backend_index(backend2);