From 31efa893e9457c468ceaa648daea09b6de374e53 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Sat, 8 Nov 2025 21:23:11 +0900 Subject: [PATCH 01/32] Add random orth transformation --- cpp/CMakeLists.txt | 1 + .../linear_transform/random_orthogonal.hpp | 477 ++++++++++++++++++ .../detail/random_orthogonal.cuh | 246 +++++++++ .../linear_transform/random_orthogonal.cu | 61 +++ cpp/tests/CMakeLists.txt | 2 +- .../random_orthogonal_transformation.cu | 219 ++++++++ cpp/tests/test_utils.cuh | 6 +- cpp/tests/test_utils.h | 18 +- 8 files changed, 1020 insertions(+), 10 deletions(-) create mode 100644 cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp create mode 100644 cpp/src/preprocessing/linear_transform/detail/random_orthogonal.cuh create mode 100644 cpp/src/preprocessing/linear_transform/random_orthogonal.cu create mode 100644 cpp/tests/preprocessing/random_orthogonal_transformation.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ababb22548..6b36b25bb5 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -513,6 +513,7 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/vamana_serialize_float.cu src/neighbors/vamana_serialize_uint8.cu src/neighbors/vamana_serialize_int8.cu + src/preprocessing/linear_transform/random_orthogonal.cu src/preprocessing/quantize/scalar.cu src/preprocessing/quantize/binary.cu src/preprocessing/spectral/spectral_embedding.cu diff --git a/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp b/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp new file mode 100644 index 0000000000..fb6e302183 --- /dev/null +++ b/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp @@ -0,0 +1,477 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace cuvs::preprocessing::linear_transform::random_orthogonal { + +/** + * @defgroup scalar Scalar transformer utilities + * @{ + */ + +/** + * @brief transformer parameters. + */ +struct params { + /* + * random seed + */ + uint64_t seed = 0; +}; + +/** + * @brief Defines and stores the orthogonal matrix to apply + * + * The transformation is performed by a matrix multiplication + * + * @tparam T data element type + * + */ +template +struct transformer { + raft::device_matrix orthogonal_matrix; +}; + +/** + * @brief Initializes a random orthogonal transoformation to be used later for transforming the + * dataset. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure scalar transformer, e.g. quantile + * @param[in] dataset a row-major matrix view on device + * + * @return transformer + */ +transformer train(raft::resources const& res, + const params params, + raft::device_matrix_view dataset); + +/** + * @brief Initializes a random orthogonal transoformation to be used later for transforming the + * dataset. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure scalar transformer, e.g. quantile + * @param[in] dataset a row-major matrix view on host + * + * @return transformer + */ +transformer train(raft::resources const& res, + const params params, + raft::host_matrix_view dataset); + +/** + * @brief Applies quantization transform to given dataset + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); auto quantized_dataset = raft::make_device_matrix(handle, samples, + * features); cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void transform(raft::resources const& res, + const transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Applies quantization transform to given dataset + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); auto quantized_dataset = raft::make_host_matrix(samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void transform(raft::resources const& res, + const transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Perform inverse quantization step on previously quantized dataset + * + * Note that depending on the chosen data types train dataset the conversion is + * not lossless. + * + * Usage example: + * @code{.cpp} + * auto quantized_dataset = raft::make_device_matrix(handle, samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); auto dataset_revert = raft::make_device_matrix(handle, samples, features); + * cuvs::preprocessing::quantize::scalar::inverse_transform(handle, transformer, + * dataset_revert.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void inverse_transform(raft::resources const& res, + const transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Perform inverse quantization step on previously quantized dataset + * + * Note that depending on the chosen data types train dataset the conversion is + * not lossless. + * + * Usage example: + * @code{.cpp} + * auto quantized_dataset = raft::make_host_matrix(samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); auto dataset_revert = raft::make_host_matrix(samples, + * features); cuvs::preprocessing::quantize::scalar::inverse_transform(handle, transformer, + * dataset_revert.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void inverse_transform(raft::resources const& res, + const transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Initializes a scalar transformer to be used later for quantizing the dataset. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure scalar transformer, e.g. quantile + * @param[in] dataset a row-major matrix view on device + * + * @return transformer + */ +transformer train(raft::resources const& res, + const params params, + raft::device_matrix_view dataset); + +/** + * @brief Initializes a scalar transformer to be used later for quantizing the dataset. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure scalar transformer, e.g. quantile + * @param[in] dataset a row-major matrix view on host + * + * @return transformer + */ +transformer train(raft::resources const& res, + const params params, + raft::host_matrix_view dataset); + +/** + * @brief Applies quantization transform to given dataset + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); auto quantized_dataset = raft::make_device_matrix(handle, samples, + * features); cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void transform(raft::resources const& res, + const transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Applies quantization transform to given dataset + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); auto quantized_dataset = raft::make_host_matrix(samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void transform(raft::resources const& res, + const transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Perform inverse quantization step on previously quantized dataset + * + * Note that depending on the chosen data types train dataset the conversion is + * not lossless. + * + * Usage example: + * @code{.cpp} + * auto quantized_dataset = raft::make_device_matrix(handle, samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); auto dataset_revert = raft::make_device_matrix(handle, + * samples, features); cuvs::preprocessing::quantize::scalar::inverse_transform(handle, transformer, + * dataset_revert.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void inverse_transform(raft::resources const& res, + const transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Perform inverse quantization step on previously quantized dataset + * + * Note that depending on the chosen data types train dataset the conversion is + * not lossless. + * + * Usage example: + * @code{.cpp} + * auto quantized_dataset = raft::make_host_matrix(samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); auto dataset_revert = raft::make_host_matrix(samples, + * features); cuvs::preprocessing::quantize::scalar::inverse_transform(handle, transformer, + * dataset_revert.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void inverse_transform(raft::resources const& res, + const transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Initializes a scalar transformer to be used later for quantizing the dataset. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure scalar transformer, e.g. quantile + * @param[in] dataset a row-major matrix view on device + * + * @return transformer + */ +transformer train(raft::resources const& res, + const params params, + raft::device_matrix_view dataset); + +/** + * @brief Initializes a scalar transformer to be used later for quantizing the dataset. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); + * @endcode + * + * @param[in] res raft resource + * @param[in] params configure scalar transformer, e.g. quantile + * @param[in] dataset a row-major matrix view on host + * + * @return transformer + */ +transformer train(raft::resources const& res, + const params params, + raft::host_matrix_view dataset); + +/** + * @brief Applies quantization transform to given dataset + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); auto quantized_dataset = raft::make_device_matrix(handle, samples, + * features); cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void transform(raft::resources const& res, + const transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Applies quantization transform to given dataset + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * cuvs::preprocessing::quantize::scalar::params params; + * auto transformer = cuvs::preprocessing::quantize::scalar::train(handle, params, + * dataset); auto quantized_dataset = raft::make_host_matrix(samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void transform(raft::resources const& res, + const transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Perform inverse quantization step on previously quantized dataset + * + * Note that depending on the chosen data types train dataset the conversion is + * not lossless. + * + * Usage example: + * @code{.cpp} + * auto quantized_dataset = raft::make_device_matrix(handle, samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); auto dataset_revert = raft::make_device_matrix(handle, + * samples, features); cuvs::preprocessing::quantize::scalar::inverse_transform(handle, transformer, + * dataset_revert.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void inverse_transform(raft::resources const& res, + const transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Perform inverse quantization step on previously quantized dataset + * + * Note that depending on the chosen data types train dataset the conversion is + * not lossless. + * + * Usage example: + * @code{.cpp} + * auto quantized_dataset = raft::make_host_matrix(samples, features); + * cuvs::preprocessing::quantize::scalar::transform(handle, transformer, dataset, + * quantized_dataset.view()); auto dataset_revert = raft::make_host_matrix(samples, + * features); cuvs::preprocessing::quantize::scalar::inverse_transform(handle, transformer, + * dataset_revert.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] transformer a scalar transformer + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void inverse_transform(raft::resources const& res, + const transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** @} */ // end of group scalar + +} // namespace cuvs::preprocessing::linear_transform::random_orthogonal diff --git a/cpp/src/preprocessing/linear_transform/detail/random_orthogonal.cuh b/cpp/src/preprocessing/linear_transform/detail/random_orthogonal.cuh new file mode 100644 index 0000000000..bb9fc3eb82 --- /dev/null +++ b/cpp/src/preprocessing/linear_transform/detail/random_orthogonal.cuh @@ -0,0 +1,246 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::preprocessing::linear_transform::detail { + +template +raft::device_matrix generate_random_orth(raft::resources const& res, + const int64_t dim, + const uint64_t seed) +{ + auto orthogonal_matrix = raft::make_device_matrix(res, dim, dim); + + raft::random::RngState rng(seed); + if constexpr (!std::is_same_v) { + auto rand_matrix = raft::make_device_matrix(res, dim, dim); + raft::random::normal(res, + rng, + rand_matrix.data_handle(), + rand_matrix.size(), + static_cast(0), + static_cast(1)); + raft::linalg::qrGetQ(res, + rand_matrix.data_handle(), + orthogonal_matrix.data_handle(), + dim, + dim, + raft::resource::get_cuda_stream(res)); + } else { + using compute_t = float; + auto rand_matrix = raft::make_device_matrix(res, dim, dim); + auto orthogonal_matrix_f32 = raft::make_device_matrix(res, dim, dim); + raft::random::normal(res, + rng, + rand_matrix.data_handle(), + rand_matrix.size(), + static_cast(0), + static_cast(1)); + raft::linalg::qrGetQ(res, + rand_matrix.data_handle(), + orthogonal_matrix_f32.data_handle(), + dim, + dim, + raft::resource::get_cuda_stream(res)); + + raft::linalg::map(res, + orthogonal_matrix.view(), + raft::cast_op{}, + raft::make_const_mdspan(orthogonal_matrix_f32.view())); + } + + return orthogonal_matrix; +} + +template +cuvs::preprocessing::linear_transform::random_orthogonal::transformer train( + raft::resources const& res, + const cuvs::preprocessing::linear_transform::random_orthogonal::params params, + raft::device_matrix_view dataset) +{ + cuvs::preprocessing::linear_transform::random_orthogonal::transformer transformer{ + .orthogonal_matrix = generate_random_orth(res, dataset.extent(1), params.seed)}; + + return transformer; +} + +template +cuvs::preprocessing::linear_transform::random_orthogonal::transformer train( + raft::resources const& res, + const cuvs::preprocessing::linear_transform::random_orthogonal::params params, + raft::host_matrix_view dataset) +{ + cuvs::preprocessing::linear_transform::random_orthogonal::transformer transformer{ + .orthogonal_matrix = generate_random_orth(res, dataset.extent(1), params.seed)}; + + return transformer; +} + +template +void transform( + raft::resources const& res, + const cuvs::preprocessing::linear_transform::random_orthogonal::transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(dataset.extent(0) == out.extent(0), "Input and output dataset sizes mismatch."); + RAFT_EXPECTS(dataset.extent(1) == out.extent(1), "Input and output dataset dimensions mismatch."); + RAFT_EXPECTS(dataset.extent(1) == transformer.orthogonal_matrix.extent(1), + "Input dataset and transformer dimensions mismatch."); + RAFT_EXPECTS(transformer.orthogonal_matrix.extent(0) == transformer.orthogonal_matrix.extent(1), + "Transformer matrix must be square."); + + const auto src_begin = reinterpret_cast(dataset.data_handle()); + const auto src_end = src_begin + sizeof(T) * dataset.size(); + const auto dst_begin = reinterpret_cast(out.data_handle()); + const auto dst_end = dst_begin + sizeof(T) * out.size(); + + // overlapped && src range < dsr range + const auto overlap = !((src_end < dst_begin) || (dst_end < src_begin)); + + const auto dataset_dim = dataset.extent(1); + const auto orth_view = raft::make_device_matrix_view( + const_cast(transformer.orthogonal_matrix.data_handle()), dataset_dim, dataset_dim); + if (!overlap) { + // Remove `const` + auto dataset_view = raft::make_device_matrix_view( + const_cast(dataset.data_handle()), dataset.extent(0), dataset.extent(1)); + + raft::linalg::gemm(res, dataset_view, orth_view, out); + } else { + RAFT_EXPECTS(dst_begin <= src_begin, + "Must be dst_begin <= src_begin in the current implementation"); + + auto mr = raft::resource::get_workspace_resource(res); + + const auto gemm_chunk_size = + std::min(static_cast(dataset.extent(0)), + raft::resource::get_workspace_free_bytes(res) / (dataset.extent(1) * sizeof(T))); + auto neighbor_indices = raft::make_device_mdarray( + res, mr, raft::make_extents(gemm_chunk_size, dataset.extent(1))); + + raft::spatial::knn::detail::utils::batch_load_iterator dataset_chunk_set( + dataset.data_handle(), + dataset.extent(0), + dataset.extent(1), + gemm_chunk_size, + raft::resource::get_cuda_stream(res), + mr); + for (auto& batch : dataset_chunk_set) { + raft::linalg::gemm(res, + raft::make_device_matrix_view( + const_cast(batch.data()), batch.size(), dataset_dim), + orth_view, + out); + raft::copy_async(out.data_handle() + dataset_dim * batch.offset(), + batch.data(), + batch.size() * dataset_dim, + raft::resource::get_cuda_stream(res)); + } + } +} + +template +void transform( + raft::resources const& res, + const cuvs::preprocessing::linear_transform::random_orthogonal::transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out) +{ + RAFT_EXPECTS(dataset.extent(0) == out.extent(0), "Input and output dataset sizes mismatch."); + RAFT_EXPECTS(dataset.extent(1) == out.extent(1), "Input and output dataset dimensions mismatch."); + RAFT_EXPECTS(dataset.extent(1) == transformer.orthogonal_matrix.extent(1), + "Input dataset and transformer dimensions mismatch."); + RAFT_EXPECTS(transformer.orthogonal_matrix.extent(0) == transformer.orthogonal_matrix.extent(1), + "Transformer matrix must be square."); + + auto host_orth = raft::make_host_matrix(dataset.extent(1), dataset.extent(1)); + raft::copy(host_orth.data_handle(), + transformer.orthogonal_matrix.data_handle(), + host_orth.size(), + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + +#pragma omp parallel for collapse(2) + for (int64_t i = 0; i < dataset.extent(0); i++) { + for (int64_t j = 0; j < dataset.extent(1); j++) { + auto c = static_cast(0); + for (int64_t k = 0; k < dataset.extent(1); k++) { + c += dataset(i, k) * host_orth(k, j); + } + out(i, j) = c; + } + } +} + +template +void inverse_transform( + raft::resources const& res, + const cuvs::preprocessing::linear_transform::random_orthogonal::transformer& transformer, + raft::device_matrix_view dataset, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(dataset.extent(0) == out.extent(0), "Input and output dataset sizes mismatch."); + RAFT_EXPECTS(dataset.extent(1) == out.extent(1), "Input and output dataset dimensions mismatch."); + RAFT_EXPECTS(dataset.extent(1) == transformer.orthogonal_matrix.extent(1), + "Input dataset and transformer dimensions mismatch."); + RAFT_EXPECTS(transformer.orthogonal_matrix.extent(0) == transformer.orthogonal_matrix.extent(1), + "Transformer matrix must be square."); + + // Remove `const` + auto dataset_view = raft::make_device_matrix_view( + const_cast(dataset.data_handle()), dataset.extent(0), dataset.extent(1)); + + // Transpose and remove `const` + const auto dataset_dim = dataset.extent(1); + const auto orth_T_view = raft::make_device_matrix_view( + const_cast(transformer.orthogonal_matrix.data_handle()), dataset_dim, dataset_dim); + + raft::linalg::gemm(res, dataset_view, orth_T_view, out); +} + +template +void inverse_transform( + raft::resources const& res, + const cuvs::preprocessing::linear_transform::random_orthogonal::transformer& transformer, + raft::host_matrix_view dataset, + raft::host_matrix_view out) +{ + RAFT_EXPECTS(dataset.extent(0) == out.extent(0), "Input and output dataset sizes mismatch."); + RAFT_EXPECTS(dataset.extent(1) == out.extent(1), "Input and output dataset dimensions mismatch."); + RAFT_EXPECTS(dataset.extent(1) == transformer.orthogonal_matrix.extent(1), + "Input dataset and transformer dimensions mismatch."); + RAFT_EXPECTS(transformer.orthogonal_matrix.extent(0) == transformer.orthogonal_matrix.extent(1), + "Transformer matrix must be square."); + + auto host_orth = raft::make_host_matrix(dataset.extent(1), dataset.extent(1)); + raft::copy(host_orth.data_handle(), + transformer.orthogonal_matrix.data_handle(), + host_orth.size(), + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + +#pragma omp parallel for collapse(2) + for (int64_t i = 0; i < dataset.extent(0); i++) { + for (int64_t j = 0; j < dataset.extent(1); j++) { + auto c = static_cast(0); + for (int64_t k = 0; k < dataset.extent(1); k++) { + c += dataset(i, k) * host_orth(j, k); + } + out(i, j) = c; + } + } +} + +} // namespace cuvs::preprocessing::linear_transform::detail diff --git a/cpp/src/preprocessing/linear_transform/random_orthogonal.cu b/cpp/src/preprocessing/linear_transform/random_orthogonal.cu new file mode 100644 index 0000000000..9ecae9af55 --- /dev/null +++ b/cpp/src/preprocessing/linear_transform/random_orthogonal.cu @@ -0,0 +1,61 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "./detail/random_orthogonal.cuh" + +#include + +namespace cuvs::preprocessing::linear_transform::random_orthogonal { + +#define CUVS_INST_TRANSFORMATION(T) \ + auto train(raft::resources const& res, \ + const params params, \ + raft::device_matrix_view dataset) -> transformer \ + { \ + return detail::train(res, params, dataset); \ + } \ + auto train(raft::resources const& res, \ + const params params, \ + raft::host_matrix_view dataset) -> transformer \ + { \ + return detail::train(res, params, dataset); \ + } \ + void transform(raft::resources const& res, \ + const transformer& transformer, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view out) \ + { \ + detail::transform(res, transformer, dataset, out); \ + } \ + void transform(raft::resources const& res, \ + const transformer& transformer, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view out) \ + { \ + detail::transform(res, transformer, dataset, out); \ + } \ + void inverse_transform(raft::resources const& res, \ + const transformer& transformer, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view out) \ + { \ + detail::inverse_transform(res, transformer, dataset, out); \ + } \ + void inverse_transform(raft::resources const& res, \ + const transformer& transformer, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view out) \ + { \ + detail::inverse_transform(res, transformer, dataset, out); \ + } \ + template struct transformer; + +CUVS_INST_TRANSFORMATION(double); +CUVS_INST_TRANSFORMATION(float); +CUVS_INST_TRANSFORMATION(half); + +#undef CUVS_INST_TRANSFORMATION + +} // namespace cuvs::preprocessing::linear_transform::random_orthogonal diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 56b53ef697..e6e391103c 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -296,7 +296,7 @@ ConfigureTest( ConfigureTest( NAME PREPROCESSING_TEST PATH preprocessing/scalar_quantization.cu preprocessing/binary_quantization.cu - preprocessing/spectral_embedding.cu + preprocessing/spectral_embedding.cu preprocessing/random_orthogonal_transformation.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/preprocessing/random_orthogonal_transformation.cu b/cpp/tests/preprocessing/random_orthogonal_transformation.cu new file mode 100644 index 0000000000..d6b27e4cdf --- /dev/null +++ b/cpp/tests/preprocessing/random_orthogonal_transformation.cu @@ -0,0 +1,219 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::preprocessing::linear_transform::random_orthogonal { + +template +struct TransformationInputs { + int rows; + int cols; +}; + +template +constexpr double error_threshold_const = 0; +template <> +constexpr double error_threshold_const = 1e-14; +template <> +constexpr double error_threshold_const = 5e-6; +template <> +constexpr double error_threshold_const = 5e-3; +template +double error_threshold(const T dim) +{ + return error_threshold_const * sqrt(static_cast(dim)); +} + +template +std::ostream& operator<<(std::ostream& os, const TransformationInputs& inputs) +{ + return os << " rows:" << inputs.rows << " cols:" << inputs.cols; +} + +template +class RandomOrthogonalTransformation : public ::testing::TestWithParam> { + public: + RandomOrthogonalTransformation() + : params_(::testing::TestWithParam>::GetParam()), + stream(raft::resource::get_cuda_stream(handle)), + input_(0, stream) + { + } + + double getRelativeErrorStddev(const T* array_a, const T* array_b, size_t size, float quantile) + { + // relative error elementwise + rmm::device_uvector relative_error(size, stream); + raft::linalg::binaryOp( + relative_error.data(), + array_a, + array_b, + size, + [] __device__(double a, double b) { + return a != b ? (raft::abs(a - b) / raft::max(raft::abs(a), raft::abs(b))) : 0; + }, + stream); + + // sort by size --> remove largest errors to account for quantile chosen + thrust::sort(raft::resource::get_thrust_policy(handle), + relative_error.data(), + relative_error.data() + size); + int elements_to_consider = + std::ceil(double(params_.quantization_params.quantile) * double(size)); + + rmm::device_uvector mu(1, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(mu.data(), 0, sizeof(double), stream)); + + rmm::device_uvector error_stddev(1, stream); + raft::stats::stddev(error_stddev.data(), + relative_error.data(), + mu.data(), + 1, + elements_to_consider, + false, + stream); + + double error_stddev_h; + raft::update_host(&error_stddev_h, error_stddev.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + return error_stddev_h; + } + + protected: + void testRandomOrthogonalTransformation() + { + // dataset identical on host / device + auto dataset = raft::make_device_matrix_view( + (const T*)(input_.data()), rows_, cols_); + auto dataset_h = raft::make_host_matrix_view( + (const T*)(host_input_.data()), rows_, cols_); + + size_t print_size = std::min(input_.size(), 20ul); + + cuvs::preprocessing::linear_transform::random_orthogonal::params params; + auto transformer = + cuvs::preprocessing::linear_transform::random_orthogonal::train(handle, params, dataset); + + { + auto transformed_input_h = raft::make_host_matrix(rows_, cols_); + auto transformed_input_d = raft::make_device_matrix(handle, rows_, cols_); + cuvs::preprocessing::linear_transform::random_orthogonal::transform( + handle, transformer, dataset, transformed_input_d.view()); + cuvs::preprocessing::linear_transform::random_orthogonal::transform( + handle, transformer, dataset_h, transformed_input_h.view()); + + // test transform host/device equal + ASSERT_TRUE(devArrMatchHost(transformed_input_h.data_handle(), + transformed_input_d.data_handle(), + input_.size(), + cuvs::CompareApprox(error_threshold(cols_)), + stream)); + + auto transformed_input_h_const_view = raft::make_host_matrix_view( + transformed_input_h.data_handle(), rows_, cols_); + auto re_transformed_input_h = raft::make_host_matrix(rows_, cols_); + cuvs::preprocessing::linear_transform::random_orthogonal::inverse_transform( + handle, transformer, transformed_input_h_const_view, re_transformed_input_h.view()); + + auto transformed_input_d_const_view = raft::make_device_matrix_view( + transformed_input_d.data_handle(), rows_, cols_); + auto re_transformed_input_d = raft::make_device_matrix(handle, rows_, cols_); + cuvs::preprocessing::linear_transform::random_orthogonal::inverse_transform( + handle, transformer, transformed_input_d_const_view, re_transformed_input_d.view()); + + // test transform host/device equal + ASSERT_TRUE( + devArrMatchHost(re_transformed_input_h.data_handle(), + re_transformed_input_d.data_handle(), + input_.size(), + cuvs::CompareApprox(error_threshold(cols_) * 2 /*=transform+inv*/), + stream)); + ASSERT_TRUE( + devArrMatchHost(re_transformed_input_h.data_handle(), + dataset.data_handle(), + input_.size(), + cuvs::CompareApprox(error_threshold(cols_) * 2 /*=transform+inv*/), + stream)); + } + } + + void SetUp() override + { + rows_ = params_.rows; + cols_ = params_.cols; + + int n_elements = rows_ * cols_; + input_.resize(n_elements, stream); + host_input_.resize(n_elements); + + // random input + unsigned long long int seed = 1234ULL; + raft::random::RngState r(seed); + uniform(handle, r, input_.data(), input_.size(), static_cast(-8), static_cast(8)); + + raft::update_host(host_input_.data(), input_.data(), input_.size(), stream); + + raft::resource::sync_stream(handle, stream); + } + + private: + raft::resources handle; + cudaStream_t stream; + + TransformationInputs params_; + int rows_; + int cols_; + rmm::device_uvector input_; + std::vector host_input_; +}; + +template +const std::vector> inputs = {{1000, 1}, + {1000, 3}, + {1000, 13}, + {1000, 128}, + {1000, 199}, + {1000, 876}, + {1000, 1289}, + {10000, 67}, + {10000, 128}}; + +typedef RandomOrthogonalTransformation RandomOrthogonalTransformation_float_int8t; +TEST_P(RandomOrthogonalTransformation_float_int8t, RandomOrthogonalTransformationTest) +{ + this->testRandomOrthogonalTransformation(); +} + +typedef RandomOrthogonalTransformation RandomOrthogonalTransformation_double_int8t; +TEST_P(RandomOrthogonalTransformation_double_int8t, RandomOrthogonalTransformationTest) +{ + this->testRandomOrthogonalTransformation(); +} + +typedef RandomOrthogonalTransformation RandomOrthogonalTransformation_half_int8t; +TEST_P(RandomOrthogonalTransformation_half_int8t, RandomOrthogonalTransformationTest) +{ + this->testRandomOrthogonalTransformation(); +} + +INSTANTIATE_TEST_CASE_P(RandomOrthogonalTransformation, + RandomOrthogonalTransformation_float_int8t, + ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(RandomOrthogonalTransformation, + RandomOrthogonalTransformation_double_int8t, + ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(RandomOrthogonalTransformation, + RandomOrthogonalTransformation_half_int8t, + ::testing::ValuesIn(inputs)); + +} // namespace cuvs::preprocessing::linear_transform::random_orthogonal diff --git a/cpp/tests/test_utils.cuh b/cpp/tests/test_utils.cuh index 11b21f3647..92d9ef1856 100644 --- a/cpp/tests/test_utils.cuh +++ b/cpp/tests/test_utils.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -149,8 +149,8 @@ testing::AssertionResult devArrMatchHost( bool ok = true; auto fail = testing::AssertionFailure(); for (size_t i(0); i < size; ++i) { - auto exp = expected_h[i]; - auto act = act_h.get()[i]; + const auto exp = static_cast(expected_h[i]); + const auto act = static_cast(act_h.get()[i]); if (!eq_compare(exp, act)) { ok = false; fail << "actual=" << act << " != expected=" << exp << " @" << i << "; "; diff --git a/cpp/tests/test_utils.h b/cpp/tests/test_utils.h index c879b59912..60becbf261 100644 --- a/cpp/tests/test_utils.h +++ b/cpp/tests/test_utils.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -42,11 +42,16 @@ struct CompareApprox { CompareApprox(T eps_) : eps(eps_) {} bool operator()(const T& a, const T& b) const { - T diff = std::abs(a - b); - T m = std::max(std::abs(a), std::abs(b)); - T ratio = diff > eps ? diff / m : diff; + using compute_t = double; + const auto a_comp = static_cast(a); + const auto b_comp = static_cast(b); + const auto eps_comp = static_cast(eps); - return (ratio <= eps); + const auto diff = std::abs(a_comp - b_comp); + const auto m = std::max(std::abs(a_comp), std::abs(b_comp)); + const auto ratio = diff > eps_comp ? diff / m : diff; + + return (ratio <= eps_comp); } private: @@ -105,7 +110,8 @@ template testing::AssertionResult match(const T& expected, const T& actual, L eq_compare) { if (!eq_compare(expected, actual)) { - return testing::AssertionFailure() << "actual=" << actual << " != expected=" << expected; + return testing::AssertionFailure() << "actual=" << static_cast(actual) + << " != expected=" << static_cast(expected); } return testing::AssertionSuccess(); } From 13d99d74d1b45b016c7b8fb83d8740759a6451f5 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Tue, 11 Nov 2025 11:13:43 +0900 Subject: [PATCH 02/32] Add docs --- .../linear_transform/random_orthogonal.hpp | 2 +- docs/source/cpp_api/preprocessing.rst | 1 + .../preprocessing_linear_transform.rst | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 docs/source/cpp_api/preprocessing_linear_transform.rst diff --git a/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp b/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp index fb6e302183..5790b6746c 100644 --- a/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp +++ b/cpp/include/cuvs/preprocessing/linear_transform/random_orthogonal.hpp @@ -16,7 +16,7 @@ namespace cuvs::preprocessing::linear_transform::random_orthogonal { /** - * @defgroup scalar Scalar transformer utilities + * @defgroup random_orthogonal Random orthogonal transformer utilities * @{ */ diff --git a/docs/source/cpp_api/preprocessing.rst b/docs/source/cpp_api/preprocessing.rst index 5fd8cfe778..e8599941e5 100644 --- a/docs/source/cpp_api/preprocessing.rst +++ b/docs/source/cpp_api/preprocessing.rst @@ -10,4 +10,5 @@ Preprocessing :caption: Contents: preprocessing_quantize.rst + preprocessing_linear_transform.rst preprocessing_spectral_embedding.rst diff --git a/docs/source/cpp_api/preprocessing_linear_transform.rst b/docs/source/cpp_api/preprocessing_linear_transform.rst new file mode 100644 index 0000000000..b039255331 --- /dev/null +++ b/docs/source/cpp_api/preprocessing_linear_transform.rst @@ -0,0 +1,19 @@ +Linear transformation +======== + +This page provides C++ class references for the publicly-exposed elements of the +`cuvs/preprocessing/linear_transform` package. + +.. role:: py(code) + :language: c++ + :class: highlight + +Random orthogonal +------ + +``#include `` + +namespace *cuvs::preprocessing::linear_transform::random_orthogonal* + +.. doxygengroup:: random_orthogonal + :project: cuvs From fc69fea83e88229134cfbbd9b6c81105ef8d714e Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Tue, 11 Nov 2025 14:57:20 +0900 Subject: [PATCH 03/32] Update linear trandform docs --- docs/source/cpp_api/preprocessing_linear_transform.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/cpp_api/preprocessing_linear_transform.rst b/docs/source/cpp_api/preprocessing_linear_transform.rst index b039255331..c8906dc3ce 100644 --- a/docs/source/cpp_api/preprocessing_linear_transform.rst +++ b/docs/source/cpp_api/preprocessing_linear_transform.rst @@ -1,5 +1,5 @@ Linear transformation -======== +===================== This page provides C++ class references for the publicly-exposed elements of the `cuvs/preprocessing/linear_transform` package. @@ -9,7 +9,7 @@ This page provides C++ class references for the publicly-exposed elements of the :class: highlight Random orthogonal ------- +----------------- ``#include `` From 32e9ea6e1d95eea41bb81a939561995aea2df870 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Thu, 13 Nov 2025 08:01:16 +0100 Subject: [PATCH 04/32] ANN_BENCH: integrate NVTX statistics (#1529) Add the aggregate reporting of NVTX ranges in the output of benchmark executable. ### Usage ```bash # Measure the CPU and GPU runtime of all NVTX ranges nsys launch --trace=cuda,nvtx # Measure only the CPU runtime of all NVTX ranges nsys launch --trace=nvtx # Do not measure/report any NVTX ranges # Do not measure/report any NVTX ranges within benchmark, but use nsys profiling as usual nsys profile ... ``` ### Implementation The PR adds a single module `nvtx_stats.hpp` to the benchmark executable; there are no changes to the library at all. The program leverages NVIDIA Nsight Systems CLI to collect and export NVTX statistics and then SQLite API to aggregate it into the benchmark state: 1. Detect if run via `nsys launch`; if so, call `nsys start` / `nsys stop` around benchmark loop; otherwise do nothing. 2. If the report is generated, read it and query all NVTX events and the GPU correlation data using SQLite 3. Aggregate the NVTX events by their short names (without arguments to reduce the number of columns) 4. Add them to the benchmark performance counters with the same averaging strategy as the global CPU/GPU runtime. ### Performance cost If the benchmark is **not** run using `nsys launch`, there's virtually zero overhead in the new functionality. Otherwise, there are overheads: 1. Usual nsys profiling overheads (minimized by disabling unused information via `nsys start` CLI internally). This affects the reported performance the same way as normal nsys profiling does (especially if cuda tracing is enabled). 2. One or more data collection/exporting events per benchmark case. These add some extra time to the benchmark time, but do not affect the counters (they are not the part of the benchmark loop) Closes https://github.com/rapidsai/cuvs/issues/1367 Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/cuvs/pull/1529 --- cpp/bench/ann/CMakeLists.txt | 4 +- cpp/bench/ann/src/common/benchmark.hpp | 5 +- cpp/bench/ann/src/common/nvtx_stats.hpp | 539 ++++++++++++++++++++++++ cpp/cmake/thirdparty/get_sqlite.cmake | 44 ++ 4 files changed, 590 insertions(+), 2 deletions(-) create mode 100644 cpp/bench/ann/src/common/nvtx_stats.hpp create mode 100644 cpp/cmake/thirdparty/get_sqlite.cmake diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index c3a09272a6..46bbd318d2 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -93,6 +93,7 @@ if(CUVS_ANN_BENCH_USE_HNSWLIB OR CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB) endif() include(cmake/thirdparty/get_nlohmann_json) +include(cmake/thirdparty/get_sqlite) if(CUVS_ANN_BENCH_USE_GGNN) include(cmake/thirdparty/get_ggnn) @@ -144,6 +145,7 @@ function(ConfigureAnnBench) ${BENCH_NAME} PRIVATE ${ConfigureAnnBench_LINKS} nlohmann_json::nlohmann_json + sqlite3 Threads::Threads $<$:CUDA::cudart_static> $ @@ -358,7 +360,7 @@ if(CUVS_ANN_BENCH_SINGLE_EXE) target_include_directories(ANN_BENCH PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_link_libraries( - ANN_BENCH PRIVATE raft::raft nlohmann_json::nlohmann_json benchmark::benchmark dl + ANN_BENCH PRIVATE raft::raft nlohmann_json::nlohmann_json sqlite3 benchmark::benchmark dl $<$:CUDA::nvtx3> ) set_target_properties( diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index b0c471cb29..22859e9ab8 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -7,6 +7,7 @@ #include "ann_types.hpp" #include "conf.hpp" #include "dataset.hpp" +#include "nvtx_stats.hpp" #include "util.hpp" #include @@ -138,6 +139,7 @@ void bench_build(::benchmark::State& state, cuda_timer gpu_timer{algo}; { + nvtx_stats nvtx_stats{state}; nvtx_case nvtx{state.name()}; /* Note: GPU timing @@ -293,6 +295,7 @@ void bench_search(::benchmark::State& state, auto* distances_ptr = reinterpret_cast(neighbors_ptr + result_elem_count); { + nvtx_stats nvtx_stats{state}; nvtx_case nvtx{state.name()}; std::unique_ptr> a{nullptr}; diff --git a/cpp/bench/ann/src/common/nvtx_stats.hpp b/cpp/bench/ann/src/common/nvtx_stats.hpp new file mode 100644 index 0000000000..f12776336b --- /dev/null +++ b/cpp/bench/ann/src/common/nvtx_stats.hpp @@ -0,0 +1,539 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "util.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::bench { + +namespace detail { + +// Strip parameters from NVTX range names for grouping +// Removes everything inside and including (), <>, and {} +inline std::string strip_nvtx_parameters(const std::string& name) +{ + if (name.empty()) { return ""; } + + std::string result; + result.reserve(name.size()); + int paren_depth = 0; + int angle_depth = 0; + int brace_depth = 0; + + for (char c : name) { + // Track nesting depth + if (c == '(') { + paren_depth++; + } else if (c == ')') { + paren_depth--; + } else if (c == '<') { + angle_depth++; + } else if (c == '>') { + angle_depth--; + } else if (c == '{') { + brace_depth++; + } else if (c == '}') { + brace_depth--; + } else if (paren_depth == 0 && angle_depth == 0 && brace_depth == 0) { + // Only add character if we're not inside any brackets + result += c; + } + } + + // Trim trailing whitespace + while (!result.empty() && std::isspace(result.back())) { + result.pop_back(); + } + + return result; +} + +// Extract CPU and GPU NVTX statistics with correlation +inline auto extract_cpu_gpu_stats(sqlite3* db, + int64_t algo_bench_domain_id, + const std::vector& activity_tables = {}) +{ + // Accumulate times by base_name + std::map> stats; + + // Query 1: NVTX events sorted by start, then end + const char* nvtx_query = + "SELECT start, end, globalTid, text " + "FROM NVTX_EVENTS " + "WHERE end IS NOT NULL " + " AND eventType = 59 " + " AND domainId != ? " + " AND text IS NOT NULL " + "ORDER BY start, end"; + + // Query 2: Runtime+GPU events sorted by start, then end + std::string runtime_query; + if (activity_tables.empty()) { + // Return empty result set with correct schema + runtime_query = + "SELECT NULL as start, NULL as end, NULL as globalTid, NULL as min_start, NULL as max_end " + "WHERE 1=0"; + } else { + // Build union query for GPU activities + std::string gpu_union = "("; + for (size_t i = 0; i < activity_tables.size(); ++i) { + if (i > 0) gpu_union += " UNION ALL "; + gpu_union += "SELECT correlationId, start, end FROM " + activity_tables[i]; + } + gpu_union += ")"; + + runtime_query = + "SELECT r.start, r.end, r.globalTid, MIN(ga.start), MAX(ga.end) " + "FROM CUPTI_ACTIVITY_KIND_RUNTIME r " + "INNER JOIN (" + + gpu_union + + ") ga ON r.correlationId = ga.correlationId " + "GROUP BY r.start, r.end, r.globalTid, r.correlationId " + "ORDER BY r.start, r.end"; + } + + sqlite3_stmt* nvtx_stmt = nullptr; + sqlite3_stmt* runtime_stmt = nullptr; + + if (sqlite3_prepare_v2(db, nvtx_query, -1, &nvtx_stmt, nullptr) != SQLITE_OK) { return stats; } + sqlite3_bind_int64(nvtx_stmt, 1, algo_bench_domain_id); + + if (sqlite3_prepare_v2(db, runtime_query.c_str(), -1, &runtime_stmt, nullptr) != SQLITE_OK) { + sqlite3_finalize(nvtx_stmt); + return stats; + } + + // Structure to hold runtime events in a sliding window queue + struct RuntimeEvent { + int64_t rt_start; + int64_t rt_end; + int64_t globalTid; + int64_t gpu_start; + int64_t gpu_end; + }; + std::deque runtime_queue; + bool runtime_exhausted = false; + + // Process each NVTX event + while (sqlite3_step(nvtx_stmt) == SQLITE_ROW) { + int64_t nvtx_start = sqlite3_column_int64(nvtx_stmt, 0); + int64_t nvtx_end = sqlite3_column_int64(nvtx_stmt, 1); + int64_t nvtx_tid = sqlite3_column_int64(nvtx_stmt, 2); + const char* name = reinterpret_cast(sqlite3_column_text(nvtx_stmt, 3)); + + // Apply strip_params in C++ instead of SQL + std::string base_name = name ? strip_nvtx_parameters(name) : ""; + if (base_name.empty()) continue; // Skip events with no name after stripping + + // Accumulate CPU time + auto& [cpu_ns, gpu_ns] = stats[base_name]; // Creates with zeros if not exists + cpu_ns += (nvtx_end - nvtx_start); + + // Remove runtime events that start before this NVTX event starts + // (they can't match this or any future NVTX events since NVTX is sorted) + while (!runtime_queue.empty() && runtime_queue.front().rt_start < nvtx_start) { + runtime_queue.pop_front(); + } + + // Load more runtime events into the queue until we have all that could match + while (!runtime_exhausted) { + // Peek: do we need more events? + if (!runtime_queue.empty() && runtime_queue.back().rt_start > nvtx_end) { + // We have enough events in queue for this NVTX + break; + } + + // Fetch next runtime event + if (sqlite3_step(runtime_stmt) == SQLITE_ROW) { + RuntimeEvent evt; + evt.rt_start = sqlite3_column_int64(runtime_stmt, 0); + evt.rt_end = sqlite3_column_int64(runtime_stmt, 1); + evt.globalTid = sqlite3_column_int64(runtime_stmt, 2); + evt.gpu_start = sqlite3_column_int64(runtime_stmt, 3); + evt.gpu_end = sqlite3_column_int64(runtime_stmt, 4); + runtime_queue.push_back(evt); + } else { + runtime_exhausted = true; + break; + } + } + + // Scan the queue to find all matching runtime events + int64_t gpu_min = INT64_MAX; + int64_t gpu_max = INT64_MIN; + bool found_any = false; + + for (const auto& rt : runtime_queue) { + // Stop if runtime event starts after NVTX event ends + if (rt.rt_start > nvtx_end) { break; } + + // Check if this runtime event is contained in the NVTX event + if (rt.globalTid == nvtx_tid && rt.rt_start >= nvtx_start && rt.rt_end <= nvtx_end) { + gpu_min = std::min(gpu_min, rt.gpu_start); + gpu_max = std::max(gpu_max, rt.gpu_end); + found_any = true; + } + } + + // Record GPU time + if (found_any && gpu_max > gpu_min) { gpu_ns += (gpu_max - gpu_min); } + } + + sqlite3_finalize(nvtx_stmt); + sqlite3_finalize(runtime_stmt); + return stats; // Return the accumulated stats map +} + +// Common setup: open database, register functions, find domain ID, discover GPU tables +inline std::tuple> setup_nvtx_database( + const std::string& sqlite_file) +{ + sqlite3* db = nullptr; + int rc = sqlite3_open_v2(sqlite_file.c_str(), &db, SQLITE_OPEN_READONLY, nullptr); + if (rc != SQLITE_OK) { + if (db) sqlite3_close(db); + return {nullptr, -1, {}}; + } + + // Find the domainId for "algo benchmark" domain to exclude it + const char* find_domain_sql = + "SELECT domainId FROM NVTX_EVENTS WHERE text = 'algo benchmark' LIMIT 1"; + sqlite3_stmt* domain_stmt = nullptr; + int domain_rc = sqlite3_prepare_v2(db, find_domain_sql, -1, &domain_stmt, nullptr); + + int64_t algo_bench_domain_id = -1; + if (domain_rc == SQLITE_OK && sqlite3_step(domain_stmt) == SQLITE_ROW) { + algo_bench_domain_id = sqlite3_column_int64(domain_stmt, 0); + } + if (domain_stmt) sqlite3_finalize(domain_stmt); + + // Check if GPU activity tables exist + const char* check_tables_sql = + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND " + "(name='CUPTI_ACTIVITY_KIND_RUNTIME' OR name='CUPTI_ACTIVITY_KIND_KERNEL')"; + sqlite3_stmt* check_stmt = nullptr; + bool has_gpu_tables = false; + if (sqlite3_prepare_v2(db, check_tables_sql, -1, &check_stmt, nullptr) == SQLITE_OK) { + if (sqlite3_step(check_stmt) == SQLITE_ROW) { + has_gpu_tables = sqlite3_column_int64(check_stmt, 0) == 2; + } + sqlite3_finalize(check_stmt); + } + + std::vector activity_tables; + if (has_gpu_tables) { + // Find all CUPTI_ACTIVITY_KIND_* tables (except RUNTIME) + const char* find_tables_sql = + "SELECT name FROM sqlite_master WHERE type='table' " + "AND name LIKE 'CUPTI_ACTIVITY_KIND_%' AND name != 'CUPTI_ACTIVITY_KIND_RUNTIME'"; + sqlite3_stmt* tables_stmt = nullptr; + + if (sqlite3_prepare_v2(db, find_tables_sql, -1, &tables_stmt, nullptr) == SQLITE_OK) { + while (sqlite3_step(tables_stmt) == SQLITE_ROW) { + const char* table_name = reinterpret_cast(sqlite3_column_text(tables_stmt, 0)); + if (table_name) { + // Check if this table has start, end, and correlationId columns + std::string check_cols_sql = "PRAGMA table_info(" + std::string(table_name) + ")"; + sqlite3_stmt* cols_stmt = nullptr; + bool has_start = false, has_end = false, has_corr = false; + + if (sqlite3_prepare_v2(db, check_cols_sql.c_str(), -1, &cols_stmt, nullptr) == + SQLITE_OK) { + while (sqlite3_step(cols_stmt) == SQLITE_ROW) { + const char* col_name = + reinterpret_cast(sqlite3_column_text(cols_stmt, 1)); + if (col_name) { + if (strcmp(col_name, "start") == 0) has_start = true; + if (strcmp(col_name, "end") == 0) has_end = true; + if (strcmp(col_name, "correlationId") == 0) has_corr = true; + } + } + sqlite3_finalize(cols_stmt); + } + + if (has_start && has_end && has_corr) { activity_tables.push_back(table_name); } + } + } + sqlite3_finalize(tables_stmt); + } + } + + return {db, algo_bench_domain_id, activity_tables}; +} + +// Extract NVTX statistics from SQLite database +// Returns pair of (cpu_times, gpu_times) maps with times in seconds +inline auto extract_nvtx_stats_from_sqlite(const std::string& sqlite_file) +{ + auto [db, algo_bench_domain_id, activity_tables] = setup_nvtx_database(sqlite_file); + if (!db) { return std::map>{}; } + + // Extract CPU and GPU stats (works even if no GPU tables available) + auto stats = extract_cpu_gpu_stats(db, algo_bench_domain_id, activity_tables); + + sqlite3_close(db); + return stats; +} + +// Filter out the ranges with less than min_time_ratio of the max detected range time. +inline auto filter_stats(const std::map>& stats, + double min_time_ratio) + -> std::pair, std::map> +{ + std::map cpu_times; + std::map gpu_times; + int64_t cpu_threshold = 0; + int64_t gpu_threshold = 0; + for (const auto& [_, times] : stats) { + auto [cpu_time, gpu_time] = times; + cpu_threshold = std::max(cpu_threshold, static_cast(cpu_time * min_time_ratio)); + gpu_threshold = std::max(gpu_threshold, static_cast(gpu_time * min_time_ratio)); + } + for (const auto& [name, times] : stats) { + auto [cpu_time, gpu_time] = times; + if (cpu_time > cpu_threshold) { cpu_times[name] = static_cast(cpu_time) / 1.0e9; } + if (gpu_time > gpu_threshold) { gpu_times[name] = static_cast(gpu_time) / 1.0e9; } + } + return {cpu_times, gpu_times}; +} + +// Get process name from PID +inline std::string get_process_name(pid_t pid) +{ + std::ifstream comm_file("/proc/" + std::to_string(pid) + "/comm"); + std::string name; + if (comm_file.is_open()) { std::getline(comm_file, name); } + return name; +} + +// Get process executable path from PID +inline std::string get_process_exe_path(pid_t pid) +{ + char buffer[PATH_MAX]; + ssize_t len = + readlink(("/proc/" + std::to_string(pid) + "/exe").c_str(), buffer, sizeof(buffer) - 1); + if (len != -1) { + buffer[len] = '\0'; + return std::string(buffer); + } + return ""; +} + +// Get parent PID from a given PID +inline pid_t get_parent_pid(pid_t pid) +{ + std::ifstream stat_file("/proc/" + std::to_string(pid) + "/stat"); + if (stat_file.is_open()) { + std::string line; + std::getline(stat_file, line); + + // stat file format: pid (comm) state ppid ... + size_t last_paren = line.rfind(')'); + if (last_paren != std::string::npos) { + std::istringstream iss(line.substr(last_paren + 1)); + char state; + pid_t ppid; + iss >> state >> ppid; + return ppid; + } + } + return 0; +} + +// Check if a process has 'launch' in its command line +inline bool has_launch_arg(pid_t pid) +{ + std::ifstream cmdline_file("/proc/" + std::to_string(pid) + "/cmdline"); + if (cmdline_file.is_open()) { + std::string arg; + while (std::getline(cmdline_file, arg, '\0')) { + if (arg == "launch") { return true; } + } + } + return false; +} + +// Detect if the program was launched by nsys with 'launch' subcommand +// by walking up the process tree +inline std::optional detect_nsys_launch() +{ + pid_t parent_pid = getppid(); + + // Walk up the process tree (max 10 levels) + for (int depth = 0; depth < 10 && parent_pid > 1; ++depth) { + std::string parent_name = get_process_name(parent_pid); + + // Check if this process is nsys + if (parent_name.find("nsys") != std::string::npos) { + // Verify it has the 'launch' argument + if (has_launch_arg(parent_pid)) { + std::string nsys_exe = get_process_exe_path(parent_pid); + return nsys_exe.empty() ? std::optional(parent_name) + : std::optional(nsys_exe); + } + } + + // Move to the next parent + pid_t grandparent = get_parent_pid(parent_pid); + if (grandparent == 0 || grandparent == parent_pid) { break; } + parent_pid = grandparent; + } + + return std::nullopt; +} + +} // namespace detail + +struct nsys_launcher { + nsys_launcher() + { + std::lock_guard lock(mtx); + nsys_exe = detail::detect_nsys_launch(); + } + + bool is_enabled() const + { + std::lock_guard lock(mtx); + return nsys_exe.has_value(); + } + + bool start(const std::string& output_path) const + { + std::lock_guard lock(mtx); + if (nsys_exe.has_value()) { + std::string cmd = + nsys_exe.value() + " start --export=sqlite -o " + output_path + " >/dev/null 2>&1"; + auto res = system(cmd.c_str()); + if (res != 0) { + log_warn( + "Failed to start nsys: %s with error %d. Disabling profiler stats.", cmd.c_str(), res); + nsys_exe.reset(); + return false; + } + return true; + } + return false; + } + + bool stop() const + { + std::lock_guard lock(mtx); + if (nsys_exe.has_value()) { + std::string cmd = nsys_exe.value() + " stop >/dev/null 2>&1"; + auto res = system(cmd.c_str()); + if (res != 0) { + log_warn( + "Failed to start nsys: %s with error %d. Disabling profiler stats.", cmd.c_str(), res); + nsys_exe.reset(); + return false; + } + return true; + } + return false; + } + + private: + mutable std::mutex mtx; + mutable std::optional nsys_exe; +}; + +/** + * @brief Returns the nsys executable path if launched via 'nsys launch'. + * + * Detects if the program is running under 'nsys launch' by walking up the process tree. + * Returns std::nullopt for 'nsys profile' or other modes to avoid interference. + * + */ +inline const nsys_launcher& get_nsys_launcher() +{ + static const nsys_launcher nsys_launcher; + return nsys_launcher; +} + +struct nvtx_stats { + explicit nvtx_stats(::benchmark::State& state, + double min_time_ratio = 0.01, + bool debug_logs_enabled = false) + : state_(state), min_time_ratio(min_time_ratio), debug_logs_enabled(debug_logs_enabled) + { + if (state_.thread_index() != 0) { return; } + if (get_nsys_launcher().is_enabled()) { get_nsys_launcher().start(report_path); } + } + + ~nvtx_stats() + { + if (state_.thread_index() != 0) { return; } + if (!get_nsys_launcher().is_enabled()) { return; } + + // Stop nsys profiling + if (!get_nsys_launcher().stop()) { return; } + + auto sql_start = std::chrono::high_resolution_clock::now(); + if (debug_logs_enabled) { log_info("Extracting NVTX stats from SQLite database..."); } + + // Extract NVTX statistics from SQLite database + std::string sqlite_file = report_path + ".sqlite"; + std::string nsys_file = report_path + ".nsys-rep"; + + auto stats = detail::extract_nvtx_stats_from_sqlite(sqlite_file); + auto [cpu_times, gpu_times] = detail::filter_stats(stats, min_time_ratio); + + auto sql_end = std::chrono::high_resolution_clock::now(); + auto sql_duration = std::chrono::duration_cast(sql_end - sql_start); + if (debug_logs_enabled) { + log_info("NVTX stats SQL query took %d ms (%zu CPU ranges, %zu GPU ranges)", + static_cast(sql_duration.count()), + cpu_times.size(), + gpu_times.size()); + } + + // Insert counters into benchmark state + for (const auto& [range_name, cpu_time] : cpu_times) { + state_.counters.insert( + {{"CPU::" + range_name, {cpu_time, benchmark::Counter::kAvgIterations}}}); + } + + for (const auto& [range_name, gpu_time] : gpu_times) { + state_.counters.insert( + {{"GPU::" + range_name, {gpu_time, benchmark::Counter::kAvgIterations}}}); + } + + // Clean up generated files (ignore errors if files don't exist) + std::remove(sqlite_file.c_str()); + std::remove(nsys_file.c_str()); + } + + private: + std::string report_path = + std::filesystem::temp_directory_path() / + ("nvtx_stats_" + std::to_string(getpid()) + "_" + + std::to_string(std::chrono::steady_clock::now().time_since_epoch().count())); + ::benchmark::State& state_; + double min_time_ratio; + bool debug_logs_enabled; +}; + +}; // namespace cuvs::bench diff --git a/cpp/cmake/thirdparty/get_sqlite.cmake b/cpp/cmake/thirdparty/get_sqlite.cmake new file mode 100644 index 0000000000..0cda9d1157 --- /dev/null +++ b/cpp/cmake/thirdparty/get_sqlite.cmake @@ -0,0 +1,44 @@ +#============================================================================= +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +#============================================================================= + +function(find_and_configure_sqlite) + set(oneValueArgs VERSION YEAR) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + # SQLite amalgamation is distributed as a single .c file with a header + # We'll fetch it and create a static library + rapids_cpm_find(sqlite3 ${PKG_VERSION} + GLOBAL_TARGETS sqlite3 + CPM_ARGS + URL https://www.sqlite.org/${PKG_YEAR}/sqlite-amalgamation-${PKG_VERSION}.zip + DOWNLOAD_ONLY YES + ) + + if(sqlite3_ADDED) + message(VERBOSE "cuVS: Using SQLite3 amalgamation from ${sqlite3_SOURCE_DIR}") + + # Create a static library from the amalgamation + add_library(sqlite3 STATIC ${sqlite3_SOURCE_DIR}/sqlite3.c) + + target_include_directories(sqlite3 PUBLIC + $ + $ + ) + + target_link_libraries(sqlite3 PUBLIC dl) + + set_target_properties(sqlite3 PROPERTIES EXCLUDE_FROM_ALL ON) + else() + message(VERBOSE "cuVS: Using SQLite3 located in ${sqlite3_DIR}") + endif() + +endfunction() + +find_and_configure_sqlite( + VERSION 3470200 + YEAR 2024 +) From 90c57404e2a1ec3315061d1a2b912e82ff99918e Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 13 Nov 2025 11:45:04 -0800 Subject: [PATCH 05/32] Check stride information in from_dlpack (#1458) When converting from a DLManagedTensor to a mdspan in our c-api, we weren't checking the stride information on the dlmanaged tensor is the c-api. This caused invalid results when passing a strided matrix to functions like cuvsCagraBuild. Fix and add a unittest. Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/1458 --- c/src/core/detail/interop.hpp | 72 +++++++++++++++++++---------------- c/tests/core/interop.cu | 55 ++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 32 deletions(-) diff --git a/c/src/core/detail/interop.hpp b/c/src/core/detail/interop.hpp index f41ae3d9a8..3b94feb78d 100644 --- a/c/src/core/detail/interop.hpp +++ b/c/src/core/detail/interop.hpp @@ -58,11 +58,51 @@ inline bool is_dlpack_host_compatible(DLTensor tensor) tensor.device.device_type == kDLCPU; } +inline bool is_f_contiguous(DLManagedTensor* managed_tensor) +{ + auto tensor = managed_tensor->dl_tensor; + + if (!tensor.strides) { return false; } + int64_t expected_stride = 1; + for (int64_t i = 0; i < tensor.ndim; ++i) { + if (tensor.strides[i] != expected_stride) { return false; } + expected_stride *= tensor.shape[i]; + } + + return true; +} + +inline bool is_c_contiguous(DLManagedTensor* managed_tensor) +{ + auto tensor = managed_tensor->dl_tensor; + + if (!tensor.strides) { + // no stride information indicates a row-major tensor according to the dlpack spec + return true; + } + + int64_t expected_stride = 1; + for (int64_t i = tensor.ndim - 1; i >= 0; --i) { + if (tensor.strides[i] != expected_stride) { return false; } + expected_stride *= tensor.shape[i]; + } + + return true; +} + template > inline MdspanType from_dlpack(DLManagedTensor* managed_tensor) { auto tensor = managed_tensor->dl_tensor; + if constexpr (std::is_same_v) { + RAFT_EXPECTS(is_c_contiguous(managed_tensor), "Expected a row-major matrix"); + } + + if constexpr (std::is_same_v) { + RAFT_EXPECTS(is_f_contiguous(managed_tensor), "Expected a col-major matrix"); + } + auto to_data_type = data_type_to_DLDataType(); RAFT_EXPECTS(to_data_type.code == tensor.dtype.code, "code mismatch between return mdspan (%i) and DLTensor (%i)", @@ -98,38 +138,6 @@ inline MdspanType from_dlpack(DLManagedTensor* managed_tensor) return MdspanType{reinterpret_cast(tensor.data), exts}; } -inline bool is_f_contiguous(DLManagedTensor* managed_tensor) -{ - auto tensor = managed_tensor->dl_tensor; - - if (!tensor.strides) { return false; } - int64_t expected_stride = 1; - for (int64_t i = 0; i < tensor.ndim; ++i) { - if (tensor.strides[i] != expected_stride) { return false; } - expected_stride *= tensor.shape[i]; - } - - return true; -} - -inline bool is_c_contiguous(DLManagedTensor* managed_tensor) -{ - auto tensor = managed_tensor->dl_tensor; - - if (!tensor.strides) { - // no stride information indicates a row-major tensor according to the dlpack spec - return true; - } - - int64_t expected_stride = 1; - for (int64_t i = tensor.ndim - 1; i >= 0; --i) { - if (tensor.strides[i] != expected_stride) { return false; } - expected_stride *= tensor.shape[i]; - } - - return true; -} - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-function" static void free_dlmanaged_tensor_metadata(DLManagedTensor* tensor) diff --git a/c/tests/core/interop.cu b/c/tests/core/interop.cu index 4e193fe5eb..cee075866f 100644 --- a/c/tests/core/interop.cu +++ b/c/tests/core/interop.cu @@ -40,4 +40,59 @@ TEST(Interop, FromDLPack) ASSERT_EQ(out(1), data(1)); } +TEST(Interop, FromDLPackStrides) +{ + raft::resources res; + auto data = raft::make_host_matrix(res, 3, 2); + data(0, 0) = 1; + data(1, 0) = 2; + data(2, 0) = 3; + data(0, 1) = 4; + data(1, 1) = 5; + data(2, 1) = 6; + + auto device = DLDevice{kDLCPU}; + auto data_type = DLDataType{kDLFloat, 4 * 8, 1}; + auto shape = std::vector{3 ,2}; + + auto tensor = DLTensor{data.data_handle(), device, 2, data_type, shape.data()}; + auto managed_tensor = DLManagedTensor{tensor}; + + // converting a 2D dltensor to a 1D mspan should fail + using vector_mdspan_type = raft::host_mdspan>; + ASSERT_THROW(from_dlpack(&managed_tensor), raft::logic_error); + + // No stride information in the dltensor indicates row major + using mdspan_type = raft::host_matrix_view; + auto out = from_dlpack(&managed_tensor); + ASSERT_EQ(out.rank(), data.rank()); + ASSERT_EQ(out.extent(0), data.extent(0)); + ASSERT_EQ(out.extent(1), data.extent(1)); + for (int64_t row = 0; row < data.extent(0); row++) { + for (int64_t col = 0; col < data.extent(1); col++) { + ASSERT_EQ(out(row, col), data(row, col)); + } + } + + // asking for a col-major mdspan should fail if no strides are present + using colmajor_mdspan_type = raft::host_matrix_view; + ASSERT_THROW(from_dlpack(&managed_tensor), raft::logic_error); + + // Setting strides equal to row major should also work + auto strides = std::vector{2, 1}; + managed_tensor.dl_tensor.strides = strides.data(); + auto out_strided = from_dlpack(&managed_tensor); + ASSERT_EQ(out_strided.rank(), data.rank()); + + // Setting strides indicating col-major should be able to convert to a col-major + // mdspan + auto strides_colmajor = std::vector{1, 3}; + managed_tensor.dl_tensor.strides = strides_colmajor.data(); + auto out_colmajor = from_dlpack(&managed_tensor); + ASSERT_EQ(out_colmajor.rank(), data.rank()); + + // But shouldn't be able to convert to a row-major + ASSERT_THROW(from_dlpack(&managed_tensor), raft::logic_error); +} + } // namespace cuvs::core From a8813b137d04b34932decb036d14c9de23cca1b5 Mon Sep 17 00:00:00 2001 From: Nate Rock Date: Thu, 13 Nov 2025 15:52:24 -0600 Subject: [PATCH 06/32] refactored update-version.sh to handle new branching strategy (#1535) This PR supports handling the new main branch strategy outlined below: * [RSN 47 - Changes to RAPIDS branching strategy in 25.12](https://docs.rapids.ai/notices/rsn0047/) The `update-version.sh` script should now supports two modes controlled via `CLI` params or `ENV` vars: CLI arguments: `--run-context=main|release` ENV var `RAPIDS_RUN_CONTEXT=main|release` xref: https://github.com/rapidsai/build-planning/issues/224 Authors: - Nate Rock (https://github.com/rockhowse) Approvers: - Jake Awe (https://github.com/AyodeAwe) - Corey J. Nolet (https://github.com/cjnolet) - MithunR (https://github.com/mythrocks) URL: https://github.com/rapidsai/cuvs/pull/1535 --- README.md | 4 +- ci/release/update-version.sh | 94 ++++++++++++++++--- docs/source/developer_guide.md | 4 +- python/cuvs_bench/cuvs_bench/plot/__main__.py | 2 +- 4 files changed, 88 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index a54466aa88..5dba0cfc38 100755 --- a/README.md +++ b/README.md @@ -171,7 +171,7 @@ cuvsCagraIndexParamsDestroy(index_params); cuvsResourcesDestroy(res); ``` -For more code examples of the C APIs, including drop-in Cmake project templates, please refer to the [C examples](https://github.com/rapidsai/cuvs/tree/branch-25.12/examples/c) +For more code examples of the C APIs, including drop-in Cmake project templates, please refer to the [C examples](https://github.com/rapidsai/cuvs/tree/main/examples/c) ### Rust API @@ -234,7 +234,7 @@ fn cagra_example() -> Result<()> { } ``` -For more code examples of the Rust APIs, including a drop-in project templates, please refer to the [Rust examples](https://github.com/rapidsai/cuvs/tree/branch-25.12/examples/rust). +For more code examples of the Rust APIs, including a drop-in project templates, please refer to the [Rust examples](https://github.com/rapidsai/cuvs/tree/main/examples/rust). ## Contributing diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 7d67a23d4e..9842281ed4 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -6,8 +6,66 @@ ######################## ## Usage -# bash update-version.sh +# Primary interface: ./ci/release/update-version.sh --run-context=main|release +# Fallback: Environment variable support for automation needs +# NOTE: Must be run from the root of the repository +# +# CLI args take precedence when both are provided +# If neither RUN_CONTEXT nor --run-context is provided, defaults to main +# +# Examples: +# ./ci/release/update-version.sh --run-context=main 25.12.00 +# ./ci/release/update-version.sh --run-context=release 25.12.00 +# RAPIDS_RUN_CONTEXT=main ./ci/release/update-version.sh 25.12.00 + +# Verify we're running from the repository root +if [[ ! -f "VERSION" ]] || [[ ! -f "ci/release/update-version.sh" ]] || [[ ! -d "python" ]]; then + echo "Error: This script must be run from the root of the cuvs repository" + echo "" + echo "Usage:" + echo " cd /path/to/cuvs" + echo " ./ci/release/update-version.sh --run-context=main|release " + echo "" + echo "Example:" + echo " ./ci/release/update-version.sh --run-context=main 25.12.00" + exit 1 +fi + +# Parse command line arguments +POSITIONAL_ARGS=() +while [[ $# -gt 0 ]]; do + case $1 in + --run-context=*) + CLI_RUN_CONTEXT="${1#*=}" + shift + ;; + *) + POSITIONAL_ARGS+=("$1") + shift + ;; + esac +done +# Restore positional parameters +set -- "${POSITIONAL_ARGS[@]}" + +# Determine RUN_CONTEXT with precedence: CLI > Environment > Default +if [[ -n "${CLI_RUN_CONTEXT:-}" ]]; then + RUN_CONTEXT="${CLI_RUN_CONTEXT}" + echo "Using run-context from CLI: ${RUN_CONTEXT}" +elif [[ -n "${RAPIDS_RUN_CONTEXT:-}" ]]; then + RUN_CONTEXT="${RAPIDS_RUN_CONTEXT}" + echo "Using RUN_CONTEXT from environment: ${RUN_CONTEXT}" +else + RUN_CONTEXT="main" + echo "Using default run-context: ${RUN_CONTEXT}" +fi + +# Validate RUN_CONTEXT +if [[ "${RUN_CONTEXT}" != "main" && "${RUN_CONTEXT}" != "release" ]]; then + echo "Error: Invalid run-context '${RUN_CONTEXT}'. Must be 'main' or 'release'" + exit 1 +fi # Format is YY.MM.PP - no leading 'v' or trailing 'a' NEXT_FULL_TAG=$1 @@ -28,7 +86,14 @@ NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR} NEXT_SHORT_TAG_PEP440=$(python -c "from packaging.version import Version; print(Version('${NEXT_SHORT_TAG}'))") PATCH_PEP440=$(python -c "from packaging.version import Version; print(Version('${NEXT_PATCH}'))") -echo "Preparing release $CURRENT_TAG => $NEXT_FULL_TAG" +# Determine branch name based on context +if [[ "${RUN_CONTEXT}" == "main" ]]; then + RAPIDS_BRANCH_NAME="main" + echo "Preparing development branch update ${CURRENT_TAG} => ${NEXT_FULL_TAG} (targeting main branch)" +elif [[ "${RUN_CONTEXT}" == "release" ]]; then + RAPIDS_BRANCH_NAME="release/${NEXT_SHORT_TAG}" + echo "Preparing release branch update ${CURRENT_TAG} => ${NEXT_FULL_TAG} (targeting release/${NEXT_SHORT_TAG} branch)" +fi # Inplace sed replace; workaround for Linux and Mac function sed_runner() { @@ -37,7 +102,7 @@ function sed_runner() { # Centralized version file update echo "${NEXT_FULL_TAG}" > VERSION -echo "branch-${NEXT_SHORT_TAG}" > RAPIDS_BRANCH +echo "${RAPIDS_BRANCH_NAME}" > RAPIDS_BRANCH DEPENDENCIES=( dask-cuda @@ -62,23 +127,30 @@ for FILE in python/*/pyproject.toml; do done done +# CI files - context-aware branch references for FILE in .github/workflows/*.yaml; do - sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" + sed_runner "/shared-workflows/ s|@.*|@${RAPIDS_BRANCH_NAME}|g" "${FILE}" sed_runner "s/:[0-9]*\\.[0-9]*-/:${NEXT_SHORT_TAG}-/g" "${FILE}" done -sed_runner "/rapidsai\/raft/ s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" docs/source/developer_guide.md - -# Update cuvs-bench Docker image references +# Documentation and code references - context-aware +if [[ "${RUN_CONTEXT}" == "main" ]]; then + # In main context, keep documentation on main (no changes needed) + : +elif [[ "${RUN_CONTEXT}" == "release" ]]; then + # In release context, use release branch for documentation links (word boundaries to avoid partial matches) + sed_runner "/rapidsai\\/cuvs/ s|\\bmain\\b|release/${NEXT_SHORT_TAG}|g" docs/source/developer_guide.md + sed_runner "s|\\bmain\\b|release/${NEXT_SHORT_TAG}|g" README.md + sed_runner "s|\\bmain\\b|release/${NEXT_SHORT_TAG}|g" python/cuvs_bench/cuvs_bench/plot/__main__.py +fi + +# Update cuvs-bench Docker image references (version-only, not branch-related) sed_runner "s|rapidsai/cuvs-bench:[0-9][0-9].[0-9][0-9]|rapidsai/cuvs-bench:${NEXT_SHORT_TAG}|g" docs/source/cuvs_bench/index.rst +# Version references (not branch-related) sed_runner "s|=[0-9][0-9].[0-9][0-9]|=${NEXT_SHORT_TAG}|g" README.md -sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" README.md sed_runner "s|@v[0-9][0-9].[0-9][0-9].[0-9][0-9]|@v${NEXT_FULL_TAG}|g" examples/go/README.md -# references to license files -sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" python/cuvs_bench/cuvs_bench/plot/__main__.py - # rust can't handle leading 0's in the major/minor/patch version - remove NEXT_FULL_RUST_TAG=$(printf "%d.%d.%d" $((10#$NEXT_MAJOR)) $((10#$NEXT_MINOR)) $((10#$NEXT_PATCH))) sed_runner "s/version = \".*\"/version = \"${NEXT_FULL_RUST_TAG}\"/g" rust/Cargo.toml diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index 48eafc2f62..da50a44d27 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -187,7 +187,7 @@ RAFT relies on `clang-format` to enforce code style across all C++ and CUDA sour 1. Do not split empty functions/records/namespaces. 2. Two-space indentation everywhere, including the line continuations. 3. Disable reflowing of comments. - The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/raft/blob/branch-25.12/cpp/.clang-format). + The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/cuvs/blob/main/cpp/.clang-format). [`doxygen`](https://doxygen.nl/) is used as documentation generator and also as a documentation linter. In order to run doxygen as a linter on C++/CUDA code, run @@ -205,7 +205,7 @@ you can run `codespell -i 3 -w .` from the repository root directory. This will bring up an interactive prompt to select which spelling fixes to apply. ### #include style -[include_checker.py](https://github.com/rapidsai/raft/blob/branch-25.12/cpp/scripts/include_checker.py) is used to enforce the include style as follows: +[include_checker.py](https://github.com/rapidsai/cuvs/blob/main/cpp/scripts/include_checker.py) is used to enforce the include style as follows: 1. `#include "..."` should be used for referencing local files only. It is acceptable to be used for referencing files in a sub-folder/parent-folder of the same algorithm, but should never be used to include files in other algorithms or between algorithms and the primitives or other dependencies. 2. `#include <...>` should be used for referencing everything else diff --git a/python/cuvs_bench/cuvs_bench/plot/__main__.py b/python/cuvs_bench/cuvs_bench/plot/__main__.py index 19937caa16..6d9d9cb4cd 100644 --- a/python/cuvs_bench/cuvs_bench/plot/__main__.py +++ b/python/cuvs_bench/cuvs_bench/plot/__main__.py @@ -6,7 +6,7 @@ # 1: https://github.com/erikbern/ann-benchmarks/blob/main/plot.py # 2: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/utils.py # noqa: E501 # 3: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/metrics.py # noqa: E501 -# License: https://github.com/rapidsai/cuvs/blob/branch-25.12/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 +# License: https://github.com/rapidsai/cuvs/blob/main/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 import itertools import os From 006f9f1bab71e576bd1d0c835c88bfb3c7a16b9c Mon Sep 17 00:00:00 2001 From: Julian Miller Date: Thu, 13 Nov 2025 22:53:34 +0100 Subject: [PATCH 07/32] Add Augmented Core Extraction Algorithm (#1404) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces **Augmented Core Extraction (ACE)**, an approach proposed by @anaruse for building CAGRA indices on very large datasets that exceed GPU memory capacity. ACE enables users to build high-quality approximate nearest neighbor search indices on datasets that would otherwise be impossible to process on a single GPU. The approach uses the host memory if large enough and falls back to the disk if required. This work is a collaboration: @anaruse, @tfeher, @achirkin, @mfoerste4 ## Algorithm Description 1. **Dataset Partitioning**: The dataset is partitioned using balanced k-means clustering on sampled data. Each vector is assigned to its two closest partition centroids (primary and augmented). The primary partitions are non-overlapping. The augmentation ensures that cross-partition edges are captured in the final graph. Partitions smaller than a minimum threshold are automatically merged with larger partitions to ensure computational efficiency and graph quality. Vectors from small partitions are reassigned to the nearest valid partitions. 2. **Per-Partition Graph Building**: For each partition, a sub-index is built independently (regular `build_knn_graph()` flow) with its primary vectors plus augmented vectors from neighboring partitions. 3. **Graph Combining**: The per-partition graphs are combined into a single unified CAGRA index. Merging is not needed since the primary partitions are non-overlapping. The in-memory variant remaps the local partition IDs to global dataset IDs to create a correct index. The disk variant stores the backward index mappings (`dataset_mapping.bin`), the reordered dataset (`reordered_dataset.bin`) and the optimized CAGRA graph (`cagra_graph.bin`) on disk. The index is then incomplete as show by `cuvs::neighbors::index::on_disk()`. The files are stored in `cuvs::neighbors::index::file_directory()`. The HNSW index serialization was provided by @mfoerste4 in #1410, which was merged here. This adds the `serialize_to_hnsw()` serialization routine that allows combination of dataset, graph, and mapping. The data will be combined on-the-fly while streamed from disk to disk while trying to minimize the required host memory. The host needs enough memory to hold the index though. ## Core Components - **`ace_build()`**: Main routine which users should call. - **`ace_get_partition_labels()`**: Performs balanced k-means clustering to assign each vector to two closest partitions while handling small partition merging. - **`ace_create_forward_and_backward_lists()`**: Creates bidirectional ID mappings between original dataset indices and reordered partition-local indices. - **`ace_set_index_params()`**: Set the index parameters based on the partition and augmented dataset to ensure an efficient KNN graph building. - **`ace_gather_partition_dataset()`**: In-memory only: gather the partition and augmented dataset. - **`ace_adjust_sub_graph_ids`**: In-memory only: Adjust ids in sub search graph and store them into the main search graph. - **`ace_adjust_final_graph_ids`**: In-memory only: Map graph neighbor IDs from reordered space back to original vector IDs. - **`ace_reorder_and_store_dataset`**: Disk only: Reorder the dataset based on partitions and store to disk. Uses write buffers to improve performance. - **`ace_load_partition_dataset_from_disk`**: Disk only: Load partition dataset and augmented dataset from disk. - **`file_descriptor` and `ace_read_large_file()` / `ace_write_large_file()`**: RAII file handle and chunked file I/O operations. - **CAGRA index changes**: Added `on_disk_` flag and `file_directory_` to the CAGRA index structure to support disk-backed indices. - **CAGRA parameter changes**: Added `ace_npartitions` and `ace_build_dir` to the CAGRA parameters for users to specify that ACE should be used and which directory should be used if required. ## Usage ### C++ API ```cpp #include using namespace cuvs::neighbors; // Configure index parameters cagra::index_params params; params.ace_npartitions = 10; // Number of partitions (unset or <= 1 to disable ACE) params.ace_build_dir = "/tmp/ace_build"; // Directory for intermediate files (should be a fast NVMe) params.graph_degree = 64; params.intermediate_graph_degree = 128; // Build ACE index (dataset can be on host memory) auto dataset = raft::make_host_matrix(n_rows, n_cols); // ... load dataset ... auto index = cagra::build_ace(res, params, dataset.view(), params.ace_npartitions); // Search works identically to standard CAGRA if the host has enough memory (index.on_disk() == false) cagra::search_params search_params; auto neighbors = raft::make_device_matrix(res, n_queries, k); auto distances = raft::make_device_matrix(res, n_queries, k); cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); ``` ### Storage Requirements 1. `cagra_graph.bin`: `n_rows * graph_degree * sizeof(IdxT)` 2. `dataset_mapping.bin`: `n_rows * sizeof(IdxT)` 2. `reordered_dataset.bin`: Size of the input dataset 3. `augmented_dataset.bin`: Size of the input dataset Authors: - Julian Miller (https://github.com/julianmi) - Anupam (https://github.com/aamijar) - Tarang Jain (https://github.com/tarang-jain) - Malte Förster (https://github.com/mfoerste4) - Jake Awe (https://github.com/AyodeAwe) - Bradley Dice (https://github.com/bdice) - Artem M. Chirkin (https://github.com/achirkin) - Jinsol Park (https://github.com/jinsolp) Approvers: - MithunR (https://github.com/mythrocks) - Robert Maynard (https://github.com/robertmaynard) - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/1404 --- c/include/cuvs/neighbors/cagra.h | 94 +- c/src/neighbors/cagra.cpp | 76 +- c/tests/neighbors/ann_cagra_c.cu | 223 +++ cpp/CMakeLists.txt | 3 + .../src/cuvs/cuvs_ann_bench_param_parser.h | 22 + cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu | 20 +- .../ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h | 23 +- cpp/include/cuvs/neighbors/cagra.hpp | 328 +++- .../cuvs/neighbors/graph_build_types.hpp | 41 +- cpp/include/cuvs/neighbors/hnsw.hpp | 5 + cpp/include/cuvs/util/file_io.hpp | 243 +++ cpp/include/cuvs/util/host_memory.hpp | 25 + cpp/src/neighbors/cagra.cuh | 16 +- cpp/src/neighbors/cagra_build_float.cu | 40 +- cpp/src/neighbors/cagra_build_half.cu | 10 +- cpp/src/neighbors/cagra_build_int8.cu | 40 +- cpp/src/neighbors/cagra_build_uint8.cu | 40 +- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 5 + .../neighbors/detail/cagra/cagra_build.cuh | 1436 ++++++++++++++++- .../neighbors/detail/cagra/cagra_search.cuh | 5 + .../detail/cagra/cagra_serialize.cuh | 9 + cpp/src/neighbors/detail/cagra/graph_core.cuh | 19 +- cpp/src/neighbors/detail/hnsw.hpp | 588 ++++++- cpp/src/util/file_io.cpp | 81 + cpp/src/util/host_memory.cpp | 29 + cpp/tests/CMakeLists.txt | 28 + cpp/tests/neighbors/ann_cagra_ace.cuh | 270 ++++ .../ann_cagra_ace/test_float_uint32_t.cu | 17 + .../ann_cagra_ace/test_half_uint32_t.cu | 17 + .../ann_cagra_ace/test_int8_t_uint32_t.cu | 17 + .../ann_cagra_ace/test_uint8_t_uint32_t.cu | 17 + docs/source/cuvs_bench/param_tuning.rst | 30 +- examples/cpp/CMakeLists.txt | 2 + examples/cpp/src/cagra_hnsw_ace_example.cu | 182 +++ .../com/nvidia/cuvs/CagraIndexParams.java | 32 +- .../java/com/nvidia/cuvs/CuVSAceParams.java | 184 +++ .../main/java/com/nvidia/cuvs/HnswIndex.java | 15 + .../com/nvidia/cuvs/spi/CuVSProvider.java | 11 + .../nvidia/cuvs/spi/UnsupportedProvider.java | 6 + .../nvidia/cuvs/internal/CagraIndexImpl.java | 33 + .../cuvs/internal/CuVSParamsHelper.java | 19 + .../nvidia/cuvs/internal/HnswIndexImpl.java | 115 ++ .../com/nvidia/cuvs/internal/common/Util.java | 18 + .../com/nvidia/cuvs/spi/JDKProvider.java | 6 + .../nvidia/cuvs/CagraAceBuildAndSearchIT.java | 243 +++ python/cuvs/cuvs/neighbors/cagra/__init__.py | 2 + python/cuvs/cuvs/neighbors/cagra/cagra.pxd | 15 +- python/cuvs/cuvs/neighbors/cagra/cagra.pyx | 215 ++- python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd | 3 +- python/cuvs/cuvs/tests/test_cagra_ace.py | 173 ++ 50 files changed, 4944 insertions(+), 147 deletions(-) create mode 100644 cpp/include/cuvs/util/file_io.hpp create mode 100644 cpp/include/cuvs/util/host_memory.hpp create mode 100644 cpp/src/util/file_io.cpp create mode 100644 cpp/src/util/host_memory.cpp create mode 100644 cpp/tests/neighbors/ann_cagra_ace.cuh create mode 100644 cpp/tests/neighbors/ann_cagra_ace/test_float_uint32_t.cu create mode 100644 cpp/tests/neighbors/ann_cagra_ace/test_half_uint32_t.cu create mode 100644 cpp/tests/neighbors/ann_cagra_ace/test_int8_t_uint32_t.cu create mode 100644 cpp/tests/neighbors/ann_cagra_ace/test_uint8_t_uint32_t.cu create mode 100644 examples/cpp/src/cagra_hnsw_ace_example.cu create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java create mode 100644 java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraAceBuildAndSearchIT.java create mode 100644 python/cuvs/cuvs/tests/test_cagra_ace.py diff --git a/c/include/cuvs/neighbors/cagra.h b/c/include/cuvs/neighbors/cagra.h index 58025ef359..487ada503d 100644 --- a/c/include/cuvs/neighbors/cagra.h +++ b/c/include/cuvs/neighbors/cagra.h @@ -34,7 +34,14 @@ enum cuvsCagraGraphBuildAlgo { /* Experimental, use NN-Descent to build all-neighbors knn graph */ NN_DESCENT = 2, /* Experimental, use iterative cagra search and optimize to build the knn graph */ - ITERATIVE_CAGRA_SEARCH = 3 + ITERATIVE_CAGRA_SEARCH = 3, + /** + * Experimental, use ACE (Augmented Core Extraction) to build the graph. ACE partitions the + * dataset into core and augmented partitions and builds a sub-index for each partition. This + * enables building indices for datasets too large to fit in GPU or host memory. + * See cuvsAceParams for more details about the ACE algorithm and its parameters. + */ + ACE = 4 }; /** @@ -118,6 +125,52 @@ struct cuvsIvfPqParams { typedef struct cuvsIvfPqParams* cuvsIvfPqParams_t; +/** + * Parameters for ACE (Augmented Core Extraction) graph build. + * ACE enables building indices for datasets too large to fit in GPU memory by: + * 1. Partitioning the dataset in core (closest) and augmented (second-closest) + * partitions using balanced k-means. + * 2. Building sub-indices for each partition independently + * 3. Concatenating sub-graphs into a final unified index + */ +struct cuvsAceParams { + /** + * Number of partitions for ACE (Augmented Core Extraction) partitioned build. + * + * Small values might improve recall but potentially degrade performance and + * increase memory usage. Partitions should not be too small to prevent issues + * in KNN graph construction. 100k - 5M vectors per partition is recommended + * depending on the available host and GPU memory. The partition size is on + * average 2 * (n_rows / npartitions) * dim * sizeof(T). 2 is because of the + * core and augmented vectors. Please account for imbalance in the partition + * sizes (up to 3x in our tests). + */ + size_t npartitions; + /** + * The index quality for the ACE build. + * + * Bigger values increase the index quality. At some point, increasing this will no longer + * improve the quality. + */ + size_t ef_construction; + /** + * Directory to store ACE build artifacts (e.g., KNN graph, optimized graph). + * + * Used when `use_disk` is true or when the graph does not fit in host and GPU + * memory. This should be the fastest disk in the system and hold enough space + * for twice the dataset, final graph, and label mapping. + */ + const char* build_dir; + /** + * Whether to use disk-based storage for ACE build. + * + * When true, enables disk-based operations for memory-efficient graph construction. + */ + bool use_disk; +}; + +typedef struct cuvsAceParams* cuvsAceParams_t; + /** * @brief Supplemental parameters to build CAGRA Index * @@ -140,9 +193,12 @@ struct cuvsCagraIndexParams { */ cuvsCagraCompressionParams_t compression; /** - * Optional: specify ivf pq params when `build_algo = IVF_PQ` + * Optional: specify graph build params based on build_algo + * - IVF_PQ: cuvsIvfPqParams_t + * - ACE: cuvsAceParams_t + * - Others: nullptr */ - cuvsIvfPqParams_t graph_build_params; + void* graph_build_params; }; typedef struct cuvsCagraIndexParams* cuvsCagraIndexParams_t; @@ -179,6 +235,38 @@ cuvsError_t cuvsCagraCompressionParamsCreate(cuvsCagraCompressionParams_t* param */ cuvsError_t cuvsCagraCompressionParamsDestroy(cuvsCagraCompressionParams_t params); +/** + * @brief Allocate ACE params, and populate with default values + * + * @param[in] params cuvsAceParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsAceParamsCreate(cuvsAceParams_t* params); + +/** + * @brief De-allocate ACE params + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsAceParamsDestroy(cuvsAceParams_t params); + +/** + * @brief Allocate ACE params, and populate with default values + * + * @param[in] params cuvsAceParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsAceParamsCreate(cuvsAceParams_t* params); + +/** + * @brief De-allocate ACE params + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsAceParamsDestroy(cuvsAceParams_t params); + /** * @brief Create CAGRA index parameters similar to an HNSW index * diff --git a/c/src/neighbors/cagra.cpp b/c/src/neighbors/cagra.cpp index b8242fe07a..611ef3e086 100644 --- a/c/src/neighbors/cagra.cpp +++ b/c/src/neighbors/cagra.cpp @@ -4,6 +4,7 @@ */ #include +#include #include #include @@ -17,6 +18,7 @@ #include #include #include +#include #include "../core/exceptions.hpp" #include "../core/interop.hpp" @@ -30,6 +32,7 @@ static void _set_graph_build_params( std::variant& out_params, cuvsCagraIndexParams& params, cuvsCagraGraphBuildAlgo algo, @@ -81,6 +84,18 @@ static void _set_graph_build_params( out_params = nn_params; break; } + case cuvsCagraGraphBuildAlgo::ACE: { + cuvs::neighbors::cagra::graph_build_params::ace_params ace_p; + if (params.graph_build_params) { + auto ace_params_c = static_cast(params.graph_build_params); + ace_p.npartitions = ace_params_c->npartitions; + ace_p.ef_construction = ace_params_c->ef_construction; + ace_p.build_dir = std::string(ace_params_c->build_dir); + ace_p.use_disk = ace_params_c->use_disk; + } + out_params = ace_p; + break; + } case cuvsCagraGraphBuildAlgo::ITERATIVE_CAGRA_SEARCH: { cuvs::neighbors::cagra::graph_build_params::iterative_search_params p; out_params = p; @@ -388,7 +403,19 @@ static void _populate_cagra_index_params_from_cpp(cuvsCagraIndexParams_t c_param std::get( cpp_params.graph_build_params); - _populate_c_ivf_pq_params(c_params->graph_build_params, ivf_pq_params); + _populate_c_ivf_pq_params(static_cast(c_params->graph_build_params), ivf_pq_params); + } else if (std::holds_alternative( + cpp_params.graph_build_params)) { + c_params->build_algo = ACE; + auto ace_params = + std::get( + cpp_params.graph_build_params); + cuvsAceParams* c_ace_params = new cuvsAceParams; + c_ace_params->npartitions = ace_params.npartitions; + c_ace_params->ef_construction = ace_params.ef_construction; + c_ace_params->build_dir = ace_params.build_dir.empty() ? nullptr : strdup(ace_params.build_dir.c_str()); + c_ace_params->use_disk = ace_params.use_disk; + c_params->graph_build_params = c_ace_params; } } @@ -700,7 +727,26 @@ extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params extern "C" cuvsError_t cuvsCagraIndexParamsDestroy(cuvsCagraIndexParams_t params) { return cuvs::core::translate_exceptions([=] { - delete params->graph_build_params; + // Delete graph_build_params based on the build algorithm type + if (params->graph_build_params != nullptr) { + switch (params->build_algo) { + case cuvsCagraGraphBuildAlgo::IVF_PQ: + delete static_cast(params->graph_build_params); + break; + case cuvsCagraGraphBuildAlgo::ACE: { + auto ace_params = static_cast(params->graph_build_params); + // Free the allocated build directory string + if (ace_params->build_dir) { free(const_cast(ace_params->build_dir)); } + delete ace_params; + break; + } + case cuvsCagraGraphBuildAlgo::AUTO_SELECT: + case cuvsCagraGraphBuildAlgo::NN_DESCENT: + case cuvsCagraGraphBuildAlgo::ITERATIVE_CAGRA_SEARCH: + // These algorithms don't have separate parameter structs + break; + } + } delete params; }); } @@ -724,6 +770,32 @@ extern "C" cuvsError_t cuvsCagraCompressionParamsDestroy(cuvsCagraCompressionPar return cuvs::core::translate_exceptions([=] { delete params; }); } +extern "C" cuvsError_t cuvsAceParamsCreate(cuvsAceParams_t* params) +{ + return cuvs::core::translate_exceptions([=] { + auto ps = cuvs::neighbors::cagra::graph_build_params::ace_params(); + + // Allocate and copy the build directory string + const char* build_dir = strdup(ps.build_dir.c_str()); + + *params = new cuvsAceParams{.npartitions = ps.npartitions, + .ef_construction = ps.ef_construction, + .build_dir = build_dir, + .use_disk = ps.use_disk}; + }); +} + +extern "C" cuvsError_t cuvsAceParamsDestroy(cuvsAceParams_t params) +{ + return cuvs::core::translate_exceptions([=] { + if (params) { + // Free the allocated build directory string + if (params->build_dir) { free(const_cast(params->build_dir)); } + delete params; + } + }); +} + extern "C" cuvsError_t cuvsCagraIndexParamsFromHnswParams(cuvsCagraIndexParams_t params, int64_t n_rows, int64_t dim, diff --git a/c/tests/neighbors/ann_cagra_c.cu b/c/tests/neighbors/ann_cagra_c.cu index ab46c8b877..31f0e79e80 100644 --- a/c/tests/neighbors/ann_cagra_c.cu +++ b/c/tests/neighbors/ann_cagra_c.cu @@ -10,7 +10,9 @@ #include #include +#include #include +#include #include #include @@ -44,6 +46,9 @@ float distances_exp[4] = {0.03878258, 0.12472608, 0.04776672, 0.15224178}; uint32_t neighbors_exp_filtered[4] = {3, 0, 3, 0}; float distances_exp_filtered[4] = {0.03878258, 0.12472608, 0.04776672, 0.59063464}; +std::vector neighbors_exp_disk = {3, 0, 3, 1}; +std::vector distances_exp_disk = {0.03878258, 0.12472608, 0.04776672, 0.15224178}; + TEST(CagraC, BuildSearch) { // create cuvsResources_t @@ -565,3 +570,221 @@ TEST(CagraC, BuildMergeSearch) cuvsCagraIndexDestroy(index_main); cuvsResourcesDestroy(res); } + +TEST(CagraC, BuildSearchACEMemory) +{ + // create cuvsResources_t + cuvsResources_t res; + cuvsResourcesCreate(&res); + cudaStream_t stream; + cuvsStreamGet(res, &stream); + + // create dataset DLTensor + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = dataset; + dataset_tensor.dl_tensor.device.device_type = kDLCPU; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + int64_t dataset_shape[2] = {4, 2}; + dataset_tensor.dl_tensor.shape = dataset_shape; + dataset_tensor.dl_tensor.strides = nullptr; + + // create index + cuvsCagraIndex_t index; + cuvsCagraIndexCreate(&index); + + // build index with ACE memory mode + cuvsCagraIndexParams_t build_params; + cuvsCagraIndexParamsCreate(&build_params); + build_params->build_algo = ACE; + + cuvsAceParams_t ace_params; + cuvsAceParamsCreate(&ace_params); + ace_params->npartitions = 2; + ace_params->ef_construction = 120; + ace_params->use_disk = false; + + build_params->graph_build_params = ace_params; + cuvsCagraBuild(res, build_params, &dataset_tensor, index); + + // create queries DLTensor + rmm::device_uvector queries_d(4 * 2, stream); + raft::copy(queries_d.data(), (float*)queries, 4 * 2, stream); + + DLManagedTensor queries_tensor; + queries_tensor.dl_tensor.data = queries_d.data(); + queries_tensor.dl_tensor.device.device_type = kDLCUDA; + queries_tensor.dl_tensor.ndim = 2; + queries_tensor.dl_tensor.dtype.code = kDLFloat; + queries_tensor.dl_tensor.dtype.bits = 32; + queries_tensor.dl_tensor.dtype.lanes = 1; + int64_t queries_shape[2] = {4, 2}; + queries_tensor.dl_tensor.shape = queries_shape; + queries_tensor.dl_tensor.strides = nullptr; + + // create neighbors DLTensor + rmm::device_uvector neighbors_d(4, stream); + + DLManagedTensor neighbors_tensor; + neighbors_tensor.dl_tensor.data = neighbors_d.data(); + neighbors_tensor.dl_tensor.device.device_type = kDLCUDA; + neighbors_tensor.dl_tensor.ndim = 2; + neighbors_tensor.dl_tensor.dtype.code = kDLUInt; + neighbors_tensor.dl_tensor.dtype.bits = 32; + neighbors_tensor.dl_tensor.dtype.lanes = 1; + int64_t neighbors_shape[2] = {4, 1}; + neighbors_tensor.dl_tensor.shape = neighbors_shape; + neighbors_tensor.dl_tensor.strides = nullptr; + + // create distances DLTensor + rmm::device_uvector distances_d(4, stream); + + DLManagedTensor distances_tensor; + distances_tensor.dl_tensor.data = distances_d.data(); + distances_tensor.dl_tensor.device.device_type = kDLCUDA; + distances_tensor.dl_tensor.ndim = 2; + distances_tensor.dl_tensor.dtype.code = kDLFloat; + distances_tensor.dl_tensor.dtype.bits = 32; + distances_tensor.dl_tensor.dtype.lanes = 1; + int64_t distances_shape[2] = {4, 1}; + distances_tensor.dl_tensor.shape = distances_shape; + distances_tensor.dl_tensor.strides = nullptr; + + cuvsFilter filter; + filter.type = NO_FILTER; + filter.addr = (uintptr_t)NULL; + + // search index + cuvsCagraSearchParams_t search_params; + cuvsCagraSearchParamsCreate(&search_params); + cuvsCagraSearch( + res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, filter); + + // verify output + ASSERT_TRUE( + cuvs::devArrMatchHost(neighbors_exp, neighbors_d.data(), 4, cuvs::Compare())); + ASSERT_TRUE(cuvs::devArrMatchHost( + distances_exp, distances_d.data(), 4, cuvs::CompareApprox(0.001f))); + + // de-allocate index and res + cuvsCagraSearchParamsDestroy(search_params); + cuvsCagraIndexParamsDestroy(build_params); + cuvsCagraIndexDestroy(index); + cuvsResourcesDestroy(res); +} + +TEST(CagraC, BuildSearchACEDisk) +{ + // create cuvsResources_t + cuvsResources_t res; + cuvsResourcesCreate(&res); + + // create dataset DLTensor + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = dataset; + dataset_tensor.dl_tensor.device.device_type = kDLCPU; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + int64_t dataset_shape[2] = {4, 2}; + dataset_tensor.dl_tensor.shape = dataset_shape; + dataset_tensor.dl_tensor.strides = nullptr; + + // create index + cuvsCagraIndex_t index; + cuvsCagraIndexCreate(&index); + + // build index with ACE memory mode + cuvsCagraIndexParams_t build_params; + cuvsCagraIndexParamsCreate(&build_params); + build_params->build_algo = ACE; + + cuvsAceParams_t ace_params; + cuvsAceParamsCreate(&ace_params); + ace_params->npartitions = 2; + ace_params->ef_construction = 120; + ace_params->use_disk = true; + ace_params->build_dir = strdup("/tmp/cagra_ace_test_disk"); + + build_params->graph_build_params = ace_params; + cuvsCagraBuild(res, build_params, &dataset_tensor, index); + + // Convert CAGRA index to HNSW (automatically serializes to disk for ACE) + cuvsHnswIndex_t hnsw_index_ser; + cuvsHnswIndexCreate(&hnsw_index_ser); + cuvsHnswIndexParams_t hnsw_params; + cuvsHnswIndexParamsCreate(&hnsw_params); + + cuvsHnswFromCagra(res, hnsw_params, index, hnsw_index_ser); + ASSERT_NE(hnsw_index_ser->addr, 0); + cuvsHnswIndexDestroy(hnsw_index_ser); + + DLManagedTensor queries_tensor; + queries_tensor.dl_tensor.data = queries; + queries_tensor.dl_tensor.device.device_type = kDLCPU; + queries_tensor.dl_tensor.ndim = 2; + queries_tensor.dl_tensor.dtype.code = kDLFloat; + queries_tensor.dl_tensor.dtype.bits = 32; + queries_tensor.dl_tensor.dtype.lanes = 1; + int64_t queries_shape[2] = {4, 2}; + queries_tensor.dl_tensor.shape = queries_shape; + queries_tensor.dl_tensor.strides = nullptr; + + // create neighbors DLTensor + std::vector neighbors(4); + + DLManagedTensor neighbors_tensor; + neighbors_tensor.dl_tensor.data = neighbors.data(); + neighbors_tensor.dl_tensor.device.device_type = kDLCPU; + neighbors_tensor.dl_tensor.ndim = 2; + neighbors_tensor.dl_tensor.dtype.code = kDLUInt; + neighbors_tensor.dl_tensor.dtype.bits = 64; + neighbors_tensor.dl_tensor.dtype.lanes = 1; + int64_t neighbors_shape[2] = {4, 1}; + neighbors_tensor.dl_tensor.shape = neighbors_shape; + neighbors_tensor.dl_tensor.strides = nullptr; + + // create distances DLTensor + std::vector distances(4); + + DLManagedTensor distances_tensor; + distances_tensor.dl_tensor.data = distances.data(); + distances_tensor.dl_tensor.device.device_type = kDLCPU; + distances_tensor.dl_tensor.ndim = 2; + distances_tensor.dl_tensor.dtype.code = kDLFloat; + distances_tensor.dl_tensor.dtype.bits = 32; + distances_tensor.dl_tensor.dtype.lanes = 1; + int64_t distances_shape[2] = {4, 1}; + distances_tensor.dl_tensor.shape = distances_shape; + distances_tensor.dl_tensor.strides = nullptr; + + // Deserialize the HNSW index from disk for search + cuvsHnswIndex_t hnsw_index; + cuvsHnswIndexCreate(&hnsw_index); + hnsw_index->dtype = index->dtype; + + // Use the actual dimension from the dataset + int dim = dataset_tensor.dl_tensor.shape[1]; + cuvsHnswDeserialize(res, hnsw_params, "/tmp/cagra_ace_test_disk/hnsw_index.bin", dim, L2Expanded, hnsw_index); + ASSERT_NE(hnsw_index->addr, 0); + + // Search the HNSW index + cuvsHnswSearchParams_t search_params; + cuvsHnswSearchParamsCreate(&search_params); + cuvsHnswSearch( + res, search_params, hnsw_index, &queries_tensor, &neighbors_tensor, &distances_tensor); + + // Verify output + ASSERT_TRUE(cuvs::hostVecMatch(neighbors_exp_disk, neighbors, cuvs::Compare())); + ASSERT_TRUE(cuvs::hostVecMatch(distances_exp_disk, distances, cuvs::CompareApprox(0.001f))); + + cuvsCagraIndexParamsDestroy(build_params); + cuvsCagraIndexDestroy(index); + cuvsHnswSearchParamsDestroy(search_params); + cuvsHnswIndexParamsDestroy(hnsw_params); + cuvsHnswIndexDestroy(hnsw_index); + cuvsResourcesDestroy(res); +} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6b36b25bb5..31f7227ae4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -351,6 +351,8 @@ if(NOT BUILD_CPU_ONLY) src/cluster/spectral.cu src/core/bitset.cu src/core/omp_wrapper.cpp + src/util/file_io.cpp + src/util/host_memory.cpp src/distance/detail/kernels/gram_matrix.cu src/distance/detail/kernels/kernel_factory.cu src/distance/detail/kernels/kernel_matrices.cu @@ -442,6 +444,7 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/cagra_index_wrapper.cu src/neighbors/composite/index.cu src/neighbors/composite/merge.cpp + $<$:src/neighbors/cagra.cpp> $<$:src/neighbors/hnsw.cpp> src/neighbors/ivf_flat_index.cpp src/neighbors/ivf_flat/ivf_flat_build_extend_float_int64_t.cu diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h index 84c6f0628d..d9e7c0f41d 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h @@ -260,6 +260,11 @@ void parse_build_param(const nlohmann::json& conf, cuvs::neighbors::cagra::index params.graph_build_params)) { params.graph_build_params = cuvs::neighbors::graph_build_params::nn_descent_params{}; } + } else if (conf.at("graph_build_algo") == "ACE") { + if (!std::holds_alternative( + params.graph_build_params)) { + params.graph_build_params = cuvs::neighbors::graph_build_params::ace_params{}; + } } } @@ -267,12 +272,15 @@ void parse_build_param(const nlohmann::json& conf, cuvs::neighbors::cagra::index nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); + nlohmann::json ace_conf = collect_conf_with_prefix(conf, "ace_"); if (std::holds_alternative(params.graph_build_params)) { if (!ivf_pq_build_conf.empty() || !ivf_pq_search_conf.empty()) { params.graph_build_params = cuvs::neighbors::graph_build_params::ivf_pq_params{}; } else if (!nn_descent_conf.empty()) { params.graph_build_params = cuvs::neighbors::graph_build_params::nn_descent_params{}; + } else if (!ace_conf.empty()) { + params.graph_build_params = cuvs::neighbors::graph_build_params::ace_params{}; } else { params.graph_build_params = cuvs::neighbors::graph_build_params::iterative_search_params{}; } @@ -328,6 +336,20 @@ void parse_build_param(const nlohmann::json& conf, cuvs::neighbors::cagra::graph_build_params::nn_descent_params( conf.value("intermediate_graph_degree", cagra_params.intermediate_graph_degree), dist_type); + } else if (conf.value("graph_build_algo", "") == "ACE") { + cagra_params.graph_build_params = cuvs::neighbors::cagra::graph_build_params::ace_params{}; + } + // Parse ACE parameters if provided + nlohmann::json ace_conf = collect_conf_with_prefix(conf, "ace_"); + if (!ace_conf.empty()) { + auto ace_params = cuvs::neighbors::cagra::graph_build_params::ace_params(); + if (ace_conf.contains("npartitions")) { ace_params.npartitions = ace_conf.at("npartitions"); } + if (ace_conf.contains("build_dir")) { ace_params.build_dir = ace_conf.at("build_dir"); } + if (ace_conf.contains("ef_construction")) { + ace_params.ef_construction = ace_conf.at("ef_construction"); + } + if (ace_conf.contains("use_disk")) { ace_params.use_disk = ace_conf.at("use_disk"); } + cagra_params.graph_build_params = ace_params; } ::parse_build_param(conf, cagra_params); return cagra_params; diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu index b69b3946b2..113e79fa15 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu @@ -48,9 +48,25 @@ auto parse_build_param(const nlohmann::json& conf) -> // to override them. cagra_params.cagra_params = [conf, hnsw_params](raft::matrix_extent extents, cuvs::distance::DistanceType dist_type) { - auto ps = cuvs::neighbors::hnsw::to_cagra_params( - extents, conf.at("M"), hnsw_params.ef_construction, dist_type); + auto ps = cuvs::neighbors::cagra::index_params::from_hnsw_params( + extents, + conf.at("M"), + hnsw_params.ef_construction, + cuvs::neighbors::cagra::hnsw_heuristic_type::SAME_GRAPH_FOOTPRINT, + dist_type); ps.metric = dist_type; + // Parse ACE parameters if provided + if (conf.contains("npartitions") || conf.contains("build_dir") || + conf.contains("ef_construction") || conf.contains("use_disk")) { + auto ace_params = cuvs::neighbors::cagra::graph_build_params::ace_params(); + if (conf.contains("npartitions")) { ace_params.npartitions = conf.at("npartitions"); } + if (conf.contains("build_dir")) { ace_params.build_dir = conf.at("build_dir"); } + if (conf.contains("ef_construction")) { + ace_params.ef_construction = conf.at("ef_construction"); + } + if (conf.contains("use_disk")) { ace_params.use_disk = conf.at("use_disk"); } + ps.graph_build_params = ace_params; + } // NB: above, we only provide the defaults. Below we parse the explicit parameters as usual. ::parse_build_param(conf, ps); return ps; diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h index 1b237f82df..2f0c54e1bd 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h @@ -75,6 +75,8 @@ class cuvs_cagra_hnswlib : public algo, public algo_gpu { build_param build_param_; search_param search_param_; std::shared_ptr> hnsw_index_; + + bool cagra_ace_build_ = false; }; template @@ -110,6 +112,11 @@ void cuvs_cagra_hnswlib::build(const T* dataset, size_t nrow) // convert the index to HNSW format hnsw_index_ = cuvs::neighbors::hnsw::from_cagra( handle_, build_param_.hnsw_index_params, cagra_index, opt_dataset_view); + + // special treatment in save/serialize step + if (cagra_index.dataset_fd().has_value() && cagra_index.graph_fd().has_value()) { + cagra_ace_build_ = true; + } } template @@ -123,7 +130,21 @@ void cuvs_cagra_hnswlib::set_search_param(const search_param_base& para template void cuvs_cagra_hnswlib::save(const std::string& file) const { - cuvs::neighbors::hnsw::serialize(handle_, file, *(hnsw_index_.get())); + if (cagra_ace_build_) { + std::string index_filename = hnsw_index_->file_path(); + RAFT_EXPECTS(!index_filename.empty(), "HNSW index file path is not available."); + RAFT_EXPECTS(std::filesystem::exists(index_filename), + "Index file '%s' does not exist.", + index_filename.c_str()); + if (std::filesystem::exists(file)) { std::filesystem::remove(file); } + // might fail when using 2 different filesystems + std::error_code ec; + std::filesystem::rename(index_filename, file, ec); + RAFT_EXPECTS( + !ec, "Failed to rename index file '%s' to '%s'.", index_filename.c_str(), file.c_str()); + } else { + cuvs::neighbors::hnsw::serialize(handle_, file, *(hnsw_index_.get())); + } } template diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 6192b263c3..fb1b1549af 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -11,17 +11,24 @@ #include #include #include +#include #include #include #include #include #include #include +#include + +#include +#include +#include #include #include #include #include +#include #include namespace cuvs::neighbors::cagra { @@ -79,9 +86,9 @@ struct index_params : cuvs::neighbors::index_params { /** Parameters for graph building. * - * Set ivf_pq_params, nn_descent_params, or iterative_search_params to select the graph build - * algorithm and control their parameters. The default (std::monostate) is to use a heuristic - * to decide the algorithm and its parameters. + * Set ivf_pq_params, nn_descent_params, ace_params, or iterative_search_params to select the + * graph build algorithm and control their parameters. The default (std::monostate) is to use a + * heuristic to decide the algorithm and its parameters. * * @code{.cpp} * cagra::index_params params; @@ -93,7 +100,10 @@ struct index_params : cuvs::neighbors::index_params { * params.graph_build_params = * cagra::graph_build_params::nn_descent_params(params.intermediate_graph_degree); * - * // 3. Choose iterative graph building using CAGRA's search() and optimize() [Experimental] + * // 3. Choose ACE algorithm for graph construction + * params.graph_build_params = cagra::graph_build_params::ace_params(); + * + * // 4. Choose iterative graph building using CAGRA's search() and optimize() [Experimental] * params.graph_build_params = * cagra::graph_build_params::iterative_search_params(); * @endcode @@ -101,9 +111,9 @@ struct index_params : cuvs::neighbors::index_params { std::variant graph_build_params; - /** * Whether to use MST optimization to guarantee graph connectivity. */ @@ -363,15 +373,19 @@ struct index : cuvs::neighbors::index { [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { auto data_rows = dataset_->n_rows(); + if (dataset_fd_.has_value()) { return n_rows_; } return data_rows > 0 ? data_rows : graph_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); } + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t + { + return dataset_fd_.has_value() ? dim_ : dataset_->dim(); + } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { - return graph_view_.extent(1); + return dataset_fd_.has_value() ? graph_degree_ : graph_view_.extent(1); } [[nodiscard]] inline auto dataset() const noexcept @@ -406,6 +420,27 @@ struct index : cuvs::neighbors::index { : std::nullopt; } + /** Get the dataset file descriptor (for disk-backed index) */ + [[nodiscard]] inline auto dataset_fd() const noexcept + -> const std::optional& + { + return dataset_fd_; + } + + /** Get the graph file descriptor (for disk-backed index) */ + [[nodiscard]] inline auto graph_fd() const noexcept + -> const std::optional& + { + return graph_fd_; + } + + /** Get the mapping file descriptor (for disk-backed index) */ + [[nodiscard]] inline auto mapping_fd() const noexcept + -> const std::optional& + { + return mapping_fd_; + } + /** Dataset norms for cosine distance [size] */ [[nodiscard]] inline auto dataset_norms() const noexcept -> std::optional> @@ -677,6 +712,117 @@ struct index : cuvs::neighbors::index { raft::resource::get_cuda_stream(res)); } + /** + * Update the dataset from a disk file using a file descriptor. + * + * This method configures the index to use a disk-based dataset. + * The dataset file should contain a numpy header followed by vectors in row-major format. + * The number of rows and dimensionality are read from the numpy header. + * + * @param[in] res raft resources + * @param[in] fd File descriptor (will be moved into the index for lifetime management) + */ + void update_dataset(raft::resources const& res, cuvs::util::file_descriptor&& fd) + { + RAFT_EXPECTS(fd.is_valid(), "Invalid file descriptor provided for dataset"); + + auto stream = fd.make_istream(); + if (lseek(fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of dataset file"); + } + auto header = raft::detail::numpy_serializer::read_header(stream); + RAFT_EXPECTS(header.shape.size() == 2, + "Dataset file should be 2D, got %zu dimensions", + header.shape.size()); + + n_rows_ = header.shape[0]; + dim_ = header.shape[1]; + + RAFT_LOG_DEBUG("ACE: Dataset has shape [%zu, %zu]", n_rows_, dim_); + + // Re-open the file descriptor in read-only mode for subsequent operations + dataset_fd_.emplace(std::move(fd)); + + dataset_ = std::make_unique>(0); + dataset_norms_.reset(); + } + + /** + * Update the graph from a disk file using a file descriptor. + * + * This method configures the index to use a disk-based graph. + * The graph file should contain a numpy header followed by neighbor indices in row-major format. + * The number of rows and graph degree are read from the numpy header. + * + * @param[in] res raft resources + * @param[in] fd File descriptor (will be moved into the index for lifetime management) + */ + void update_graph(raft::resources const& res, cuvs::util::file_descriptor&& fd) + { + RAFT_EXPECTS(fd.is_valid(), "Invalid file descriptor provided for graph"); + + auto stream = fd.make_istream(); + if (lseek(fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of graph file"); + } + auto header = raft::detail::numpy_serializer::read_header(stream); + RAFT_EXPECTS( + header.shape.size() == 2, "Graph file should be 2D, got %zu dimensions", header.shape.size()); + + if (dataset_fd_.has_value() && n_rows_ != 0) { + RAFT_EXPECTS(n_rows_ == header.shape[0], + "Graph size (%zu) must match dataset size (%zu)", + header.shape[0], + n_rows_); + } + + n_rows_ = header.shape[0]; + graph_degree_ = header.shape[1]; + + RAFT_LOG_DEBUG("ACE: Graph has shape [%zu, %zu]", n_rows_, graph_degree_); + + // Re-open the file descriptor in read-only mode for subsequent operations + graph_fd_.emplace(std::move(fd)); + + graph_ = raft::make_device_matrix(res, 0, 0); + graph_view_ = graph_.view(); + } + + /** + * Update the dataset mapping from a disk file using a file descriptor. + * + * This method configures the index to use a disk-based dataset mapping. + * The mapping file should contain a numpy header followed by index mappings. + * + * @param[in] res raft resources + * @param[in] fd File descriptor (will be moved into the index for lifetime management) + */ + void update_mapping(raft::resources const& res, cuvs::util::file_descriptor&& fd) + { + RAFT_EXPECTS(fd.is_valid(), "Invalid file descriptor provided for mapping"); + + // Read header from file using ifstream + auto stream = fd.make_istream(); + if (lseek(fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of mapping file"); + } + auto header = raft::detail::numpy_serializer::read_header(stream); + RAFT_EXPECTS(header.shape.size() == 1, + "Mapping file should be 1D, got %zu dimensions", + header.shape.size()); + + if (dataset_fd_.has_value() && n_rows_ != 0) { + RAFT_EXPECTS(header.shape[0] == n_rows_, + "Mapping size (%zu) must match dataset size (%zu)", + header.shape[0], + n_rows_); + } + + RAFT_LOG_DEBUG("ACE: Mapping has %zu elements", header.shape[0]); + + mapping_fd_.emplace(std::move(fd)); + } + private: cuvs::distance::DistanceType metric_; raft::device_matrix graph_; @@ -687,7 +833,15 @@ struct index : cuvs::neighbors::index { // only float distances supported at the moment std::optional> dataset_norms_; + // File descriptors for disk-backed index components (ACE disk mode) + std::optional dataset_fd_; + std::optional graph_fd_; + std::optional mapping_fd_; + void compute_dataset_norms_(raft::resources const& res); + size_t n_rows_ = 0; + size_t dim_ = 0; + size_t graph_degree_ = 0; }; /** * @} @@ -2866,6 +3020,166 @@ template auto distribute(const raft::resources& clique, const std::string& filename) -> cuvs::neighbors::mg_index, T, IdxT>; +/** + * @brief Build a kNN graph using IVF-PQ. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters based on shape of the dataset + * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); + * ivf_pq::search_params search_params; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric(), dataset, + * optimized_graph.view()); + * @endcode + * + * @param[in] res raft resources + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] build_params ivf-pq parameters for graph build + */ +void build_knn_graph(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view knn_graph, + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params build_params); + +/** + * @brief Build a kNN graph using IVF-PQ. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters based on shape of the dataset + * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); + * ivf_pq::search_params search_params; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric(), dataset, + * optimized_graph.view()); + * @endcode + * + * @param[in] res raft resources + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] build_params ivf-pq parameters for graph build + */ +void build_knn_graph(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view knn_graph, + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params build_params); + +/** + * @brief Build a kNN graph using IVF-PQ. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters based on shape of the dataset + * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); + * ivf_pq::search_params search_params; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric(), dataset, + * optimized_graph.view()); + * @endcode + * + * @param[in] res raft resources + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] build_params ivf-pq parameters for graph build + */ +void build_knn_graph(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view knn_graph, + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params build_params); + +/** + * @brief Build a kNN graph using IVF-PQ. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters based on shape of the dataset + * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); + * ivf_pq::search_params search_params; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric(), dataset, + * optimized_graph.view()); + * @endcode + * + * @param[in] res raft resources + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] build_params ivf-pq parameters for graph build + */ +void build_knn_graph(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view knn_graph, + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params build_params); + } // namespace cuvs::neighbors::cagra #include diff --git a/cpp/include/cuvs/neighbors/graph_build_types.hpp b/cpp/include/cuvs/neighbors/graph_build_types.hpp index 2f3f93b3f9..7f501240a2 100644 --- a/cpp/include/cuvs/neighbors/graph_build_types.hpp +++ b/cpp/include/cuvs/neighbors/graph_build_types.hpp @@ -16,7 +16,7 @@ namespace cuvs::neighbors { * @{ */ -enum GRAPH_BUILD_ALGO { BRUTE_FORCE = 0, IVF_PQ = 1, NN_DESCENT = 2 }; +enum GRAPH_BUILD_ALGO { BRUTE_FORCE = 0, IVF_PQ = 1, NN_DESCENT = 2, ACE = 3 }; namespace graph_build_params { @@ -94,6 +94,45 @@ struct brute_force_params { cuvs::neighbors::brute_force::search_params search_params; }; +/** Specialized parameters for ACE (Augmented Core Extraction) graph build */ +struct ace_params { + /** + * Number of partitions for ACE (Augmented Core Extraction) partitioned build. + * + * Small values might improve recall but potentially degrade performance and + * increase memory usage. Partitions should not be too small to prevent issues + * in KNN graph construction. 100k - 5M vectors per partition is recommended + * depending on the available host and GPU memory. The partition size is on + * average 2 * (n_rows / npartitions) * dim * sizeof(T). 2 is because of the + * core and augmented vectors. Please account for imbalance in the partition + * sizes (up to 3x in our tests). + */ + size_t npartitions = 1; + /** + * The index quality for the ACE build. + * + * Bigger values increase the index quality. At some point, increasing this will no longer improve + * the quality. + */ + size_t ef_construction = 120; + /** + * Directory to store ACE build artifacts (e.g., KNN graph, optimized graph). + * + * Used when `use_disk` is true or when the graph does not fit in host and GPU + * memory. This should be the fastest disk in the system and hold enough space + * for twice the dataset, final graph, and label mapping. + */ + std::string build_dir = "/tmp/ace_build"; + /** + * Whether to use disk-based storage for ACE build. + * + * When true, enables disk-based operations for memory-efficient graph construction. + */ + bool use_disk = false; + + ace_params() = default; +}; + // **** Experimental **** using iterative_search_params = cuvs::neighbors::search_params; } // namespace graph_build_params diff --git a/cpp/include/cuvs/neighbors/hnsw.hpp b/cpp/include/cuvs/neighbors/hnsw.hpp index 3ba9df63e2..c2bfb1993d 100644 --- a/cpp/include/cuvs/neighbors/hnsw.hpp +++ b/cpp/include/cuvs/neighbors/hnsw.hpp @@ -132,6 +132,11 @@ struct index : cuvs::neighbors::index { */ virtual void set_ef(int ef) const; + /** + @brief Get file path for disk-backed index + */ + virtual std::string file_path() const { return ""; } + private: int dim_; cuvs::distance::DistanceType metric_; diff --git a/cpp/include/cuvs/util/file_io.hpp b/cpp/include/cuvs/util/file_io.hpp new file mode 100644 index 0000000000..363b1b1ca0 --- /dev/null +++ b/cpp/include/cuvs/util/file_io.hpp @@ -0,0 +1,243 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cuvs::util { +/** + * @brief Streambuf that reads from a POSIX file descriptor + */ +class fd_streambuf : public std::streambuf { + int fd_; + std::unique_ptr buffer_; + size_t buffer_size_; + + protected: + int_type underflow() override + { + if (gptr() < egptr()) return traits_type::to_int_type(*gptr()); + ssize_t n = ::read(fd_, buffer_.get(), buffer_size_); + if (n <= 0) return traits_type::eof(); + setg(buffer_.get(), buffer_.get(), buffer_.get() + n); + return traits_type::to_int_type(*gptr()); + } + + public: + explicit fd_streambuf(int fd, size_t buffer_size = 8192) + : fd_(fd), buffer_(new char[buffer_size]), buffer_size_(buffer_size) + { + setg(buffer_.get(), buffer_.get(), buffer_.get()); + } + + ~fd_streambuf() + { + if (fd_ != -1) ::close(fd_); + } + + fd_streambuf(const fd_streambuf&) = delete; + fd_streambuf& operator=(const fd_streambuf&) = delete; + fd_streambuf(fd_streambuf&&) noexcept = default; + fd_streambuf& operator=(fd_streambuf&&) noexcept = default; +}; + +/** + * @brief Istream that reads from a POSIX file descriptor + */ +class fd_istream : public std::istream { + fd_streambuf buf_; + + public: + explicit fd_istream(int fd) : std::istream(&buf_), buf_(fd) {} + + fd_istream(const fd_istream&) = delete; + fd_istream& operator=(const fd_istream&) = delete; + + fd_istream(fd_istream&& o) noexcept : std::istream(std::move(o)), buf_(std::move(o.buf_)) + { + rdbuf(&buf_); + } + + fd_istream& operator=(fd_istream&& o) noexcept + { + std::istream::operator=(std::move(o)); + buf_ = std::move(o.buf_); + rdbuf(&buf_); + return *this; + } +}; + +/** + * @brief RAII wrapper for POSIX file descriptors + * + * Manages file descriptor lifecycle with automatic cleanup. + * Non-copyable, move-only. + */ +class file_descriptor { + public: + explicit file_descriptor(int fd = -1) : fd_(fd) {} + + file_descriptor(const std::string& path, int flags, mode_t mode = 0644) + : fd_(open(path.c_str(), flags, mode)), path_(path) + { + if (fd_ == -1) { + RAFT_FAIL("Failed to open file: %s (errno: %d, %s)", path.c_str(), errno, strerror(errno)); + } + } + + file_descriptor(const file_descriptor&) = delete; + file_descriptor& operator=(const file_descriptor&) = delete; + + file_descriptor(file_descriptor&& other) noexcept + : fd_{std::exchange(other.fd_, -1)}, path_{std::move(other.path_)} + { + } + + file_descriptor& operator=(file_descriptor&& other) noexcept + { + std::swap(this->fd_, other.fd_); + std::swap(this->path_, other.path_); + return *this; + } + + ~file_descriptor() noexcept { close(); } + + [[nodiscard]] int get() const noexcept { return fd_; } + [[nodiscard]] bool is_valid() const noexcept { return fd_ != -1; } + + void close() noexcept + { + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } + } + + [[nodiscard]] int release() noexcept + { + const int fd = fd_; + fd_ = -1; + return fd; + } + + [[nodiscard]] std::string get_path() const { return path_; } + + /** + * @brief Create an input stream from this file descriptor + * + * Creates an istream that reads directly from the file descriptor using POSIX read(). + * The original descriptor remains valid and unchanged (we duplicate it internally). + * Returns the stream by value (uses move semantics). + * + * @return fd_istream (movable istream) + */ + [[nodiscard]] fd_istream make_istream() const + { + RAFT_EXPECTS(is_valid(), "Invalid file descriptor"); + + // Duplicate the fd to avoid consuming the original + int dup_fd = dup(fd_); + RAFT_EXPECTS(dup_fd != -1, "Failed to duplicate file descriptor"); + + // Create stream that owns the duplicated fd + // Returned by value, uses move semantics + return fd_istream(dup_fd); + } + + private: + int fd_; + std::string path_; +}; + +/** + * @brief Read large file in chunks using pread + * + * Reads a file in chunks to avoid issues with very large reads. + * Uses pread for thread-safe, offset-based reading. + * + * @param fd File descriptor to read from + * @param dest_ptr Destination buffer + * @param total_bytes Total bytes to read + * @param file_offset Starting offset in file + */ +void read_large_file(const file_descriptor& fd, + void* dest_ptr, + const size_t total_bytes, + const uint64_t file_offset); + +/** + * @brief Write large file in chunks using pwrite + * + * Writes data to a file in chunks to avoid issues with very large writes. + * Uses pwrite for thread-safe, offset-based writing. + * + * @param fd File descriptor to write to + * @param data_ptr Source data buffer + * @param total_bytes Total bytes to write + * @param file_offset Starting offset in file + */ +void write_large_file(const file_descriptor& fd, + const void* data_ptr, + const size_t total_bytes, + const uint64_t file_offset); + +/** + * @brief Buffered output stream wrapper + * + * Wraps an std::ostream with a buffer to improve write performance by + * reducing the number of system calls. Automatically flushes on destruction. + * Non-copyable, non-movable. + */ +class buffered_ofstream { + public: + buffered_ofstream(std::ostream* os, size_t buffer_size) : os_(os), buffer_(buffer_size), pos_(0) + { + } + + ~buffered_ofstream() noexcept { flush(); } + + buffered_ofstream(const buffered_ofstream& res) = delete; + auto operator=(const buffered_ofstream& other) -> buffered_ofstream& = delete; + buffered_ofstream(buffered_ofstream&& other) = delete; + auto operator=(buffered_ofstream&& other) -> buffered_ofstream& = delete; + + void flush() + { + if (pos_ > 0) { + os_->write(reinterpret_cast(&buffer_.front()), pos_); + if (!os_->good()) { RAFT_FAIL("Error writing HNSW file!"); } + pos_ = 0; + } + } + + void write(const char* input, size_t size) + { + if (pos_ + size > buffer_.size()) { flush(); } + std::copy(input, input + size, &buffer_[pos_]); + pos_ += size; + } + + private: + std::vector buffer_; + std::ostream* os_; + size_t pos_; +}; + +} // namespace cuvs::util diff --git a/cpp/include/cuvs/util/host_memory.hpp b/cpp/include/cuvs/util/host_memory.hpp new file mode 100644 index 0000000000..7ca9da1687 --- /dev/null +++ b/cpp/include/cuvs/util/host_memory.hpp @@ -0,0 +1,25 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +#include +#include + +namespace cuvs::util { + +/** + * @brief Get available host memory from /proc/meminfo + * + * Queries the system for available memory by reading /proc/meminfo. + * This is useful for determining how much host memory can be used + * for buffering or temporary storage. + * + * @return Available memory in bytes + */ +size_t get_free_host_memory(); + +} // namespace cuvs::util diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index fdc5b2b7e7..cf65ef7c4d 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -118,12 +118,7 @@ void build_knn_graph( raft::mdspan, raft::row_major, accessor>( dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - cagra::detail::build_knn_graph(res, - dataset_internal, - knn_graph_internal, - ivf_pq_params.refinement_rate, - ivf_pq_params.build_params, - ivf_pq_params.search_params); + cagra::detail::build_knn_graph(res, dataset_internal, knn_graph_internal, ivf_pq_params); } /** @@ -278,6 +273,15 @@ index build( const index_params& params, raft::mdspan, raft::row_major, Accessor> dataset) { + // Check if ACE dispatch is requested via graph_build_params + if (std::holds_alternative(params.graph_build_params)) { + // ACE expects the dataset to be on host due to the large dataset size + RAFT_EXPECTS(raft::get_device_for_address(dataset.data_handle()) == -1, + "ACE: Dataset must be on host for ACE build"); + auto dataset_view = raft::make_host_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + return cuvs::neighbors::cagra::detail::build_ace(res, params, dataset_view); + } return cuvs::neighbors::cagra::detail::build(res, params, dataset); } diff --git a/cpp/src/neighbors/cagra_build_float.cu b/cpp/src/neighbors/cagra_build_float.cu index fe4c757b72..b3097f7647 100644 --- a/cpp/src/neighbors/cagra_build_float.cu +++ b/cpp/src/neighbors/cagra_build_float.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,21 +8,29 @@ namespace cuvs::neighbors::cagra { -#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - -> cuvs::neighbors::cagra::index \ - { \ - return cuvs::neighbors::cagra::build(handle, params, dataset); \ - } \ - \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - -> cuvs::neighbors::cagra::index \ - { \ - return cuvs::neighbors::cagra::build(handle, params, dataset); \ +#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ + void build_knn_graph(raft::resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view knn_graph, \ + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params params) \ + { \ + cuvs::neighbors::cagra::build_knn_graph(handle, dataset, knn_graph, params); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::cagra::index \ + { \ + return cuvs::neighbors::cagra::build(handle, params, dataset); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::cagra::index \ + { \ + return cuvs::neighbors::cagra::build(handle, params, dataset); \ } RAFT_INST_CAGRA_BUILD(float, uint32_t); diff --git a/cpp/src/neighbors/cagra_build_half.cu b/cpp/src/neighbors/cagra_build_half.cu index 10a995da45..dd57cb87cc 100644 --- a/cpp/src/neighbors/cagra_build_half.cu +++ b/cpp/src/neighbors/cagra_build_half.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,6 +9,14 @@ namespace cuvs::neighbors::cagra { +void build_knn_graph(raft::resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view knn_graph, + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params params) +{ + cuvs::neighbors::cagra::build_knn_graph(handle, dataset, knn_graph, params); +} + cuvs::neighbors::cagra::index build( raft::resources const& handle, const cuvs::neighbors::cagra::index_params& params, diff --git a/cpp/src/neighbors/cagra_build_int8.cu b/cpp/src/neighbors/cagra_build_int8.cu index 291c3a4bae..d651790662 100644 --- a/cpp/src/neighbors/cagra_build_int8.cu +++ b/cpp/src/neighbors/cagra_build_int8.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,21 +8,29 @@ namespace cuvs::neighbors::cagra { -#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - -> cuvs::neighbors::cagra::index \ - { \ - return cuvs::neighbors::cagra::build(handle, params, dataset); \ - } \ - \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - -> cuvs::neighbors::cagra::index \ - { \ - return cuvs::neighbors::cagra::build(handle, params, dataset); \ +#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ + void build_knn_graph(raft::resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view knn_graph, \ + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params params) \ + { \ + cuvs::neighbors::cagra::build_knn_graph(handle, dataset, knn_graph, params); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::cagra::index \ + { \ + return cuvs::neighbors::cagra::build(handle, params, dataset); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::cagra::index \ + { \ + return cuvs::neighbors::cagra::build(handle, params, dataset); \ } RAFT_INST_CAGRA_BUILD(int8_t, uint32_t); diff --git a/cpp/src/neighbors/cagra_build_uint8.cu b/cpp/src/neighbors/cagra_build_uint8.cu index 6bba2814bf..a819675d9c 100644 --- a/cpp/src/neighbors/cagra_build_uint8.cu +++ b/cpp/src/neighbors/cagra_build_uint8.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,21 +8,29 @@ namespace cuvs::neighbors::cagra { -#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - -> cuvs::neighbors::cagra::index \ - { \ - return cuvs::neighbors::cagra::build(handle, params, dataset); \ - } \ - \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - -> cuvs::neighbors::cagra::index \ - { \ - return cuvs::neighbors::cagra::build(handle, params, dataset); \ +#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ + void build_knn_graph(raft::resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view knn_graph, \ + cuvs::neighbors::cagra::graph_build_params::ivf_pq_params params) \ + { \ + cuvs::neighbors::cagra::build_knn_graph(handle, dataset, knn_graph, params); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::cagra::index \ + { \ + return cuvs::neighbors::cagra::build(handle, params, dataset); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::cagra::index \ + { \ + return cuvs::neighbors::cagra::build(handle, params, dataset); \ } RAFT_INST_CAGRA_BUILD(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 755b37c119..9d70f7848c 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -357,6 +357,11 @@ void extend_core( std::optional> new_dataset_buffer_view, std::optional> new_graph_buffer_view) { + RAFT_EXPECTS(!index.dataset_fd().has_value(), + "Cannot extend a disk-backed CAGRA index. Convert it with " + "cuvs::neighbors::hnsw::from_cagra() and load it into memory via " + "cuvs::neighbors::hnsw::deserialize() before calling extend()."); + if (dynamic_cast*>(&index.data()) != nullptr && !new_dataset_buffer_view.has_value()) { RAFT_LOG_WARN( diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 08ba6bf207..5f7389493a 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -7,7 +7,6 @@ #include "../../../core/nvtx.hpp" #include "../../vpq_dataset.cuh" #include "graph_core.cuh" -#include #include #include @@ -17,12 +16,16 @@ #include #include #include +#include +#include #include +#include #include -#include - #include +#include +#include +#include // TODO: This shouldn't be calling spatial/knn APIs #include "../ann_utils.cuh" @@ -31,13 +34,1362 @@ #include #include +#include +#include #include +#include #include #include +#include namespace cuvs::neighbors::cagra::detail { +template +void check_graph_degree(size_t& intermediate_degree, size_t& graph_degree, size_t dataset_size) +{ + if (intermediate_degree >= static_cast(dataset_size)) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", + dataset_size); + intermediate_degree = dataset_size - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } +} + +// ACE: Get partition labels for partitioned approach +// TODO(julianmi): Use all neighbors APIs. +template +void ace_get_partition_labels( + raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view partition_labels, + raft::host_matrix_view partition_histogram, + size_t min_partition_size, + double sampling_rate = 0.01) +{ + size_t dataset_size = dataset.extent(0); + size_t dataset_dim = dataset.extent(1); + size_t labels_size = partition_labels.extent(0); + size_t labels_dim = partition_labels.extent(1); + RAFT_EXPECTS(dataset_size == labels_size, "Dataset size must match partition labels extent"); + size_t n_partitions = partition_histogram.extent(0); + RAFT_EXPECTS(labels_dim == 2, "Labels must have 2 columns"); + RAFT_EXPECTS(partition_histogram.extent(1) == 2, "Partition histogram must have 2 columns"); + cudaStream_t stream = raft::resource::get_cuda_stream(res); + + // Sampling vectors from dataset. Uses float conversion on host instead of + // raft::matrix::sample_rows to minimize GPU memory usage. + // TODO(julianmi): Switch to sample_rows when https://github.com/rapidsai/cuvs/issues/1461 is + // addressed. + size_t n_samples = dataset_size * sampling_rate; + const size_t min_samples = 100 * n_partitions; + n_samples = std::max(n_samples, min_samples); + n_samples = std::min(n_samples, dataset_size); + RAFT_LOG_DEBUG("ACE: n_samples: %lu", n_samples); + + auto sample_db = raft::make_host_matrix(n_samples, dataset_dim); +#pragma omp parallel for + for (size_t i = 0; i < n_samples; i++) { + size_t j = i * dataset_size / n_samples; + for (size_t k = 0; k < dataset_dim; k++) { + sample_db(i, k) = static_cast(dataset(j, k)); + } + } + auto sample_db_dev = raft::make_device_matrix(res, n_samples, dataset_dim); + raft::update_device( + sample_db_dev.data_handle(), sample_db.data_handle(), sample_db.size(), stream); + + cuvs::cluster::kmeans::balanced_params kmeans_params; + auto centroids_dev = raft::make_device_matrix(res, n_partitions, dataset_dim); + cuvs::cluster::kmeans::fit(res, kmeans_params, sample_db_dev.view(), centroids_dev.view()); + + // Compute distances between dataset and centroid vectors + // Uses float conversion on host instead of batch_load_iterator to minimize GPU memory usage. + const size_t chunk_size = 32 * 1024; + auto _sub_dataset = raft::make_host_matrix(chunk_size, dataset_dim); + auto _sub_distances = raft::make_host_matrix(chunk_size, n_partitions); + auto _sub_dataset_dev = raft::make_device_matrix(res, chunk_size, dataset_dim); + auto _sub_distances_dev = raft::make_device_matrix(res, chunk_size, n_partitions); + size_t report_interval = dataset_size / 10; + report_interval = (report_interval / chunk_size) * chunk_size; + report_interval = std::max(report_interval, chunk_size); + + for (size_t i_base = 0; i_base < dataset_size; i_base += chunk_size) { + const size_t sub_dataset_size = std::min(chunk_size, dataset_size - i_base); + if (i_base % report_interval == 0) { + RAFT_LOG_INFO("ACE: Processing chunk %lu / %lu (%.1f%%)", + i_base, + dataset_size, + static_cast(100 * i_base) / dataset_size); + } + + auto sub_dataset = raft::make_host_matrix_view( + _sub_dataset.data_handle(), sub_dataset_size, dataset_dim); +#pragma omp parallel for + for (size_t i_sub = 0; i_sub < sub_dataset_size; i_sub++) { + size_t i = i_base + i_sub; + for (size_t k = 0; k < dataset_dim; k++) { + sub_dataset(i_sub, k) = static_cast(dataset(i, k)); + } + } + auto sub_dataset_dev = raft::make_device_matrix_view( + _sub_dataset_dev.data_handle(), sub_dataset_size, dataset_dim); + raft::update_device( + _sub_dataset_dev.data_handle(), sub_dataset.data_handle(), sub_dataset.size(), stream); + + auto sub_distances = raft::make_host_matrix_view( + _sub_distances.data_handle(), sub_dataset_size, n_partitions); + auto sub_distances_dev = raft::make_device_matrix_view( + _sub_distances_dev.data_handle(), sub_dataset_size, n_partitions); + + cuvs::distance::pairwise_distance(res, + sub_dataset_dev, + centroids_dev.view(), + sub_distances_dev, + cuvs::distance::DistanceType::L2Expanded); + + raft::update_host( + sub_distances.data_handle(), sub_distances_dev.data_handle(), sub_distances.size(), stream); + raft::resource::sync_stream(res, stream); + + // Find two closest partitions to each dataset vector +#pragma omp parallel for + for (size_t i_sub = 0; i_sub < sub_dataset_size; i_sub++) { + size_t core_label = 0; + size_t augmented_label = 1; + if (sub_distances(i_sub, 0) > sub_distances(i_sub, 1)) { + core_label = 1; + augmented_label = 0; + } + for (size_t c = 2; c < n_partitions; c++) { + if (sub_distances(i_sub, c) < sub_distances(i_sub, core_label)) { + augmented_label = core_label; + core_label = c; + } else if (sub_distances(i_sub, c) < sub_distances(i_sub, augmented_label)) { + augmented_label = c; + } + } + size_t i = i_base + i_sub; + partition_labels(i, 0) = core_label; + partition_labels(i, 1) = augmented_label; + +#pragma omp atomic update + partition_histogram(core_label, 0) += 1; +#pragma omp atomic update + partition_histogram(augmented_label, 1) += 1; + } + } +} + +// ACE: Check partition sizes for stable KNN graph construction +template +void ace_check_partition_sizes( + size_t dataset_size, + size_t n_partitions, + raft::host_matrix_view partition_labels, + raft::host_matrix_view partition_histogram, + size_t min_partition_size) +{ + // Collect partition histogram statistics + size_t total_core_vectors = 0; + size_t total_augmented_vectors = 0; + size_t min_core_vectors = dataset_size; + size_t max_core_vectors = 0; + size_t min_augmented_vectors = dataset_size; + size_t max_augmented_vectors = 0; + size_t min_total_vectors = dataset_size; + size_t max_total_vectors = 0; + + for (size_t c = 0; c < n_partitions; c++) { + size_t core_count = partition_histogram(c, 0); + size_t augmented_count = partition_histogram(c, 1); + size_t total_count = core_count + augmented_count; + + if (total_count > 0) { + total_core_vectors += core_count; + total_augmented_vectors += augmented_count; + + min_core_vectors = std::min(min_core_vectors, core_count); + max_core_vectors = std::max(max_core_vectors, core_count); + min_augmented_vectors = std::min(min_augmented_vectors, augmented_count); + max_augmented_vectors = std::max(max_augmented_vectors, augmented_count); + min_total_vectors = std::min(min_total_vectors, total_count); + max_total_vectors = std::max(max_total_vectors, total_count); + } + } + + double avg_core_vectors = static_cast(total_core_vectors) / n_partitions; + double avg_augmented_vectors = static_cast(total_augmented_vectors) / n_partitions; + double avg_total_vectors = 2.0 * static_cast(dataset_size) / n_partitions; + double expected_avg_vectors = 2.0 * static_cast(dataset_size) / n_partitions; + + RAFT_LOG_INFO("ACE: Core vectors - Total: %lu, Avg: %.1f, Min: %lu, Max: %lu", + total_core_vectors, + avg_core_vectors, + min_core_vectors, + max_core_vectors); + RAFT_LOG_INFO("ACE: Augmented vectors - Total: %lu, Avg: %.1f, Min: %lu, Max: %lu", + total_augmented_vectors, + avg_augmented_vectors, + min_augmented_vectors, + max_augmented_vectors); + RAFT_LOG_INFO("ACE: Total per partition - Total: %lu, Avg: %.1f, Min: %lu, Max: %lu", + total_core_vectors + total_augmented_vectors, + avg_total_vectors, + min_total_vectors, + max_total_vectors); + + // Check for partition imbalance and issue warnings + size_t very_small_threshold = min_partition_size; + size_t very_large_threshold = static_cast(5.0 * expected_avg_vectors); + + for (size_t c = 0; c < n_partitions; c++) { + size_t total_count = partition_histogram(c, 0) + partition_histogram(c, 1); + + if (total_count > 0 && total_count < very_small_threshold) { + RAFT_LOG_WARN( + "ACE: Partition %lu is very small (%lu vectors, expected ~%.1f). This may affect graph " + "quality.", + c, + total_count, + expected_avg_vectors); + } else if (total_count > very_large_threshold) { + RAFT_LOG_WARN( + "ACE: Partition %lu is very large (%lu vectors, expected ~%.1f, threshold: %lu). This may " + "indicate imbalance and can lead to memory issues in restricted environments.", + c, + total_count, + expected_avg_vectors, + very_large_threshold); + } + } +} + +// ACE: Create forward/backward mappings between original and reordered vector IDs +// The in-memory path can be parallelized but the disk path requires ordering. +template +void ace_create_forward_and_backward_lists( + size_t dataset_size, + size_t n_partitions, + raft::host_matrix_view partition_labels, + raft::host_matrix_view partition_histogram, + raft::host_vector_view core_forward_mapping, + raft::host_vector_view core_backward_mapping, + raft::host_vector_view augmented_backward_mapping, + raft::host_vector_view core_partition_offsets, + raft::host_vector_view augmented_partition_offsets) +{ + core_partition_offsets(0) = 0; + augmented_partition_offsets(0) = 0; + for (size_t c = 1; c < n_partitions; c++) { + core_partition_offsets(c) = core_partition_offsets(c - 1) + partition_histogram(c - 1, 0); + augmented_partition_offsets(c) = + augmented_partition_offsets(c - 1) + partition_histogram(c - 1, 1); + } + + if (static_cast(core_forward_mapping.extent(0)) == 0) { + // Memory path: both backward mappings + RAFT_EXPECTS(static_cast(core_backward_mapping.extent(0)) == dataset_size, + "core_backward_mapping must be of size dataset_size"); + RAFT_EXPECTS(static_cast(augmented_backward_mapping.extent(0)) == dataset_size, + "augmented_backward_mapping must be of size dataset_size"); +#pragma omp parallel for + for (size_t i = 0; i < dataset_size; i++) { + size_t core_partition_id = partition_labels(i, 0); + size_t core_id; +#pragma omp atomic capture + core_id = core_partition_offsets(core_partition_id)++; + RAFT_EXPECTS(core_id < dataset_size, "Vector ID must be smaller than dataset_size"); + core_backward_mapping(core_id) = i; + + size_t augmented_partition_id = partition_labels(i, 1); + size_t augmented_id; +#pragma omp atomic capture + augmented_id = augmented_partition_offsets(augmented_partition_id)++; + RAFT_EXPECTS(augmented_id < dataset_size, "Vector ID must be smaller than dataset_size"); + augmented_backward_mapping(augmented_id) = i; + } + } else { + // Disk path: all three mappings + RAFT_EXPECTS(static_cast(core_forward_mapping.extent(0)) == dataset_size, + "core_forward_mapping must be of size dataset_size"); + RAFT_EXPECTS(static_cast(core_backward_mapping.extent(0)) == dataset_size, + "core_backward_mapping must be of size dataset_size"); + RAFT_EXPECTS(static_cast(augmented_backward_mapping.extent(0)) == dataset_size, + "augmented_backward_mapping must be of size dataset_size"); + for (size_t i = 0; i < dataset_size; i++) { + size_t core_partition_id = partition_labels(i, 0); + size_t core_id; + core_id = core_partition_offsets(core_partition_id)++; + RAFT_EXPECTS(core_id < dataset_size, "Vector ID must be smaller than dataset_size"); + core_backward_mapping(core_id) = i; + core_forward_mapping(i) = core_id; + + size_t augmented_partition_id = partition_labels(i, 1); + size_t augmented_id; + augmented_id = augmented_partition_offsets(augmented_partition_id)++; + RAFT_EXPECTS(augmented_id < dataset_size, "Vector ID must be smaller than dataset_size"); + augmented_backward_mapping(augmented_id) = i; + } + } + + // Restore idxptr arrays + for (size_t c = n_partitions; c > 0; c--) { + core_partition_offsets(c) = core_partition_offsets(c - 1); + augmented_partition_offsets(c) = augmented_partition_offsets(c - 1); + } + core_partition_offsets(0) = 0; + augmented_partition_offsets(0) = 0; +} + +// ACE: Gather partition dataset +template +void ace_gather_partition_dataset( + size_t core_sub_dataset_size, + size_t augmented_sub_dataset_size, + size_t dataset_dim, + size_t partition_id, + raft::host_matrix_view dataset, + raft::host_vector_view core_backward_mapping, + raft::host_vector_view augmented_backward_mapping, + raft::host_vector_view core_partition_offsets, + raft::host_vector_view augmented_partition_offsets, + raft::host_matrix_view sub_dataset) +{ + const size_t vector_size_bytes = dataset_dim * sizeof(T); + + // Copy core partition vectors +#pragma omp parallel for + for (size_t j = 0; j < core_sub_dataset_size; j++) { + size_t i = core_backward_mapping(j + core_partition_offsets(partition_id)); + memcpy(&sub_dataset(j, 0), &dataset(i, 0), vector_size_bytes); + } + + // Copy augmented partition vectors (2nd closest partition) +#pragma omp parallel for + for (size_t j = 0; j < augmented_sub_dataset_size; j++) { + size_t i = augmented_backward_mapping(j + augmented_partition_offsets(partition_id)); + memcpy(&sub_dataset(j + core_sub_dataset_size, 0), &dataset(i, 0), vector_size_bytes); + } +} + +// ACE: Adjust IDs from core and augmented partitions to global reordered IDs +template +void ace_adjust_sub_graph_ids( + size_t core_sub_dataset_size, + size_t augmented_sub_dataset_size, + size_t graph_degree, + size_t partition_id, + raft::host_matrix_view sub_search_graph, + raft::host_matrix_view search_graph, + raft::host_vector_view core_partition_offsets, + raft::host_vector_view augmented_partition_offsets, + raft::host_vector_view core_backward_mapping, + raft::host_vector_view augmented_backward_mapping) +{ +#pragma omp parallel for + for (size_t i = 0; i < core_sub_dataset_size; i++) { + // Map row index from local → reordered → original + size_t i_reordered = i + core_partition_offsets(partition_id); + size_t i_original = core_backward_mapping(i_reordered); + + for (size_t k = 0; k < graph_degree; k++) { + size_t j = sub_search_graph(i, k); + size_t j_original; + + if (j < core_sub_dataset_size) { + // core partition neighbor: local → core reordered → original + size_t j_reordered = j + core_partition_offsets(partition_id); + j_original = core_backward_mapping(j_reordered); + } else { + // Augmented partition neighbor: local → augmented reordered → original + size_t j_augmented = j - core_sub_dataset_size; + j_original = + augmented_backward_mapping(j_augmented + augmented_partition_offsets(partition_id)); + } + search_graph(i_original, k) = j_original; + } + } +} + +// ACE: Adjust ids in sub search graph in place for disk version +template +void ace_adjust_sub_graph_ids_disk( + size_t core_sub_dataset_size, + size_t augmented_sub_dataset_size, + size_t graph_degree, + size_t partition_id, + raft::host_matrix_view sub_search_graph, + raft::host_vector_view core_partition_offsets, + raft::host_vector_view augmented_partition_offsets, + raft::host_vector_view augmented_backward_mapping, + raft::host_vector_view core_forward_mapping) +{ +#pragma omp parallel for + for (size_t i = 0; i < core_sub_dataset_size; i++) { + for (size_t k = 0; k < graph_degree; k++) { + size_t j = sub_search_graph(i, k); + if (j < core_sub_dataset_size) { + // core partition neighbor: local → core reordered + sub_search_graph(i, k) = j + core_partition_offsets(partition_id); + } else { + // Augmented partition neighbor: local → augmented reordered→ original → core reordered + size_t j_augmented = j - core_sub_dataset_size; + size_t j_original = + augmented_backward_mapping(j_augmented + augmented_partition_offsets(partition_id)); + sub_search_graph(i, k) = core_forward_mapping(j_original); + } + } + } +} + +// ACE: Reorder dataset based on partition assignments and store to disk +// Writes two files: reordered_dataset.npy (core partitions) and augmented_dataset.npy (secondary +// partitions). Uses buffered writes optimized for NVMe storage. +template +void ace_reorder_and_store_dataset( + raft::resources const& res, + const std::string& build_dir, + raft::host_matrix_view dataset, + raft::host_matrix_view partition_labels, + raft::host_matrix_view partition_histogram, + raft::host_vector_view core_backward_mapping, + raft::host_vector_view core_partition_offsets, + raft::host_vector_view augmented_partition_offsets, + cuvs::util::file_descriptor& reordered_fd, + cuvs::util::file_descriptor& augmented_fd, + cuvs::util::file_descriptor& mapping_fd, + size_t reordered_header_size, + size_t augmented_header_size, + size_t mapping_header_size) +{ + auto start = std::chrono::high_resolution_clock::now(); + + size_t dataset_size = dataset.extent(0); + size_t dataset_dim = dataset.extent(1); + size_t n_partitions = partition_histogram.extent(0); + + RAFT_LOG_DEBUG( + "ACE: Reordering and storing dataset to disk (%lu vectors, %lu dimensions, %lu partitions)", + dataset_size, + dataset_dim, + n_partitions); + + // Calculate total sizes for pre-allocation + size_t total_core_vectors = 0; + size_t total_augmented_vectors = 0; + size_t max_core_vectors = 0; + size_t max_augmented_vectors = 0; + for (size_t p = 0; p < n_partitions; p++) { + total_core_vectors += partition_histogram(p, 0); + total_augmented_vectors += partition_histogram(p, 1); + max_core_vectors = std::max(max_core_vectors, partition_histogram(p, 0)); + max_augmented_vectors = std::max(max_augmented_vectors, partition_histogram(p, 1)); + } + RAFT_EXPECTS(total_core_vectors == dataset_size, + "Total core vectors must be equal to dataset size"); + RAFT_EXPECTS(total_augmented_vectors == dataset_size, + "Total augmented vectors must be equal to dataset size"); + + // Pre-allocate file space for better performance + const size_t vector_size = dataset_dim * sizeof(T); + size_t reordered_file_size = total_core_vectors * vector_size; + size_t augmented_file_size = total_augmented_vectors * vector_size; + + RAFT_LOG_DEBUG("ACE: Reordered dataset: %lu core vectors (%.2f GiB)", + total_core_vectors, + reordered_file_size / (1024.0 * 1024.0 * 1024.0)); + RAFT_LOG_DEBUG("ACE: Augmented dataset: %lu secondary vectors (%.2f GiB)", + total_augmented_vectors, + augmented_file_size / (1024.0 * 1024.0 * 1024.0)); + + // Calculate partition start offsets for reordered and augmented datasets + auto core_partition_starts = raft::make_host_vector(n_partitions + 1); + memset(core_partition_starts.data_handle(), 0, (n_partitions + 1) * sizeof(size_t)); + auto augmented_partition_starts = raft::make_host_vector(n_partitions + 1); + memset(augmented_partition_starts.data_handle(), 0, (n_partitions + 1) * sizeof(size_t)); + auto core_partition_current = raft::make_host_vector(n_partitions); + memset(core_partition_current.data_handle(), 0, n_partitions * sizeof(size_t)); + auto augmented_partition_current = raft::make_host_vector(n_partitions); + memset(augmented_partition_current.data_handle(), 0, n_partitions * sizeof(size_t)); + + for (size_t p = 0; p < n_partitions; p++) { + core_partition_starts(p + 1) = core_partition_starts(p) + partition_histogram(p, 0); + augmented_partition_starts(p + 1) = augmented_partition_starts(p) + partition_histogram(p, 1); + } + + const size_t free_memory = cuvs::util::get_free_host_memory(); + // Conservatively allocate 50% of free memory per partition. Accounts for internal buffers and + // overhead. + // TODO: Adjust overhead if needed. + const size_t memory_per_partition = 0.5 * free_memory / (n_partitions * 2); + size_t disk_write_size = raft::bound_by_power_of_two(memory_per_partition); + // 64MB should be enough to saturate typical NVMe SSDs. + disk_write_size = std::min(disk_write_size, 64 * 1024 * 1024); + size_t vectors_per_buffer = std::max(64, disk_write_size / vector_size); + + RAFT_LOG_DEBUG("ACE: Reorder buffers: %lu vectors per buffer (%.2f MiB)", + vectors_per_buffer, + vectors_per_buffer * vector_size / (1024.0 * 1024.0)); + + std::vector> core_buffers; + std::vector> augmented_buffers; + auto core_buffer_counts = raft::make_host_vector(n_partitions); + auto augmented_buffer_counts = raft::make_host_vector(n_partitions); + + core_buffers.reserve(n_partitions); + augmented_buffers.reserve(n_partitions); + + for (size_t p = 0; p < n_partitions; p++) { + core_buffers.emplace_back(raft::make_host_matrix(vectors_per_buffer, dataset_dim)); + augmented_buffers.emplace_back( + raft::make_host_matrix(vectors_per_buffer, dataset_dim)); + core_buffer_counts(p) = 0; + augmented_buffer_counts(p) = 0; + } + auto flush_core_buffer = [&](size_t partition_id) { + const size_t count = core_buffer_counts(partition_id); + if (count > 0) { + const size_t bytes_to_write = count * vector_size; + const size_t file_offset = + (core_partition_starts(partition_id) + core_partition_current(partition_id)) * vector_size + + reordered_header_size; + + cuvs::util::write_large_file( + reordered_fd, core_buffers[partition_id].data_handle(), bytes_to_write, file_offset); + + core_partition_current(partition_id) += count; + core_buffer_counts(partition_id) = 0; + } + }; + + auto flush_augmented_buffer = [&](size_t partition_id) { + const size_t count = augmented_buffer_counts(partition_id); + if (count > 0) { + const size_t bytes_to_write = count * vector_size; + const size_t file_offset = + (augmented_partition_starts(partition_id) + augmented_partition_current(partition_id)) * + vector_size + + augmented_header_size; + + cuvs::util::write_large_file( + augmented_fd, augmented_buffers[partition_id].data_handle(), bytes_to_write, file_offset); + + augmented_partition_current(partition_id) += count; + augmented_buffer_counts(partition_id) = 0; + } + }; + + size_t vectors_processed = 0; + const size_t log_interval = std::max(dataset_size / 10, size_t(1)); + for (size_t i = 0; i < dataset_size; i++) { + size_t core_partition = partition_labels(i, 0); + size_t secondary_partition = partition_labels(i, 1); + + // Add vector to core partition buffer + size_t core_buffer_row = core_buffer_counts(core_partition); + memcpy( + &core_buffers[core_partition](core_buffer_row, 0), &dataset(i, 0), dataset_dim * sizeof(T)); + core_buffer_counts(core_partition)++; + + // Flush core buffer if full + if (core_buffer_counts(core_partition) >= vectors_per_buffer) { + flush_core_buffer(core_partition); + } + + // Add vector to augmented partition buffer + size_t augmented_buffer_row = augmented_buffer_counts(secondary_partition); + memcpy(&augmented_buffers[secondary_partition](augmented_buffer_row, 0), + &dataset(i, 0), + dataset_dim * sizeof(T)); + augmented_buffer_counts(secondary_partition)++; + + // Flush augmented buffer if full + if (augmented_buffer_counts(secondary_partition) >= vectors_per_buffer) { + flush_augmented_buffer(secondary_partition); + } + + vectors_processed++; + if (vectors_processed % log_interval == 0) { + RAFT_LOG_INFO("ACE: Processed %lu/%lu vectors (%.1f%%)", + vectors_processed, + dataset_size, + 100.0 * vectors_processed / dataset_size); + } + } + + // Flush all remaining buffers + RAFT_LOG_DEBUG("ACE: Flushing remaining buffers..."); +#pragma omp parallel sections + { +#pragma omp section + { + for (size_t p = 0; p < n_partitions; p++) { + flush_core_buffer(p); + } + } +#pragma omp section + { + for (size_t p = 0; p < n_partitions; p++) { + flush_augmented_buffer(p); + } + } + } + + const size_t mapping_file_size = dataset_size * sizeof(IdxT); + cuvs::util::write_large_file( + mapping_fd, core_backward_mapping.data_handle(), mapping_file_size, mapping_header_size); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_ms = std::chrono::duration_cast(end - start).count(); + + // Calculate total bytes written + size_t total_bytes_written = reordered_file_size + augmented_file_size + mapping_file_size; + double throughput_mb_s = + elapsed_ms > 0 ? (total_bytes_written / (1024.0 * 1024.0)) / (elapsed_ms / 1000.0) : 0.0; + + RAFT_LOG_INFO( + "ACE: Dataset (%.2f GiB reordered, %.2f GiB augmented, %.2f GiB mapping) reordering completed " + "in %ld ms (%.1f MiB/s)", + reordered_file_size / (1024.0 * 1024.0 * 1024.0), + augmented_file_size / (1024.0 * 1024.0 * 1024.0), + mapping_file_size / (1024.0 * 1024.0 * 1024.0), + elapsed_ms, + throughput_mb_s); +} + +// ACE: Load partition dataset and augmented dataset from disk +template +void ace_load_partition_dataset_from_disk( + raft::resources const& res, + const std::string& build_dir, + size_t partition_id, + size_t dataset_dim, + raft::host_matrix_view partition_histogram, + raft::host_vector_view core_partition_offsets, + raft::host_vector_view augmented_partition_offsets, + raft::host_matrix_view sub_dataset) +{ + size_t n_partitions = partition_histogram.extent(0); + + RAFT_LOG_DEBUG("ACE: Loading partition %lu dataset from disk", partition_id); + + size_t core_size = partition_histogram(partition_id, 0); + size_t augmented_size = partition_histogram(partition_id, 1); + size_t total_partition_size = core_size + augmented_size; + + RAFT_LOG_DEBUG("ACE: Partition %lu: %lu core + %lu augmented = %lu total vectors", + partition_id, + core_size, + augmented_size, + total_partition_size); + + RAFT_EXPECTS(static_cast(sub_dataset.extent(0)) == total_partition_size, + "sub_dataset rows (%lu) must match total partition size (%lu)", + sub_dataset.extent(0), + total_partition_size); + RAFT_EXPECTS(static_cast(sub_dataset.extent(1)) == dataset_dim, + "sub_dataset columns (%lu) must match dataset dimensions (%lu)", + sub_dataset.extent(1), + dataset_dim); + + const size_t vector_size = dataset_dim * sizeof(T); + + const std::string reordered_dataset_path = build_dir + "/reordered_dataset.npy"; + const std::string augmented_dataset_path = build_dir + "/augmented_dataset.npy"; + + if (!std::filesystem::exists(reordered_dataset_path)) { + RAFT_FAIL("ACE: Required file does not exist: %s", reordered_dataset_path.c_str()); + } + if (!std::filesystem::exists(augmented_dataset_path)) { + RAFT_FAIL("ACE: Required file does not exist: %s", augmented_dataset_path.c_str()); + } + + size_t core_header_size = 0; + size_t augmented_header_size = 0; + size_t core_file_offset = 0; + size_t augmented_file_offset = 0; + { + std::ifstream is(reordered_dataset_path, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", reordered_dataset_path.c_str()); } + auto start_pos = is.tellg(); + raft::detail::numpy_serializer::read_header(is); + core_header_size = static_cast(is.tellg() - start_pos); + } + { + std::ifstream is(augmented_dataset_path, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", augmented_dataset_path.c_str()); } + auto start_pos = is.tellg(); + raft::detail::numpy_serializer::read_header(is); + augmented_header_size = static_cast(is.tellg() - start_pos); + } + + for (size_t p = 0; p < partition_id; p++) { + core_file_offset += partition_histogram(p, 0); + augmented_file_offset += partition_histogram(p, 1); + } + + core_file_offset *= vector_size; + augmented_file_offset *= vector_size; + + core_file_offset += core_header_size; + augmented_file_offset += augmented_header_size; + + RAFT_LOG_DEBUG("ACE: Core file offset: %lu bytes, Augmented file offset: %lu bytes", + core_file_offset, + augmented_file_offset); + + // Read core and augmented data in parallel + std::exception_ptr core_exception = nullptr; + std::exception_ptr augmented_exception = nullptr; + +#pragma omp parallel sections + { +#pragma omp section + { + try { + if (core_size > 0) { + RAFT_LOG_DEBUG( + "ACE: Reading %lu core vectors from offset %lu", core_size, core_file_offset); + cuvs::util::file_descriptor reordered_fd(reordered_dataset_path, O_RDONLY); + const size_t core_bytes = core_size * vector_size; + cuvs::util::read_large_file( + reordered_fd, sub_dataset.data_handle(), core_bytes, core_file_offset); + } + } catch (...) { + core_exception = std::current_exception(); + } + } +#pragma omp section + { + try { + if (augmented_size > 0) { + RAFT_LOG_DEBUG("ACE: Reading %lu augmented vectors from offset %lu", + augmented_size, + augmented_file_offset); + cuvs::util::file_descriptor augmented_fd(augmented_dataset_path, O_RDONLY); + const size_t augmented_bytes = augmented_size * vector_size; + T* augmented_dest = sub_dataset.data_handle() + (core_size * dataset_dim); + cuvs::util::read_large_file( + augmented_fd, augmented_dest, augmented_bytes, augmented_file_offset); + } + } catch (...) { + augmented_exception = std::current_exception(); + } + } + } + + // Check for exceptions from parallel sections + if (core_exception) { std::rethrow_exception(core_exception); } + if (augmented_exception) { std::rethrow_exception(augmented_exception); } +} + +// Build CAGRA index using ACE (Augmented Core Extraction) partitioning +// ACE enables building indices for datasets too large to fit in GPU memory by: +// 1. Partitioning the dataset using balanced k-means in core (non-overlapping) and augmented +// (second-closest) partitions +// 2. Building sub-indices for each partition independently +// 3. Concatenating sub-graphs (of core partitions) into a final unified index +// Supports both in-memory and disk-based modes depending on available host memory. +// In disk mode, the graph is stored in build_dir and dataset is reordered on disk. +// The returned index is not usable for search. Use the created files for search instead. +template +index build_ace(raft::resources const& res, + const index_params& params, + raft::host_matrix_view dataset) +{ + // Extract ACE parameters from graph_build_params + RAFT_EXPECTS( + std::holds_alternative(params.graph_build_params), + "ACE build requires graph_build_params to be set to ace_params"); + + auto ace_params = std::get(params.graph_build_params); + size_t npartitions = ace_params.npartitions; + size_t ef_construction = ace_params.ef_construction; + std::string build_dir = ace_params.build_dir; + bool use_disk = ace_params.use_disk; + + common::nvtx::range function_scope( + "cagra::build_ace(%zu, %zu, %zu)", + params.intermediate_graph_degree, + params.graph_degree, + npartitions); + + size_t dataset_size = dataset.extent(0); + size_t dataset_dim = dataset.extent(1); + + RAFT_EXPECTS(dataset_size > 0, "ACE: Dataset must not be empty"); + if (dataset_size < 1000) { + RAFT_LOG_WARN("ACE: Very small dataset size (%zu), consider using regular CAGRA build instead.", + dataset_size); + } + RAFT_EXPECTS(dataset_dim > 0, "ACE: Dataset dimension must be greater than 0"); + RAFT_EXPECTS(params.intermediate_graph_degree > 0, + "ACE: Intermediate graph degree must be greater than 0"); + RAFT_EXPECTS(params.graph_degree > 0, "ACE: Graph degree must be greater than 0"); + + size_t n_partitions = npartitions; + RAFT_EXPECTS(n_partitions > 0, "ACE: npartitions must be greater than 0"); + + size_t min_required_per_partition = 1000; + if (n_partitions > dataset_size / min_required_per_partition) { + n_partitions = dataset_size / min_required_per_partition; + if (n_partitions < 2) { + RAFT_LOG_WARN( + "ACE: Reduced number of partitions to the minimum of 2 to avoid tiny partitions. Consider " + "using regular CAGRA build instead."); + n_partitions = 2; + } else { + RAFT_LOG_WARN("ACE: Reduced number of partitions to %zu to avoid tiny partitions", + n_partitions); + } + } + + auto total_start = std::chrono::high_resolution_clock::now(); + RAFT_LOG_INFO("ACE: Starting partitioned CAGRA build with %zu partitions", n_partitions); + + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + // Track whether to clean up build directory on failure + bool cleanup_on_failure = false; + + try { + check_graph_degree(intermediate_degree, graph_degree, dataset_size); + + size_t available_memory = cuvs::util::get_free_host_memory(); + + // Optimistic memory model: focus on largest arrays, assumes all partitions are of equal size + // For memory path: + // - Partition labes (core + augmented): 2 * dataset_size * sizeof(IdxT) + // - Backward ID mapping arrays (core + augmented): 2 * dataset_size * sizeof(IdxT) + // - Per-partition dataset (2x for imbalanced partitions): 4 * (dataset_size / n_partitions) * + // dataset_dim * sizeof(T) + // - Per-partition graph during build: (dataset_size / n_partitions) * (intermediate + final) + // * sizeof(IdxT) + // - Final assembled graph: dataset_size * graph_degree * sizeof(IdxT) + size_t ace_partition_labels_size = 2 * dataset_size * sizeof(IdxT); + size_t ace_id_mapping_size = 2 * dataset_size * sizeof(IdxT); + size_t ace_sub_dataset_size = 4 * (dataset_size / n_partitions) * dataset_dim * sizeof(T); + size_t ace_sub_graph_size = + (dataset_size / n_partitions) * (intermediate_degree + graph_degree) * sizeof(IdxT); + size_t cagra_graph_size = dataset_size * graph_degree * sizeof(IdxT); + size_t total_size = ace_partition_labels_size + ace_id_mapping_size + ace_sub_dataset_size + + ace_sub_graph_size + cagra_graph_size; + RAFT_LOG_INFO("ACE: Estimated host memory required: %.2f GiB, available: %.2f GiB", + total_size / (1024.0 * 1024.0 * 1024.0), + available_memory / (1024.0 * 1024.0 * 1024.0)); + // TODO: Adjust overhead factor if needed + bool host_memory_limited = static_cast(0.8 * available_memory) < total_size; + + // GPU is mostly limited by the index size (update_graph() in the end of this routine). + // Check if GPU has enough memory for the final graph or use disk mode instead. + // TODO: Extend model or use managed memory if running out of GPU memory. + auto available_gpu_memory = rmm::available_device_memory().second; + bool gpu_memory_limited = static_cast(0.8 * available_gpu_memory) < cagra_graph_size; + RAFT_LOG_INFO("ACE: Estimated GPU memory required: %.2f GiB, available: %.2f GiB", + cagra_graph_size / (1024.0 * 1024.0 * 1024.0), + available_gpu_memory / (1024.0 * 1024.0 * 1024.0)); + + bool use_disk_mode = use_disk || host_memory_limited || gpu_memory_limited; + if (use_disk_mode) { + bool valid_build_dir = !build_dir.empty(); + valid_build_dir &= build_dir.length() <= 255; + valid_build_dir &= build_dir.find('\0') == std::string::npos; + valid_build_dir &= build_dir.find("//") == std::string::npos; + if (!valid_build_dir) { + RAFT_LOG_WARN("ACE: Invalid build_dir path, resetting to default: /tmp/ace_build"); + build_dir = "/tmp/ace_build"; + } + if (mkdir(build_dir.c_str(), 0755) != 0 && errno != EEXIST) { + RAFT_EXPECTS(false, "Failed to create ACE build directory: %s", build_dir.c_str()); + } + } + + if (host_memory_limited && gpu_memory_limited) { + RAFT_LOG_INFO( + "ACE: Graph does not fit in host and GPU memory. Using disk-mode with temporary storage %s", + build_dir.c_str()); + } else if (host_memory_limited) { + RAFT_LOG_INFO( + "ACE: Graph does not fit in host memory. Using disk-mode with temporary storage %s", + build_dir.c_str()); + } else if (gpu_memory_limited) { + RAFT_LOG_INFO( + "ACE: Graph does not fit in GPU memory. Using disk-mode with temporary storage %s", + build_dir.c_str()); + } else if (use_disk) { + RAFT_LOG_INFO( + "ACE: Graph fits in host and GPU memory but disk mode is forced. Using disk-mode with " + "temporary storage %s", + build_dir.c_str()); + } else { + RAFT_LOG_INFO("ACE: Graph fits in host and GPU memory. Using in-memory mode."); + } + + // Preallocate space for files for better performance and fail early if not enough space. + cuvs::util::file_descriptor reordered_fd; + cuvs::util::file_descriptor augmented_fd; + cuvs::util::file_descriptor mapping_fd; + cuvs::util::file_descriptor graph_fd; + size_t reordered_header_size = 0; + size_t augmented_header_size = 0; + size_t mapping_header_size = 0; + size_t graph_header_size = 0; + + if (use_disk_mode) { + if (mkdir(build_dir.c_str(), 0755) != 0 && errno != EEXIST) { + RAFT_EXPECTS(false, "Failed to create ACE build directory: %s", build_dir.c_str()); + } + // Mark for cleanup if we fail after creating the directory + cleanup_on_failure = true; + + // Helper lambda to write numpy header to file descriptor + auto write_numpy_header = [](int fd, + const std::vector& shape, + const raft::detail::numpy_serializer::dtype_t& dtype) { + std::stringstream ss; + + const bool fortran_order = false; + const raft::detail::numpy_serializer::header_t header = {dtype, fortran_order, shape}; + + raft::detail::numpy_serializer::write_header(ss, header); + + std::string header_str = ss.str(); + ssize_t written = write(fd, header_str.data(), header_str.size()); + if (written < 0 || static_cast(written) != header_str.size()) { + RAFT_FAIL("Failed to write numpy header to file descriptor"); + } + return header_str.size(); + }; + + // Create and allocate dataset file + reordered_fd = cuvs::util::file_descriptor( + build_dir + "/reordered_dataset.npy", O_CREAT | O_RDWR | O_TRUNC, 0644); + { + std::stringstream ss; + const auto dtype = raft::detail::numpy_serializer::get_numpy_dtype(); + const bool fortran_order = false; + const raft::detail::numpy_serializer::header_t header = { + dtype, fortran_order, {dataset_size, dataset_dim}}; + raft::detail::numpy_serializer::write_header(ss, header); + reordered_header_size = ss.str().size(); + } + if (posix_fallocate(reordered_fd.get(), + 0, + reordered_header_size + dataset_size * dataset_dim * sizeof(T)) != 0) { + RAFT_FAIL("Failed to pre-allocate space for reordered dataset file"); + } + { + auto dtype_for_dataset = raft::detail::numpy_serializer::get_numpy_dtype(); + RAFT_LOG_DEBUG("Writing reordered_dataset.npy header: shape=[%zu,%zu], dtype=%c", + dataset_size, + dataset_dim, + dtype_for_dataset.kind); + if (lseek(reordered_fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of reordered dataset file"); + } + write_numpy_header(reordered_fd.get(), {dataset_size, dataset_dim}, dtype_for_dataset); + } + + // Create and allocate augmented dataset file + augmented_fd = cuvs::util::file_descriptor( + build_dir + "/augmented_dataset.npy", O_CREAT | O_RDWR | O_TRUNC, 0644); + { + std::stringstream ss; + const auto dtype = raft::detail::numpy_serializer::get_numpy_dtype(); + const bool fortran_order = false; + const raft::detail::numpy_serializer::header_t header = { + dtype, fortran_order, {dataset_size, dataset_dim}}; + raft::detail::numpy_serializer::write_header(ss, header); + augmented_header_size = ss.str().size(); + } + if (posix_fallocate(augmented_fd.get(), + 0, + augmented_header_size + dataset_size * dataset_dim * sizeof(T)) != 0) { + RAFT_FAIL("Failed to pre-allocate space for augmented dataset file"); + } + // Seek to beginning before writing header + if (lseek(augmented_fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of augmented dataset file"); + } + write_numpy_header(augmented_fd.get(), + {dataset_size, dataset_dim}, + raft::detail::numpy_serializer::get_numpy_dtype()); + + // Create and allocate mapping file + mapping_fd = cuvs::util::file_descriptor( + build_dir + "/dataset_mapping.npy", O_CREAT | O_RDWR | O_TRUNC, 0644); + { + std::stringstream ss; + const auto dtype = raft::detail::numpy_serializer::get_numpy_dtype(); + const bool fortran_order = false; + const raft::detail::numpy_serializer::header_t header = { + dtype, fortran_order, {dataset_size}}; + raft::detail::numpy_serializer::write_header(ss, header); + mapping_header_size = ss.str().size(); + } + if (posix_fallocate(mapping_fd.get(), 0, mapping_header_size + dataset_size * sizeof(IdxT)) != + 0) { + RAFT_FAIL("Failed to pre-allocate space for dataset mapping file"); + } + { + auto dtype_for_mapping = raft::detail::numpy_serializer::get_numpy_dtype(); + RAFT_LOG_DEBUG("Writing dataset_mapping.npy header: shape=[%zu], dtype=%c", + dataset_size, + dtype_for_mapping.kind); + if (lseek(mapping_fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of mapping file"); + } + write_numpy_header(mapping_fd.get(), {dataset_size}, dtype_for_mapping); + } + + // Create and allocate graph file + graph_fd = cuvs::util::file_descriptor( + build_dir + "/cagra_graph.npy", O_CREAT | O_RDWR | O_TRUNC, 0644); + { + std::stringstream ss; + const auto dtype = raft::detail::numpy_serializer::get_numpy_dtype(); + const bool fortran_order = false; + const raft::detail::numpy_serializer::header_t header = { + dtype, fortran_order, {dataset_size, graph_degree}}; + raft::detail::numpy_serializer::write_header(ss, header); + graph_header_size = ss.str().size(); + } + if (posix_fallocate(graph_fd.get(), 0, graph_header_size + cagra_graph_size) != 0) { + RAFT_FAIL("Failed to pre-allocate space for graph file"); + } + { + auto dtype_for_graph = raft::detail::numpy_serializer::get_numpy_dtype(); + RAFT_LOG_DEBUG("Writing cagra_graph.npy header: shape=[%zu,%zu], dtype=%c", + dataset_size, + graph_degree, + dtype_for_graph.kind); + if (lseek(graph_fd.get(), 0, SEEK_SET) == -1) { + RAFT_FAIL("Failed to seek to beginning of graph file"); + } + write_numpy_header(graph_fd.get(), {dataset_size, graph_degree}, dtype_for_graph); + } + + RAFT_LOG_DEBUG( + "ACE: Wrote numpy headers (reordered: %zu, augmented: %zu, mapping: %zu, graph: %zu bytes)", + reordered_header_size, + augmented_header_size, + mapping_header_size, + graph_header_size); + } + + auto partition_start = std::chrono::high_resolution_clock::now(); + auto partition_labels = raft::make_host_matrix(dataset_size, 2); + auto partition_histogram = raft::make_host_matrix(n_partitions, 2); + for (size_t c = 0; c < n_partitions; c++) { + partition_histogram(c, 0) = 0; + partition_histogram(c, 1) = 0; + } + + // Determine minimum partition size for stable KNN graph construction + size_t min_partition_size = std::max(1000ULL, dataset_size / n_partitions * 0.1); + + ace_get_partition_labels( + res, dataset, partition_labels.view(), partition_histogram.view(), min_partition_size); + + ace_check_partition_sizes(dataset_size, + n_partitions, + partition_labels.view(), + partition_histogram.view(), + min_partition_size); + + auto partition_end = std::chrono::high_resolution_clock::now(); + auto partition_elapsed = + std::chrono::duration_cast(partition_end - partition_start) + .count(); + RAFT_LOG_INFO( + "ACE: Partition labeling completed in %ld ms (min_partition_size: " + "%lu)", + partition_elapsed, + min_partition_size); + + // Create vector lists for each partition + auto vectorlist_start = std::chrono::high_resolution_clock::now(); + auto core_forward_mapping = use_disk_mode ? raft::make_host_vector(dataset_size) + : raft::make_host_vector(0); + auto core_backward_mapping = raft::make_host_vector(dataset_size); + auto augmented_backward_mapping = raft::make_host_vector(dataset_size); + auto core_partition_offsets = raft::make_host_vector(n_partitions + 1); + auto augmented_partition_offsets = raft::make_host_vector(n_partitions + 1); + + ace_create_forward_and_backward_lists(dataset_size, + n_partitions, + partition_labels.view(), + partition_histogram.view(), + core_forward_mapping.view(), + core_backward_mapping.view(), + augmented_backward_mapping.view(), + core_partition_offsets.view(), + augmented_partition_offsets.view()); + + auto vectorlist_end = std::chrono::high_resolution_clock::now(); + auto vectorlist_elapsed = + std::chrono::duration_cast(vectorlist_end - vectorlist_start) + .count(); + RAFT_LOG_INFO("ACE: Vector list creation completed in %ld ms", vectorlist_elapsed); + + // Reorder the dataset based on partitions and store to disk. Uses write buffers to improve + // performance. + if (use_disk_mode) { + ace_reorder_and_store_dataset(res, + build_dir, + dataset, + partition_labels.view(), + partition_histogram.view(), + core_backward_mapping.view(), + core_partition_offsets.view(), + augmented_partition_offsets.view(), + reordered_fd, + augmented_fd, + mapping_fd, + reordered_header_size, + augmented_header_size, + mapping_header_size); + // core_backward_mapping is not needed anymore. + core_backward_mapping = raft::make_host_vector(0); + } + + // Placeholder search graph for in-memory version + auto search_graph = use_disk_mode + ? raft::make_host_matrix(0, 0) + : raft::make_host_matrix(dataset_size, graph_degree); + + // Process each partition + auto partition_processing_start = std::chrono::high_resolution_clock::now(); + for (size_t partition_id = 0; partition_id < n_partitions; partition_id++) { + RAFT_LOG_DEBUG("ACE: Processing partition %lu/%lu", partition_id + 1, n_partitions); + auto start = std::chrono::high_resolution_clock::now(); + + // Extract vectors for this partition + size_t core_sub_dataset_size = partition_histogram(partition_id, 0); + size_t augmented_sub_dataset_size = partition_histogram(partition_id, 1); + size_t sub_dataset_size = core_sub_dataset_size + augmented_sub_dataset_size; + + if (sub_dataset_size == 0) { + RAFT_LOG_WARN("ACE: Skipping empty partition %lu", partition_id); + continue; + } + RAFT_LOG_DEBUG("ACE: Sub-dataset size: %lu (%lu + %lu)", + sub_dataset_size, + core_sub_dataset_size, + augmented_sub_dataset_size); + + auto sub_dataset = raft::make_host_matrix(sub_dataset_size, dataset_dim); + + if (use_disk_mode) { + // Load partition dataset from disk files + ace_load_partition_dataset_from_disk(res, + build_dir, + partition_id, + dataset_dim, + partition_histogram.view(), + core_partition_offsets.view(), + augmented_partition_offsets.view(), + sub_dataset.view()); + } else { + // Gather partition dataset from memory + ace_gather_partition_dataset(core_sub_dataset_size, + augmented_sub_dataset_size, + dataset_dim, + partition_id, + dataset, + core_backward_mapping.view(), + augmented_backward_mapping.view(), + core_partition_offsets.view(), + augmented_partition_offsets.view(), + sub_dataset.view()); + } + auto read_end = std::chrono::high_resolution_clock::now(); + auto read_elapsed = + std::chrono::duration_cast(read_end - start).count(); + + // Create index for this partition + cuvs::neighbors::cagra::index_params sub_index_params; + sub_index_params = cuvs::neighbors::cagra::index_params::from_hnsw_params( + raft::make_extents(sub_dataset_size, dataset_dim), + graph_degree / 2, + ef_construction, + cuvs::neighbors::cagra::hnsw_heuristic_type::SAME_GRAPH_FOOTPRINT, + params.metric); + sub_index_params.attach_dataset_on_build = false; + sub_index_params.guarantee_connectivity = params.guarantee_connectivity; + + auto sub_index = cuvs::neighbors::cagra::build( + res, sub_index_params, raft::make_const_mdspan(sub_dataset.view())); + + auto optimize_end = std::chrono::high_resolution_clock::now(); + auto optimize_elapsed = + std::chrono::duration_cast(optimize_end - read_end).count(); + + // Copy graph edges for core members of this partition + auto sub_search_graph = + raft::make_host_matrix(core_sub_dataset_size, graph_degree); + cudaStream_t stream = raft::resource::get_cuda_stream(res); + raft::update_host(sub_search_graph.data_handle(), + sub_index.graph().data_handle(), + sub_search_graph.size(), + stream); + raft::resource::sync_stream(res, stream); + + if (use_disk_mode) { + // Adjust IDs in sub_search_graph in place for disk storage + ace_adjust_sub_graph_ids_disk(core_sub_dataset_size, + augmented_sub_dataset_size, + graph_degree, + partition_id, + sub_search_graph.view(), + core_partition_offsets.view(), + augmented_partition_offsets.view(), + augmented_backward_mapping.view(), + core_forward_mapping.view()); + } else { + // Adjust IDs in sub_search_graph and save to search_graph + ace_adjust_sub_graph_ids(core_sub_dataset_size, + augmented_sub_dataset_size, + graph_degree, + partition_id, + sub_search_graph.view(), + search_graph.view(), + core_partition_offsets.view(), + augmented_partition_offsets.view(), + core_backward_mapping.view(), + augmented_backward_mapping.view()); + } + + auto adjust_end = std::chrono::high_resolution_clock::now(); + auto adjust_elapsed = + std::chrono::duration_cast(adjust_end - optimize_end).count(); + + if (use_disk_mode) { + const size_t graph_offset = + static_cast(core_partition_offsets(partition_id)) * graph_degree * sizeof(IdxT) + + graph_header_size; + const size_t graph_bytes = core_sub_dataset_size * graph_degree * sizeof(IdxT); + cuvs::util::write_large_file( + graph_fd, sub_search_graph.data_handle(), graph_bytes, graph_offset); + } + + auto end = std::chrono::high_resolution_clock::now(); + auto write_elapsed = + std::chrono::duration_cast(end - adjust_end).count(); + auto elapsed_ms = std::chrono::duration_cast(end - start).count(); + double read_throughput = read_elapsed > 0 ? sub_dataset_size * dataset_dim * sizeof(T) / + (1024.0 * 1024.0) / (read_elapsed / 1000.0) + : 0.0; + double write_throughput = write_elapsed > 0 + ? core_sub_dataset_size * dataset_dim * sizeof(T) / + (1024.0 * 1024.0) / (write_elapsed / 1000.0) + : 0.0; + RAFT_LOG_INFO( + "ACE: Partition %4lu (%8lu + %8lu) completed in %6ld ms: read %6ld ms (%7.1f MiB/s), " + "optimize %6ld ms, adjust %6ld ms, write %6ld ms (%7.1f MiB/s)", + partition_id, + core_sub_dataset_size, + augmented_sub_dataset_size, + elapsed_ms, + read_elapsed, + read_throughput, + optimize_elapsed, + adjust_elapsed, + write_elapsed, + write_throughput); + } + + auto partition_processing_end = std::chrono::high_resolution_clock::now(); + auto partition_processing_elapsed = std::chrono::duration_cast( + partition_processing_end - partition_processing_start) + .count(); + RAFT_LOG_INFO("ACE: All partition processing completed in %ld ms (%zu partitions)", + partition_processing_elapsed, + n_partitions); + + // Clean up augmented dataset file to save disk space (no longer needed after partitions + // processed) + if (use_disk_mode) { + const std::string augmented_dataset_path = build_dir + "/augmented_dataset.npy"; + if (std::filesystem::exists(augmented_dataset_path)) { + std::filesystem::remove(augmented_dataset_path); + RAFT_LOG_INFO("ACE: Removed augmented dataset file to save disk space"); + } + } + + auto index_creation_start = std::chrono::high_resolution_clock::now(); + index idx(res, params.metric); + // Only add graph and dataset if not using disk storage. The returned index is empty if using + // disk storage. Use the files written to disk for search. + if (!use_disk_mode) { + idx.update_graph(res, raft::make_const_mdspan(search_graph.view())); + + if (params.attach_dataset_on_build) { + try { + idx.update_dataset(res, dataset); + } catch (std::bad_alloc& e) { + RAFT_LOG_WARN( + "Insufficient GPU memory to attach dataset to ACE index. Only the graph will be " + "stored."); + } catch (raft::logic_error& e) { + RAFT_LOG_WARN( + "Insufficient GPU memory to attach dataset to ACE index. Only the graph will be " + "stored."); + } + } + } else { + idx.update_dataset(res, std::move(reordered_fd)); + idx.update_graph(res, std::move(graph_fd)); + idx.update_mapping(res, std::move(mapping_fd)); + + RAFT_LOG_INFO( + "ACE: Set disk storage at %s (dataset shape [%zu, %zu], graph shape [%zu, %zu])", + build_dir.c_str(), + idx.size(), + idx.dim(), + idx.size(), + idx.graph_degree()); + } + + auto index_creation_end = std::chrono::high_resolution_clock::now(); + auto index_creation_elapsed = std::chrono::duration_cast( + index_creation_end - index_creation_start) + .count(); + RAFT_LOG_INFO("ACE: Final index creation completed in %ld ms", index_creation_elapsed); + + auto total_end = std::chrono::high_resolution_clock::now(); + auto total_elapsed = + std::chrono::duration_cast(total_end - total_start).count(); + RAFT_LOG_INFO("ACE: Partitioned CAGRA build completed in %ld ms total", total_elapsed); + + return idx; + } catch (const std::exception& e) { + // Clean up build directory on failure if we created it + RAFT_LOG_ERROR("ACE: Build failed with exception: %s", e.what()); + if (cleanup_on_failure && !build_dir.empty()) { + RAFT_LOG_INFO("ACE: Cleaning up build directory: %s", build_dir.c_str()); + try { + std::filesystem::remove_all(build_dir); + RAFT_LOG_INFO("ACE: Successfully removed build directory"); + } catch (const std::exception& cleanup_error) { + RAFT_LOG_WARN("ACE: Failed to clean up build directory: %s", cleanup_error.what()); + } + } + // Re-throw the original exception + throw; + } +} + template void write_to_graph(raft::host_matrix_view knn_graph, raft::host_matrix_view neighbors_host_view, @@ -536,26 +1888,23 @@ auto iterative_build_graph( // Determine graph degree and number of search results while increasing // graph size. - auto small_graph_degree = std::max(graph_degree / 2, std::min(graph_degree, (uint64_t)32)); - auto small_topk = topk * small_graph_degree / graph_degree; - RAFT_LOG_DEBUG("# graph_degree = %lu", (uint64_t)graph_degree); + auto small_graph_degree = std::max(graph_degree / 2, std::min(graph_degree, (uint64_t)24)); RAFT_LOG_DEBUG("# small_graph_degree = %lu", (uint64_t)small_graph_degree); + RAFT_LOG_DEBUG("# graph_degree = %lu", (uint64_t)graph_degree); RAFT_LOG_DEBUG("# topk = %lu", (uint64_t)topk); - RAFT_LOG_DEBUG("# small_topk = %lu", (uint64_t)small_topk); // Create an initial graph. The initial graph created here is not suitable for // searching, but connectivity is guaranteed. - auto offset = raft::make_host_vector(small_graph_degree); - const double base = sqrt((double)2.0); + auto offset = raft::make_host_vector(small_graph_degree); for (uint64_t j = 0; j < small_graph_degree; j++) { if (j == 0) { offset(j) = 1; } else { offset(j) = offset(j - 1) + 1; } - IdxT ofst = initial_graph_size * pow(base, (double)j - small_graph_degree - 1); + IdxT ofst = pow((double)(initial_graph_size - 1) / 2, (double)(j + 1) / small_graph_degree); if (offset(j) < ofst) { offset(j) = ofst; } - RAFT_LOG_DEBUG("# offset(%lu) = %lu\n", (uint64_t)j, (uint64_t)offset(j)); + RAFT_LOG_DEBUG("# offset(%lu) = %lu", (uint64_t)j, (uint64_t)offset(j)); } cagra_graph = raft::make_host_matrix(initial_graph_size, small_graph_degree); for (uint64_t i = 0; i < initial_graph_size; i++) { @@ -572,22 +1921,34 @@ auto iterative_build_graph( IdxT* neighbors_ptr = (IdxT*)neighbors_list.data(); memset(neighbors_ptr, 0, byte_size); + bool flag_last = false; auto curr_graph_size = initial_graph_size; while (true) { - RAFT_LOG_DEBUG("# graph_size = %lu (%.3lf)", - (uint64_t)curr_graph_size, - (double)curr_graph_size / final_graph_size); - - auto curr_query_size = std::min(2 * curr_graph_size, final_graph_size); - auto curr_topk = small_topk; - auto curr_itopk_size = small_topk * 3 / 2; - auto curr_graph_degree = small_graph_degree; - if (curr_query_size == final_graph_size) { - curr_topk = topk; - curr_itopk_size = topk * 2; - curr_graph_degree = graph_degree; + auto start = std::chrono::high_resolution_clock::now(); + auto curr_query_size = std::min(2 * curr_graph_size, final_graph_size); + + auto next_graph_degree = small_graph_degree; + if (curr_graph_size == final_graph_size) { next_graph_degree = graph_degree; } + + // The search count (topk) is set to the next graph degree + 1, because + // pruning is not used except in the last iteration. + // (*) The appropriate setting for itopk_size requires careful consideration. + auto curr_topk = next_graph_degree + 1; + auto curr_itopk_size = next_graph_degree + 32; + if (flag_last) { + curr_topk = topk; + curr_itopk_size = curr_topk + 32; } + RAFT_LOG_INFO( + "# graph_size = %lu (%.3lf), graph_degree = %lu, query_size = %lu, itopk = %lu, topk = %lu", + (uint64_t)cagra_graph.extent(0), + (double)cagra_graph.extent(0) / final_graph_size, + (uint64_t)cagra_graph.extent(1), + (uint64_t)curr_query_size, + (uint64_t)curr_itopk_size, + (uint64_t)curr_topk); + cuvs::neighbors::cagra::search_params search_params; search_params.algo = cuvs::neighbors::cagra::search_algo::AUTO; search_params.max_queries = max_chunk_size; @@ -640,13 +2001,19 @@ auto iterative_build_graph( } // Optimize graph - bool flag_last = (curr_graph_size == final_graph_size); - curr_graph_size = curr_query_size; - cagra_graph = raft::make_host_matrix(0, 0); // delete existing grahp - cagra_graph = raft::make_host_matrix(curr_graph_size, curr_graph_degree); + auto next_graph_size = curr_query_size; + cagra_graph = raft::make_host_matrix(0, 0); // delete existing grahp + cagra_graph = raft::make_host_matrix(next_graph_size, next_graph_degree); optimize( res, neighbors_view, cagra_graph.view(), flag_last ? params.guarantee_connectivity : 0); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_ms = std::chrono::duration_cast(end - start).count(); + RAFT_LOG_INFO("# elapsed time: %.3lf sec", (double)elapsed_ms / 1000); + if (flag_last) { break; } + flag_last = (curr_graph_size == final_graph_size); + curr_graph_size = next_graph_size; } return cagra_graph; @@ -670,20 +2037,7 @@ index build( : "device", intermediate_degree, graph_degree); - if (intermediate_degree >= static_cast(dataset.extent(0))) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = dataset.extent(0) - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } + check_graph_degree(intermediate_degree, graph_degree, dataset.extent(0)); // Set default value in case knn_build_params is not defined. auto knn_build_params = params.graph_build_params; diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 26e0aafd2d..45328377be 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -142,6 +142,11 @@ void search_main(raft::resources const& res, raft::device_matrix_view distances, CagraSampleFilterT sample_filter = CagraSampleFilterT()) { + RAFT_EXPECTS(!index.dataset_fd().has_value(), + "Cannot search a CAGRA index that is stored on disk. " + "Use cuvs::neighbors::hnsw::from_cagra() to convert the index and " + "cuvs::neighbors::hnsw::deserialize() to load it into memory before searching."); + // n_rows has the same type as the dataset index (the array extents type) using ds_idx_type = decltype(index.data().n_rows()); using graph_idx_type = uint32_t; diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 20984e3e45..866415b1e4 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -44,6 +44,11 @@ void serialize(raft::resources const& res, { raft::common::nvtx::range fun_scope("cagra::serialize"); + RAFT_EXPECTS(!index_.dataset_fd().has_value(), + "Cannot serialize a disk-backed CAGRA index. Convert it with " + "cuvs::neighbors::hnsw::from_cagra() and load it into memory via " + "cuvs::neighbors::hnsw::deserialize() before serialization."); + RAFT_LOG_DEBUG( "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); @@ -80,6 +85,10 @@ void serialize(raft::resources const& res, const index& index_, bool include_dataset) { + RAFT_EXPECTS(!index_.dataset_fd().has_value(), + "Cannot serialize a disk-backed CAGRA index. Convert it with " + "cuvs::neighbors::hnsw::from_cagra() and load it into memory via " + "cuvs::neighbors::hnsw::deserialize() before serialization."); std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index a746cf473a..f8091c9e51 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -169,16 +169,17 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_size) { return; } + const uint64_t iA = blockIdx.x + (batch_size * batch_id); + if (iA >= graph_size) { return; } for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { smem_num_detour[k] = 0; + if (knn_graph[k + ((uint64_t)graph_degree * iA)] == iA) { + // Lower the priority of self-edge + smem_num_detour[k] = graph_degree; + } } __syncthreads(); - const uint64_t iA = nid; - if (iA >= graph_size) { return; } - // count number of detours (A->D->B) for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; @@ -1410,7 +1411,7 @@ void optimize( "overflows occur during the norm computation between the dataset vectors."); const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf sec", time_prune_end - time_prune_start); + RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); } auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); @@ -1480,7 +1481,8 @@ void optimize( raft::resource::get_cuda_stream(res)); const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); } { @@ -1567,7 +1569,8 @@ void optimize( "many MST optimization edges."); const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start); + RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", + (time_replace_end - time_replace_start) * 1000.0); /* stats */ uint64_t num_replaced_edges = 0; diff --git a/cpp/src/neighbors/detail/hnsw.hpp b/cpp/src/neighbors/detail/hnsw.hpp index 2ce9e0dda8..186216a4da 100644 --- a/cpp/src/neighbors/detail/hnsw.hpp +++ b/cpp/src/neighbors/detail/hnsw.hpp @@ -9,18 +9,25 @@ #include "../../core/omp_wrapper.hpp" #include +#include #include +#include +#include #include #include #include #include +#include #include +#include #include #include +#include #include +#include namespace cuvs::neighbors::hnsw::detail { @@ -107,9 +114,32 @@ struct index_impl : index { return space_.get(); } + /** + @brief Set file descriptor for disk-backed index + */ + void set_file_descriptor(cuvs::util::file_descriptor&& fd) { hnsw_fd_.emplace(std::move(fd)); } + + /** + @brief Get file descriptor + */ + auto file_descriptor() const -> const std::optional& + { + return hnsw_fd_; + } + + /** + @brief Get file path for disk-backed index + */ + std::string file_path() const override + { + if (hnsw_fd_.has_value() && hnsw_fd_->is_valid()) { return hnsw_fd_->get_path(); } + return ""; + } + private: std::unique_ptr::type>> appr_alg_; std::unique_ptr::type>> space_; + std::optional hnsw_fd_; }; template @@ -179,7 +209,7 @@ std::enable_if_t>> fro auto appr_algo = std::make_unique::type>>( hnsw_index->get_space(), host_dataset_view.extent(0), - cagra_index.graph().extent(1) / 2, + (cagra_index.graph().extent(1) + 1) / 2, params.ef_construction); appr_algo->base_layer_init = false; // tell hnswlib to build upper layers only [[maybe_unused]] auto num_threads = @@ -257,12 +287,499 @@ void all_neighbors_graph(raft::resources const& res, raft::host_matrix_view neighbors, cuvs::distance::DistanceType metric) { - nn_descent::index_params nn_params; - nn_params.graph_degree = neighbors.extent(1); - nn_params.intermediate_graph_degree = neighbors.extent(1) * 2; - nn_params.metric = metric; - nn_params.return_distances = false; - auto nn_index = nn_descent::build(res, nn_params, dataset, neighbors); + // FIXME: choose better heuristic + bool use_nn_decent = neighbors.size() < 1e7; + if (use_nn_decent) { + nn_descent::index_params nn_params; + nn_params.graph_degree = neighbors.extent(1); + nn_params.intermediate_graph_degree = neighbors.extent(1) * 2; + nn_params.metric = metric; + nn_params.return_distances = false; + auto nn_index = nn_descent::build(res, nn_params, dataset, neighbors); + } else { + // TODO: choose parameters to minimize memory consumption + cagra::graph_build_params::ivf_pq_params ivfpq_params(dataset.extents(), metric); + cagra::build_knn_graph(res, dataset, neighbors, ivfpq_params); + } +} + +template +void serialize_to_hnswlib_from_disk(raft::resources const& res, + std::ostream& os_raw, + const cuvs::neighbors::hnsw::index_params& params, + const cuvs::neighbors::cagra::index& index_) +{ + raft::common::nvtx::range fun_scope("cagra::serialize"); + + auto start_time = std::chrono::system_clock::now(); + + cuvs::util::buffered_ofstream os(&os_raw, 1 << 20 /*1MB*/); + + RAFT_EXPECTS(index_.dataset_fd().has_value() && index_.graph_fd().has_value(), + "Function only implements serialization from disk."); + RAFT_EXPECTS(params.hierarchy != HnswHierarchy::CPU, + "Disk2disk serialization not supported for CPU hierarchy."); + + auto n_rows = index_.size(); + auto dim = index_.dim(); + auto graph_degree_int = static_cast(index_.graph_degree()); + RAFT_LOG_INFO("Saving CAGRA index to hnswlib format, size %zu, dim %zu, graph_degree %zu", + static_cast(n_rows), + static_cast(dim), + static_cast(graph_degree_int)); + + // Get file descriptors from index + const auto& graph_fd_opt = index_.graph_fd(); + const auto& dataset_fd_opt = index_.dataset_fd(); + const auto& mapping_fd_opt = index_.mapping_fd(); + + RAFT_EXPECTS(graph_fd_opt.has_value() && graph_fd_opt->is_valid(), + "Graph file descriptor is not available"); + RAFT_EXPECTS(dataset_fd_opt.has_value() && dataset_fd_opt->is_valid(), + "Dataset file descriptor is not available"); + RAFT_EXPECTS(mapping_fd_opt.has_value() && mapping_fd_opt->is_valid(), + "Mapping file descriptor is not available"); + + // Get file paths from file descriptors + std::string graph_path = graph_fd_opt->get_path(); + std::string dataset_path = dataset_fd_opt->get_path(); + std::string mapping_path = mapping_fd_opt->get_path(); + + RAFT_EXPECTS(!graph_path.empty(), "Unable to get path from graph file descriptor"); + RAFT_EXPECTS(!dataset_path.empty(), "Unable to get path from dataset file descriptor"); + RAFT_EXPECTS(!mapping_path.empty(), "Unable to get path from mapping file descriptor"); + + int graph_fd = graph_fd_opt->get(); + int dataset_fd = dataset_fd_opt->get(); + int label_fd = mapping_fd_opt->get(); + + // Read headers from files to get dimensions + size_t graph_header_size = 0; + size_t graph_n_rows = 0; + size_t graph_n_cols = 0; + { + std::ifstream graph_stream(graph_path, std::ios::binary); + RAFT_EXPECTS(graph_stream.good(), "Failed to open graph file: %s", graph_path.c_str()); + + auto header = raft::detail::numpy_serializer::read_header(graph_stream); + graph_header_size = static_cast(graph_stream.tellg()); + RAFT_EXPECTS( + header.shape.size() == 2, "Graph file should be 2D, got %zu dimensions", header.shape.size()); + + graph_n_rows = header.shape[0]; + graph_n_cols = header.shape[1]; + RAFT_LOG_DEBUG("Graph file: %zu x %zu, header size: %zu bytes", + graph_n_rows, + graph_n_cols, + graph_header_size); + } + + size_t dataset_header_size = 0; + size_t dataset_n_rows = 0; + size_t dataset_n_cols = 0; + { + std::ifstream dataset_stream(dataset_path, std::ios::binary); + RAFT_EXPECTS(dataset_stream.good(), "Failed to open dataset file: %s", dataset_path.c_str()); + + auto header = raft::detail::numpy_serializer::read_header(dataset_stream); + dataset_header_size = static_cast(dataset_stream.tellg()); + RAFT_EXPECTS(header.shape.size() == 2, + "Dataset file should be 2D, got %zu dimensions", + header.shape.size()); + + dataset_n_rows = header.shape[0]; + dataset_n_cols = header.shape[1]; + RAFT_LOG_DEBUG("Dataset file: %zu x %zu, header size: %zu bytes", + dataset_n_rows, + dataset_n_cols, + dataset_header_size); + } + + size_t label_header_size = 0; + size_t label_n_elements = 0; + { + std::ifstream mapping_stream(mapping_path, std::ios::binary); + RAFT_EXPECTS(mapping_stream.good(), "Failed to open mapping file: %s", mapping_path.c_str()); + + auto header = raft::detail::numpy_serializer::read_header(mapping_stream); + label_header_size = static_cast(mapping_stream.tellg()); + RAFT_EXPECTS(header.shape.size() == 1, + "Mapping file should be 1D, got %zu dimensions", + header.shape.size()); + + label_n_elements = header.shape[0]; + RAFT_LOG_DEBUG( + "Mapping file: %zu elements, header size: %zu bytes", label_n_elements, label_header_size); + } + + // Verify consistency + RAFT_EXPECTS(graph_n_rows == static_cast(n_rows), + "Graph rows (%zu) != index size (%zu)", + graph_n_rows, + static_cast(n_rows)); + RAFT_EXPECTS(dataset_n_rows == static_cast(n_rows), + "Dataset rows (%zu) != index size (%zu)", + dataset_n_rows, + static_cast(n_rows)); + RAFT_EXPECTS(label_n_elements == static_cast(n_rows), + "Label elements (%zu) != index size (%zu)", + label_n_elements, + static_cast(n_rows)); + RAFT_EXPECTS(graph_n_cols == static_cast(graph_degree_int), + "Graph cols (%zu) != graph degree (%d)", + graph_n_cols, + graph_degree_int); + RAFT_EXPECTS(dataset_n_cols == static_cast(dim), + "Dataset cols (%zu) != dimensions (%zu)", + dataset_n_cols, + static_cast(dim)); + + const size_t row_size_bytes = + graph_degree_int * sizeof(IdxT) + dim * sizeof(T) + sizeof(uint32_t); + const size_t target_batch_bytes = 64 * 1024 * 1024; + const size_t batch_size = std::max(1, target_batch_bytes / row_size_bytes); + + RAFT_LOG_DEBUG("Using batch size %zu rows (~%.2f MiB/batch)", + batch_size, + (batch_size * row_size_bytes) / (1024.0 * 1024.0)); + + // Allocate buffers for batched reading + auto graph_buffer = raft::make_host_matrix(batch_size, graph_degree_int); + auto dataset_buffer = raft::make_host_matrix(batch_size, dim); + auto label_buffer = raft::make_host_vector(batch_size); + + RAFT_LOG_DEBUG("Allocated buffers: graph[%ld,%d], dataset[%ld,%ld], labels[%ld]", + graph_buffer.extent(0), + graph_degree_int, + dataset_buffer.extent(0), + dataset_buffer.extent(1), + label_buffer.extent(0)); + + // initialize dummy HNSW index to retrieve constants + auto hnsw_index = std::make_unique>(dim, index_.metric(), params.hierarchy); + + int odd_graph_degree = graph_degree_int % 2; + auto appr_algo = std::make_unique::type>>( + hnsw_index->get_space(), 1, (graph_degree_int + 1) / 2, params.ef_construction); + + bool create_hierarchy = params.hierarchy != HnswHierarchy::NONE; + + // create hierarchy order + // sort the points by levels + // roll dice & build histogram + std::vector hist; + std::vector order(n_rows); + std::vector order_bw(n_rows); + std::vector levels(n_rows); + std::vector offsets; + + if (create_hierarchy) { + RAFT_LOG_INFO("Sort points by levels"); + for (int64_t i = 0; i < n_rows; i++) { + auto pt_level = appr_algo->getRandomLevel(appr_algo->mult_); + while (pt_level >= static_cast(hist.size())) + hist.push_back(0); + hist[pt_level]++; + levels[i] = pt_level; + } + + // accumulate + offsets.resize(hist.size() + 1, 0); + for (size_t i = 0; i < hist.size() - 1; i++) { + offsets[i + 1] = offsets[i] + hist[i]; + RAFT_LOG_INFO("Level %zu : %zu", i + 1, size_t(n_rows) - offsets[i + 1]); + } + + // fw/bw indices + for (int64_t i = 0; i < n_rows; i++) { + auto pt_level = levels[i]; + order_bw[i] = offsets[pt_level]; + order[offsets[pt_level]++] = i; + } + } + + // set last point of the highest level as the entry point + appr_algo->enterpoint_node_ = create_hierarchy ? order.back() : n_rows / 2; + appr_algo->maxlevel_ = create_hierarchy ? hist.size() - 1 : 1; + + // write header information + RAFT_LOG_DEBUG("Writing HNSW header: offsetLevel0=%zu, n_rows=%zu, size_data_per_element=%zu", + appr_algo->offsetLevel0_, + static_cast(n_rows), + appr_algo->size_data_per_element_); + RAFT_LOG_DEBUG(" maxlevel=%d, enterpoint=%d, maxM=%zu, maxM0=%zu, M=%zu", + appr_algo->maxlevel_, + appr_algo->enterpoint_node_, + appr_algo->maxM_, + appr_algo->maxM0_, + appr_algo->M_); + + // offset_level_0 + os.write(reinterpret_cast(&appr_algo->offsetLevel0_), sizeof(std::size_t)); + // 8 max_element - override with n_rows + size_t num_elements = (size_t)n_rows; + os.write(reinterpret_cast(&num_elements), sizeof(std::size_t)); + // 16 curr_element_count - override with n_rows + os.write(reinterpret_cast(&num_elements), sizeof(std::size_t)); + // 24 size_data_per_element + os.write(reinterpret_cast(&appr_algo->size_data_per_element_), sizeof(std::size_t)); + // 32 label_offset + os.write(reinterpret_cast(&appr_algo->label_offset_), sizeof(std::size_t)); + // 40 offset_data + os.write(reinterpret_cast(&appr_algo->offsetData_), sizeof(std::size_t)); + // 48 maxlevel + os.write(reinterpret_cast(&appr_algo->maxlevel_), sizeof(int)); + // 52 enterpoint_node + os.write(reinterpret_cast(&appr_algo->enterpoint_node_), sizeof(int)); + // 56 maxM + os.write(reinterpret_cast(&appr_algo->maxM_), sizeof(std::size_t)); + // 64 maxM0 + os.write(reinterpret_cast(&appr_algo->maxM0_), sizeof(std::size_t)); + // 72 M + os.write(reinterpret_cast(&appr_algo->M_), sizeof(std::size_t)); + // 80 mult + os.write(reinterpret_cast(&appr_algo->mult_), sizeof(double)); + // 88 ef_construction + os.write(reinterpret_cast(&appr_algo->ef_construction_), sizeof(std::size_t)); + + // host queries + auto host_query_set = + raft::make_host_matrix(create_hierarchy ? n_rows - hist[0] : 0, dim); + + int64_t d_report_offset = n_rows / 10; // Report progress in 10% steps. + int64_t next_report_offset = d_report_offset; + auto start_clock = std::chrono::system_clock::now(); + + RAFT_EXPECTS(appr_algo->size_data_per_element_ == + dim * sizeof(T) + appr_algo->maxM0_ * sizeof(IdxT) + sizeof(int) + sizeof(size_t), + "Size data per element mismatch"); + + RAFT_LOG_INFO("Writing base level"); + size_t bytes_written = 0; + float GiB = 1 << 30; + IdxT zero = 0; + RAFT_EXPECTS(appr_algo->size_data_per_element_ == + dim * sizeof(T) + appr_algo->maxM0_ * sizeof(IdxT) + sizeof(int) + sizeof(size_t), + "Size data per element mismatch"); + + // Helper lambda for parallel reading of batches + auto read_batch = [&](int64_t start_row, int64_t rows_to_read) { + const size_t graph_bytes = rows_to_read * graph_degree_int * sizeof(IdxT); + const size_t dataset_bytes = rows_to_read * dim * sizeof(T); + const size_t label_bytes = rows_to_read * sizeof(uint32_t); + + const off_t graph_offset = graph_header_size + start_row * graph_degree_int * sizeof(IdxT); + const off_t dataset_offset = dataset_header_size + start_row * dim * sizeof(T); + const off_t label_offset = label_header_size + start_row * sizeof(uint32_t); + + RAFT_LOG_DEBUG("Reading batch: row=%ld, rows=%ld", start_row, rows_to_read); + RAFT_LOG_DEBUG( + " graph: offset=%zu, bytes=%zu", static_cast(graph_offset), graph_bytes); + RAFT_LOG_DEBUG( + " dataset: offset=%zu, bytes=%zu", static_cast(dataset_offset), dataset_bytes); + RAFT_LOG_DEBUG( + " label: offset=%zu, bytes=%zu", static_cast(label_offset), label_bytes); + +#pragma omp parallel sections num_threads(3) + { +#pragma omp section + { + ssize_t bytes_read = pread(graph_fd, graph_buffer.data_handle(), graph_bytes, graph_offset); + RAFT_EXPECTS(bytes_read == static_cast(graph_bytes), + "Failed to read graph data: expected %zu, got %zd", + graph_bytes, + bytes_read); + } +#pragma omp section + { + ssize_t bytes_read = + pread(dataset_fd, dataset_buffer.data_handle(), dataset_bytes, dataset_offset); + RAFT_EXPECTS(bytes_read == static_cast(dataset_bytes), + "Failed to read dataset data: expected %zu, got %zd", + dataset_bytes, + bytes_read); + } +#pragma omp section + { + ssize_t bytes_read = pread(label_fd, label_buffer.data_handle(), label_bytes, label_offset); + RAFT_EXPECTS(bytes_read == static_cast(label_bytes), + "Failed to read label data: expected %zu, got %zd", + label_bytes, + bytes_read); + } + } + + // Log first few values from first batch for debugging + if (start_row == 0 && rows_to_read > 0) { + RAFT_LOG_DEBUG("First graph row: [%u, %u, %u, ...]", + static_cast(graph_buffer(0, 0)), + graph_degree_int > 1 ? static_cast(graph_buffer(0, 1)) : 0, + graph_degree_int > 2 ? static_cast(graph_buffer(0, 2)) : 0); + RAFT_LOG_DEBUG("First dataset row: [%f, %f, %f, ...]", + static_cast(dataset_buffer(0, 0)), + dim > 1 ? static_cast(dataset_buffer(0, 1)) : 0.0f, + dim > 2 ? static_cast(dataset_buffer(0, 2)) : 0.0f); + RAFT_LOG_DEBUG("First labels: [%u, %u, %u, ...]", + static_cast(label_buffer(0)), + rows_to_read > 1 ? static_cast(label_buffer(1)) : 0, + rows_to_read > 2 ? static_cast(label_buffer(2)) : 0); + } + }; + + for (int64_t batch_start = 0; batch_start < n_rows; batch_start += batch_size) { + const int64_t current_batch_size = std::min(batch_size, n_rows - batch_start); + + RAFT_LOG_DEBUG("Reading batch: start=%ld, size=%ld (batch_size=%zu)", + batch_start, + current_batch_size, + batch_size); + read_batch(batch_start, current_batch_size); + + for (int64_t batch_idx = 0; batch_idx < current_batch_size; batch_idx++) { + const int64_t i = batch_start + batch_idx; + + os.write(reinterpret_cast(&graph_degree_int), sizeof(int)); + + const IdxT* graph_row = &graph_buffer(batch_idx, 0); + os.write(reinterpret_cast(graph_row), sizeof(IdxT) * graph_degree_int); + + if (odd_graph_degree) { + RAFT_EXPECTS(odd_graph_degree == static_cast(appr_algo->maxM0_) - graph_degree_int, + "Odd graph degree mismatch"); + os.write(reinterpret_cast(&zero), sizeof(IdxT)); + } + + const T* data_row = &dataset_buffer(batch_idx, 0); + os.write(reinterpret_cast(data_row), sizeof(T) * dim); + + if (create_hierarchy && levels[i] > 0) { + // position in query: order_bw[i]-hist[0] + std::copy(data_row, + data_row + dim, + reinterpret_cast(&host_query_set(order_bw[i] - hist[0], 0))); + } + + // assign original label + auto label = static_cast(label_buffer(batch_idx)); + os.write(reinterpret_cast(&label), sizeof(std::size_t)); + + bytes_written += appr_algo->size_data_per_element_; + + const auto end_clock = std::chrono::system_clock::now(); + // if (!os.good()) { RAFT_FAIL("Error writing HNSW file, row %zu", i); } + if (i > next_report_offset) { + next_report_offset += d_report_offset; + const auto time = + std::chrono::duration_cast(end_clock - start_clock).count() * + 1e-6; + float throughput = bytes_written / GiB / time; + float rows_throughput = i / time; + float ETA = (n_rows - i) / rows_throughput; + RAFT_LOG_INFO( + "# Writing rows %12lu / %12lu (%3.2f %%), %3.2f GiB/sec, ETA %d:%3.1f, written %3.2f " + "GiB\r", + i, + n_rows, + i / static_cast(n_rows) * 100, + throughput, + int(ETA / 60), + std::fmod(ETA, 60.0f), + bytes_written / GiB); + } + } + } + + RAFT_LOG_DEBUG("Completed writing %ld base level rows", n_rows); + + // trigger knn builds for all levels + std::vector> host_neighbors; + if (create_hierarchy) { + for (size_t pt_level = 1; pt_level < hist.size(); pt_level++) { + auto num_pts = n_rows - offsets[pt_level - 1]; + auto neighbor_size = num_pts > appr_algo->M_ ? appr_algo->M_ : num_pts - 1; + host_neighbors.emplace_back(raft::make_host_matrix(num_pts, neighbor_size)); + } + for (size_t pt_level = 1; pt_level < hist.size(); pt_level++) { + RAFT_LOG_INFO("Compute hierarchy neighbors level %zu", pt_level); + auto removed_rows = offsets[pt_level - 1] - offsets[0]; + raft::host_matrix_view sub_query_view( + host_query_set.data_handle() + removed_rows * dim, + host_query_set.extent(0) - removed_rows, + dim); + auto neighbor_view = host_neighbors[pt_level - 1].view(); + all_neighbors_graph( + res, raft::make_const_mdspan(sub_query_view), neighbor_view, index_.metric()); + } + } + + if (create_hierarchy) { + RAFT_LOG_INFO("Assemble hierarchy linklists"); + next_report_offset = d_report_offset; + } + bytes_written = 0; + start_clock = std::chrono::system_clock::now(); + + for (int64_t i = 0; i < n_rows; i++) { + size_t cur_level = create_hierarchy ? levels[i] : 0; + unsigned int linkListSize = + create_hierarchy && cur_level > 0 ? appr_algo->size_links_per_element_ * cur_level : 0; + os.write(reinterpret_cast(&linkListSize), sizeof(int)); + bytes_written += sizeof(int); + if (linkListSize) { + for (size_t pt_level = 1; pt_level <= cur_level; pt_level++) { + auto neighbor_view = host_neighbors[pt_level - 1].view(); + auto my_row = order_bw[i] - offsets[pt_level - 1]; + + IdxT* neighbors = &neighbor_view(my_row, 0); + unsigned int extent = neighbor_view.extent(1); + os.write(reinterpret_cast(&extent), sizeof(int)); + for (unsigned int j = 0; j < extent; j++) { + const IdxT converted = order[neighbors[j] + offsets[pt_level - 1]]; + os.write(reinterpret_cast(&converted), sizeof(IdxT)); + } + auto remainder = appr_algo->M_ - neighbor_view.extent(1); + for (size_t j = 0; j < remainder; j++) { + os.write(reinterpret_cast(&zero), sizeof(IdxT)); + } + bytes_written += (neighbor_view.extent(1) + remainder) * sizeof(IdxT) + sizeof(int); + RAFT_EXPECTS(appr_algo->size_links_per_element_ == + (neighbor_view.extent(1) + remainder) * sizeof(IdxT) + sizeof(int), + "Size links per element mismatch"); + } + } + + const auto end_clock = std::chrono::system_clock::now(); + if (i > next_report_offset) { + next_report_offset += d_report_offset; + const auto time = + std::chrono::duration_cast(end_clock - start_clock).count() * + 1e-6; + float throughput = bytes_written / GiB / time; + float rows_throughput = i / time; + float ETA = (n_rows - i) / rows_throughput; + RAFT_LOG_INFO( + "# Writing rows %12lu / %12lu (%3.2f %%), %3.2f GiB/sec, ETA %d:%3.1f, written %3.2f GiB\r", + i, + n_rows, + i / static_cast(n_rows) * 100, + throughput, + int(ETA / 60), + std::fmod(ETA, 60.0f), + bytes_written / GiB); + } + } + + // Flush buffered output and check data was written + os.flush(); + os_raw.flush(); + auto final_pos = os_raw.tellp(); + RAFT_LOG_DEBUG("HNSW file size: %ld bytes", static_cast(final_pos)); + if (!os_raw.good()) { RAFT_LOG_WARN("Output stream is not in good state after serialization"); } + + auto end_time = std::chrono::system_clock::now(); + auto elapsed_time = + std::chrono::duration_cast(end_time - start_time).count(); + RAFT_LOG_INFO("HNSW serialization from disk complete in %ld ms", elapsed_time); } template @@ -315,7 +832,10 @@ std::enable_if_t>> fro // initialize HNSW index auto hnsw_index = std::make_unique>(dim, cagra_index.metric(), hierarchy); auto appr_algo = std::make_unique::type>>( - hnsw_index->get_space(), n_rows, cagra_index.graph().extent(1) / 2, params.ef_construction); + hnsw_index->get_space(), + n_rows, + (cagra_index.graph().extent(1) + 1) / 2, + params.ef_construction); appr_algo->cur_element_count = n_rows; // Initialize linked lists @@ -514,6 +1034,45 @@ std::unique_ptr> from_cagra( const cuvs::neighbors::cagra::index& cagra_index, std::optional> dataset) { + // special treatment for index on disk + if (cagra_index.dataset_fd().has_value() && cagra_index.graph_fd().has_value()) { + // Get directory from graph file descriptor + const auto& graph_fd = cagra_index.graph_fd(); + RAFT_EXPECTS(graph_fd.has_value() && graph_fd->is_valid(), + "Graph file descriptor is not available for disk-backed index"); + + std::string graph_path = graph_fd->get_path(); + RAFT_EXPECTS(!graph_path.empty(), "Unable to get path from graph file descriptor"); + + std::string index_directory = std::filesystem::path(graph_path).parent_path().string(); + RAFT_EXPECTS( + std::filesystem::exists(index_directory) && std::filesystem::is_directory(index_directory), + "Directory '%s' does not exist", + index_directory.c_str()); + std::string index_filename = + (std::filesystem::path(index_directory) / "hnsw_index.bin").string(); + + std::ofstream of(index_filename, std::ios::out | std::ios::binary); + + RAFT_EXPECTS(of, "Cannot open file %s", index_filename.c_str()); + + serialize_to_hnswlib_from_disk(res, of, params, cagra_index); + + of.close(); + RAFT_EXPECTS(of, "Error writing output %s", index_filename.c_str()); + + // Create an empty HNSW index that holds the file descriptor + auto hnsw_index = + std::make_unique>(cagra_index.dim(), cagra_index.metric(), params.hierarchy); + + // Open file descriptor for the HNSW index file and transfer ownership to the index + hnsw_index->set_file_descriptor(cuvs::util::file_descriptor(index_filename, O_RDONLY)); + + RAFT_LOG_INFO("HNSW index written to disk at: %s", index_filename.c_str()); + + return hnsw_index; + } + if (params.hierarchy == HnswHierarchy::NONE) { return from_cagra(res, params, cagra_index, dataset); } else if (params.hierarchy == HnswHierarchy::CPU) { @@ -531,6 +1090,10 @@ void extend(raft::resources const& res, raft::host_matrix_view additional_dataset, index& idx) { + auto* idx_impl = dynamic_cast*>(&idx); + RAFT_EXPECTS(!idx_impl || !idx_impl->file_descriptor().has_value(), + "Cannot extend an HNSW index that is stored on disk. " + "The index must be deserialized into memory first using hnsw::deserialize()."); auto* hnswlib_index = reinterpret_cast::type>*>( const_cast(idx.get_index())); auto current_element_count = hnswlib_index->getCurrentElementCount(); @@ -572,6 +1135,11 @@ void search(raft::resources const& res, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { + auto* idx_impl = dynamic_cast*>(&idx); + RAFT_EXPECTS(!idx_impl || !idx_impl->file_descriptor().has_value(), + "Cannot search an HNSW index that is stored on disk. " + "The index must be deserialized into memory first using hnsw::deserialize()."); + RAFT_EXPECTS(queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), "Number of rows in output neighbors and distances matrices must equal the number of " "queries."); @@ -611,6 +1179,10 @@ void search(raft::resources const& res, template void serialize(raft::resources const& res, const std::string& filename, const index& idx) { + auto* idx_impl = dynamic_cast*>(&idx); + RAFT_EXPECTS(!idx_impl || !idx_impl->file_descriptor().has_value(), + "Cannot serialize an HNSW index that is stored on disk. " + "The index must be deserialized into memory first using hnsw::deserialize()."); auto* hnswlib_index = reinterpret_cast::type>*>( const_cast(idx.get_index())); hnswlib_index->saveIndex(filename); diff --git a/cpp/src/util/file_io.cpp b/cpp/src/util/file_io.cpp new file mode 100644 index 0000000000..d924527e72 --- /dev/null +++ b/cpp/src/util/file_io.cpp @@ -0,0 +1,81 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include + +#include +#include + +namespace cuvs::util { + +void read_large_file(const file_descriptor& fd, + void* dest_ptr, + const size_t total_bytes, + const uint64_t file_offset) +{ + RAFT_EXPECTS(total_bytes > 0, "Total bytes must be greater than 0"); + RAFT_EXPECTS(dest_ptr != nullptr, "Destination pointer must not be nullptr"); + RAFT_EXPECTS(fd.is_valid(), "File descriptor must be valid"); + + const size_t read_chunk_size = std::min(1024 * 1024 * 1024, SSIZE_MAX); + size_t bytes_remaining = total_bytes; + size_t offset = 0; + + while (bytes_remaining > 0) { + const size_t chunk_size = std::min(read_chunk_size, bytes_remaining); + const uint64_t file_pos = file_offset + offset; + const ssize_t bytes_read = + pread(fd.get(), reinterpret_cast(dest_ptr) + offset, chunk_size, file_pos); + + RAFT_EXPECTS( + bytes_read != -1, "Failed to read from file at offset %lu: %s", file_pos, strerror(errno)); + RAFT_EXPECTS(bytes_read == static_cast(chunk_size), + "Incomplete read from file. Expected %zu bytes, got %zd at offset %lu", + chunk_size, + bytes_read, + file_pos); + + bytes_remaining -= chunk_size; + offset += chunk_size; + } +} + +void write_large_file(const file_descriptor& fd, + const void* data_ptr, + const size_t total_bytes, + const uint64_t file_offset) +{ + RAFT_EXPECTS(total_bytes > 0, "Total bytes must be greater than 0"); + RAFT_EXPECTS(data_ptr != nullptr, "Data pointer must not be nullptr"); + RAFT_EXPECTS(fd.is_valid(), "File descriptor must be valid"); + + const size_t write_chunk_size = std::min(1024 * 1024 * 1024, SSIZE_MAX); + size_t bytes_remaining = total_bytes; + size_t offset = 0; + + while (bytes_remaining > 0) { + const size_t chunk_size = std::min(write_chunk_size, bytes_remaining); + const uint64_t file_pos = file_offset + offset; + const ssize_t chunk_written = + pwrite(fd.get(), reinterpret_cast(data_ptr) + offset, chunk_size, file_pos); + + RAFT_EXPECTS( + chunk_written != -1, "Failed to write to file at offset %lu: %s", file_pos, strerror(errno)); + RAFT_EXPECTS(chunk_written == static_cast(chunk_size), + "Incomplete write to file. Expected %zu bytes, wrote %zd at offset %lu", + chunk_size, + chunk_written, + file_pos); + + bytes_remaining -= chunk_size; + offset += chunk_size; + } +} + +} // namespace cuvs::util diff --git a/cpp/src/util/host_memory.cpp b/cpp/src/util/host_memory.cpp new file mode 100644 index 0000000000..23e5ff4258 --- /dev/null +++ b/cpp/src/util/host_memory.cpp @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include + +namespace cuvs::util { + +size_t get_free_host_memory() +{ + size_t available_memory = 0; + std::ifstream meminfo("/proc/meminfo"); + std::string line; + while (std::getline(meminfo, line)) { + if (line.find("MemAvailable:") != std::string::npos) { + available_memory = std::stoi(line.substr(line.find(":") + 1)); + } + } + available_memory *= 1024; + meminfo.close(); + RAFT_EXPECTS(available_memory > 0, "Failed to get available memory from /proc/meminfo"); + return available_memory; +} + +} // namespace cuvs::util diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index e6e391103c..094e6c3db9 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -194,6 +194,34 @@ ConfigureTest( PERCENT 100 ) +ConfigureTest( + NAME NEIGHBORS_ANN_CAGRA_ACE_FLOAT_UINT32_TEST + PATH neighbors/ann_cagra_ace/test_float_uint32_t.cu + GPUS 1 + PERCENT 100 +) + +ConfigureTest( + NAME NEIGHBORS_ANN_CAGRA_ACE_HALF_UINT32_TEST + PATH neighbors/ann_cagra_ace/test_half_uint32_t.cu + GPUS 1 + PERCENT 100 +) + +ConfigureTest( + NAME NEIGHBORS_ANN_CAGRA_ACE_INT8_UINT32_TEST + PATH neighbors/ann_cagra_ace/test_int8_t_uint32_t.cu + GPUS 1 + PERCENT 100 +) + +ConfigureTest( + NAME NEIGHBORS_ANN_CAGRA_ACE_UINT8_UINT32_TEST + PATH neighbors/ann_cagra_ace/test_uint8_t_uint32_t.cu + GPUS 1 + PERCENT 100 +) + ConfigureTest( NAME NEIGHBORS_ANN_NN_DESCENT_TEST PATH neighbors/ann_nn_descent/test_float_uint32_t.cu diff --git a/cpp/tests/neighbors/ann_cagra_ace.cuh b/cpp/tests/neighbors/ann_cagra_ace.cuh new file mode 100644 index 0000000000..4c6d96050b --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra_ace.cuh @@ -0,0 +1,270 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "ann_cagra.cuh" + +#include + +#include +#include + +namespace cuvs::neighbors::cagra { + +struct AnnCagraAceInputs { + int n_queries; + int n_rows; + int dim; + int k; + int npartitions; + int ef_construction; + bool use_disk; + cuvs::distance::DistanceType metric; + double min_recall; +}; + +inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraAceInputs& p) +{ + os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim + << ", k=" << p.k << ", npartitions=" << p.npartitions + << ", ef_construction=" << p.ef_construction + << ", use_disk=" << (p.use_disk ? "true" : "false") << ", metric="; + switch (p.metric) { + case cuvs::distance::DistanceType::L2Expanded: os << "L2"; break; + case cuvs::distance::DistanceType::InnerProduct: os << "InnerProduct"; break; + default: os << "Unknown"; break; + } + os << ", min_recall=" << p.min_recall << "}"; + return os; +} + +template +class AnnCagraAceTest : public ::testing::TestWithParam { + public: + AnnCagraAceTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database_dev(0, stream_), + search_queries(0, stream_) + { + } + + protected: + void testAce() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_ace(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ace(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_dev.data(), + ps.n_queries, + ps.n_rows, + ps.dim, + ps.k, + ps.metric); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + // Create temporary directory for ACE build + std::string temp_dir = std::string("/tmp/cuvs_ace_test_") + std::to_string(std::time(nullptr)) + + "_" + std::to_string(reinterpret_cast(this)); + std::filesystem::create_directories(temp_dir); + + { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database_dev.data(), ps.n_rows * ps.dim, stream_); + raft::resource::sync_stream(handle_); + + cagra::index_params index_params; + index_params.metric = ps.metric; + index_params.intermediate_graph_degree = 128; + index_params.graph_degree = 64; + auto ace_params = graph_build_params::ace_params(); + ace_params.npartitions = ps.npartitions; + ace_params.ef_construction = ps.ef_construction; + ace_params.build_dir = temp_dir; + ace_params.use_disk = ps.use_disk; + index_params.graph_build_params = ace_params; + + auto index = + cagra::build(handle_, index_params, raft::make_const_mdspan(database_host.view())); + + ASSERT_EQ(index.size(), ps.n_rows); + + if (ps.use_disk) { + // Verify disk-based ACE index using HNSW index from disk + EXPECT_TRUE(index.dataset_fd().has_value() && index.graph_fd().has_value()); + + // Verify file directory from graph file descriptor + const auto& graph_fd = index.graph_fd(); + EXPECT_TRUE(graph_fd.has_value() && graph_fd->is_valid()); + std::string graph_path = graph_fd->get_path(); + std::string file_dir = std::filesystem::path(graph_path).parent_path().string(); + EXPECT_EQ(file_dir, temp_dir); + + EXPECT_TRUE(std::filesystem::exists(temp_dir + "/cagra_graph.npy")); + EXPECT_GE(std::filesystem::file_size(temp_dir + "/cagra_graph.npy"), + ps.n_rows * index_params.graph_degree * sizeof(IdxT)); + + EXPECT_TRUE(std::filesystem::exists(temp_dir + "/reordered_dataset.npy")); + EXPECT_GE(std::filesystem::file_size(temp_dir + "/reordered_dataset.npy"), + ps.n_rows * ps.dim * sizeof(DataT)); + + EXPECT_TRUE(std::filesystem::exists(temp_dir + "/dataset_mapping.npy")); + EXPECT_GE(std::filesystem::file_size(temp_dir + "/dataset_mapping.npy"), + ps.n_rows * sizeof(IdxT)); + + hnsw::index_params hnsw_params; + hnsw_params.hierarchy = hnsw::HnswHierarchy::GPU; + + auto hnsw_index = hnsw::from_cagra(handle_, hnsw_params, index); + ASSERT_NE(hnsw_index, nullptr); + + std::string hnsw_index_path = temp_dir + "/hnsw_index.bin"; + EXPECT_TRUE(std::filesystem::exists(hnsw_index_path)); + // For GPU hierarchy, HNSW index includes multi-layer structure + // The size should be at least the base layer size + auto hnsw_file_size = std::filesystem::file_size(hnsw_index_path); + EXPECT_GE(hnsw_file_size, ps.n_rows * index_params.graph_degree * sizeof(IdxT)); + + hnsw::index* hnsw_index_raw = nullptr; + hnsw::deserialize( + handle_, hnsw_params, hnsw_index_path, ps.dim, ps.metric, &hnsw_index_raw); + ASSERT_NE(hnsw_index_raw, nullptr); + + std::unique_ptr> hnsw_index_deserialized(hnsw_index_raw); + EXPECT_EQ(hnsw_index_deserialized->dim(), ps.dim); + EXPECT_EQ(hnsw_index_deserialized->metric(), ps.metric); + + auto queries_host = raft::make_host_matrix(ps.n_queries, ps.dim); + raft::copy( + queries_host.data_handle(), search_queries.data(), ps.n_queries * ps.dim, stream_); + raft::resource::sync_stream(handle_); + + auto indices_hnsw_host = raft::make_host_matrix(ps.n_queries, ps.k); + auto distances_hnsw_host = raft::make_host_matrix(ps.n_queries, ps.k); + + hnsw::search_params search_params; + search_params.ef = std::max(ps.ef_construction, ps.k * 2); + search_params.num_threads = 1; + + hnsw::search(handle_, + search_params, + *hnsw_index_deserialized, + queries_host.view(), + indices_hnsw_host.view(), + distances_hnsw_host.view()); + + for (size_t i = 0; i < queries_size; i++) { + indices_ace[i] = static_cast(indices_hnsw_host.data_handle()[i]); + distances_ace[i] = distances_hnsw_host.data_handle()[i]; + } + + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_ace, + distances_naive, + distances_ace, + ps.n_queries, + ps.k, + 0.003, + ps.min_recall)) + << "Disk-based ACE index loaded via HNSW failed recall check"; + } else { + // For in-memory ACE, we can search directly + EXPECT_FALSE(index.dataset_fd().has_value() || index.graph_fd().has_value()); + ASSERT_GT(index.graph().size(), 0); + EXPECT_EQ(index.graph_degree(), 64); + + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + auto queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto distances_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + + cagra::search_params search_params; + search_params.itopk_size = 64; + + cagra::search(handle_, search_params, index, queries_view, indices_view, distances_view); + + raft::update_host(distances_ace.data(), distances_dev.data(), queries_size, stream_); + raft::update_host(indices_ace.data(), indices_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_ace, + distances_naive, + distances_ace, + ps.n_queries, + ps.k, + 0.003, + ps.min_recall)) + << "In-memory ACE index failed recall check"; + } + } + + // Clean up temporary directory + std::filesystem::remove_all(temp_dir); + } + + void SetUp() override + { + database_dev.resize(((size_t)ps.n_rows) * ps.dim, stream_); + search_queries.resize(ps.n_queries * ps.dim, stream_); + raft::random::RngState r(1234ULL); + InitDataset(handle_, database_dev.data(), ps.n_rows, ps.dim, ps.metric, r); + InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r); + raft::resource::sync_stream(handle_); + } + + void TearDown() override + { + raft::resource::sync_stream(handle_); + database_dev.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnCagraAceInputs ps; + rmm::device_uvector database_dev; + rmm::device_uvector search_queries; +}; + +inline std::vector generate_ace_inputs() +{ + return raft::util::itertools::product( + {10}, // n_queries + {5000}, // n_rows + {64, 128}, // dim + {10}, // k + {2, 4}, // npartitions + {100}, // ef_construction + {false, true}, // use_disk (test both modes) + {cuvs::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType::InnerProduct}, // metric + {0.9} // min_recall + ); +} + +const std::vector ace_inputs = generate_ace_inputs(); + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/tests/neighbors/ann_cagra_ace/test_float_uint32_t.cu b/cpp/tests/neighbors/ann_cagra_ace/test_float_uint32_t.cu new file mode 100644 index 0000000000..de96a40339 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra_ace/test_float_uint32_t.cu @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "../ann_cagra_ace.cuh" + +namespace cuvs::neighbors::cagra { + +typedef AnnCagraAceTest AnnCagraAceTestF_U32; +TEST_P(AnnCagraAceTestF_U32, AnnCagraAce) { this->testAce(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraAceTest, AnnCagraAceTestF_U32, ::testing::ValuesIn(ace_inputs)); + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/tests/neighbors/ann_cagra_ace/test_half_uint32_t.cu b/cpp/tests/neighbors/ann_cagra_ace/test_half_uint32_t.cu new file mode 100644 index 0000000000..a1a6ec1397 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra_ace/test_half_uint32_t.cu @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "../ann_cagra_ace.cuh" + +namespace cuvs::neighbors::cagra { + +typedef AnnCagraAceTest AnnCagraAceTestF16_U32; +TEST_P(AnnCagraAceTestF16_U32, AnnCagraAce) { this->testAce(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraAceTest, AnnCagraAceTestF16_U32, ::testing::ValuesIn(ace_inputs)); + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/tests/neighbors/ann_cagra_ace/test_int8_t_uint32_t.cu b/cpp/tests/neighbors/ann_cagra_ace/test_int8_t_uint32_t.cu new file mode 100644 index 0000000000..3973b72cd6 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra_ace/test_int8_t_uint32_t.cu @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "../ann_cagra_ace.cuh" + +namespace cuvs::neighbors::cagra { + +typedef AnnCagraAceTest AnnCagraAceTestI8_U32; +TEST_P(AnnCagraAceTestI8_U32, AnnCagraAce) { this->testAce(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraAceTest, AnnCagraAceTestI8_U32, ::testing::ValuesIn(ace_inputs)); + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/tests/neighbors/ann_cagra_ace/test_uint8_t_uint32_t.cu b/cpp/tests/neighbors/ann_cagra_ace/test_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..5ca6f038df --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra_ace/test_uint8_t_uint32_t.cu @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "../ann_cagra_ace.cuh" + +namespace cuvs::neighbors::cagra { + +typedef AnnCagraAceTest AnnCagraAceTestU8_U32; +TEST_P(AnnCagraAceTestU8_U32, AnnCagraAce) { this->testAce(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraAceTest, AnnCagraAceTestU8_U32, ::testing::ValuesIn(ace_inputs)); + +} // namespace cuvs::neighbors::cagra diff --git a/docs/source/cuvs_bench/param_tuning.rst b/docs/source/cuvs_bench/param_tuning.rst index a0de6daf09..12e804dd56 100644 --- a/docs/source/cuvs_bench/param_tuning.rst +++ b/docs/source/cuvs_bench/param_tuning.rst @@ -201,7 +201,7 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g * - `graph_build_algo` - `build` - `N` - - [`IVF_PQ`, `NN_DESCENT`] + - [`IVF_PQ`, `NN_DESCENT`, `ACE`] - `IVF_PQ` - Algorithm to use for building the initial kNN graph, from which CAGRA will optimize into the navigable CAGRA graph @@ -212,6 +212,34 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g - `mmap` - Where should the dataset reside? + * - `npartitions` + - `build` + - N + - Positive integer >0 + - 1 + - The number of partitions to use for the ACE build. Small values might improve recall but potentially degrade performance and increase memory usage. Partitions should not be too small to prevent issues in KNN graph construction. 100k - 5M vectors per partition is recommended depending on the available host and GPU memory. The partition size is on average 2 * (n_rows / npartitions) * dim * sizeof(T). 2 is because of the core and augmented vectors. Please account for imbalance in the partition sizes (up to 3x in our tests). + + * - `build_dir` + - `build` + - N + - String + - "/tmp/ace_build" + - The directory to use for the ACE build. Must be specified when using ACE build. This should be the fastest disk in the system and hold enough space for twice the dataset, final graph, and label mapping. + + * - `ef_construction` + - `build` + - Y + - Positive integer >0 + - 120 + - Controls index time and accuracy when using ACE build. Bigger values increase the index quality. At some point, increasing this will no longer improve the quality. + + * - `use_disk` + - `build` + - N + - Boolean + - `false` + - Whether to use disk-based storage for ACE build. When true, forces ACE to use disk-based storage even if the graph fits in host and GPU memory. When false, ACE will use in-memory storage if the graph fits in host and GPU memory and disk-based storage otherwise. + * - `query_memory_type` - `search` - N diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index bd87df5782..619583a83e 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -30,6 +30,7 @@ include(../cmake/thirdparty/get_cuvs.cmake) # -------------- compile tasks ----------------- # add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu) +add_executable(CAGRA_HNSW_ACE_EXAMPLE src/cagra_hnsw_ace_example.cu) add_executable(CAGRA_EXAMPLE src/cagra_example.cu) add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu) add_executable(DYNAMIC_BATCHING_EXAMPLE src/dynamic_batching_example.cu) @@ -41,6 +42,7 @@ add_executable(SCANN_EXAMPLE src/scann_example.cu) # `$` is a generator expression that ensures that targets are # installed in a conda environment, if one exists target_link_libraries(BRUTE_FORCE_EXAMPLE PRIVATE cuvs::cuvs $) +target_link_libraries(CAGRA_HNSW_ACE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries( CAGRA_PERSISTENT_EXAMPLE PRIVATE cuvs::cuvs $ Threads::Threads diff --git a/examples/cpp/src/cagra_hnsw_ace_example.cu b/examples/cpp/src/cagra_hnsw_ace_example.cu new file mode 100644 index 0000000000..b2474eeab9 --- /dev/null +++ b/examples/cpp/src/cagra_hnsw_ace_example.cu @@ -0,0 +1,182 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "common.cuh" + +void cagra_build_search_ace(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + using namespace cuvs::neighbors; + + int64_t topk = 12; + int64_t n_queries = queries.extent(0); + + // create output arrays + auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); + auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); + + // CAGRA index parameters + cagra::index_params index_params; + index_params.intermediate_graph_degree = 128; + index_params.graph_degree = 64; + + // ACE index parameters + auto ace_params = cagra::graph_build_params::ace_params(); + // Set the number of partitions. Small values might improve recall but potentially degrade + // performance and increase memory usage. Partitions should not be too small to prevent issues in + // KNN graph construction. 100k - 5M vectors per partition is recommended depending on the + // available host and GPU memory. The partition size is on average 2 * (n_rows / npartitions) * + // dim * sizeof(T). 2 is because of the core and augmented vectors. Please account for imbalance + // in the partition sizes (up to 3x in our tests). + ace_params.npartitions = 4; + // Set the index quality for the ACE build. Bigger values increase the index quality. At some + // point, increasing this will no longer improve the quality. + ace_params.ef_construction = 120; + // Set the directory to store the ACE build artifacts. This should be the fastest disk in the + // system and hold enough space for twice the dataset, final graph, and label mapping. + ace_params.build_dir = "/tmp/ace_build"; + // Set whether to use disk-based storage for ACE build. When true, enables disk-based operations + // for memory-efficient graph construction. If not set, the index will be built in memory if the + // graph fits in host and GPU memory, and on disk otherwise. + ace_params.use_disk = true; + index_params.graph_build_params = ace_params; + + // ACE requires the dataset to be on the host + auto dataset_host = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); + raft::copy(dataset_host.data_handle(), + dataset.data_handle(), + dataset.extent(0) * dataset.extent(1), + raft::resource::get_cuda_stream(dev_resources)); + raft::resource::sync_stream(dev_resources); + auto dataset_host_view = raft::make_host_matrix_view( + dataset_host.data_handle(), dataset_host.extent(0), dataset_host.extent(1)); + + std::cout << "Building CAGRA index (search graph)" << std::endl; + auto index = cagra::build(dev_resources, index_params, dataset_host_view); + // In-memory build of ACE provides the index in memory, so we can search it directly using + // cagra::search + + // On-disk build of ACE stores the reordered dataset, the dataset mapping, and the graph on disk. + // The index is not directly usable for CAGRA search. Convert to HNSW for search operations. + + // Convert CAGRA index to HNSW + // For disk-based indices: serializes CAGRA to HNSW format on disk, returns an index with file + // descriptor For in-memory indices: creates HNSW index in memory + std::cout << "Converting CAGRA index to HNSW" << std::endl; + hnsw::index_params hnsw_params; + auto hnsw_index = hnsw::from_cagra(dev_resources, hnsw_params, index); + + // HNSW search requires host matrices + auto queries_host = raft::make_host_matrix(n_queries, queries.extent(1)); + raft::copy(queries_host.data_handle(), + queries.data_handle(), + n_queries * queries.extent(1), + raft::resource::get_cuda_stream(dev_resources)); + raft::resource::sync_stream(dev_resources); + + // HNSW search outputs uint64_t indices + auto indices_hnsw_host = raft::make_host_matrix(n_queries, topk); + auto distances_hnsw_host = raft::make_host_matrix(n_queries, topk); + + hnsw::search_params hnsw_search_params; + hnsw_search_params.ef = std::max(200, static_cast(topk) * 2); + hnsw_search_params.num_threads = 1; + + // If the HNSW index is in memory, search directly + // std::cout << "HNSW index in memory. Searching..." << std::endl; + // hnsw::search(dev_resources, + // hnsw_search_params, + // *hnsw_index, + // queries_host.view(), + // indices_hnsw_host.view(), + // distances_hnsw_host.view()); + + // If the HNSW index is stored on disk, deserialize it for searching + std::cout << "HNSW index is stored on disk." << std::endl; + + // For disk-based indices, the HNSW index file path can be obtained via file_path() + std::string hnsw_index_path = hnsw_index->file_path(); + std::cout << "HNSW index file location: " << hnsw_index_path << std::endl; + std::cout << "Deserializing HNSW index from disk for search." << std::endl; + + hnsw::index* hnsw_index_raw = nullptr; + hnsw::deserialize( + dev_resources, hnsw_params, hnsw_index_path, index.dim(), index.metric(), &hnsw_index_raw); + + std::unique_ptr> hnsw_index_deserialized(hnsw_index_raw); + + std::cout << "Searching HNSW index." << std::endl; + hnsw::search(dev_resources, + hnsw_search_params, + *hnsw_index_deserialized, + queries_host.view(), + indices_hnsw_host.view(), + distances_hnsw_host.view()); + + // Convert HNSW uint64_t indices back to uint32_t for printing + auto neighbors_host = raft::make_host_matrix(n_queries, topk); + for (int64_t i = 0; i < n_queries; i++) { + for (int64_t j = 0; j < topk; j++) { + neighbors_host(i, j) = static_cast(indices_hnsw_host(i, j)); + } + } + + // Copy results to device + raft::copy(neighbors.data_handle(), + neighbors_host.data_handle(), + n_queries * topk, + raft::resource::get_cuda_stream(dev_resources)); + raft::copy(distances.data_handle(), + distances_hnsw_host.data_handle(), + n_queries * topk, + raft::resource::get_cuda_stream(dev_resources)); + raft::resource::sync_stream(dev_resources); + + print_results(dev_resources, neighbors.view(), distances.view()); +} + +int main() +{ + raft::device_resources dev_resources; + + // Set pool memory resource with 1 GiB initial pool size. All allocations use the same pool. + rmm::mr::pool_memory_resource pool_mr( + rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(&pool_mr); + + // Alternatively, one could define a pool allocator for temporary arrays (used within RAFT + // algorithms). In that case only the internal arrays would use the pool, any other allocation + // uses the default RMM memory resource. Here is how to change the workspace memory resource to + // a pool with 2 GiB upper limit. + // raft::resource::set_workspace_to_pool_resource(dev_resources, 2 * 1024 * 1024 * 1024ull); + + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 90; + int64_t n_queries = 10; + auto dataset = raft::make_device_matrix(dev_resources, n_samples, n_dim); + auto queries = raft::make_device_matrix(dev_resources, n_queries, n_dim); + generate_dataset(dev_resources, dataset.view(), queries.view()); + + // ACE build and search example. + cagra_build_search_ace(dev_resources, + raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java index 2e47ac27e2..e185ed9f26 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java @@ -24,6 +24,7 @@ public class CagraIndexParams { private final long nnDescentNiter; private final int numWriterThreads; private final CuVSIvfPqParams cuVSIvfPqParams; + private final CuVSAceParams cuVSAceParams; private final CagraCompressionParams cagraCompressionParams; /** @@ -41,7 +42,12 @@ public enum CagraGraphBuildAlgo { /** * Experimental, use NN-Descent to build all-neighbors knn graph */ - NN_DESCENT(2); + NN_DESCENT(2), + /** + * Experimental, use ACE (Augmented Core Extraction) to build graph for large datasets. + * 4 to be consistent with the other interfaces. + */ + ACE(4); /** * The value for the enum choice. @@ -329,6 +335,7 @@ private CagraIndexParams( int writerThreads, CuvsDistanceType cuvsDistanceType, CuVSIvfPqParams cuVSIvfPqParams, + CuVSAceParams cuVSAceParams, CagraCompressionParams cagraCompressionParams) { this.intermediateGraphDegree = intermediateGraphDegree; this.graphDegree = graphDegree; @@ -337,6 +344,7 @@ private CagraIndexParams( this.numWriterThreads = writerThreads; this.cuvsDistanceType = cuvsDistanceType; this.cuVSIvfPqParams = cuVSIvfPqParams; + this.cuVSAceParams = cuVSAceParams; this.cagraCompressionParams = cagraCompressionParams; } @@ -405,6 +413,13 @@ public CuVSIvfPqParams getCuVSIvfPqParams() { return cuVSIvfPqParams; } + /** + * Gets the ACE parameters. + */ + public CuVSAceParams getCuVSAceParams() { + return cuVSAceParams; + } + /** * Gets the CAGRA build algorithm. */ @@ -435,6 +450,8 @@ public String toString() { + numWriterThreads + ", cuVSIvfPqParams=" + cuVSIvfPqParams + + ", cuVSAceParams=" + + cuVSAceParams + ", cagraCompressionParams=" + cagraCompressionParams + "]"; @@ -452,6 +469,7 @@ public static class Builder { private long nnDescentNumIterations = 20; private int numWriterThreads = 2; private CuVSIvfPqParams cuVSIvfPqParams = new CuVSIvfPqParams.Builder().build(); + private CuVSAceParams cuVSAceParams = new CuVSAceParams.Builder().build(); private CagraCompressionParams cagraCompressionParams; public Builder() {} @@ -535,6 +553,17 @@ public Builder withCuVSIvfPqParams(CuVSIvfPqParams cuVSIvfPqParams) { return this; } + /** + * Sets the ACE index parameters. + * + * @param cuVSAceParams the ACE index parameters + * @return an instance of Builder + */ + public Builder withCuVSAceParams(CuVSAceParams cuVSAceParams) { + this.cuVSAceParams = cuVSAceParams; + return this; + } + /** * Registers an instance of configured {@link CagraCompressionParams} with this * Builder. @@ -561,6 +590,7 @@ public CagraIndexParams build() { numWriterThreads, cuvsDistanceType, cuVSIvfPqParams, + cuVSAceParams, cagraCompressionParams); } } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java new file mode 100644 index 0000000000..54c25814bb --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java @@ -0,0 +1,184 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs; + +/** + * Parameters for ACE (Augmented Core Extraction) graph build algorithm. + * ACE enables building indices for datasets too large to fit in GPU memory by: + * 1. Partitioning the dataset in core (closest) and augmented (second-closest) + * partitions using balanced k-means. + * 2. Building sub-indices for each partition independently + * 3. Concatenating sub-graphs into a final unified index + * + * @since 25.12 + */ +public class CuVSAceParams { + + /** + * Number of partitions for ACE (Augmented Core Extraction) partitioned build. + * + * Small values might improve recall but potentially degrade performance and increase memory usage. + * Partitions should not be too small to prevent issues in KNN graph construction. 100k - 5M + * vectors per partition is recommended depending on the available host and GPU memory. The + * partition size is on average {@code 2 * (n_rows / npartitions) * dim * sizeof(T)}—the factor 2 + * accounts for core and augmented vectors. Please account for imbalance in the partition sizes + * (up to 3x in our tests). + */ + private final long npartitions; + + /** + * The index quality for the ACE build. + * + * Bigger values increase the index quality. At some point, increasing this will no longer improve + * the quality. + */ + private final long efConstruction; + + /** + * Directory to store ACE build artifacts (e.g., KNN graph, optimized graph). + * + * Used when {@link #isUseDisk()} is true or when the graph does not fit in host and GPU memory. + * This should be the fastest disk in the system and hold enough space for twice the dataset, final + * graph, and label mapping. + */ + private final String buildDir; + + /** + * Whether to use disk-based storage for ACE builds. + * + * When true, enables disk-based operations for memory-efficient graph construction. + */ + private final boolean useDisk; + + private CuVSAceParams( + long npartitions, long efConstruction, String buildDir, boolean useDisk) { + this.npartitions = npartitions; + this.efConstruction = efConstruction; + this.buildDir = buildDir; + this.useDisk = useDisk; + } + + /** + * Gets the number of partitions. + * + * @return the number of partitions + */ + public long getNpartitions() { + return npartitions; + } + + /** + * Gets the {@code ef_construction} parameter. + * + * @return the {@code ef_construction} parameter + */ + public long getEfConstruction() { + return efConstruction; + } + + /** + * Gets the build directory path. + * + * @return the build directory path + */ + public String getBuildDir() { + return buildDir; + } + + /** + * Gets whether disk-based mode is enabled. + * + * @return true if disk-based mode is enabled + */ + public boolean isUseDisk() { + return useDisk; + } + + @Override + public String toString() { + return "CuVSAceParams [npartitions=" + + npartitions + + ", efConstruction=" + + efConstruction + + ", buildDir=" + + buildDir + + ", useDisk=" + + useDisk + + "]"; + } + + /** + * Builder configures and creates an instance of {@link CuVSAceParams}. + */ + public static class Builder { + + /** Number of partitions to split the dataset into */ + private long npartitions = 1; + + /** ef_construction parameter for HNSW used in ACE */ + private long efConstruction = 120; + + /** Directory to store intermediate build files */ + private String buildDir = "/tmp/ace_build"; + + /** Whether to use disk-based mode for very large datasets */ + private boolean useDisk = false; + + public Builder() {} + + /** + * Sets the number of partitions. + * + * @param npartitions the number of partitions + * @return an instance of Builder + */ + public Builder withNpartitions(long npartitions) { + this.npartitions = npartitions; + return this; + } + + /** + * Sets the ef_construction parameter. + * + * @param efConstruction the ef_construction parameter + * @return an instance of Builder + */ + public Builder withEfConstruction(long efConstruction) { + this.efConstruction = efConstruction; + return this; + } + + /** + * Sets the build directory path. + * + * @param buildDir the build directory path + * @return an instance of Builder + */ + public Builder withBuildDir(String buildDir) { + this.buildDir = buildDir; + return this; + } + + /** + * Sets whether to use disk-based mode. + * + * @param useDisk whether to use disk-based mode + * @return an instance of Builder + */ + public Builder withUseDisk(boolean useDisk) { + this.useDisk = useDisk; + return this; + } + + /** + * Builds an instance of {@link CuVSAceParams}. + * + * @return an instance of {@link CuVSAceParams} + */ + public CuVSAceParams build() { + return new CuVSAceParams(npartitions, efConstruction, buildDir, useDisk); + } + } +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java index 6837c50505..c09111fcc3 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java @@ -43,6 +43,21 @@ static HnswIndex.Builder newBuilder(CuVSResources cuvsResources) { return CuVSProvider.provider().newHnswIndexBuilder(cuvsResources); } + /** + * Creates an HNSW index from an existing CAGRA index. + * + * @param hnswParams Parameters for the HNSW index + * @param cagraIndex The CAGRA index to convert from + * @return A new HNSW index + * @throws Throwable if an error occurs during conversion + */ + static HnswIndex fromCagra(HnswIndexParams hnswParams, CagraIndex cagraIndex) + throws Throwable { + Objects.requireNonNull(hnswParams); + Objects.requireNonNull(cagraIndex); + return CuVSProvider.provider().hnswIndexFromCagra(hnswParams, cagraIndex); + } + /** * Builder helps configure and create an instance of {@link HnswIndex}. */ diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java index 107da0bd8e..5ff87e5c64 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java @@ -113,6 +113,17 @@ CagraIndex.Builder newCagraIndexBuilder(CuVSResources cuVSResources) HnswIndex.Builder newHnswIndexBuilder(CuVSResources cuVSResources) throws UnsupportedOperationException; + /** + * Creates an HNSW index from an existing CAGRA index. + * + * @param hnswParams Parameters for the HNSW index + * @param cagraIndex The CAGRA index to convert from + * @return A new HNSW index + * @throws Throwable if an error occurs during conversion + */ + HnswIndex hnswIndexFromCagra(HnswIndexParams hnswParams, CagraIndex cagraIndex) + throws Throwable; + /** Creates a new TieredIndex Builder. */ TieredIndex.Builder newTieredIndexBuilder(CuVSResources cuVSResources) throws UnsupportedOperationException; diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java index 0b229009dd..d0f244d9ce 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java @@ -40,6 +40,12 @@ public HnswIndex.Builder newHnswIndexBuilder(CuVSResources cuVSResources) { throw new UnsupportedOperationException(reasons); } + @Override + public HnswIndex hnswIndexFromCagra(HnswIndexParams hnswParams, CagraIndex cagraIndex) + throws Throwable { + throw new UnsupportedOperationException(reasons); + } + @Override public TieredIndex.Builder newTieredIndexBuilder(CuVSResources cuVSResources) { throw new UnsupportedOperationException(reasons); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java index efc278c1f5..41e46f78fd 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java @@ -505,6 +505,20 @@ public CuVSResources getCuVSResources() { return resources; } + /** + * Gets the CAGRA index reference (for internal use). + * Package-private to allow access from HnswIndexImpl. + * + * @return the memory segment representing the CAGRA index + */ + MemorySegment getCagraIndexReference() { + return cagraIndexReference.getMemorySegment(); + } + + CuVSMatrix getDatasetForConversion() { + return cagraIndexReference.dataset; + } + /** * Allocates the native CagraIndexParams data structures and fills the configured index parameters in. */ @@ -610,6 +624,25 @@ private static void populateNativeIndexParams( cuvsIvfPqParamsMemorySegment, params.getCuVSIvfPqParams().getRefinementRate()); cuvsCagraIndexParams.graph_build_params(indexPtr, cuvsIvfPqParamsMemorySegment); + } else if (params.getCagraGraphBuildAlgo().equals(CagraGraphBuildAlgo.ACE)) { + var aceParams = createAceParams(); + // Note: Do NOT add aceParams to handles list. + // The cuvsCagraIndexParamsDestroy will handle freeing the ACE params + // when graph_build_algo is ACE, just like it does for IVF-PQ params. + MemorySegment cuvsAceParamsMemorySegment = aceParams.handle(); + CuVSAceParams cuVSAceParams = params.getCuVSAceParams(); + + cuvsAceParams.npartitions(cuvsAceParamsMemorySegment, cuVSAceParams.getNpartitions()); + cuvsAceParams.ef_construction(cuvsAceParamsMemorySegment, cuVSAceParams.getEfConstruction()); + cuvsAceParams.use_disk(cuvsAceParamsMemorySegment, cuVSAceParams.isUseDisk()); + + String buildDir = cuVSAceParams.getBuildDir(); + if (buildDir != null && !buildDir.isEmpty()) { + MemorySegment buildDirSegment = Util.duplicateNativeString(buildDir); + cuvsAceParams.build_dir(cuvsAceParamsMemorySegment, buildDirSegment); + } + + cuvsCagraIndexParams.graph_build_params(indexPtr, cuvsAceParamsMemorySegment); } } diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java index b45b90c7ec..9cfc7e2f49 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java @@ -122,6 +122,25 @@ public void close() { } } + public static CloseableHandle createAceParams() { + try (var localArena = Arena.ofConfined()) { + var paramsPtrPtr = localArena.allocate(cuvsAceParams_t); + checkCuVSError(cuvsAceParamsCreate(paramsPtrPtr), "cuvsAceParamsCreate"); + var paramsPtr = paramsPtrPtr.get(cuvsAceParams_t, 0L); + return new CloseableHandle() { + @Override + public MemorySegment handle() { + return paramsPtr; + } + + @Override + public void close() { + checkCuVSError(cuvsAceParamsDestroy(paramsPtr), "cuvsAceParamsDestroy"); + } + }; + } + } + static CloseableHandle createHnswIndexParams() { try (var localArena = Arena.ofConfined()) { var paramsPtrPtr = localArena.allocate(cuvsHnswIndexParams_t); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java index 876efce7ef..90ac4e7357 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java @@ -12,6 +12,8 @@ import static com.nvidia.cuvs.internal.common.Util.prepareTensor; import static com.nvidia.cuvs.internal.panama.headers_h.*; +import com.nvidia.cuvs.CagraIndex; +import com.nvidia.cuvs.CuVSMatrix; import com.nvidia.cuvs.CuVSResources; import com.nvidia.cuvs.HnswIndex; import com.nvidia.cuvs.HnswIndexParams; @@ -59,6 +61,20 @@ private HnswIndexImpl( this.hnswIndexReference = deserialize(inputStream); } + /** + * Constructor for creating index from an existing IndexReference + * + * @param indexReference the index reference + * @param resources an instance of {@link CuVSResources} + * @param hnswIndexParams an instance of {@link HnswIndexParams} + */ + private HnswIndexImpl( + IndexReference indexReference, CuVSResources resources, HnswIndexParams hnswIndexParams) { + this.hnswIndexParams = hnswIndexParams; + this.resources = resources; + this.hnswIndexReference = indexReference; + } + /** * Invokes the native destroy_hnsw_index to de-allocate the HNSW index */ @@ -222,6 +238,105 @@ public static HnswIndex.Builder newBuilder(CuVSResources cuvsResources) { return new HnswIndexImpl.Builder(Objects.requireNonNull(cuvsResources)); } + /** + * Creates an HNSW index from an existing CAGRA index. + * + * @param hnswParams Parameters for the HNSW index + * @param cagraIndex The CAGRA index to convert from + * @return A new HNSW index for in-memory indices, or null for disk-based indices + * @throws Throwable if an error occurs during conversion + */ + public static HnswIndex fromCagra(HnswIndexParams hnswParams, CagraIndex cagraIndex) + throws Throwable { + Objects.requireNonNull(hnswParams); + Objects.requireNonNull(cagraIndex); + + // Get the CAGRA index implementation to access internals + if (!(cagraIndex instanceof CagraIndexImpl)) { + throw new IllegalArgumentException("Invalid CagraIndex implementation"); + } + CagraIndexImpl cagraImpl = (CagraIndexImpl) cagraIndex; + CuVSResources resources = cagraImpl.getCuVSResources(); + + // Create HNSW index + MemorySegment hnswIndex = createHnswIndexHandle(); + + initializeIndexDType(hnswIndex, cagraImpl.getDatasetForConversion()); + + try (var localArena = Arena.ofConfined(); + var hnswParamsHandle = createHnswIndexParams()) { + MemorySegment hnswParamsMemorySegment = hnswParamsHandle.handle(); + + // Set HNSW params + cuvsHnswIndexParams.hierarchy( + hnswParamsMemorySegment, + hnswParams.getHierarchy().value); + cuvsHnswIndexParams.ef_construction( + hnswParamsMemorySegment, + hnswParams.getEfConstruction()); + cuvsHnswIndexParams.num_threads( + hnswParamsMemorySegment, + hnswParams.getNumThreads()); + + try (var resourcesAccessor = resources.access()) { + var cuvsRes = resourcesAccessor.handle(); + + // Call cuvsHnswFromCagra + int returnValue = + cuvsHnswFromCagra( + cuvsRes, + hnswParamsMemorySegment, + cagraImpl.getCagraIndexReference(), + hnswIndex); + checkCuVSError(returnValue, "cuvsHnswFromCagra"); + + returnValue = cuvsStreamSync(cuvsRes); + checkCuVSError(returnValue, "cuvsStreamSync"); + } + } + return new HnswIndexImpl(new IndexReference(hnswIndex), resources, hnswParams); + } + + /** + * Creates a new HNSW index handle. + */ + private static MemorySegment createHnswIndexHandle() { + try (var localArena = Arena.ofConfined()) { + MemorySegment indexPtrPtr = localArena.allocate(cuvsHnswIndex_t); + var returnValue = cuvsHnswIndexCreate(indexPtrPtr); + checkCuVSError(returnValue, "cuvsHnswIndexCreate"); + return indexPtrPtr.get(cuvsHnswIndex_t, 0); + } + } + + private static void initializeIndexDType(MemorySegment hnswIndex, CuVSMatrix dataset) { + int bits = 32; + int code = kDLFloat(); + + if (dataset instanceof CuVSMatrixInternal matrixInternal) { + bits = matrixInternal.bits(); + code = matrixInternal.code(); + } else if (dataset != null) { + bits = bitsFromDataType(dataset.dataType()); + code = CuVSMatrixInternal.code(dataset.dataType()); + } + + try (var localArena = Arena.ofConfined()) { + MemorySegment dtype = DLDataType.allocate(localArena); + DLDataType.bits(dtype, (byte) bits); + DLDataType.code(dtype, (byte) code); + DLDataType.lanes(dtype, (byte) 1); + cuvsHnswIndex.dtype(hnswIndex, dtype); + } + } + + private static int bitsFromDataType(CuVSMatrix.DataType dataType) { + return switch (dataType) { + case BYTE -> 8; + default -> 32; + }; + } + /** * Builder helps configure and create an instance of {@link HnswIndex}. */ diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java index f84eed5dbe..1117c08f23 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java @@ -54,6 +54,11 @@ private Util() {} LINKER.downcallHandle( cudaMemcpyAsync$address(), cudaMemcpyAsync$descriptor(), Linker.Option.critical(true)); + private static final MethodHandle strdup$mh = + LINKER.downcallHandle( + SYMBOL_LOOKUP.find("strdup").orElseThrow(UnsatisfiedLinkError::new), + FunctionDescriptor.of(C_POINTER, C_POINTER)); + private static final MethodHandle cudaGetDeviceProperties$mh = LINKER.downcallHandle( SYMBOL_LOOKUP @@ -215,6 +220,19 @@ public static MemorySegment buildMemorySegment(Arena arena, String str) { return stringMemorySegment; } + /** + * Allocates a native (C-owned) copy of a string using strdup(). The returned memory must be freed + * by the native side (e.g. cuVS APIs) via free(). + */ + public static MemorySegment duplicateNativeString(String str) { + try (var arena = Arena.ofConfined()) { + MemorySegment src = buildMemorySegment(arena, str); + return (MemorySegment) strdup$mh.invokeExact(src); + } catch (Throwable t) { + throw new RuntimeException("Failed to duplicate native string", t); + } + } + /** * A utility method for building a {@link MemorySegment} for a 1D long array. * diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java index facbc36a51..c639c48460 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java @@ -156,6 +156,12 @@ public HnswIndex.Builder newHnswIndexBuilder(CuVSResources cuVSResources) { return HnswIndexImpl.newBuilder(Objects.requireNonNull(cuVSResources)); } + @Override + public HnswIndex hnswIndexFromCagra(HnswIndexParams hnswParams, CagraIndex cagraIndex) + throws Throwable { + return HnswIndexImpl.fromCagra(hnswParams, cagraIndex); + } + @Override public TieredIndex.Builder newTieredIndexBuilder(CuVSResources cuVSResources) { return TieredIndexImpl.newBuilder(Objects.requireNonNull(cuVSResources)); diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraAceBuildAndSearchIT.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraAceBuildAndSearchIT.java new file mode 100644 index 0000000000..997fa560a7 --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraAceBuildAndSearchIT.java @@ -0,0 +1,243 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue; +import static org.junit.Assert.*; + +import com.carrotsearch.randomizedtesting.RandomizedRunner; +import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo; +import com.nvidia.cuvs.CagraIndexParams.CuvsDistanceType; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Integration tests for CAGRA index using ACE (Augmented Core Extraction) build algorithm. + * ACE enables building indices for datasets too large to fit in GPU memory by partitioning + * the dataset and building sub-indices. + * + * @since 25.12 + */ +@RunWith(RandomizedRunner.class) +public class CagraAceBuildAndSearchIT extends CuVSTestCase { + + private static final Logger log = LoggerFactory.getLogger(CagraAceBuildAndSearchIT.class); + + @Before + public void setup() { + assumeTrue("not supported on " + System.getProperty("os.name"), isLinuxAmd64()); + initializeRandom(); + log.trace("Random context initialized for test."); + } + + private static List> getExpectedResults() { + return Arrays.asList( + Map.of(3, 0.038782578f, 2, 0.3590463f, 0, 0.83774555f), + Map.of(0, 0.12472608f, 2, 0.21700792f, 1, 0.31918612f), + Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f), + Map.of(1, 0.15224178f, 0, 0.59063464f, 3, 0.5986642f)); + } + + private static float[][] createSampleQueries() { + return new float[][] { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f}, + {0.51260436f, 0.2643005f}, + {0.05198065f, 0.5789965f} + }; + } + + private static float[][] createSampleData() { + return new float[][] { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + } + + /** + * Test ACE build with in-memory mode (use_disk=false). + * This tests the basic ACE functionality with small datasets that fit in memory. + */ + @Test + public void testAceInMemoryBuild() throws Throwable { + float[][] dataset = createSampleData(); + float[][] queries = createSampleQueries(); + List> expectedResults = getExpectedResults(); + + try (CuVSResources resources = CheckedCuVSResources.create()) { + // Configure ACE parameters for in-memory mode + CuVSAceParams aceParams = + new CuVSAceParams.Builder() + .withNpartitions(2) + .withEfConstruction(120) + .withUseDisk(false) + .build(); + + // Configure index parameters with ACE build algorithm + CagraIndexParams indexParams = + new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.ACE) + .withGraphDegree(64) + .withIntermediateGraphDegree(128) + .withNumWriterThreads(2) + .withMetric(CuvsDistanceType.L2Expanded) + .withCuVSAceParams(aceParams) + .build(); + + // Build the index with ACE + try (CagraIndex index = + CagraIndex.newBuilder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build()) { + + // Verify index was built + assertNotNull("Index should not be null", index); + log.debug("ACE index built successfully in memory"); + + // Perform search + CagraSearchParams searchParams = new CagraSearchParams.Builder().build(); + + try (var queryVectors = CuVSMatrix.ofArray(queries)) { + CagraQuery cuvsQuery = + new CagraQuery.Builder(resources) + .withTopK(3) + .withSearchParams(searchParams) + .withQueryVectors(queryVectors) + .build(); + + SearchResults results = index.search(cuvsQuery); + log.debug("Search results: " + results.getResults().toString()); + + // Verify search results + checkResults(expectedResults, results.getResults()); + } + } + } + } + + /** + * Test ACE build with disk-based mode (use_disk=true). + * This tests ACE's ability to handle large datasets that don't fit in GPU memory. + */ + @Test + public void testAceDiskBasedBuild() throws Throwable { + float[][] dataset = createSampleData(); + float[][] queries = createSampleQueries(); + List> expectedResults = getExpectedResults(); + + try (CuVSResources resources = CheckedCuVSResources.create()) { + // Configure ACE parameters for disk-based mode + Path buildDir = Path.of("/tmp/java_ace_test"); + CuVSAceParams aceParams = + new CuVSAceParams.Builder() + .withNpartitions(2) + .withEfConstruction(120) + .withUseDisk(true) + .withBuildDir(buildDir.toString()) + .build(); + + // Configure index parameters with ACE build algorithm + CagraIndexParams indexParams = + new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.ACE) + .withGraphDegree(64) + .withIntermediateGraphDegree(128) + .withNumWriterThreads(32) + .withMetric(CuvsDistanceType.L2Expanded) + .withCuVSAceParams(aceParams) + .build(); + + // Build the index with ACE in disk mode + try (CagraIndex index = + CagraIndex.newBuilder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build()) { + + // Verify index was built + assertNotNull("Index should not be null", index); + log.debug("ACE index built successfully with disk mode"); + + // Verify ACE created the expected output files in the build directory + assertTrue( + "CAGRA graph file should exist", + Files.exists(buildDir.resolve("cagra_graph.npy"))); + assertTrue( + "Reordered dataset file should exist", + Files.exists(buildDir.resolve("reordered_dataset.npy"))); + assertTrue( + "Dataset mapping file should exist", + Files.exists(buildDir.resolve("dataset_mapping.npy"))); + + log.debug("ACE disk output files verified"); + + // Convert CAGRA index to HNSW using fromCagra + // This automatically handles disk-based indices + HnswIndexParams hnswIndexParams = + new HnswIndexParams.Builder().withVectorDimension(2).build(); + + try (var hnswIndexSerialized = HnswIndex.fromCagra(hnswIndexParams, index)) { + var hnswIndexSerializedPath = buildDir.resolve("hnsw_index.bin"); + assertTrue("HNSW index should exist", Files.exists(hnswIndexSerializedPath)); + log.debug("HNSW index created from disk-based ACE CAGRA index"); + + // Load the serialized index from disk + try (var inputStreamHNSW = Files.newInputStream(hnswIndexSerializedPath)) { + var hnswIndex = + HnswIndex.newBuilder(resources) + .from(inputStreamHNSW) + .withIndexParams(hnswIndexParams) + .build(); + + HnswSearchParams hnswSearchParams = new HnswSearchParams.Builder().build(); + HnswQuery hnswQuery = + new HnswQuery.Builder(resources) + .withTopK(3) + .withSearchParams(hnswSearchParams) + .withQueryVectors(queries) + .build(); + + SearchResults results = hnswIndex.search(hnswQuery); + log.debug("HNSW search results: " + results.getResults().toString()); + + checkResults(expectedResults, results.getResults()); + log.debug("HNSW search verification passed"); + + hnswIndex.close(); + } + } + + // Clean up the default build directory + deleteRecursively(buildDir); + } + } + } + + /** + * Helper method to recursively delete a directory and its contents. + */ + private void deleteRecursively(Path path) { + try { + if (Files.isDirectory(path)) { + Files.list(path).forEach(this::deleteRecursively); + } + Files.deleteIfExists(path); + } catch (Exception e) { + log.warn("Failed to delete {}: {}", path, e.getMessage()); + } + } +} diff --git a/python/cuvs/cuvs/neighbors/cagra/__init__.py b/python/cuvs/cuvs/neighbors/cagra/__init__.py index 7e59e62ed0..ec70305d72 100644 --- a/python/cuvs/cuvs/neighbors/cagra/__init__.py +++ b/python/cuvs/cuvs/neighbors/cagra/__init__.py @@ -3,6 +3,7 @@ from .cagra import ( + AceParams, CompressionParams, ExtendParams, Index, @@ -17,6 +18,7 @@ ) __all__ = [ + "AceParams", "CompressionParams", "ExtendParams", "Index", diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd index 9a8f167c27..44c1fa7aee 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd @@ -40,6 +40,7 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: IVF_PQ NN_DESCENT ITERATIVE_CAGRA_SEARCH + ACE ctypedef struct cuvsCagraCompressionParams: uint32_t pq_bits @@ -57,6 +58,13 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: float refinement_rate ctypedef cuvsIvfPqParams* cuvsIvfPqParams_t + ctypedef struct cuvsAceParams: + size_t npartitions + size_t ef_construction + const char* build_dir + bool use_disk + ctypedef cuvsAceParams* cuvsAceParams_t + ctypedef struct cuvsCagraIndexParams: cuvsDistanceType metric size_t intermediate_graph_degree @@ -64,7 +72,7 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: cuvsCagraGraphBuildAlgo build_algo size_t nn_descent_niter cuvsCagraCompressionParams_t compression - cuvsIvfPqParams_t graph_build_params + void* graph_build_params ctypedef cuvsCagraIndexParams* cuvsCagraIndexParams_t @@ -111,6 +119,10 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: cuvsError_t cuvsCagraCompressionParamsDestroy( cuvsCagraCompressionParams_t index) + cuvsError_t cuvsAceParamsCreate(cuvsAceParams_t* params) + + cuvsError_t cuvsAceParamsDestroy(cuvsAceParams_t params) + cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params) cuvsError_t cuvsCagraIndexParamsDestroy(cuvsCagraIndexParams_t index) @@ -193,6 +205,7 @@ cdef class IndexParams: cdef public object compression cdef public object ivf_pq_build_params cdef public object ivf_pq_search_params + cdef public object ace_params cdef class SearchParams: cdef cuvsCagraSearchParams * params diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index bb59aa55dc..c7ce834f5a 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -17,7 +17,9 @@ from libcpp cimport bool, cast from libcpp.string cimport string from cuvs.common cimport cydlpack + from cuvs.common.device_tensor_view import DeviceTensorView + from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray @@ -37,6 +39,8 @@ from libc.stdint cimport ( uint64_t, uintptr_t, ) +from libc.stdlib cimport free, malloc +from libc.string cimport strdup from cuvs.common.exceptions import check_cuvs from cuvs.neighbors import ivf_pq @@ -119,6 +123,94 @@ cdef class CompressionParams: def get_handle(self): return self.params + +cdef class AceParams: + """ + Parameters for ACE (Augmented Core Extraction) graph building algorithm. + + ACE enables building indices for datasets too large to fit in GPU memory by + partitioning the dataset using balanced k-means and building sub-indices + for each partition independently. + + Parameters + ---------- + npartitions : int, default = 1 + Number of partitions for ACE partitioned build. Small values might + improve recall but potentially degrade performance and increase memory + usage. Partitions should not be too small to prevent issues in KNN + graph construction. 100k - 5M vectors per partition is recommended + depending on the available host and GPU memory. The partition size is + on average 2 * (n_rows / npartitions) * dim * sizeof(T). 2 is because + of the core and augmented vectors. Please account for imbalance in the + partition sizes (up to 3x in our tests). + ef_construction : int, default = 120 + The index quality for the ACE build. Bigger values increase the index + quality. At some point, increasing this will no longer improve the + quality. + build_dir : str, default = "/tmp/ace_build" + Directory to store ACE build artifacts (e.g., KNN graph, optimized + graph). Used when `use_disk` is true or when the graph does not fit + in host and GPU memory. This should be the fastest disk in the system + and hold enough space for twice the dataset, final graph, and label + mapping. + use_disk : bool, default = False + Whether to use disk-based storage for ACE build. When true, enables + disk-based operations for memory-efficient graph construction. + """ + cdef cuvsAceParams* params + cdef bytes _build_dir_bytes # Keep Python bytes alive for property access + + def __cinit__(self): + check_cuvs(cuvsAceParamsCreate(&self.params)) + self._build_dir_bytes = b"" + + def __dealloc__(self): + if self.params != NULL: + check_cuvs(cuvsAceParamsDestroy(self.params)) + + def __init__(self, *, + npartitions=1, + ef_construction=120, + build_dir="/tmp/ace_build", + use_disk=False): + self.params.npartitions = npartitions + self.params.ef_construction = ef_construction + self.params.use_disk = use_disk + + # Need to replace the default build_dir allocated by + # cuvsAceParamsCreate + # First free the old C string, then allocate new one + if self.params.build_dir != NULL: + free(self.params.build_dir) + + # Store Python bytes for property access + self._build_dir_bytes = build_dir.encode('utf-8') + # Allocate C memory and copy the string (strdup-like behavior) + self.params.build_dir = strdup(self._build_dir_bytes) + + @property + def npartitions(self): + return self.params.npartitions + + @property + def ef_construction(self): + return self.params.ef_construction + + @property + def build_dir(self): + if self._build_dir_bytes: + return self._build_dir_bytes.decode('utf-8') + else: + return "" + + @property + def use_disk(self): + return self.params.use_disk + + def get_handle(self): + return self.params + + cdef class IndexParams: """ Parameters to build index for CAGRA nearest neighbor search @@ -141,7 +233,7 @@ cdef class IndexParams: graph_degree : int, default = 64 build_algo: str, default = "ivf_pq" string denoting the graph building algorithm to use. Valid values for - algo: ["ivf_pq", "nn_descent", "iterative_cagra_search"], where + algo: ["ivf_pq", "nn_descent", "iterative_cagra_search", "ace"], where - ivf_pq will use the IVF-PQ algorithm for building the knn graph - nn_descent (experimental) will use the NN-Descent algorithm for @@ -149,6 +241,8 @@ cdef class IndexParams: faster than ivf_pq. - iterative_cagra_search will iteratively build the knn graph using CAGRA's search() and optimize() + - ace will use ACE (Augmented Core Extraction) for building indices + for datasets too large to fit in GPU memory compression: CompressionParams, optional If compression is desired should be a CompressionParams object. If None @@ -159,6 +253,9 @@ cdef class IndexParams: ivf_pq_search_params: cuvs.neighbors.ivf_pq.SearchParams, optional Parameters for IVF-PQ search. If provided, it will be used for searching the graph. + ace_params: AceParams, optional + Parameters for ACE algorithm. If provided, it will be used for + building the graph with ACE partitioning. refinement_rate: float, default = 1.0 """ @@ -168,6 +265,7 @@ cdef class IndexParams: self.compression = None self.ivf_pq_build_params = None self.ivf_pq_search_params = None + self.ace_params = None def __dealloc__(self): if self.params != NULL: @@ -182,7 +280,11 @@ cdef class IndexParams: compression=None, ivf_pq_build_params: ivf_pq.IndexParams = None, ivf_pq_search_params: ivf_pq.SearchParams = None, + ace_params: AceParams = None, refinement_rate: float = 1.0): + # Declare cdef variables at the top of the function + cdef cuvsIvfPqParams_t ivf_pq_params_ptr + cdef cuvsAceParams_t new_ace_params self.params.metric = DISTANCE_TYPES[metric] self.params.intermediate_graph_degree = intermediate_graph_degree @@ -194,6 +296,8 @@ cdef class IndexParams: elif build_algo == "iterative_cagra_search": self.params.build_algo = \ cuvsCagraGraphBuildAlgo.ITERATIVE_CAGRA_SEARCH + elif build_algo == "ace": + self.params.build_algo = cuvsCagraGraphBuildAlgo.ACE else: raise ValueError(f"Unknown build_algo '{build_algo}'") @@ -202,19 +306,56 @@ cdef class IndexParams: self.compression = compression self.params.compression = \ compression.get_handle() - if ivf_pq_build_params is not None: - if ivf_pq_build_params.metric != self.metric: - raise ValueError("Metric mismatch with IVF-PQ build params") - self.ivf_pq_build_params = ivf_pq_build_params - self.params.graph_build_params.ivf_pq_build_params = \ - \ - ivf_pq_build_params.get_handle() - if ivf_pq_search_params is not None: - self.ivf_pq_search_params = ivf_pq_search_params - self.params.graph_build_params.ivf_pq_search_params = \ - \ - ivf_pq_search_params.get_handle() - self.params.graph_build_params.refinement_rate = refinement_rate + + # Handle graph build params based on build algorithm + if build_algo == "ace": + if ace_params is None: + ace_params = AceParams() + self.ace_params = ace_params + + # Create a new C-allocated cuvsAceParams that the C API will own + # We cannot pass the Python object's pointer directly because + # cuvsCagraIndexParamsDestroy will try to delete it + check_cuvs(cuvsAceParamsCreate(&new_ace_params)) + + # Copy values from Python object to new C struct + new_ace_params.npartitions = ace_params.params.npartitions + new_ace_params.ef_construction = ace_params.params.ef_construction + new_ace_params.use_disk = ace_params.params.use_disk + + # Copy the build_dir string + if new_ace_params.build_dir != NULL: + free(new_ace_params.build_dir) + new_ace_params.build_dir = strdup(ace_params.params.build_dir) + + # Pass the new C struct to the index params + self.params.graph_build_params = new_ace_params + else: + # For IVF-PQ algorithm, handle ivf_pq params + # Cast the void* back to cuvsIvfPqParams_t + ivf_pq_params_ptr = ( + self.params.graph_build_params + ) + + if ivf_pq_build_params is not None: + if ivf_pq_build_params.metric != self.metric: + raise ValueError( + "Metric mismatch with IVF-PQ build params" + ) + self.ivf_pq_build_params = ivf_pq_build_params + ivf_pq_params_ptr.ivf_pq_build_params = ( + + ivf_pq_build_params.get_handle() + ) + + if ivf_pq_search_params is not None: + self.ivf_pq_search_params = ivf_pq_search_params + ivf_pq_params_ptr.ivf_pq_search_params = ( + + ivf_pq_search_params.get_handle() + ) + + ivf_pq_params_ptr.refinement_rate = refinement_rate def get_handle(self): return self.params @@ -241,7 +382,15 @@ cdef class IndexParams: @property def refinement_rate(self): - return self.params.graph_build_params.refinement_rate + # refinement_rate only applies to IVF-PQ builds + if self.params.build_algo == cuvsCagraGraphBuildAlgo.IVF_PQ: + return ( + (self.params.graph_build_params) + .refinement_rate + ) + else: + # For ACE and other algorithms, refinement_rate doesn't apply + return 1.0 cdef class Index: @@ -327,6 +476,10 @@ def build(IndexParams index_params, dataset, resources=None): It is required that both the dataset and the optimized graph fit the GPU memory. + Note: When using ACE (Augmented Core Extraction) build algorithm, the + dataset must be in host memory (CPU). The ACE algorithm is designed for + datasets too large to fit in GPU memory. + The following distance metrics are supported: - L2 - InnerProduct @@ -337,6 +490,8 @@ def build(IndexParams index_params, dataset, resources=None): index_params : IndexParams object dataset : CUDA array interface compliant matrix shape (n_samples, dim) Supported dtype [float, half, int8, uint8] + **Note:** For ACE build algorithm, the dataset MUST be in host memory. + Use NumPy arrays or call .get() on CuPy arrays before passing. {resources_docstring} Returns @@ -361,8 +516,28 @@ def build(IndexParams index_params, dataset, resources=None): ... k) >>> distances = cp.asarray(distances) >>> neighbors = cp.asarray(neighbors) + + >>> # ACE example with host data + >>> import numpy as np + >>> dataset_host = np.random.random_sample( + ... (n_samples, n_features) + ... ).astype(np.float32) + >>> ace_params = cagra.AceParams( + ... npartitions=4, use_disk=True, build_dir="/tmp/ace" + ... ) + >>> build_params = cagra.IndexParams( + ... metric="sqeuclidean", + ... build_algo="ace", + ... ace_params=ace_params + ... ) + >>> idx = cagra.build(build_params, dataset_host) """ + # Check if ACE build is requested + is_ace_build = ( + index_params.params.build_algo == cuvsCagraGraphBuildAlgo.ACE + ) + # todo(dgd): we can make the check of dtype a parameter of wrap_array # in RAFT to make this a single call dataset_ai = wrap_array(dataset) @@ -371,6 +546,16 @@ def build(IndexParams index_params, dataset, resources=None): np.dtype('byte'), np.dtype('ubyte')]) + # For ACE, verify dataset is on host + if is_ace_build: + # Check if data is on device (has __cuda_array_interface__) + if hasattr(dataset, '__cuda_array_interface__'): + raise ValueError( + "ACE build requires dataset to be in host memory. " + "Please use NumPy arrays or transfer CuPy arrays to host with " + "dataset.get() before calling build()." + ) + cdef Index idx = Index() cdef cydlpack.DLManagedTensor* dataset_dlpack = \ cydlpack.dlpack_c(dataset_ai) diff --git a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd index 2db0902e18..399fc06d47 100644 --- a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd +++ b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd @@ -1,10 +1,11 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: language_level=3 from libc.stdint cimport int32_t, uintptr_t +from libcpp cimport bool from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor diff --git a/python/cuvs/cuvs/tests/test_cagra_ace.py b/python/cuvs/cuvs/tests/test_cagra_ace.py new file mode 100644 index 0000000000..d2d9ccf6a4 --- /dev/null +++ b/python/cuvs/cuvs/tests/test_cagra_ace.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# + +import os +import tempfile + +import cupy as cp +import numpy as np +import pytest +from pylibraft.common import device_ndarray +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize + +from cuvs.neighbors import cagra, hnsw +from cuvs.tests.ann_utils import calc_recall, generate_data + + +def run_cagra_ace_build_search_test( + n_rows=5000, + n_cols=64, + n_queries=10, + k=10, + dtype=np.float32, + metric="sqeuclidean", + intermediate_graph_degree=128, + graph_degree=64, + npartitions=2, + ef_construction=100, + use_disk=False, + hierarchy="none", +): + dataset = generate_data((n_rows, n_cols), dtype) + queries = generate_data((n_queries, n_cols), dtype) + if metric == "inner_product": + dataset = normalize(dataset, norm="l2", axis=1) + queries = normalize(queries, norm="l2", axis=1) + if dtype in [np.int8, np.uint8]: + # Quantize the normalized data to the int8/uint8 range + dtype_max = np.iinfo(dtype).max + dataset = (dataset * dtype_max).astype(dtype) + queries = (queries * dtype_max).astype(dtype) + + # Create a temporary directory for ACE build + with tempfile.TemporaryDirectory() as temp_dir: + # Set up ACE parameters + ace_params = cagra.AceParams( + npartitions=npartitions, + ef_construction=ef_construction, + build_dir=temp_dir, + use_disk=use_disk, + ) + + # Build parameters + build_params = cagra.IndexParams( + metric=metric, + intermediate_graph_degree=intermediate_graph_degree, + graph_degree=graph_degree, + build_algo="ace", + ace_params=ace_params, + ) + + # Build the index with ACE (uses host memory) + index = cagra.build(build_params, dataset) + + assert index.trained + + # For disk-based mode, we can't search directly + # (would need HNSW conversion which is tested separately) + if not use_disk: + # For in-memory mode, we can search directly + # But queries need to be on device + search_params = cagra.SearchParams(itopk_size=64) + + # Transfer queries to device for search + queries_device = device_ndarray(queries) + + out_dist, out_idx = cagra.search( + search_params, index, queries_device, k + ) + + # Convert results back to host + out_idx_host = out_idx.copy_to_host() + + # Calculate reference values with sklearn + skl_metric = { + "sqeuclidean": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] + nn_skl = NearestNeighbors( + n_neighbors=k, algorithm="brute", metric=skl_metric + ) + nn_skl.fit(dataset) + skl_idx = nn_skl.kneighbors(queries, return_distance=False) + + recall = calc_recall(out_idx_host, skl_idx) + assert recall > 0.7 + + # test that we can get the cagra graph from the index + graph = index.graph + assert graph.shape == (n_rows, graph_degree) + + # make sure we can convert the graph to cupy, and access it + cp_graph = cp.array(graph) + assert cp_graph.shape == (n_rows, graph_degree) + else: + # For disk-based mode, verify that expected files were created + assert os.path.exists(os.path.join(temp_dir, "cagra_graph.npy")) + assert os.path.exists( + os.path.join(temp_dir, "reordered_dataset.npy") + ) + assert os.path.exists( + os.path.join(temp_dir, "dataset_mapping.npy") + ) + + # Test HNSW conversion from disk-based ACE index + hnsw_params = hnsw.IndexParams(hierarchy=hierarchy) + hnsw_index_serialized = hnsw.from_cagra(hnsw_params, index) + assert hnsw_index_serialized is not None + assert os.path.exists(os.path.join(temp_dir, "hnsw_index.bin")) + + # Deserialize the HNSW index from disk for search + hnsw_index = hnsw.load( + hnsw_params, + os.path.join(temp_dir, "hnsw_index.bin"), + n_cols, + dtype, + ) + + search_params = hnsw.SearchParams(ef=200, num_threads=1) + out_dist, out_idx = hnsw.search( + search_params, hnsw_index, queries, k + ) + + # Calculate reference values with sklearn + skl_metric = { + "sqeuclidean": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] + nn_skl = NearestNeighbors( + n_neighbors=k, algorithm="brute", metric=skl_metric + ) + nn_skl.fit(dataset) + skl_dist, skl_idx = nn_skl.kneighbors( + queries, return_distance=True + ) + + recall = calc_recall(out_idx, skl_idx) + assert recall > 0.7 + + +@pytest.mark.parametrize("dim", [64, 128]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) +@pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product"]) +@pytest.mark.parametrize("npartitions", [2, 4]) +@pytest.mark.parametrize("ef_construction", [100, 200]) +@pytest.mark.parametrize("use_disk", [False, True]) +@pytest.mark.parametrize("hierarchy", ["none", "gpu"]) +def test_cagra_ace_dtypes_and_metrics( + dim, dtype, metric, npartitions, ef_construction, use_disk, hierarchy +): + """Test ACE with different data types and metrics.""" + run_cagra_ace_build_search_test( + n_cols=dim, + dtype=dtype, + metric=metric, + npartitions=npartitions, + ef_construction=ef_construction, + use_disk=use_disk, + hierarchy=hierarchy, + ) From 42358721bb4c8c456213f4d27e055becca6eadc1 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 13 Nov 2025 16:14:43 -0600 Subject: [PATCH 08/32] Update RMM includes from `` to `` (#1538) This updates RMM memory resource includes to use the header path `` instead of ``. xref: https://github.com/rapidsai/rmm/issues/2141 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/1538 --- c/src/core/c_api.cpp | 10 +++++----- cpp/bench/ann/src/common/cuda_huge_page_resource.hpp | 2 +- cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h | 8 ++++---- cpp/bench/ann/src/cuvs/cuvs_benchmark.cu | 4 ++-- cpp/bench/ann/src/cuvs/cuvs_cagra_diskann.cu | 2 +- cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu | 2 +- cpp/bench/ann/src/cuvs/cuvs_vamana.cu | 2 +- cpp/internal/cuvs_internal/neighbors/naive_knn.cuh | 4 ++-- cpp/src/cluster/detail/kmeans_balanced.cuh | 2 +- .../detail/cagra/search_single_cta_kernel-inl.cuh | 2 +- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 2 +- cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh | 4 ++-- cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh | 2 +- cpp/tests/neighbors/ann_ivf_pq.cuh | 4 ++-- cpp/tests/neighbors/ann_scann.cuh | 2 +- cpp/tests/neighbors/ann_utils.cuh | 4 ++-- cpp/tests/neighbors/dynamic_batching.cuh | 4 ++-- cpp/tests/neighbors/naive_knn.cuh | 2 +- examples/cpp/src/brute_force_bitmap.cu | 2 +- examples/cpp/src/cagra_example.cu | 6 +++--- examples/cpp/src/cagra_persistent_example.cu | 4 ++-- examples/cpp/src/dynamic_batching_example.cu | 4 ++-- examples/cpp/src/ivf_flat_example.cu | 4 ++-- examples/cpp/src/ivf_pq_example.cu | 4 ++-- examples/cpp/src/scann_example.cu | 4 ++-- examples/cpp/src/vamana_example.cu | 4 ++-- 26 files changed, 47 insertions(+), 47 deletions(-) diff --git a/c/src/core/c_api.cpp b/c/src/core/c_api.cpp index afa59e88b7..7c0f0c7ebf 100644 --- a/c/src/core/c_api.cpp +++ b/c/src/core/c_api.cpp @@ -14,11 +14,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include "../core/exceptions.hpp" diff --git a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp index d06448fed9..8039187bde 100644 --- a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp +++ b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h index 53cbef1488..83cb7303c8 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h @@ -19,10 +19,10 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include diff --git a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu index 37fc1bc56e..aebac654c2 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu @@ -1,12 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include "../common/ann_types.hpp" #include "cuvs_ann_bench_param_parser.h" -#include +#include #include #include diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_diskann.cu b/cpp/bench/ann/src/cuvs/cuvs_cagra_diskann.cu index 5e047044f2..1521333c5e 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_diskann.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_diskann.cu @@ -8,7 +8,7 @@ #include "cuvs_cagra_diskann_wrapper.h" #include -#include +#include #include namespace cuvs::bench { diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu index 113e79fa15..26028b6d98 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu @@ -8,7 +8,7 @@ #include "cuvs_cagra_hnswlib_wrapper.h" #include -#include +#include #include namespace cuvs::bench { diff --git a/cpp/bench/ann/src/cuvs/cuvs_vamana.cu b/cpp/bench/ann/src/cuvs/cuvs_vamana.cu index b9bb62b437..185095d5b4 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_vamana.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_vamana.cu @@ -7,7 +7,7 @@ #include "cuvs_vamana_wrapper.h" #include -#include +#include #include namespace cuvs::bench { diff --git a/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh b/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh index 5b4cb45618..6c7577065b 100644 --- a/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -13,7 +13,7 @@ #include #include #include -#include +#include namespace cuvs::neighbors { diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 51c631e709..2637e81517 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -35,7 +35,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 939589c942..96e0c419f2 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -33,7 +33,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 912c32f0a2..50b8d117be 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -49,7 +49,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh index 7f38342461..47b48b2555 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_fp_8bit.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,7 +17,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index bf0ec44a4b..7a25006aaa 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -40,7 +40,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/tests/neighbors/ann_ivf_pq.cuh b/cpp/tests/neighbors/ann_ivf_pq.cuh index c919211b53..4660c5d0d3 100644 --- a/cpp/tests/neighbors/ann_ivf_pq.cuh +++ b/cpp/tests/neighbors/ann_ivf_pq.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -15,7 +15,7 @@ #include #include #include -#include +#include #include namespace cuvs::neighbors::ivf_pq { diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index 621bb7664b..d8bc27fed7 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include namespace cuvs::neighbors::experimental::scann { diff --git a/cpp/tests/neighbors/ann_utils.cuh b/cpp/tests/neighbors/ann_utils.cuh index ee9e86bcd3..8a908c0187 100644 --- a/cpp/tests/neighbors/ann_utils.cuh +++ b/cpp/tests/neighbors/ann_utils.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -16,7 +16,7 @@ #include #include -#include +#include #include "naive_knn.cuh" diff --git a/cpp/tests/neighbors/dynamic_batching.cuh b/cpp/tests/neighbors/dynamic_batching.cuh index d98c803bd1..9f54325b7a 100644 --- a/cpp/tests/neighbors/dynamic_batching.cuh +++ b/cpp/tests/neighbors/dynamic_batching.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -14,7 +14,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/tests/neighbors/naive_knn.cuh b/cpp/tests/neighbors/naive_knn.cuh index b24e598166..0b549cfdb1 100644 --- a/cpp/tests/neighbors/naive_knn.cuh +++ b/cpp/tests/neighbors/naive_knn.cuh @@ -13,7 +13,7 @@ #include #include #include -#include +#include namespace cuvs::neighbors { diff --git a/examples/cpp/src/brute_force_bitmap.cu b/examples/cpp/src/brute_force_bitmap.cu index b67e19d3bd..73f59b1348 100644 --- a/examples/cpp/src/brute_force_bitmap.cu +++ b/examples/cpp/src/brute_force_bitmap.cu @@ -11,7 +11,7 @@ #include #include -#include +#include void load_dataset(const raft::device_resources& res, float* data_ptr, int n_vectors, int dim) { diff --git a/examples/cpp/src/cagra_example.cu b/examples/cpp/src/cagra_example.cu index 65cb66d9f8..7d5b6d867b 100644 --- a/examples/cpp/src/cagra_example.cu +++ b/examples/cpp/src/cagra_example.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,8 +10,8 @@ #include -#include -#include +#include +#include #include "common.cuh" diff --git a/examples/cpp/src/cagra_persistent_example.cu b/examples/cpp/src/cagra_persistent_example.cu index 310ce51ef5..1639ceb0db 100644 --- a/examples/cpp/src/cagra_persistent_example.cu +++ b/examples/cpp/src/cagra_persistent_example.cu @@ -9,8 +9,8 @@ #include #include #include -#include -#include +#include +#include #include #include diff --git a/examples/cpp/src/dynamic_batching_example.cu b/examples/cpp/src/dynamic_batching_example.cu index 66ed85179b..ff5e36959f 100644 --- a/examples/cpp/src/dynamic_batching_example.cu +++ b/examples/cpp/src/dynamic_batching_example.cu @@ -14,8 +14,8 @@ #include #include #include -#include -#include +#include +#include #include #include diff --git a/examples/cpp/src/ivf_flat_example.cu b/examples/cpp/src/ivf_flat_example.cu index 2097e817b5..404cd86e89 100644 --- a/examples/cpp/src/ivf_flat_example.cu +++ b/examples/cpp/src/ivf_flat_example.cu @@ -11,8 +11,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/examples/cpp/src/ivf_pq_example.cu b/examples/cpp/src/ivf_pq_example.cu index 4fd1e6603e..9283fe373d 100644 --- a/examples/cpp/src/ivf_pq_example.cu +++ b/examples/cpp/src/ivf_pq_example.cu @@ -10,8 +10,8 @@ #include #include -#include -#include +#include +#include #include diff --git a/examples/cpp/src/scann_example.cu b/examples/cpp/src/scann_example.cu index d1f420ba82..dae80e8372 100644 --- a/examples/cpp/src/scann_example.cu +++ b/examples/cpp/src/scann_example.cu @@ -12,8 +12,8 @@ #include -#include -#include +#include +#include #include "common.cuh" diff --git a/examples/cpp/src/vamana_example.cu b/examples/cpp/src/vamana_example.cu index 3adce03d68..63d48bc197 100644 --- a/examples/cpp/src/vamana_example.cu +++ b/examples/cpp/src/vamana_example.cu @@ -11,8 +11,8 @@ #include -#include -#include +#include +#include #include "common.cuh" From 6fe96ce64ab485f840eed5e1ae5b6914564cfd84 Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Thu, 13 Nov 2025 20:25:41 -0500 Subject: [PATCH 09/32] Extend CI to build and test x86 libcuvs_c tarballs (#1524) Adds new `rocky8-clib-standalone-build` and `rocky8-clib-tests` PR jobs that validate that the C api binaries can be built and run all C tests correctly. Also adds a new nightly build job that produces the C api binaries. Authors: - Robert Maynard (https://github.com/robertmaynard) - Ben Frederickson (https://github.com/benfred) Approvers: - Jake Awe (https://github.com/AyodeAwe) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cuvs/pull/1524 --- .github/workflows/build.yaml | 27 +++++++++- .github/workflows/pr.yaml | 102 +++++++++++++++++++++++++++++------ c/CMakeLists.txt | 9 ++++ c/tests/CMakeLists.txt | 5 +- ci/build_java.sh | 4 +- ci/build_standalone_c.sh | 90 +++++++++++++++++++++++++++++++ ci/test_standalone_c.sh | 37 +++++++++++++ cpp/CMakeLists.txt | 6 +-- 8 files changed, 256 insertions(+), 24 deletions(-) create mode 100755 ci/build_standalone_c.sh create mode 100755 ci/test_standalone_c.sh diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 29128311a7..e76c010f26 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -41,6 +41,29 @@ jobs: date: ${{ inputs.date }} script: ci/build_cpp.sh sha: ${{ inputs.sha }} + + rocky8-clib-standalone-build: + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main + strategy: + fail-fast: false + matrix: + cuda_version: + - &latest_cuda12 '12.9.1' + - &latest_cuda13 '13.0.2' + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + arch: "amd64" + date: ${{ inputs.date }} + container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + node_type: "cpu16" + name: "${{ matrix.cuda_version }}, amd64, rockylinux8" + # requires_license_builder: false + script: "ci/build_standalone_c.sh" + artifact-name: "libcuvs_c_${{ matrix.cuda_version }}.tar.gz" + file_to_upload: "libcuvs_c.tar.gz" + sha: ${{ inputs.sha }} rust-build: needs: cpp-build secrets: inherit @@ -51,8 +74,8 @@ jobs: fail-fast: false matrix: cuda_version: - - &latest_cuda12 '12.9.1' - - &latest_cuda13 '13.0.2' + - *latest_cuda12 + - *latest_cuda13 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index a69bb19362..de56414fc9 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -17,10 +17,12 @@ jobs: - conda-cpp-checks - conda-python-build - conda-python-tests - - conda-java-tests - - docs-build + - rocky8-clib-standalone-build + - rocky8-clib-tests + - conda-java-build-and-tests - rust-build - go-build + - docs-build - wheel-build-libcuvs - wheel-build-cuvs - wheel-tests-cuvs @@ -103,6 +105,30 @@ jobs: - '!rust/**' - '!go/**' - '!thirdparty/LICENSES/**' + test_rust: + - '**' + - '!.devcontainer/**' + - '!.pre-commit-config.yaml' + - '!README.md' + - '!docs/**' + - '!img/**' + - '!notebooks/**' + - '!python/**' + - '!java/**' + - '!go/**' + - '!thirdparty/LICENSES/**' + test_go: + - '**' + - '!.devcontainer/**' + - '!.pre-commit-config.yaml' + - '!README.md' + - '!docs/**' + - '!img/**' + - '!notebooks/**' + - '!python/**' + - '!java/**' + - '!rust/**' + - '!thirdparty/LICENSES/**' checks: needs: telemetry-setup secrets: inherit @@ -148,41 +174,72 @@ jobs: with: build_type: pull-request script: ci/test_python.sh - conda-java-tests: - needs: [conda-cpp-build, changed-files] + rocky8-clib-standalone-build: + needs: [checks] secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main - if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_java - # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. - # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: fail-fast: false matrix: cuda_version: - &latest_cuda12 '12.9.1' - &latest_cuda13 '13.0.2' + with: + build_type: pull-request + arch: "amd64" + date: ${{ inputs.date }}_c + container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + node_type: "cpu16" + # requires_license_builder: false + script: "ci/build_standalone_c.sh --build-tests" + artifact-name: "libcuvs_c_${{ matrix.cuda_version }}.tar.gz" + file_to_upload: "libcuvs_c.tar.gz" + sha: ${{ inputs.sha }} + rocky8-clib-tests: + needs: [rocky8-clib-standalone-build, changed-files] + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main + if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp + strategy: + fail-fast: false + matrix: + cuda_version: + - *latest_cuda12 + - *latest_cuda13 with: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" - script: "ci/test_java.sh" - artifact-name: "cuvs-java-cuda${{ matrix.cuda_version }}" - file_to_upload: "java/cuvs-java/target/" - docs-build: - needs: conda-python-build + date: ${{ inputs.date }}_c + container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + script: "ci/test_standalone_c.sh" + sha: ${{ inputs.sha }} + conda-java-build-and-tests: + needs: [conda-cpp-build, changed-files] secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main + if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_java || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp + # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. + # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. + strategy: + fail-fast: false + matrix: + cuda_version: + - *latest_cuda12 + - *latest_cuda13 with: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-latest" - script: "ci/build_docs.sh" + container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + script: "ci/test_java.sh" + artifact-name: "cuvs-java-cuda${{ matrix.cuda_version }}" + file_to_upload: "java/cuvs-java/target/" rust-build: - needs: conda-cpp-build + needs: [conda-cpp-build, changed-files] secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main + if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_rust || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -198,9 +255,10 @@ jobs: container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_rust.sh" go-build: - needs: conda-cpp-build + needs: [conda-cpp-build, changed-files] secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main + if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_go || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -215,6 +273,16 @@ jobs: arch: "amd64" container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_go.sh" + docs-build: + needs: conda-python-build + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main + with: + build_type: pull-request + node_type: "gpu-l4-latest-1" + arch: "amd64" + container_image: "rapidsai/ci-conda:25.12-latest" + script: "ci/build_docs.sh" wheel-build-libcuvs: needs: checks secrets: inherit diff --git a/c/CMakeLists.txt b/c/CMakeLists.txt index 1608a5598d..5c66cad9fe 100644 --- a/c/CMakeLists.txt +++ b/c/CMakeLists.txt @@ -64,6 +64,15 @@ if(BUILD_CAGRA_HNSWLIB) include(../cpp/cmake/thirdparty/get_hnswlib.cmake) endif() +if(BUILD_MG_ALGOS AND CUVSC_STATIC_CUVS_LIBRARY) + rapids_find_generate_module( + NCCL + HEADER_NAMES nccl.h + LIBRARY_NAMES nccl + ) + find_package(NCCL REQUIRED) +endif() + # ################################################################################################## # * cuvs_c ------------------------------------------------------------------------------- add_library( diff --git a/c/tests/CMakeLists.txt b/c/tests/CMakeLists.txt index 80152da986..0218f8d4b8 100644 --- a/c/tests/CMakeLists.txt +++ b/c/tests/CMakeLists.txt @@ -15,6 +15,9 @@ if(PROJECT_IS_TOP_LEVEL) rapids_test_init() endif() +rapids_cmake_install_lib_dir(lib_dir) +include(GNUInstallDirs) + include(${rapids-cmake-dir}/cpm/gtest.cmake) rapids_cpm_gtest(BUILD_STATIC) @@ -46,7 +49,7 @@ function(ConfigureTest) set_target_properties( ${TEST_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "$" - INSTALL_RPATH "\$ORIGIN/../../../lib" + INSTALL_RPATH "\$ORIGIN/../../../${lib_dir}" ) target_include_directories( diff --git a/ci/build_java.sh b/ci/build_java.sh index 9ad38a22d0..d5352910f8 100755 --- a/ci/build_java.sh +++ b/ci/build_java.sh @@ -12,7 +12,9 @@ if [[ "${1:-}" == "--run-java-tests" ]]; then EXTRA_BUILD_ARGS+=("--run-java-tests") fi -. /opt/conda/etc/profile.d/conda.sh +if [ -e "/opt/conda/etc/profile.d/conda.sh" ]; then + . /opt/conda/etc/profile.d/conda.sh +fi rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-github cpp) diff --git a/ci/build_standalone_c.sh b/ci/build_standalone_c.sh new file mode 100755 index 0000000000..88043b10ad --- /dev/null +++ b/ci/build_standalone_c.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +TOOLSET_VERSION=14 +CMAKE_VERSION=3.31.8 +CMAKE_ARCH=x86_64 + +BUILD_C_LIB_TESTS="OFF" +if [[ "${1:-}" == "--build-tests" ]]; then + BUILD_C_LIB_TESTS="ON" +fi + +dnf install -y \ + patch \ + tar \ + make + +# Fetch and install CMake. +if [ ! -e "/usr/local/bin/cmake" ]; then + pushd /usr/local + wget --quiet https://github.com/Kitware/CMake/releases/download/v"${CMAKE_VERSION}"/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz + tar zxf cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz + rm cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz + ln -s /usr/local/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}"/bin/cmake /usr/local/bin/cmake + popd +fi + +source rapids-configure-sccache + +source rapids-date-string + +rapids-print-env + +rapids-logger "Begin cpp build" + +sccache --zero-stats + + +RAPIDS_PACKAGE_VERSION=$(rapids-generate-version) +export RAPIDS_PACKAGE_VERSION + +RAPIDS_ARTIFACTS_DIR=${RAPIDS_ARTIFACTS_DIR:-"${PWD}/artifacts"} +mkdir -p "${RAPIDS_ARTIFACTS_DIR}" +export RAPIDS_ARTIFACTS_DIR + +scl enable gcc-toolset-${TOOLSET_VERSION} -- \ + cmake -S cpp -B cpp/build/ \ + -DCMAKE_CUDA_HOST_COMPILER=/opt/rh/gcc-toolset-${TOOLSET_VERSION}/root/usr/bin/gcc \ + -DCMAKE_CUDA_ARCHITECTURES=RAPIDS \ + -DBUILD_SHARED_LIBS=OFF \ + -DCUTLASS_ENABLE_TESTS=OFF \ + -DDISABLE_OPENMP=OFF \ + -DBUILD_TESTS=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DCUVS_STATIC_RAPIDS_LIBRARIES=ON +cmake --build cpp/build "-j${PARALLEL_LEVEL}" + +rapids-logger "Begin c build" + +scl enable gcc-toolset-${TOOLSET_VERSION} -- \ + cmake -S c -B c/build \ + -DCMAKE_CUDA_HOST_COMPILER=/opt/rh/gcc-toolset-${TOOLSET_VERSION}/root/usr/bin/gcc \ + -DCUVSC_STATIC_CUVS_LIBRARY=ON \ + -DCMAKE_PREFIX_PATH="$PWD/cpp/build/" \ + -DBUILD_TESTS=${BUILD_C_LIB_TESTS} +cmake --build c/build "-j${PARALLEL_LEVEL}" + +rapids-logger "Begin c install" +cmake --install c/build --prefix c/build/install + +# need to install the tests +if [ "${BUILD_C_LIB_TESTS}" != "OFF" ]; then + cmake --install c/build --prefix c/build/install --component testing +fi + + +rapids-logger "Begin gathering licenses" +cp LICENSE c/build/install/ +if [ -e "./tool/extract_licenses_via_spdx.py" ]; then + python ./tool/extract_licenses_via_spdx.py "." --with-licenses >> c/build/install/LICENSE +fi + +rapids-logger "Begin c tarball creation" +tar czf libcuvs_c.tar.gz -C c/build/install/ . +ls -lh libcuvs_c.tar.gz + +sccache --show-adv-stats diff --git a/ci/test_standalone_c.sh b/ci/test_standalone_c.sh new file mode 100755 index 0000000000..123f14a061 --- /dev/null +++ b/ci/test_standalone_c.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +CMAKE_VERSION=4.1.2 +CMAKE_ARCH=x86_64 + +# Fetch and install CMake. +if [ ! -e "/usr/local/bin/cmake" ]; then + pushd /usr/local + wget --quiet https://github.com/Kitware/CMake/releases/download/v"${CMAKE_VERSION}"/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz + tar zxf cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz + rm cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz + ln -s /usr/local/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}"/bin/cmake /usr/local/bin/cmake + popd +fi + +# Download the standalone C library artifact +payload_name="libcuvs_c_${RAPIDS_CUDA_VERSION}.tar.gz" +pkg_name="libcuvs_c.tar.gz" +rapids-logger "Download ${payload_name} artifacts from previous jobs" +DOWNLOAD_LOCATION=$(rapids-download-from-github "${payload_name}") + +# Extract the artifact to a staging directory +INSTALL_PREFIX="${PWD}/libcuvs_c_install" +mkdir -p "${INSTALL_PREFIX}" +ls -l "${DOWNLOAD_LOCATION}" +tar -xf "${DOWNLOAD_LOCATION}/${pkg_name}" -C "${INSTALL_PREFIX}" + +rapids-logger "Run C API tests" +ls -l "${INSTALL_PREFIX}" +cd "$INSTALL_PREFIX"/bin/gtests/libcuvs +ctest -j8 --output-on-failure + +rapids-logger "C API tests completed successfully" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 31f7227ae4..c5c3cc86ac 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -690,10 +690,10 @@ SECTIONS PUBLIC rmm::rmm raft::raft ${CUVS_CTK_MATH_DEPENDENCIES} - $> - $> - $<$:CUDA::nvtx3> + $ # needs to be public for DT_NEEDED + $> # header only PRIVATE nvidia::cutlass::cutlass $ + $<$:CUDA::nvtx3> ) endif() From c1ea376b4a63ea21887eb82dcc8f319bfb73dcac Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Sat, 15 Nov 2025 15:22:11 -0500 Subject: [PATCH 10/32] Use ruff-check, ruff-format instead of black, flake8 (#1500) Issue: https://github.com/rapidsai/build-planning/issues/130 Ops-Bot-Merge-Barrier: true Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - Gil Forsyth (https://github.com/gforsyth) - Bradley Dice (https://github.com/bdice) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/cuvs/pull/1500 --- .flake8 | 25 -- .pre-commit-config.yaml | 24 +- cpp/scripts/analyze_nvcc_log.py | 92 +++-- cpp/scripts/gitutils.py | 103 +++--- .../select_k/algorithm_selection.ipynb | 13 +- .../select_k/generate_heuristic.ipynb | 87 +++-- .../heuristics/select_k/generate_plots.ipynb | 31 +- cpp/scripts/include_checker.py | 29 +- cpp/scripts/run-clang-compile.py | 182 ++++++---- cpp/scripts/run-clang-tidy.py | 226 ++++++++----- .../pairwise_matrix/dispatch_00_generate.py | 103 +++--- .../ball_cover/registers_00_generate.py | 14 +- .../cagra/compute_distance_00_generate.py | 67 ++-- .../cagra/search_multi_cta_00_generate.py | 6 +- .../cagra/search_single_cta_00_generate.py | 6 +- cpp/src/neighbors/iface/generate_iface.py | 42 +-- .../neighbors/ivf_flat/generate_ivf_flat.py | 2 +- .../ivf_pq/detail/generate_ivf_pq.py | 2 +- .../generate_ivf_pq_compute_similarity.py | 68 +++- cpp/src/neighbors/mg/generate_mg.py | 38 ++- docs/source/conf.py | 14 +- docs/source/sphinxext/github_link.py | 3 - .../VectorSearch_QuestionRetrieval.ipynb | 67 ++-- ...ectorSearch_QuestionRetrieval_Milvus.ipynb | 256 +++++++++----- notebooks/cuvs_hpo_example.ipynb | 156 +++++---- notebooks/ivf_flat_example.ipynb | 153 +++++---- notebooks/tutorial_ivf_pq.ipynb | 316 ++++++++++++------ notebooks/utils.py | 3 +- pyproject.toml | 41 ++- python/cuvs/cuvs/tests/test_cagra.py | 1 - python/cuvs/cuvs/tests/test_mg_cagra.py | 6 +- python/cuvs/cuvs/tests/test_mg_ivf_flat.py | 6 +- python/cuvs/cuvs/tests/test_mg_ivf_pq.py | 6 +- python/cuvs/cuvs/tests/test_refine.py | 3 +- python/cuvs_bench/cuvs_bench/plot/__main__.py | 1 - .../cuvs_bench/cuvs_bench/run/data_export.py | 11 +- .../cuvs_bench/cuvs_bench/tests/test_cli.py | 36 +- 37 files changed, 1395 insertions(+), 844 deletions(-) delete mode 100644 .flake8 diff --git a/.flake8 b/.flake8 deleted file mode 100644 index f9bda5354f..0000000000 --- a/.flake8 +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. -# SPDX-License-Identifier: Apache-2.0 - -[flake8] -filename = *.py, *.pyx, *.pxd, *.pxi -exclude = __init__.py, *.egg, build, docs, .git -force-check = True -ignore = - # line break before binary operator - W503, - # whitespace before : - E203 -per-file-ignores = - # Rules ignored only in Cython: - # E211: whitespace before '(' (used in multi-line imports) - # E225: Missing whitespace around operators (breaks cython casting syntax like ) - # E226: Missing whitespace around arithmetic operators (breaks cython pointer syntax like int*) - # E227: Missing whitespace around bitwise or shift operator (Can also break casting syntax) - # E275: Missing whitespace after keyword (Doesn't work with Cython except?) - # E402: invalid syntax (works for Python, not Cython) - # E999: invalid syntax (works for Python, not Cython) - # W504: line break after binary operator (breaks lines that end with a pointer) - *.pyx: E211, E225, E226, E227, E275, E402, E999, W504 - *.pxd: E211, E225, E226, E227, E275, E402, E999, W504 - *.pxi: E211, E225, E226, E227, E275, E402, E999, W504 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b0c510d532..d5b622c061 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,23 +17,13 @@ repos: # project can specify its own first/third-party packages. args: ["--config-root=python/", "--resolve-all-configs"] files: python/.* - types_or: [python, cython, pyi] - - repo: https://github.com/psf/black - rev: 22.3.0 + types: [cython] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.3 hooks: - - id: black - files: python/.* - # Explicitly specify the pyproject.toml at the repo root, not per-project. - args: ["--config", "pyproject.toml"] - - repo: https://github.com/PyCQA/flake8 - rev: 7.1.1 - hooks: - - id: flake8 - args: ["--config=.flake8"] - files: python/.*$ - types: [file] - types_or: [python, cython] - additional_dependencies: ["flake8-force"] + - id: ruff-check + args: [--fix] + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: 'v0.971' hooks: @@ -110,7 +100,7 @@ repos: ^CHANGELOG[.]md$| ^cpp/cmake/patches/cutlass/build-export[.]patch$ - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v1.2.0 + rev: v1.2.1 hooks: - id: verify-copyright name: verify-copyright-cuvs diff --git a/cpp/scripts/analyze_nvcc_log.py b/cpp/scripts/analyze_nvcc_log.py index 823c4f8e3e..936b5b1751 100755 --- a/cpp/scripts/analyze_nvcc_log.py +++ b/cpp/scripts/analyze_nvcc_log.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import sys import pandas as pd -import numpy as np import matplotlib.pyplot as plt import seaborn as sns from pathlib import Path from matplotlib import colors + def main(input_path): input_path = Path(input_path) print("-- loading data") @@ -22,40 +22,56 @@ def main(input_path): df["file"] = df["source file name"] df["phase"] = df["phase name"].str.strip() - dfp = (df - # Remove nvcc driver entries. They don't contain a source file name - .query("phase!='nvcc (driver)'") - # Make a pivot table containing files as row, phase (preprocessing, - # cicc, etc.) as column and the total times as table entries. NOTE: - # if compiled for multiple archs, the archs will be summed. - .pivot_table(index="file", values="seconds", columns="phase", aggfunc='sum')) + dfp = ( + df + # Remove nvcc driver entries. They don't contain a source file name + .query("phase!='nvcc (driver)'") + # Make a pivot table containing files as row, phase (preprocessing, + # cicc, etc.) as column and the total times as table entries. NOTE: + # if compiled for multiple archs, the archs will be summed. + .pivot_table( + index="file", values="seconds", columns="phase", aggfunc="sum" + ) + ) dfp_sum = dfp.sum(axis="columns") df_fraction = dfp.divide(dfp_sum, axis="index") df_fraction["total time"] = dfp_sum - df_fraction = df_fraction.melt(ignore_index=False, id_vars="total time", var_name="phase", value_name="fraction") + df_fraction = df_fraction.melt( + ignore_index=False, + id_vars="total time", + var_name="phase", + value_name="fraction", + ) dfp["total time"] = dfp_sum - df_absolute = dfp.melt(ignore_index=False, id_vars="total time", var_name="phase", value_name="seconds") + df_absolute = dfp.melt( + ignore_index=False, + id_vars="total time", + var_name="phase", + value_name="seconds", + ) # host: light red to dark red (preprocessing, cudafe, gcc (compiling)) # device: ligt green to dark green (preprocessing, cicc, ptxas) palette = { "gcc (preprocessing 4)": colors.hsv_to_rgb((0, 1, 1)), - 'cudafe++': colors.hsv_to_rgb((0, 1, .75)), - 'gcc (compiling)': colors.hsv_to_rgb((0, 1, .4)), - "gcc (preprocessing 1)": colors.hsv_to_rgb((.33, 1, 1)), - 'cicc': colors.hsv_to_rgb((.33, 1, 0.75)), - 'ptxas': colors.hsv_to_rgb((.33, 1, 0.4)), - 'fatbinary': "grey", + "cudafe++": colors.hsv_to_rgb((0, 1, 0.75)), + "gcc (compiling)": colors.hsv_to_rgb((0, 1, 0.4)), + "gcc (preprocessing 1)": colors.hsv_to_rgb((0.33, 1, 1)), + "cicc": colors.hsv_to_rgb((0.33, 1, 0.75)), + "ptxas": colors.hsv_to_rgb((0.33, 1, 0.4)), + "fatbinary": "grey", } print("-- Ten longest translation units:") - colwidth = pd.get_option('display.max_colwidth') - 1 + colwidth = pd.get_option("display.max_colwidth") - 1 dfp = dfp.reset_index() dfp["file"] = dfp["file"].apply(lambda s: s[-colwidth:]) - print(dfp.sort_values("total time", ascending=False).reset_index().loc[:10]) + print( + dfp.sort_values("total time", ascending=False).reset_index().loc[:10] + ) print("-- Plotting absolute compile times") abs_out_path = f"{input_path}.absolute.compile_times.png" @@ -64,43 +80,57 @@ def main(input_path): y="file", hue="phase", hue_order=reversed( - ["gcc (preprocessing 4)", 'cudafe++', 'gcc (compiling)', - "gcc (preprocessing 1)", 'cicc', 'ptxas', - 'fatbinary', - ]), + [ + "gcc (preprocessing 4)", + "cudafe++", + "gcc (compiling)", + "gcc (preprocessing 1)", + "cicc", + "ptxas", + "fatbinary", + ] + ), palette=palette, weights="seconds", multiple="stack", kind="hist", height=20, ) - plt.xlabel("seconds"); + plt.xlabel("seconds") plt.savefig(abs_out_path) print(f"-- Wrote absolute compile time plot to {abs_out_path}") print("-- Plotting relative compile times") rel_out_path = f"{input_path}.relative.compile_times.png" sns.displot( - df_fraction.sort_values('total time').reset_index(), + df_fraction.sort_values("total time").reset_index(), y="file", hue="phase", - hue_order=reversed(["gcc (preprocessing 4)", 'cudafe++', 'gcc (compiling)', - "gcc (preprocessing 1)", 'cicc', 'ptxas', - 'fatbinary', - ]), + hue_order=reversed( + [ + "gcc (preprocessing 4)", + "cudafe++", + "gcc (compiling)", + "gcc (preprocessing 1)", + "cicc", + "ptxas", + "fatbinary", + ] + ), palette=palette, weights="fraction", multiple="stack", kind="hist", height=15, ) - plt.xlabel("fraction"); + plt.xlabel("fraction") plt.savefig(rel_out_path) print(f"-- Wrote relative compile time plot to {rel_out_path}") + if __name__ == "__main__": if len(sys.argv) != 2: - printf("""NVCC log analyzer + print("""NVCC log analyzer Analyzes nvcc logs and outputs a figure with highest ranking translation units. diff --git a/cpp/scripts/gitutils.py b/cpp/scripts/gitutils.py index 99fe5de676..7b1cef1f74 100644 --- a/cpp/scripts/gitutils.py +++ b/cpp/scripts/gitutils.py @@ -55,17 +55,20 @@ def repo_version_major_minor(): full_repo_version = repo_version() - match = re.match(r"^v?(?P[0-9]+)(?:\.(?P[0-9]+))?", - full_repo_version) - - if (match is None): - print(" [DEBUG] Could not determine repo major minor version. " - f"Full repo version: {full_repo_version}.") + match = re.match( + r"^v?(?P[0-9]+)(?:\.(?P[0-9]+))?", full_repo_version + ) + + if match is None: + print( + " [DEBUG] Could not determine repo major minor version. " + f"Full repo version: {full_repo_version}." + ) return None out_version = match.group("major") - if (match.group("minor")): + if match.group("minor"): out_version += "." + match.group("minor") return out_version @@ -91,44 +94,50 @@ def determine_merge_commit(current_branch="HEAD"): try: # Try to determine the target branch from the most recent tag - head_branch = __git("describe", - "--all", - "--tags", - "--match='branch-*'", - "--abbrev=0") + head_branch = __git( + "describe", "--all", "--tags", "--match='branch-*'", "--abbrev=0" + ) except subprocess.CalledProcessError: - print(" [DEBUG] Could not determine target branch from most recent " - "tag. Falling back to 'branch-{major}.{minor}.") + print( + " [DEBUG] Could not determine target branch from most recent " + "tag. Falling back to 'branch-{major}.{minor}." + ) head_branch = None - if (head_branch is not None): + if head_branch is not None: # Convert from head to branch name head_branch = __git("name-rev", "--name-only", head_branch) else: # Try and guess the target branch as "branch-." version = repo_version_major_minor() - if (version is None): + if version is None: return None head_branch = "branch-{}".format(version) try: # Now get the remote tracking branch - remote_branch = __git("rev-parse", - "--abbrev-ref", - "--symbolic-full-name", - head_branch + "@{upstream}") + remote_branch = __git( + "rev-parse", + "--abbrev-ref", + "--symbolic-full-name", + head_branch + "@{upstream}", + ) except subprocess.CalledProcessError: - print(" [DEBUG] Could not remote tracking reference for " - f"branch {head_branch}.") + print( + " [DEBUG] Could not remote tracking reference for " + f"branch {head_branch}." + ) remote_branch = None - if (remote_branch is None): + if remote_branch is None: return None - print(f" [DEBUG] Determined TARGET_BRANCH as: '{remote_branch}'. " - "Finding common ancestor.") + print( + f" [DEBUG] Determined TARGET_BRANCH as: '{remote_branch}'. " + "Finding common ancestor." + ) common_commit = __git("merge-base", remote_branch, current_branch) @@ -166,9 +175,9 @@ def changedFilesBetween(baseName, branchName, commitHash): # checkout latest commit from branch __git("checkout", "-fq", commitHash) - files = __gitdiff("--name-only", - "--ignore-submodules", - f"{baseName}..{branchName}") + files = __gitdiff( + "--name-only", "--ignore-submodules", f"{baseName}..{branchName}" + ) # restore the original branch __git("checkout", "--force", current) @@ -180,13 +189,15 @@ def changesInFileBetween(file, b1, b2, filter=None): current = branch() __git("checkout", "--quiet", b1) __git("checkout", "--quiet", b2) - diffs = __gitdiff("--ignore-submodules", - "-w", - "--minimal", - "-U0", - "%s...%s" % (b1, b2), - "--", - file) + diffs = __gitdiff( + "--ignore-submodules", + "-w", + "--minimal", + "-U0", + "%s...%s" % (b1, b2), + "--", + file, + ) __git("checkout", "--quiet", current) lines = [] for line in diffs.splitlines(): @@ -215,25 +226,29 @@ def modifiedFiles(pathFilter=None): currentBranch = branch() print( f" [DEBUG] TARGET_BRANCH={targetBranch}, COMMIT_HASH={commitHash}, " - f"currentBranch={currentBranch}") + f"currentBranch={currentBranch}" + ) if targetBranch and commitHash and (currentBranch == "current-pr-branch"): print(" [DEBUG] Assuming a CI environment.") allFiles = changedFilesBetween(targetBranch, currentBranch, commitHash) else: - print(" [DEBUG] Did not detect CI environment. " - "Determining TARGET_BRANCH locally.") + print( + " [DEBUG] Did not detect CI environment. " + "Determining TARGET_BRANCH locally." + ) common_commit = determine_merge_commit(currentBranch) - if (common_commit is not None): - + if common_commit is not None: # Now get the diff. Use --staged to get both diff between # common_commit..HEAD and any locally staged files - allFiles = __gitdiff("--name-only", - "--ignore-submodules", - "--staged", - f"{common_commit}").splitlines() + allFiles = __gitdiff( + "--name-only", + "--ignore-submodules", + "--staged", + f"{common_commit}", + ).splitlines() else: # Fallback to just uncommitted files allFiles = uncommittedFiles() diff --git a/cpp/scripts/heuristics/select_k/algorithm_selection.ipynb b/cpp/scripts/heuristics/select_k/algorithm_selection.ipynb index c56281ef58..a0be1de932 100644 --- a/cpp/scripts/heuristics/select_k/algorithm_selection.ipynb +++ b/cpp/scripts/heuristics/select_k/algorithm_selection.ipynb @@ -247,12 +247,13 @@ "source": [ "from collections import Counter\n", "\n", + "\n", "def rank_algos(df, use_relative_speedup=False):\n", " _, y, weights = get_dataset(df)\n", " times = Counter()\n", " for algo, speedup in zip(y, weights):\n", " times[algo] += speedup if use_relative_speedup else 1\n", - " return sorted(times.items(), key=lambda x:-x[-1])" + " return sorted(times.items(), key=lambda x: -x[-1])" ] }, { @@ -343,7 +344,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "bc0a10ea-652b-4822-8587-514c8f0348c3", "metadata": { "tags": [] @@ -382,11 +383,11 @@ "# well over diverse inputs.\n", "#\n", "# note: the lowest performing algorithm here might actually be pretty good, but\n", - "# just not provide much benefit over another similar algorithm. \n", - "# As an example, kWarpDistributed is an excellent selection algorithm, but in testing \n", - "# kWarpDistributedShm is slightly faster than it in situations where it does well, \n", + "# just not provide much benefit over another similar algorithm.\n", + "# As an example, kWarpDistributed is an excellent selection algorithm, but in testing\n", + "# kWarpDistributedShm is slightly faster than it in situations where it does well,\n", "# meaning that it gets removed early on in this loop\n", - "current = df[df.use_memory_pool == True]\n", + "current = df[df.use_memory_pool == True] # noqa: E712\n", "algos = set(df.algo)\n", "\n", "# we're arbitrarily getting this down to 3 selection algorithms\n", diff --git a/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb b/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb index 50bc12556a..941567b826 100644 --- a/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb +++ b/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "56765f40-96ce-46c6-bce8-ab782cd72b6e", "metadata": { "tags": [] @@ -245,12 +245,21 @@ "# load up the timings from the MATRIX_BENCH script into a pandas dataframe\n", "df = load_dataframe(\"select_k_times.json\")\n", "\n", - "# we're limiting down to 3 different select_k methods - chosen by \n", + "# we're limiting down to 3 different select_k methods - chosen by\n", "# the 'algorithm_selection.ipynb' script here\n", - "df = df[df.algo.isin(['kWarpImmediate', 'kRadix11bitsExtraPass', 'kRadix11bits', 'kWarpDistributedShm'])]\n", + "df = df[\n", + " df.algo.isin(\n", + " [\n", + " \"kWarpImmediate\",\n", + " \"kRadix11bitsExtraPass\",\n", + " \"kRadix11bits\",\n", + " \"kWarpDistributedShm\",\n", + " ]\n", + " )\n", + "]\n", "\n", "# we're also assuming we have a memory pool for now\n", - "df = df[(df.use_memory_pool == True)]\n", + "df = df[(df.use_memory_pool == True)] # noqa: E712\n", "# df = df[(df.index_type == 'int64_t') & (df.key_type == 'float')]\n", "\n", "df" @@ -278,7 +287,9 @@ "source": [ "# break down into a train/set set\n", "X, y, weights = get_dataset(df)\n", - "train_test_sets = sklearn.model_selection.train_test_split(X, y, weights, test_size=0.15, random_state=1)\n", + "train_test_sets = sklearn.model_selection.train_test_split(\n", + " X, y, weights, test_size=0.15, random_state=1\n", + ")\n", "X_train, X_test, y_train, y_test, weights_train, weights_test = train_test_sets\n", "X_train.shape, X_test.shape" ] @@ -307,7 +318,7 @@ ], "source": [ "model = sklearn.tree.DecisionTreeClassifier(max_depth=4, max_leaf_nodes=8)\n", - "model.fit(X_train, y_train) #, weights_train)" + "model.fit(X_train, y_train) # , weights_train)" ] }, { @@ -389,8 +400,15 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "plt.figure(figsize=(12,12))\n", - "viz = sklearn.tree.plot_tree(model, fontsize=8, class_names=list(model.classes_), feature_names=[\"k\", \"rows\", \"cols\", \"use_memory_pool\"], impurity=True)" + "\n", + "plt.figure(figsize=(12, 12))\n", + "viz = sklearn.tree.plot_tree(\n", + " model,\n", + " fontsize=8,\n", + " class_names=list(model.classes_),\n", + " feature_names=[\"k\", \"rows\", \"cols\", \"use_memory_pool\"],\n", + " impurity=True,\n", + ")" ] }, { @@ -441,33 +459,36 @@ " classes = model.classes_\n", " tree = model.tree_\n", " feature_names = [\"k\", \"rows\", \"cols\", \"use_memory_pool\"]\n", - " \n", + "\n", " def _get_label(nodeid):\n", - " \"\"\" returns the most frequent class name for the node \"\"\"\n", + " \"\"\"returns the most frequent class name for the node\"\"\"\n", " return classes[np.argsort(tree.value[nodeid, 0])[-1]]\n", - " \n", + "\n", " def _is_leaf_node(nodeid):\n", - " \"\"\" returns whether or not the node is a leaf node in the tree\"\"\"\n", + " \"\"\"returns whether or not the node is a leaf node in the tree\"\"\"\n", " # negative values here indicate we're a leaf\n", " if tree.feature[nodeid] < 0:\n", " return True\n", - " \n", + "\n", " # some nodes have both branches with the same label, combine those\n", - " left, right = tree.children_left[nodeid], tree.children_right[nodeid] \n", - " if (_is_leaf_node(left) and \n", - " _is_leaf_node(right) and \n", - " _get_label(left) == _get_label(right)):\n", + " left, right = tree.children_left[nodeid], tree.children_right[nodeid]\n", + " if (\n", + " _is_leaf_node(left)\n", + " and _is_leaf_node(right)\n", + " and _get_label(left) == _get_label(right)\n", + " ):\n", " return True\n", - " \n", + "\n", " return False\n", - " \n", + "\n", " code = []\n", + "\n", " def _convert_node(nodeid, indent):\n", " if _is_leaf_node(nodeid):\n", " # we're a leaf node, just output the label of the most frequent algorithm\n", " class_name = _get_label(nodeid)\n", " code.append(\" \" * indent + f\"return Algo::{class_name};\")\n", - " else: \n", + " else:\n", " feature = feature_names[tree.feature[nodeid]]\n", " threshold = int(np.floor(tree.threshold[nodeid]))\n", " code.append(\" \" * indent + f\"if ({feature} > {threshold}) \" + \"{\")\n", @@ -475,13 +496,16 @@ " code.append(\" \" * indent + \"} else {\")\n", " _convert_node(tree.children_left[nodeid], indent + 2)\n", " code.append(\" \" * indent + \"}\")\n", - " \n", - " code.append(\"inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)\")\n", + "\n", + " code.append(\n", + " \"inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)\"\n", + " )\n", " code.append(\"{\")\n", " _convert_node(0, indent=2)\n", " code.append(\"}\")\n", " return \"\\n\".join(code)\n", "\n", + "\n", "code = convert_model_to_code(model)\n", "print(code)" ] @@ -506,14 +530,27 @@ "source": [ "# also update the source code in raft/matrix/detail/select_k.cuh\n", "import pathlib\n", - "select_k_path = pathlib.Path.cwd() / \"..\" / \"..\" / \"..\" / \"include\" / \"raft\" / \"matrix\" / \"detail\" / \"select_k-inl.cuh\"\n", + "\n", + "select_k_path = (\n", + " pathlib.Path.cwd()\n", + " / \"..\"\n", + " / \"..\"\n", + " / \"..\"\n", + " / \"include\"\n", + " / \"raft\"\n", + " / \"matrix\"\n", + " / \"detail\"\n", + " / \"select_k-inl.cuh\"\n", + ")\n", "source_lines = open(select_k_path.resolve()).read().split(\"\\n\")\n", "\n", "# figure out the location of the code snippet in the file, and splice it in\n", "code_lines = code.split(\"\\n\")\n", "first_line = source_lines.index(code_lines[0])\n", - "last_line = source_lines.index(code_lines[-1], first_line)\n", - "new_source = source_lines[:first_line] + code_lines + source_lines[last_line+1:]\n", + "last_line = source_lines.index(code_lines[-1], first_line)\n", + "new_source = (\n", + " source_lines[:first_line] + code_lines + source_lines[last_line + 1 :]\n", + ")\n", "\n", "open(select_k_path.resolve(), \"w\").write(\"\\n\".join(new_source))" ] diff --git a/cpp/scripts/heuristics/select_k/generate_plots.ipynb b/cpp/scripts/heuristics/select_k/generate_plots.ipynb index ffdad58b1c..4e0d048a56 100644 --- a/cpp/scripts/heuristics/select_k/generate_plots.ipynb +++ b/cpp/scripts/heuristics/select_k/generate_plots.ipynb @@ -15,12 +15,13 @@ "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "\n", "sns.set_theme()" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "f91d6f1d-e198-46c8-9ac6-955995f058d1", "metadata": { "tags": [] @@ -233,9 +234,10 @@ ], "source": [ "from select_k_dataset import load_dataframe, get_dataset\n", + "\n", "df = load_dataframe(\"select_k_times.json\")\n", - "df = df[(df.use_memory_pool == True)]\n", - "df = df[(df.index_type == 'int64_t') & (df.key_type == 'float')]\n", + "df = df[(df.use_memory_pool == True)] # noqa: E712\n", + "df = df[(df.index_type == \"int64_t\") & (df.key_type == \"float\")]\n", "df" ] }, @@ -253,24 +255,33 @@ " for algo in sorted(set(df.algo)):\n", " current = df[(df.algo == algo) & (df.time < np.inf)]\n", " ax.plot(current[x_axis], current[\"time\"], label=algo)\n", - " ax.set_xscale('log', base=2)\n", - " ax.set_yscale('log', base=2)\n", + " ax.set_xscale(\"log\", base=2)\n", + " ax.set_yscale(\"log\", base=2)\n", " ax.set_xlabel(x_axis)\n", " ax.set_ylabel(\"time(s)\")\n", " ax.set_title(title)\n", " fig.set_dpi(200)\n", - " ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=4)\n", - "# fig.legend()\n", + " ax.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.15), ncol=4)\n", + " # fig.legend()\n", " plt.show()\n", "\n", + "\n", "def generate_k_plot(df, col, row):\n", - " return generate_plot(df[(df.col == col) & (df.row == row)], \"k\", f\"#cols={col}, #rows={row}\")\n", + " return generate_plot(\n", + " df[(df.col == col) & (df.row == row)], \"k\", f\"#cols={col}, #rows={row}\"\n", + " )\n", + "\n", "\n", "def generate_col_plot(df, row, k):\n", - " return generate_plot(df[(df.row == row) & (df.k == k)], \"col\", f\"#rows={row}, k={k}\")\n", + " return generate_plot(\n", + " df[(df.row == row) & (df.k == k)], \"col\", f\"#rows={row}, k={k}\"\n", + " )\n", + "\n", "\n", "def generate_row_plot(df, col, k):\n", - " return generate_plot(df[(df.col == col) & (df.k == k)], \"row\", f\"#cols={col}, k={k}\")" + " return generate_plot(\n", + " df[(df.col == col) & (df.k == k)], \"row\", f\"#cols={col}, k={k}\"\n", + " )" ] }, { diff --git a/cpp/scripts/include_checker.py b/cpp/scripts/include_checker.py index efbebdb765..5c45f63f0d 100644 --- a/cpp/scripts/include_checker.py +++ b/cpp/scripts/include_checker.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2020-2023, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # @@ -6,7 +6,6 @@ import sys import re import os -import subprocess import argparse @@ -15,14 +14,20 @@ exclusion_regex = re.compile(r".*thirdparty.*") + def parse_args(): argparser = argparse.ArgumentParser( - "Checks for a consistent '#include' syntax") - argparser.add_argument("--regex", type=str, - default=r"[.](cu|cuh|h|hpp|hxx|cpp)$", - help="Regex string to filter in sources") - argparser.add_argument("dirs", type=str, nargs="*", - help="List of dirs where to find sources") + "Checks for a consistent '#include' syntax" + ) + argparser.add_argument( + "--regex", + type=str, + default=r"[.](cu|cuh|h|hpp|hxx|cpp)$", + help="Regex string to filter in sources", + ) + argparser.add_argument( + "dirs", type=str, nargs="*", help="List of dirs where to find sources" + ) args = argparser.parse_args() args.regex_compiled = re.compile(args.regex) return args @@ -33,7 +38,9 @@ def list_all_source_file(file_regex, srcdirs): for srcdir in srcdirs: for root, dirs, files in os.walk(srcdir): for f in files: - if not re.search(exclusion_regex, root) and re.search(file_regex, f): + if not re.search(exclusion_regex, root) and re.search( + file_regex, f + ): src = os.path.join(root, f) all_files.append(src) return all_files @@ -51,10 +58,10 @@ def check_includes_in(src): inc_file = val[1:-1] # strip out " or < full_path = os.path.join(dir, inc_file) line_num = line_number + 1 - if val[0] == "\"" and not os.path.exists(full_path): + if val[0] == '"' and not os.path.exists(full_path): errs.append("Line:%d use #include <...>" % line_num) elif val[0] == "<" and os.path.exists(full_path): - errs.append("Line:%d use #include \"...\"" % line_num) + errs.append('Line:%d use #include "..."' % line_num) return errs diff --git a/cpp/scripts/run-clang-compile.py b/cpp/scripts/run-clang-compile.py index 30ff6fac98..d1eef26627 100644 --- a/cpp/scripts/run-clang-compile.py +++ b/cpp/scripts/run-clang-compile.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2020-2023, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # @@ -18,7 +18,8 @@ CMAKE_COMPILER_REGEX = re.compile( - r"^\s*CMAKE_CXX_COMPILER:FILEPATH=(.+)\s*$", re.MULTILINE) + r"^\s*CMAKE_CXX_COMPILER:FILEPATH=(.+)\s*$", re.MULTILINE +) CLANG_COMPILER = "clang++" GPU_ARCH_REGEX = re.compile(r"sm_(\d+)") SPACES = re.compile(r"\s+") @@ -26,28 +27,43 @@ XPTXAS_FLAG = re.compile(r"-((Xptxas)|(-ptxas-options))=?") # any options that may have equal signs in nvcc but not in clang # add those options here if you find any -OPTIONS_NO_EQUAL_SIGN = ['-isystem'] +OPTIONS_NO_EQUAL_SIGN = ["-isystem"] SEPARATOR = "-" * 8 END_SEPARATOR = "*" * 64 def parse_args(): - argparser = argparse.ArgumentParser("Runs clang++ on a project instead of nvcc") + argparser = argparse.ArgumentParser( + "Runs clang++ on a project instead of nvcc" + ) argparser.add_argument( - "-cdb", type=str, default="compile_commands.json", - help="Path to cmake-generated compilation database") + "-cdb", + type=str, + default="compile_commands.json", + help="Path to cmake-generated compilation database", + ) argparser.add_argument( - "-ignore", type=str, default=None, - help="Regex used to ignore files from checking") + "-ignore", + type=str, + default=None, + help="Regex used to ignore files from checking", + ) argparser.add_argument( - "-select", type=str, default=None, - help="Regex used to select files for checking") + "-select", + type=str, + default=None, + help="Regex used to select files for checking", + ) argparser.add_argument( - "-j", type=int, default=-1, help="Number of parallel jobs to launch.") + "-j", type=int, default=-1, help="Number of parallel jobs to launch." + ) argparser.add_argument( - "-build_dir", type=str, default=None, + "-build_dir", + type=str, + default=None, help="Directory from which compile commands should be called. " - "By default, directory of compile_commands.json file.") + "By default, directory of compile_commands.json file.", + ) args = argparser.parse_args() if args.j <= 0: args.j = mp.cpu_count() @@ -92,11 +108,14 @@ def get_gpu_archs(command): # clang only accepts a single architecture, so first determine the lowest archs = [] for loc in range(len(command)): - if (command[loc] != "-gencode" and command[loc] != "--generate-code" - and not command[loc].startswith("--generate-code=")): + if ( + command[loc] != "-gencode" + and command[loc] != "--generate-code" + and not command[loc].startswith("--generate-code=") + ): continue if command[loc].startswith("--generate-code="): - arch_flag = command[loc][len("--generate-code="):] + arch_flag = command[loc][len("--generate-code=") :] else: arch_flag = command[loc + 1] match = GPU_ARCH_REGEX.search(arch_flag) @@ -106,8 +125,9 @@ def get_gpu_archs(command): def get_index(arr, item_options): - return set(i for i, s in enumerate(arr) for item in item_options - if s == item) + return set( + i for i, s in enumerate(arr) for item in item_options if s == item + ) def remove_items(arr, item_options): @@ -120,8 +140,12 @@ def remove_items_plus_one(arr, item_options): if i < len(arr) - 1: del arr[i + 1] del arr[i] - idx = set(i for i, s in enumerate(arr) for item in item_options - if s.startswith(item + "=")) + idx = set( + i + for i, s in enumerate(arr) + for item in item_options + if s.startswith(item + "=") + ) for i in sorted(idx, reverse=True): del arr[i] @@ -131,7 +155,7 @@ def add_cuda_path(command, nvcc): if not nvcc_path: raise Exception("Command %s has invalid compiler %s" % (command, nvcc)) cuda_root = os.path.dirname(os.path.dirname(nvcc_path)) - command.append('--cuda-path=%s' % cuda_root) + command.append("--cuda-path=%s" % cuda_root) def get_clang_args(cmd, build_dir): @@ -152,57 +176,63 @@ def get_clang_args(cmd, build_dir): # provide proper cuda path to clang add_cuda_path(command, cc_orig) # remove all kinds of nvcc flags clang doesn't know about - remove_items_plus_one(command, [ - "--generate-code", - "-gencode", - "--x", - "-x", - "--compiler-bindir", - "-ccbin", - "--diag_suppress", - "-diag-suppress", - "--default-stream", - "-default-stream", - ]) - remove_items(command, [ - "-extended-lambda", - "--extended-lambda", - "-expt-extended-lambda", - "--expt-extended-lambda", - "-expt-relaxed-constexpr", - "--expt-relaxed-constexpr", - "--device-debug", - "-G", - "--generate-line-info", - "-lineinfo", - ]) + remove_items_plus_one( + command, + [ + "--generate-code", + "-gencode", + "--x", + "-x", + "--compiler-bindir", + "-ccbin", + "--diag_suppress", + "-diag-suppress", + "--default-stream", + "-default-stream", + ], + ) + remove_items( + command, + [ + "-extended-lambda", + "--extended-lambda", + "-expt-extended-lambda", + "--expt-extended-lambda", + "-expt-relaxed-constexpr", + "--expt-relaxed-constexpr", + "--device-debug", + "-G", + "--generate-line-info", + "-lineinfo", + ], + ) # "-x cuda" is the right usage in clang command.extend(["-x", "cuda"]) # we remove -Xcompiler flags: here we basically have to hope for the # best that clang++ will accept any flags which nvcc passed to gcc for i, c in reversed(list(enumerate(command))): - new_c = XCOMPILER_FLAG.sub('', c) + new_c = XCOMPILER_FLAG.sub("", c) if new_c == c: continue - command[i:i + 1] = new_c.split(',') + command[i : i + 1] = new_c.split(",") # we also change -Xptxas to -Xcuda-ptxas, always adding space here for i, c in reversed(list(enumerate(command))): if XPTXAS_FLAG.search(c): if not c.endswith("=") and i < len(command) - 1: del command[i + 1] - command[i] = '-Xcuda-ptxas' - command.insert(i + 1, XPTXAS_FLAG.sub('', c)) + command[i] = "-Xcuda-ptxas" + command.insert(i + 1, XPTXAS_FLAG.sub("", c)) # several options like isystem don't expect `=` for opt in OPTIONS_NO_EQUAL_SIGN: - opt_eq = opt + '=' + opt_eq = opt + "=" # make sure that we iterate from back to front here for insert for i, c in reversed(list(enumerate(command))): if not c.startswith(opt_eq): continue - x = c.split('=') + x = c.split("=") # we only care about the first `=` command[i] = x[0] - command.insert(i + 1, '='.join(x[1:])) + command.insert(i + 1, "=".join(x[1:])) # use extensible whole program, to avoid ptx resolution/linking command.extend(["-Xcuda-ptxas", "-ewp"]) # for libcudacxx, we need to allow variadic functions @@ -210,13 +240,17 @@ def get_clang_args(cmd, build_dir): # add some additional CUDA intrinsics cuda_intrinsics_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), - "__clang_cuda_additional_intrinsics.h") + "__clang_cuda_additional_intrinsics.h", + ) command.extend(["-include", cuda_intrinsics_file]) # somehow this option gets onto the commandline, it is unrecognized by clang - remove_items(command, [ - "--forward-unknown-to-host-compiler", - "-forward-unknown-to-host-compiler" - ]) + remove_items( + command, + [ + "--forward-unknown-to-host-compiler", + "-forward-unknown-to-host-compiler", + ], + ) # do not treat warnings as errors here ! for i, x in reversed(list(enumerate(command))): if x.startswith("-Werror"): @@ -228,8 +262,14 @@ def get_clang_args(cmd, build_dir): def run_clang_command(clang_cmd, cwd): cmd = " ".join(clang_cmd) - result = subprocess.run(cmd, check=False, shell=True, cwd=cwd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + result = subprocess.run( + cmd, + check=False, + shell=True, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) result.stdout = result.stdout.decode("utf-8").strip() out = "CMD: " + cmd + "\n" out += "CWD: " + cwd + "\n" @@ -281,11 +321,15 @@ def run_sequential(args, all_files): results = [] for cmd in all_files: # skip files that we don't want to look at - if args.ignore_compiled is not None and \ - re.search(args.ignore_compiled, cmd["file"]) is not None: + if ( + args.ignore_compiled is not None + and re.search(args.ignore_compiled, cmd["file"]) is not None + ): continue - if args.select_compiled is not None and \ - re.search(args.select_compiled, cmd["file"]) is None: + if ( + args.select_compiled is not None + and re.search(args.select_compiled, cmd["file"]) is None + ): continue results.append(run_clang(cmd, args)) return all(results) @@ -305,11 +349,15 @@ def run_parallel(args, all_files): results = [] for cmd in all_files: # skip files that we don't want to look at - if args.ignore_compiled is not None and \ - re.search(args.ignore_compiled, cmd["file"]) is not None: + if ( + args.ignore_compiled is not None + and re.search(args.ignore_compiled, cmd["file"]) is not None + ): continue - if args.select_compiled is not None and \ - re.search(args.select_compiled, cmd["file"]) is None: + if ( + args.select_compiled is not None + and re.search(args.select_compiled, cmd["file"]) is None + ): continue results.append(pool.apply_async(run_clang, args=(cmd, args))) results_final = [r.get() for r in results] diff --git a/cpp/scripts/run-clang-tidy.py b/cpp/scripts/run-clang-tidy.py index 8382668ec9..2c051fd9f7 100644 --- a/cpp/scripts/run-clang-tidy.py +++ b/cpp/scripts/run-clang-tidy.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2020-2023, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # @@ -20,7 +20,8 @@ EXPECTED_VERSIONS = ("20.1.4",) VERSION_REGEX = re.compile(r"clang version ([0-9.]+)") CMAKE_COMPILER_REGEX = re.compile( - r"^\s*CMAKE_CXX_COMPILER:FILEPATH=(.+)\s*$", re.MULTILINE) + r"^\s*CMAKE_CXX_COMPILER:FILEPATH=(.+)\s*$", re.MULTILINE +) CLANG_COMPILER = "clang++" GPU_ARCH_REGEX = re.compile(r"sm_(\d+)") SPACES = re.compile(r"\s+") @@ -28,7 +29,7 @@ XPTXAS_FLAG = re.compile(r"-((Xptxas)|(-ptxas-options))=?") # any options that may have equal signs in nvcc but not in clang # add those options here if you find any -OPTIONS_NO_EQUAL_SIGN = ['-isystem'] +OPTIONS_NO_EQUAL_SIGN = ["-isystem"] SEPARATOR = "-" * 8 END_SEPARATOR = "*" * 64 @@ -36,28 +37,48 @@ def parse_args(): argparser = argparse.ArgumentParser("Runs clang-tidy on a project") argparser.add_argument( - "-cdb", type=str, default="compile_commands.json", - help="Path to cmake-generated compilation database") + "-cdb", + type=str, + default="compile_commands.json", + help="Path to cmake-generated compilation database", + ) argparser.add_argument( - "-exe", type=str, default="clang-tidy", help="Path to clang-tidy exe") + "-exe", type=str, default="clang-tidy", help="Path to clang-tidy exe" + ) argparser.add_argument( - "-ignore", type=str, default=None, - help="Regex used to ignore files from checking") + "-ignore", + type=str, + default=None, + help="Regex used to ignore files from checking", + ) argparser.add_argument( - "-select", type=str, default=None, - help="Regex used to select files for checking") + "-select", + type=str, + default=None, + help="Regex used to select files for checking", + ) argparser.add_argument( - "-j", type=int, default=-1, help="Number of parallel jobs to launch.") + "-j", type=int, default=-1, help="Number of parallel jobs to launch." + ) argparser.add_argument( - "-root", type=str, default=None, - help="Repo root path to filter headers correctly, CWD by default.") + "-root", + type=str, + default=None, + help="Repo root path to filter headers correctly, CWD by default.", + ) argparser.add_argument( - "-thrust_dir", type=str, default=None, - help="Pass the directory to a THRUST git repo recent enough for clang.") + "-thrust_dir", + type=str, + default=None, + help="Pass the directory to a THRUST git repo recent enough for clang.", + ) argparser.add_argument( - "-build_dir", type=str, default=None, + "-build_dir", + type=str, + default=None, help="Directory from which compile commands should be called. " - "By default, directory of compile_commands.json file.") + "By default, directory of compile_commands.json file.", + ) args = argparser.parse_args() if args.j <= 0: args.j = mp.cpu_count() @@ -71,8 +92,10 @@ def parse_args(): raise Exception("Failed to figure out clang compiler version!") version = version.group(1) if version not in EXPECTED_VERSIONS: - raise Exception("clang compiler version must be in %s found '%s'" % - (EXPECTED_VERSIONS, version)) + raise Exception( + "clang compiler version must be in %s found '%s'" + % (EXPECTED_VERSIONS, version) + ) if not os.path.exists(args.cdb): raise Exception("Compilation database '%s' missing" % args.cdb) # we assume that this script is run from repo root @@ -82,7 +105,8 @@ def parse_args(): # we need to have a recent enough cub version for clang to compile if args.thrust_dir is None: args.thrust_dir = os.path.join( - os.path.dirname(args.cdb), "thrust_1.15", "src", "thrust_1.15") + os.path.dirname(args.cdb), "thrust_1.15", "src", "thrust_1.15" + ) if args.build_dir is None: args.build_dir = os.path.dirname(args.cdb) if not os.path.isdir(args.thrust_dir): @@ -120,11 +144,14 @@ def get_gpu_archs(command): # clang only accepts a single architecture, so first determine the lowest archs = [] for loc in range(len(command)): - if (command[loc] != "-gencode" and command[loc] != "--generate-code" - and not command[loc].startswith("--generate-code=")): + if ( + command[loc] != "-gencode" + and command[loc] != "--generate-code" + and not command[loc].startswith("--generate-code=") + ): continue if command[loc].startswith("--generate-code="): - arch_flag = command[loc][len("--generate-code="):] + arch_flag = command[loc][len("--generate-code=") :] else: arch_flag = command[loc + 1] match = GPU_ARCH_REGEX.search(arch_flag) @@ -134,8 +161,9 @@ def get_gpu_archs(command): def get_index(arr, item_options): - return set(i for i, s in enumerate(arr) for item in item_options - if s == item) + return set( + i for i, s in enumerate(arr) for item in item_options if s == item + ) def remove_items(arr, item_options): @@ -148,8 +176,12 @@ def remove_items_plus_one(arr, item_options): if i < len(arr) - 1: del arr[i + 1] del arr[i] - idx = set(i for i, s in enumerate(arr) for item in item_options - if s.startswith(item + "=")) + idx = set( + i + for i, s in enumerate(arr) + for item in item_options + if s.startswith(item + "=") + ) for i in sorted(idx, reverse=True): del arr[i] @@ -159,7 +191,7 @@ def add_cuda_path(command, nvcc): if not nvcc_path: raise Exception("Command %s has invalid compiler %s" % (command, nvcc)) cuda_root = os.path.dirname(os.path.dirname(nvcc_path)) - command.append('--cuda-path=%s' % cuda_root) + command.append("--cuda-path=%s" % cuda_root) def get_tidy_args(cmd, args): @@ -183,57 +215,63 @@ def get_tidy_args(cmd, args): # provide proper cuda path to clang add_cuda_path(command, cc_orig) # remove all kinds of nvcc flags clang doesn't know about - remove_items_plus_one(command, [ - "--generate-code", - "-gencode", - "--x", - "-x", - "--compiler-bindir", - "-ccbin", - "--diag_suppress", - "-diag-suppress", - "--default-stream", - "-default-stream", - ]) - remove_items(command, [ - "-extended-lambda", - "--extended-lambda", - "-expt-extended-lambda", - "--expt-extended-lambda", - "-expt-relaxed-constexpr", - "--expt-relaxed-constexpr", - "--device-debug", - "-G", - "--generate-line-info", - "-lineinfo", - ]) + remove_items_plus_one( + command, + [ + "--generate-code", + "-gencode", + "--x", + "-x", + "--compiler-bindir", + "-ccbin", + "--diag_suppress", + "-diag-suppress", + "--default-stream", + "-default-stream", + ], + ) + remove_items( + command, + [ + "-extended-lambda", + "--extended-lambda", + "-expt-extended-lambda", + "--expt-extended-lambda", + "-expt-relaxed-constexpr", + "--expt-relaxed-constexpr", + "--device-debug", + "-G", + "--generate-line-info", + "-lineinfo", + ], + ) # "-x cuda" is the right usage in clang command.extend(["-x", "cuda"]) # we remove -Xcompiler flags: here we basically have to hope for the # best that clang++ will accept any flags which nvcc passed to gcc for i, c in reversed(list(enumerate(command))): - new_c = XCOMPILER_FLAG.sub('', c) + new_c = XCOMPILER_FLAG.sub("", c) if new_c == c: continue - command[i:i + 1] = new_c.split(',') + command[i : i + 1] = new_c.split(",") # we also change -Xptxas to -Xcuda-ptxas, always adding space here for i, c in reversed(list(enumerate(command))): if XPTXAS_FLAG.search(c): if not c.endswith("=") and i < len(command) - 1: del command[i + 1] - command[i] = '-Xcuda-ptxas' - command.insert(i + 1, XPTXAS_FLAG.sub('', c)) + command[i] = "-Xcuda-ptxas" + command.insert(i + 1, XPTXAS_FLAG.sub("", c)) # several options like isystem don't expect `=` for opt in OPTIONS_NO_EQUAL_SIGN: - opt_eq = opt + '=' + opt_eq = opt + "=" # make sure that we iterate from back to front here for insert for i, c in reversed(list(enumerate(command))): if not c.startswith(opt_eq): continue - x = c.split('=') + x = c.split("=") # we only care about the first `=` command[i] = x[0] - command.insert(i + 1, '='.join(x[1:])) + command.insert(i + 1, "=".join(x[1:])) # use extensible whole program, to avoid ptx resolution/linking command.extend(["-Xcuda-ptxas", "-ewp"]) # for libcudacxx, we need to allow variadic functions @@ -241,13 +279,17 @@ def get_tidy_args(cmd, args): # add some additional CUDA intrinsics cuda_intrinsics_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), - "__clang_cuda_additional_intrinsics.h") + "__clang_cuda_additional_intrinsics.h", + ) command.extend(["-include", cuda_intrinsics_file]) # somehow this option gets onto the commandline, it is unrecognized by tidy - remove_items(command, [ - "--forward-unknown-to-host-compiler", - "-forward-unknown-to-host-compiler" - ]) + remove_items( + command, + [ + "--forward-unknown-to-host-compiler", + "-forward-unknown-to-host-compiler", + ], + ) # do not treat warnings as errors here ! for i, x in reversed(list(enumerate(command))): if x.startswith("-Werror"): @@ -271,8 +313,14 @@ def check_output_for_errors(output): def run_clang_tidy_command(tidy_cmd, cwd): cmd = " ".join(tidy_cmd) - result = subprocess.run(cmd, check=False, shell=True, cwd=cwd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + result = subprocess.run( + cmd, + check=False, + shell=True, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) result.stdout = result.stdout.decode("utf-8").strip() out = "CMD: " + cmd + "\n" out += "EXIT-CODE: %d\n" % result.returncode @@ -300,7 +348,8 @@ def __exit__(self, _, __, ___): def print_result(passed, stdout, file, errors): if any(errors): raise Exception( - "File %s: got %d errors:\n%s" % (file, len(errors), stdout)) + "File %s: got %d errors:\n%s" % (file, len(errors), stdout) + ) status_str = "PASSED" if passed else "FAILED" print("%s File:%s %s %s" % (SEPARATOR, file, status_str, SEPARATOR)) if not passed and stdout: @@ -354,11 +403,15 @@ def run_sequential(args, all_files): # actual tidy checker for cmd in all_files: # skip files that we don't want to look at - if args.ignore_compiled is not None and \ - re.search(args.ignore_compiled, cmd["file"]) is not None: + if ( + args.ignore_compiled is not None + and re.search(args.ignore_compiled, cmd["file"]) is not None + ): continue - if args.select_compiled is not None and \ - re.search(args.select_compiled, cmd["file"]) is None: + if ( + args.select_compiled is not None + and re.search(args.select_compiled, cmd["file"]) is None + ): continue results.append(run_clang_tidy(cmd, args)) return parse_results(results) @@ -379,11 +432,15 @@ def run_parallel(args, all_files): # actual tidy checker for cmd in all_files: # skip files that we don't want to look at - if args.ignore_compiled is not None and \ - re.search(args.ignore_compiled, cmd["file"]) is not None: + if ( + args.ignore_compiled is not None + and re.search(args.ignore_compiled, cmd["file"]) is not None + ): continue - if args.select_compiled is not None and \ - re.search(args.select_compiled, cmd["file"]) is None: + if ( + args.select_compiled is not None + and re.search(args.select_compiled, cmd["file"]) is None + ): continue results.append(pool.apply_async(run_clang_tidy, args=(cmd, args))) results_final = [r.get() for r in results] @@ -409,22 +466,29 @@ def main(): # first get a list of all checks that were run ret = subprocess.check_output(args.exe + " --list-checks", shell=True) ret = ret.decode("utf-8") - checks = [line.strip() for line in ret.splitlines() - if line.startswith(' ' * 4)] + checks = [ + line.strip() + for line in ret.splitlines() + if line.startswith(" " * 4) + ] max_check_len = max(len(c) for c in checks) check_counts = dict() content = os.linesep.join(lines) for check in checks: check_counts[check] = content.count(check) sorted_counts = sorted( - check_counts.items(), key=lambda x: x[1], reverse=True) - print("Failed {} check(s) in total. Counts as per below:".format( - sum(1 for _, count in sorted_counts if count > 0))) + check_counts.items(), key=lambda x: x[1], reverse=True + ) + print( + "Failed {} check(s) in total. Counts as per below:".format( + sum(1 for _, count in sorted_counts if count > 0) + ) + ) for check, count in sorted_counts: if count <= 0: break n_space = max_check_len - len(check) + 4 - print("{}:{}{}".format(check, ' ' * n_space, count)) + print("{}:{}{}".format(check, " " * n_space, count)) raise Exception("clang-tidy failed! Refer to the errors above.") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py index da66e37996..0cfa0c2c2a 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -69,128 +69,141 @@ dict( path_prefix="canberra", OpT="cuvs::distance::detail::ops::canberra_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="correlation", OpT="cuvs::distance::detail::ops::correlation_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="cosine", OpT="cuvs::distance::detail::ops::cosine_distance_op", - archs = [60, 80], + archs=[60, 80], ), dict( path_prefix="hamming_unexpanded", OpT="cuvs::distance::detail::ops::hamming_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="hellinger_expanded", OpT="cuvs::distance::detail::ops::hellinger_distance_op", - archs = [60], + archs=[60], ), # inner product is handled by cublas. dict( path_prefix="jensen_shannon", OpT="cuvs::distance::detail::ops::jensen_shannon_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="kl_divergence", OpT="cuvs::distance::detail::ops::kl_divergence_op", - archs = [60], + archs=[60], ), dict( path_prefix="l1", OpT="cuvs::distance::detail::ops::l1_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="l2_expanded", OpT="cuvs::distance::detail::ops::l2_exp_distance_op", - archs = [60, 80], + archs=[60, 80], ), dict( path_prefix="l2_unexpanded", OpT="cuvs::distance::detail::ops::l2_unexp_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="l_inf", OpT="cuvs::distance::detail::ops::l_inf_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="lp_unexpanded", OpT="cuvs::distance::detail::ops::lp_unexp_distance_op", - archs = [60], + archs=[60], ), dict( path_prefix="russel_rao", OpT="cuvs::distance::detail::ops::russel_rao_distance_op", - archs = [60], - ), + archs=[60], + ), ] + def arch_headers(archs): - include_headers ="\n".join([ - f"#include \"dispatch_sm{arch}.cuh\"" - for arch in archs - ]) + include_headers = "\n".join( + [f'#include "dispatch_sm{arch}.cuh"' for arch in archs] + ) return include_headers - for op in op_instances: for dt in data_type_instances: - DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + DataT, AccT, OutT, IdxT = ( + dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"] + ) path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" with open(path, "w") as f: f.write(header) f.write(arch_headers(op["archs"])) f.write(macro) - OpT = op['OpT'] + OpT = op["OpT"] FinOpT = "raft::identity_op" - f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") - f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + f.write( + f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n" + ) + f.write( + "\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n" + ) print(f"src/distance/detail/pairwise_matrix/{path}") # Dispatch kernels for with the RBF fin op. with open("dispatch_rbf.cu", "w") as f: - OpT="cuvs::distance::detail::ops::l2_unexp_distance_op" - archs = [60] + OpT = "cuvs::distance::detail::ops::l2_unexp_distance_op" + archs = [60] - f.write(header) - f.write("#include \"../kernels/rbf_fin_op.cuh\" // rbf_fin_op\n") - f.write(arch_headers(archs)) - f.write(macro) + f.write(header) + f.write('#include "../kernels/rbf_fin_op.cuh" // rbf_fin_op\n') + f.write(arch_headers(archs)) + f.write(macro) - for dt in data_type_instances: - DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); - IdxT = "int64_t" # overwrite IdxT - - FinOpT = f"cuvs::distance::kernels::detail::rbf_fin_op<{DataT}>" - f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = ( + dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"] + ) + IdxT = "int64_t" # overwrite IdxT - f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + FinOpT = f"cuvs::distance::kernels::detail::rbf_fin_op<{DataT}>" + f.write( + f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n" + ) + f.write( + "\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n" + ) - print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") + print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") # L2 with int64_t indices for kmeans code int64_t_op_instances = [ dict( path_prefix="l2_expanded", OpT="cuvs::distance::detail::ops::l2_exp_distance_op", - archs = [60, 80], - )] + archs=[60, 80], + ) +] for op in int64_t_op_instances: for dt in data_type_instances: - DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + DataT, AccT, OutT, IdxT = ( + dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"] + ) IdxT = "int64_t" path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" @@ -199,8 +212,12 @@ def arch_headers(archs): f.write(arch_headers(op["archs"])) f.write(macro) - OpT = op['OpT'] + OpT = op["OpT"] FinOpT = "raft::identity_op" - f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") - f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + f.write( + f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n" + ) + f.write( + "\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n" + ) print(f"src/distance/detail/pairwise_matrix/{path}") diff --git a/cpp/src/neighbors/ball_cover/detail/ball_cover/registers_00_generate.py b/cpp/src/neighbors/ball_cover/detail/ball_cover/registers_00_generate.py index 7068014e9c..f5ef49f67f 100644 --- a/cpp/src/neighbors/ball_cover/detail/ball_cover/registers_00_generate.py +++ b/cpp/src/neighbors/ball_cover/detail/ball_cover/registers_00_generate.py @@ -93,38 +93,38 @@ """ -euclideanSq="cuvs::neighbors::ball_cover::detail::EuclideanSqFunc" +euclideanSq = "cuvs::neighbors::ball_cover::detail::EuclideanSqFunc" types = dict( int64_float=("std::int64_t", "float"), ) -path = f"registers_pass_one.cu" +path = "registers_pass_one.cu" with open(path, "w") as f: f.write(header) f.write(macro_pass_one) for type_path, (int_t, data_t) in types.items(): - f.write(f"instantiate_cuvs_neighbors_detail_rbc_low_dim_pass_one(\n") + f.write("instantiate_cuvs_neighbors_detail_rbc_low_dim_pass_one(\n") f.write(f" {int_t}, {data_t});\n") f.write("#undef instantiate_cuvs_neighbors_detail_rbc_low_dim_pass_one\n") print(f"src/neighbors/ball_cover/detail/ball_cover/{path}") -path = f"registers_pass_two.cu" +path = "registers_pass_two.cu" with open(path, "w") as f: f.write(header) f.write(macro_pass_two) for type_path, (int_t, data_t) in types.items(): - f.write(f"instantiate_cuvs_neighbors_detail_rbc_low_dim_pass_two(\n") + f.write("instantiate_cuvs_neighbors_detail_rbc_low_dim_pass_two(\n") f.write(f" {int_t}, {data_t});\n") f.write("#undef instantiate_cuvs_neighbors_detail_rbc_low_dim_pass_two\n") print(f"src/neighbors/ball_cover/detail/ball_cover/{path}") -path="registers_eps_pass_euclidean.cu" +path = "registers_eps_pass_euclidean.cu" with open(path, "w") as f: f.write(header) f.write(macro_pass_eps) for type_path, (int_t, data_t) in types.items(): - f.write(f"instantiate_cuvs_neighbors_detail_rbc_eps_pass(\n") + f.write("instantiate_cuvs_neighbors_detail_rbc_eps_pass(\n") f.write(f" {int_t}, {data_t}, {euclideanSq});\n") f.write("#undef instantiate_cuvs_neighbors_detail_rbc_eps_pass\n") print(f"src/neighbors/ball_cover/detail/ball_cover/{path}") diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py index c0b5d572fb..fde2081c12 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py @@ -19,19 +19,19 @@ * */ -{includes} +{{includes}} -namespace cuvs::neighbors::cagra::detail {{ +namespace cuvs::neighbors::cagra::detail {{{{ using namespace cuvs::distance; -{content} +{{content}} -}} // namespace cuvs::neighbors::cagra::detail +}}}} // namespace cuvs::neighbors::cagra::detail """ mxdim_team = [(128, 8), (256, 16), (512, 32)] -#mxdim_team = [(64, 8), (128, 16), (256, 32)] -#mxdim_team = [(32, 8), (64, 16), (128, 32)] +# mxdim_team = [(64, 8), (128, 16), (256, 32)] +# mxdim_team = [(32, 8), (64, 16), (128, 32)] pq_bits = [8] pq_lens = [2, 4] @@ -48,27 +48,24 @@ uint8_uint32=("uint8_t", "uint32_t", "float"), ) -metric_prefix = 'DistanceType::' +metric_prefix = "DistanceType::" specs = [] descs = [] cmake_list = [] - - # Cleanup first for f in glob.glob("compute_distance_standard_*.cu"): - os.remove(f) + os.remove(f) for f in glob.glob("compute_distance_vpq_*.cu"): - os.remove(f) + os.remove(f) # Generate new files for type_path, (data_t, idx_t, distance_t) in search_types.items(): - for (mxdim, team) in mxdim_team: + for mxdim, team in mxdim_team: # CAGRA - for metric in ['L2Expanded', 'InnerProduct', 'CosineExpanded']: - + for metric in ["L2Expanded", "InnerProduct", "CosineExpanded"]: path = f"compute_distance_standard_{metric}_{type_path}_dim{mxdim}_t{team}.cu" includes = '#include "compute_distance_standard-impl.cuh"' params = f"{metric_prefix}{metric}, {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}" @@ -83,7 +80,7 @@ for code_book_t in code_book_types: for pq_len in pq_lens: for pq_bit in pq_bits: - for metric in ['L2Expanded']: + for metric in ["L2Expanded"]: path = f"compute_distance_vpq_{metric}_{type_path}_dim{mxdim}_t{team}_{pq_bit}pq_{pq_len}subd_{code_book_t}.cu" includes = '#include "compute_distance_vpq-impl.cuh"' params = f"{metric_prefix}{metric}, {team}, {mxdim}, {pq_bit}, {pq_len}, {code_book_t}, {data_t}, {idx_t}, {distance_t}" @@ -91,18 +88,26 @@ content = f"""template struct {spec};""" specs.append(spec) with open(path, "w") as f: - f.write(template.format(includes=includes, content=content)) - cmake_list.append(f" src/neighbors/detail/cagra/{path}") + f.write( + template.format( + includes=includes, content=content + ) + ) + cmake_list.append( + f" src/neighbors/detail/cagra/{path}" + ) # CAGRA (Binary Hamming distance) -for (mxdim, team) in mxdim_team: - metric = 'BitwiseHamming' - type_path = 'u8_uint32' - idx_t = 'uint32_t' - distance_t = 'float' - data_t = 'uint8_t' - - path = f"compute_distance_standard_{metric}_{type_path}_dim{mxdim}_t{team}.cu" +for mxdim, team in mxdim_team: + metric = "BitwiseHamming" + type_path = "u8_uint32" + idx_t = "uint32_t" + distance_t = "float" + data_t = "uint8_t" + + path = ( + f"compute_distance_standard_{metric}_{type_path}_dim{mxdim}_t{team}.cu" + ) includes = '#include "compute_distance_standard-impl.cuh"' params = f"{metric_prefix}{metric}, {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}" spec = f"standard_descriptor_spec<{params}>" @@ -113,14 +118,14 @@ cmake_list.append(f" src/neighbors/detail/cagra/{path}") with open("compute_distance-ext.cuh", "w") as f: - includes = ''' + includes = """ #pragma once #include "compute_distance_standard.hpp" #include "compute_distance_vpq.hpp" -''' +""" newline = "\n" - contents = f''' + contents = f""" {newline.join(map(lambda s: "extern template struct " + s + ";", specs))} extern template struct @@ -142,16 +147,16 @@ }} return init(params, dataset, metric, dataset_norms); }} -''' +""" f.write(template.format(includes=includes, content=contents)) with open("compute_distance.cu", "w") as f: includes = '#include "compute_distance-ext.cuh"' newline = "\n" - contents = f''' + contents = f""" template struct instance_selector<{("," + newline + " ").join(specs)}>; -''' +""" f.write(template.format(includes=includes, content=contents)) cmake_list.sort() diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index f5e0287321..342a61afd6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -22,7 +22,7 @@ #define COMMA , -namespace cuvs::neighbors::cagra::detail::multi_cta_search { +namespace cuvs::neighbors::cagra::detail::multi_cta_search {{ """ trailer = """ @@ -48,10 +48,10 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_sample_filter);\n" + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_sample_filter);\n" ) f.write( - f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, CagraSampleFilterWithQueryIdOffset>);\n" + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, CagraSampleFilterWithQueryIdOffset>);\n" ) f.write(trailer) # For pasting into CMakeLists.txt diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index 78e03d9b4f..0e98c6e41c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -22,7 +22,7 @@ #define COMMA , -namespace cuvs::neighbors::cagra::detail::single_cta_search { +namespace cuvs::neighbors::cagra::detail::single_cta_search {{ """ trailer = """ @@ -51,10 +51,10 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_sample_filter);\n" + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_sample_filter);\n" ) f.write( - f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, CagraSampleFilterWithQueryIdOffset>);\n" + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, CagraSampleFilterWithQueryIdOffset>);\n" ) f.write(trailer) diff --git a/cpp/src/neighbors/iface/generate_iface.py b/cpp/src/neighbors/iface/generate_iface.py index 385914e4c6..ef9e56a5b2 100644 --- a/cpp/src/neighbors/iface/generate_iface.py +++ b/cpp/src/neighbors/iface/generate_iface.py @@ -207,24 +207,24 @@ const std::string& filename); """ -flat_macros = dict ( - flat = dict( +flat_macros = dict( + flat=dict( include=include_macro, definition=flat_macro, name="CUVS_INST_MG_FLAT", ) ) -pq_macros = dict ( - pq = dict( +pq_macros = dict( + pq=dict( include=include_macro, definition=pq_macro, name="CUVS_INST_MG_PQ", ) ) -cagra_macros = dict ( - cagra = dict( +cagra_macros = dict( + cagra=dict( include=include_macro, definition=cagra_macro, name="CUVS_INST_MG_CAGRA", @@ -252,17 +252,21 @@ uint8_t_uint32_t=("uint8_t", "uint32_t"), ) -for macros, types in [(flat_macros, flat_types), (pq_macros, pq_types), (cagra_macros, cagra_types)]: - for type_path, (T, IdxT) in types.items(): - for macro_path, macro in macros.items(): - path = f"iface_{macro_path}_{type_path}.cu" - with open(path, "w") as f: - f.write(header) - f.write(macro['include']) - f.write(namespace_macro) - f.write(macro["definition"]) - f.write(f"{macro['name']}({T}, {IdxT});\n\n") - f.write(f"#undef {macro['name']}\n") - f.write(footer) +for macros, types in [ + (flat_macros, flat_types), + (pq_macros, pq_types), + (cagra_macros, cagra_types), +]: + for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"iface_{macro_path}_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro["include"]) + f.write(namespace_macro) + f.write(macro["definition"]) + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") + f.write(footer) - print(f"src/neighbors/iface/{path}") + print(f"src/neighbors/iface/{path}") diff --git a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py index b55a945c27..a44c306ca3 100644 --- a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py +++ b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py @@ -161,7 +161,7 @@ path = f"ivf_flat_{macro_path}_{type_path}.cu" with open(path, "w") as f: f.write(header) - f.write(macro['include']) + f.write(macro["include"]) f.write(namespace_macro) f.write(macro["definition"]) f.write(f"{macro['name']}({T}, {IdxT});\n\n") diff --git a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py index 9feba888ba..da9b78992e 100644 --- a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py +++ b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py @@ -77,7 +77,7 @@ path = f"ivf_pq_{macro_path}_{type_path}.cu" with open(path, "w") as f: f.write(header) - f.write(macro['include']) + f.write(macro["include"]) f.write(namespace_macro) f.write(macro["definition"]) f.write(f"{macro['name']}({T}, {IdxT});\n\n") diff --git a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py index 5c7543e973..2d9f619a62 100644 --- a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py +++ b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py @@ -65,26 +65,62 @@ #define COMMA , """ -none_filter_int64 = "cuvs::neighbors::filtering::ivf_to_sample_filter" \ - "" -bitset_filter64 = "cuvs::neighbors::filtering::ivf_to_sample_filter" \ - ">" +none_filter_int64 = ( + "cuvs::neighbors::filtering::ivf_to_sample_filter" + "" +) +bitset_filter64 = ( + "cuvs::neighbors::filtering::ivf_to_sample_filter" + ">" +) types = dict( - half_fp8_false=("half", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int64), - half_fp8_true=("half", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int64), + half_fp8_false=( + "half", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", + none_filter_int64, + ), + half_fp8_true=( + "half", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", + none_filter_int64, + ), half_half=("half", "half", none_filter_int64), float_half=("float", "half", none_filter_int64), - float_float= ("float", "float", none_filter_int64), - float_fp8_false=("float", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int64), - float_fp8_true=("float", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int64), - half_fp8_false_bitset64=("half", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter64), - half_fp8_true_bitset64=("half", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter64), + float_float=("float", "float", none_filter_int64), + float_fp8_false=( + "float", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", + none_filter_int64, + ), + float_fp8_true=( + "float", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", + none_filter_int64, + ), + half_fp8_false_bitset64=( + "half", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", + bitset_filter64, + ), + half_fp8_true_bitset64=( + "half", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", + bitset_filter64, + ), half_half_bitset64=("half", "half", bitset_filter64), float_half_bitset64=("float", "half", bitset_filter64), - float_float_bitset64= ("float", "float", bitset_filter64), - float_fp8_false_bitset64=("float", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter64), - float_fp8_true_bitset64=("float", "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter64) + float_float_bitset64=("float", "float", bitset_filter64), + float_fp8_false_bitset64=( + "float", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", + bitset_filter64, + ), + float_fp8_true_bitset64=( + "float", + "cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", + bitset_filter64, + ), ) for path_key, (OutT, LutT, FilterT) in types.items(): @@ -92,5 +128,7 @@ with open(path, "w") as f: f.write(header) f.write(declaration_macro) - f.write(f"instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, {FilterT});\n") + f.write( + f"instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, {FilterT});\n" + ) print(f"src/neighbors/ivf_pq/{path}") diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 14dcf19e6d..c85a1aeaf3 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -230,24 +230,24 @@ } // namespace cuvs::neighbors::cagra """ -flat_macros = dict ( - flat = dict( +flat_macros = dict( + flat=dict( include=include_macro, definition=flat_macro, name="CUVS_INST_MG_FLAT", ) ) -pq_macros = dict ( - pq = dict( +pq_macros = dict( + pq=dict( include=include_macro, definition=pq_macro, name="CUVS_INST_MG_PQ", ) ) -cagra_macros = dict ( - cagra = dict( +cagra_macros = dict( + cagra=dict( include=include_macro, definition=cagra_macro, name="CUVS_INST_MG_CAGRA", @@ -275,15 +275,19 @@ uint8_t_uint32_t=("uint8_t", "uint32_t"), ) -for macros, types in [(flat_macros, flat_types), (pq_macros, pq_types), (cagra_macros, cagra_types)]: - for type_path, (T, IdxT) in types.items(): - for macro_path, macro in macros.items(): - path = f"mg_{macro_path}_{type_path}.cu" - with open(path, "w") as f: - f.write(header) - f.write(macro['include']) - f.write(macro["definition"]) - f.write(f"{macro['name']}({T}, {IdxT});\n\n") - f.write(f"#undef {macro['name']}\n") +for macros, types in [ + (flat_macros, flat_types), + (pq_macros, pq_types), + (cagra_macros, cagra_types), +]: + for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"mg_{macro_path}_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro["include"]) + f.write(macro["definition"]) + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") - print(f"src/neighbors/mg/{path}") + print(f"src/neighbors/mg/{path}") diff --git a/docs/source/conf.py b/docs/source/conf.py index a2b407f8d8..5b8c38b0fb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -37,7 +37,7 @@ "breathe", "recommonmark", "sphinx_markdown_tables", - "sphinx_copybutton" + "sphinx_copybutton", ] breathe_default_project = "cuvs" @@ -74,7 +74,9 @@ # The short X.Y version. version = f"{CUVS_VERSION.major:02}.{CUVS_VERSION.minor:02}" # The full version, including alpha/beta/rc tags. -release = f"{CUVS_VERSION.major:02}.{CUVS_VERSION.minor:02}.{CUVS_VERSION.micro:02}" +release = ( + f"{CUVS_VERSION.major:02}.{CUVS_VERSION.minor:02}.{CUVS_VERSION.micro:02}" +) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -150,7 +152,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "cuvs.tex", "cuVS Documentation", "NVIDIA Corporation", "manual"), + ( + master_doc, + "cuvs.tex", + "cuVS Documentation", + "NVIDIA Corporation", + "manual", + ), ] # -- Options for manual page output --------------------------------------- diff --git a/docs/source/sphinxext/github_link.py b/docs/source/sphinxext/github_link.py index 512782af58..1ee5f610b5 100644 --- a/docs/source/sphinxext/github_link.py +++ b/docs/source/sphinxext/github_link.py @@ -10,7 +10,6 @@ import re import subprocess import sys -from functools import partial from operator import attrgetter orig = inspect.isfunction @@ -18,7 +17,6 @@ # See https://opendreamkit.org/2017/06/09/CythonSphinx/ def isfunction(obj): - orig_val = orig(obj) new_val = hasattr(type(obj), "__code__") @@ -125,7 +123,6 @@ def _linkcode_resolve(domain, info, package, url_fmt, revision): try: lineno = inspect.getsourcelines(obj)[1] except Exception: - # Can happen if its a cyfunction. See if it has `__code__` if hasattr(obj, "__code__"): lineno = obj.__code__.co_firstlineno diff --git a/notebooks/VectorSearch_QuestionRetrieval.ipynb b/notebooks/VectorSearch_QuestionRetrieval.ipynb index 1115a5920d..d93f1e8fd3 100644 --- a/notebooks/VectorSearch_QuestionRetrieval.ipynb +++ b/notebooks/VectorSearch_QuestionRetrieval.ipynb @@ -56,10 +56,13 @@ "import torch\n", "import pylibraft\n", "from cuvs.neighbors import ivf_flat, ivf_pq\n", - "pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host())\n", + "\n", + "pylibraft.config.set_output_as(\n", + " lambda device_ndarray: device_ndarray.copy_to_host()\n", + ")\n", "\n", "if not torch.cuda.is_available():\n", - " print(\"Warning: No GPU found. Please add GPU to your notebook\")" + " print(\"Warning: No GPU found. Please add GPU to your notebook\")" ] }, { @@ -70,41 +73,51 @@ "outputs": [], "source": [ "# We use the Bi-Encoder to encode all passages, so that we can use it with semantic search\n", - "model_name = 'nq-distilbert-base-v1'\n", + "model_name = \"nq-distilbert-base-v1\"\n", "bi_encoder = SentenceTransformer(model_name)\n", "\n", "# As dataset, we use Simple English Wikipedia. Compared to the full English wikipedia, it has only\n", "# about 170k articles. We split these articles into paragraphs and encode them with the bi-encoder\n", "\n", - "wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'\n", + "wikipedia_filepath = \"data/simplewiki-2020-11-01.jsonl.gz\"\n", "\n", "if not os.path.exists(wikipedia_filepath):\n", - " util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)\n", + " util.http_get(\n", + " \"http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz\",\n", + " wikipedia_filepath,\n", + " )\n", "\n", "passages = []\n", - "with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:\n", + "with gzip.open(wikipedia_filepath, \"rt\", encoding=\"utf8\") as fIn:\n", " for line in fIn:\n", " data = json.loads(line.strip())\n", - " for paragraph in data['paragraphs']:\n", + " for paragraph in data[\"paragraphs\"]:\n", " # We encode the passages as [title, text]\n", - " passages.append([data['title'], paragraph])\n", + " passages.append([data[\"title\"], paragraph])\n", "\n", "# If you like, you can also limit the number of passages you want to use\n", "print(\"Passages:\", len(passages))\n", "\n", "# To speed things up, pre-computed embeddings are downloaded.\n", "# The provided file encoded the passages with the model 'nq-distilbert-base-v1'\n", - "if model_name == 'nq-distilbert-base-v1':\n", - " embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'\n", + "if model_name == \"nq-distilbert-base-v1\":\n", + " embeddings_filepath = \"simplewiki-2020-11-01-nq-distilbert-base-v1.pt\"\n", " if not os.path.exists(embeddings_filepath):\n", - " util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)\n", + " util.http_get(\n", + " \"http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt\",\n", + " embeddings_filepath,\n", + " )\n", "\n", " corpus_embeddings = torch.load(embeddings_filepath)\n", - " corpus_embeddings = corpus_embeddings.float() # Convert embedding file to float\n", + " corpus_embeddings = (\n", + " corpus_embeddings.float()\n", + " ) # Convert embedding file to float\n", " if torch.cuda.is_available():\n", - " corpus_embeddings = corpus_embeddings.to('cuda')\n", + " corpus_embeddings = corpus_embeddings.to(\"cuda\")\n", "else: # Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)\n", - " corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)" + " corpus_embeddings = bi_encoder.encode(\n", + " passages, convert_to_tensor=True, show_progress_bar=True\n", + " )" ] }, { @@ -131,11 +144,14 @@ "pq_index = ivf_pq.build(params, corpus_embeddings)\n", "search_params = ivf_pq.SearchParams()\n", "\n", - "def search_cuvs_pq(query, top_k = 5):\n", + "\n", + "def search_cuvs_pq(query, top_k=5):\n", " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", "\n", - " hits = ivf_pq.search(search_params, pq_index, question_embedding[None], top_k)\n", + " hits = ivf_pq.search(\n", + " search_params, pq_index, question_embedding[None], top_k\n", + " )\n", "\n", " # Output of top-k hits\n", " print(\"Input question:\", query)\n", @@ -199,7 +215,7 @@ "outputs": [], "source": [ "%%time\n", - "search_cuvs_pq(query = \"What is creating tides?\")" + "search_cuvs_pq(query=\"What is creating tides?\")" ] }, { @@ -214,12 +230,15 @@ "flat_index = ivf_flat.build(params, corpus_embeddings)\n", "search_params = ivf_flat.SearchParams()\n", "\n", - "def search_cuvs_flat(query, top_k = 5):\n", + "\n", + "def search_cuvs_flat(query, top_k=5):\n", " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", - " \n", + "\n", " start_time = time.time()\n", - " hits = ivf_flat.search(search_params, flat_index, question_embedding[None], top_k)\n", + " hits = ivf_flat.search(\n", + " search_params, flat_index, question_embedding[None], top_k\n", + " )\n", " end_time = time.time()\n", "\n", " # Output of top-k hits\n", @@ -259,7 +278,7 @@ "outputs": [], "source": [ "%%time\n", - "search_cuvs_flat(query = \"What is creating tides?\")" + "search_cuvs_flat(query=\"What is creating tides?\")" ] }, { @@ -304,11 +323,13 @@ "metadata": {}, "outputs": [], "source": [ - "def search_cuvs_cagra(query, top_k = 5):\n", + "def search_cuvs_cagra(query, top_k=5):\n", " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", "\n", - " hits = cagra.search(search_params, cagra_index, question_embedding[None], top_k)\n", + " hits = cagra.search(\n", + " search_params, cagra_index, question_embedding[None], top_k\n", + " )\n", "\n", " # Output of top-k hits\n", " print(\"Input question:\", query)\n", diff --git a/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb b/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb index 09a6cca43b..647c9f83b7 100644 --- a/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb +++ b/notebooks/VectorSearch_QuestionRetrieval_Milvus.ipynb @@ -84,13 +84,14 @@ "from typing import List\n", "\n", "\n", - "from pymilvus import (\n", - " connections, utility\n", - ")\n", - "from pymilvus.bulk_writer import LocalBulkWriter, BulkFileType # pip install pymilvus[bulk_writer]\n", + "from pymilvus import connections, utility\n", + "from pymilvus.bulk_writer import (\n", + " LocalBulkWriter,\n", + " BulkFileType,\n", + ") # pip install pymilvus[bulk_writer]\n", "\n", "if not torch.cuda.is_available():\n", - " print(\"Warning: No GPU found. Please add GPU to your notebook\")" + " print(\"Warning: No GPU found. Please add GPU to your notebook\")" ] }, { @@ -118,32 +119,45 @@ "DIM = 768\n", "MILVUS_PORT = 30004\n", "MILVUS_HOST = f\"http://localhost:{MILVUS_PORT}\"\n", - "ID_FIELD=\"id\"\n", - "EMBEDDING_FIELD=\"embedding\"\n", + "ID_FIELD = \"id\"\n", + "EMBEDDING_FIELD = \"embedding\"\n", "\n", "collection_name = \"simple_wiki\"\n", "\n", + "\n", "def get_milvus_client():\n", " return pymilvus.MilvusClient(uri=MILVUS_HOST)\n", "\n", + "\n", "client = get_milvus_client()\n", "\n", "fields = [\n", - " pymilvus.FieldSchema(name=ID_FIELD, dtype=pymilvus.DataType.INT64, is_primary=True),\n", - " pymilvus.FieldSchema(name=EMBEDDING_FIELD, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=DIM)\n", + " pymilvus.FieldSchema(\n", + " name=ID_FIELD, dtype=pymilvus.DataType.INT64, is_primary=True\n", + " ),\n", + " pymilvus.FieldSchema(\n", + " name=EMBEDDING_FIELD, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=DIM\n", + " ),\n", "]\n", "\n", "schema = pymilvus.CollectionSchema(fields)\n", "schema.verify()\n", "\n", "if collection_name in client.list_collections():\n", - " print(f\"Collection '{collection_name}' already exists. Deleting collection...\")\n", + " print(\n", + " f\"Collection '{collection_name}' already exists. Deleting collection...\"\n", + " )\n", " client.drop_collection(collection_name)\n", "\n", - "client.create_collection(collection_name, schema=schema, dimension=DIM, vector_field_name=EMBEDDING_FIELD)\n", + "client.create_collection(\n", + " collection_name,\n", + " schema=schema,\n", + " dimension=DIM,\n", + " vector_field_name=EMBEDDING_FIELD,\n", + ")\n", "collection = pymilvus.Collection(name=collection_name, using=client._using)\n", "collection.release()\n", - "collection.drop_index()\n" + "collection.drop_index()" ] }, { @@ -169,40 +183,50 @@ "outputs": [], "source": [ "# We use the Bi-Encoder to encode all passages, so that we can use it with semantic search\n", - "model_name = 'nq-distilbert-base-v1'\n", + "model_name = \"nq-distilbert-base-v1\"\n", "bi_encoder = SentenceTransformer(model_name)\n", "\n", "# As dataset, we use Simple English Wikipedia. Compared to the full English wikipedia, it has only\n", "# about 170k articles. We split these articles into paragraphs and encode them with the bi-encoder\n", "\n", - "wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'\n", + "wikipedia_filepath = \"data/simplewiki-2020-11-01.jsonl.gz\"\n", "\n", "if not os.path.exists(wikipedia_filepath):\n", - " util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)\n", + " util.http_get(\n", + " \"http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz\",\n", + " wikipedia_filepath,\n", + " )\n", "\n", "passages = []\n", - "with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:\n", + "with gzip.open(wikipedia_filepath, \"rt\", encoding=\"utf8\") as fIn:\n", " for line in fIn:\n", " data = json.loads(line.strip())\n", - " for paragraph in data['paragraphs']:\n", + " for paragraph in data[\"paragraphs\"]:\n", " # We encode the passages as [title, text]\n", - " passages.append([data['title'], paragraph])\n", + " passages.append([data[\"title\"], paragraph])\n", "\n", "# If you like, you can also limit the number of passages you want to use\n", "print(\"Passages:\", len(passages))\n", "\n", "# To speed things up, pre-computed embeddings are downloaded.\n", "# The provided file encoded the passages with the model 'nq-distilbert-base-v1'\n", - "if model_name == 'nq-distilbert-base-v1':\n", - " embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'\n", + "if model_name == \"nq-distilbert-base-v1\":\n", + " embeddings_filepath = \"simplewiki-2020-11-01-nq-distilbert-base-v1.pt\"\n", " if not os.path.exists(embeddings_filepath):\n", - " util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)\n", + " util.http_get(\n", + " \"http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt\",\n", + " embeddings_filepath,\n", + " )\n", "\n", - " corpus_embeddings = torch.load(embeddings_filepath, map_location='cpu', weights_only=True).float() # Convert embedding file to float\n", - " #if torch.cuda.is_available():\n", + " corpus_embeddings = torch.load(\n", + " embeddings_filepath, map_location=\"cpu\", weights_only=True\n", + " ).float() # Convert embedding file to float\n", + " # if torch.cuda.is_available():\n", " # corpus_embeddings = corpus_embeddings.to('cuda')\n", "else: # Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)\n", - " corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True).to('cpu')" + " corpus_embeddings = bi_encoder.encode(\n", + " passages, convert_to_tensor=True, show_progress_bar=True\n", + " ).to(\"cpu\")" ] }, { @@ -236,73 +260,109 @@ "MINIO_SECRET_KEY = \"minioadmin\"\n", "MINIO_ACCESS_KEY = \"minioadmin\"\n", "\n", - "def upload_to_minio(file_paths: List[List[str]], remote_paths: List[List[str]], bucket_name=\"milvus-bucket\"):\n", - " minio_client = Minio(endpoint=MINIO_URL, access_key=MINIO_ACCESS_KEY, secret_key=MINIO_SECRET_KEY, secure=False)\n", + "\n", + "def upload_to_minio(\n", + " file_paths: List[List[str]],\n", + " remote_paths: List[List[str]],\n", + " bucket_name=\"milvus-bucket\",\n", + "):\n", + " minio_client = Minio(\n", + " endpoint=MINIO_URL,\n", + " access_key=MINIO_ACCESS_KEY,\n", + " secret_key=MINIO_SECRET_KEY,\n", + " secure=False,\n", + " )\n", " if not minio_client.bucket_exists(bucket_name):\n", " minio_client.make_bucket(bucket_name)\n", "\n", " for local_batch, remote_batch in zip(file_paths, remote_paths):\n", " for local_file, remote_file in zip(local_batch, remote_batch):\n", - " minio_client.fput_object(bucket_name, \n", - " object_name=remote_file,\n", - " file_path=local_file,\n", - " part_size=512 * 1024 * 1024,\n", - " num_parallel_uploads=5)\n", - " \n", - " \n", - "def ingest_data_bulk(collection_name, vectors, schema: pymilvus.CollectionSchema, log_times=True, bulk_writer_type=\"milvus\", debug=False):\n", + " minio_client.fput_object(\n", + " bucket_name,\n", + " object_name=remote_file,\n", + " file_path=local_file,\n", + " part_size=512 * 1024 * 1024,\n", + " num_parallel_uploads=5,\n", + " )\n", + "\n", + "\n", + "def ingest_data_bulk(\n", + " collection_name,\n", + " vectors,\n", + " schema: pymilvus.CollectionSchema,\n", + " log_times=True,\n", + " bulk_writer_type=\"milvus\",\n", + " debug=False,\n", + "):\n", " print(f\"- Ingesting {len(vectors) // 1000}k vectors, Bulk\")\n", " tic = time.perf_counter()\n", - " collection = pymilvus.Collection(collection_name, using=get_milvus_client()._using)\n", + " collection = pymilvus.Collection(\n", + " collection_name, using=get_milvus_client()._using\n", + " )\n", " remote_path = None\n", "\n", - " if bulk_writer_type == 'milvus':\n", + " if bulk_writer_type == \"milvus\":\n", " # # Prepare source data for faster ingestion\n", " writer = LocalBulkWriter(\n", " schema=schema,\n", - " local_path='bulk_data',\n", - " segment_size=512 * 1024 * 1024, # Default value\n", - " file_type=BulkFileType.NPY\n", + " local_path=\"bulk_data\",\n", + " segment_size=512 * 1024 * 1024, # Default value\n", + " file_type=BulkFileType.NPY,\n", " )\n", " for id, vec in enumerate(vectors):\n", " writer.append_row({ID_FIELD: id, EMBEDDING_FIELD: vec})\n", "\n", " if debug:\n", " print(writer.batch_files)\n", + "\n", " def callback(file_list):\n", " if debug:\n", - " print(f\" - Commit successful\")\n", + " print(\" - Commit successful\")\n", " print(file_list)\n", + "\n", " writer.commit(call_back=callback)\n", " files_to_upload = writer.batch_files\n", - " elif bulk_writer_type == 'dask':\n", + " elif bulk_writer_type == \"dask\":\n", " # Prepare source data for faster ingestion\n", " if not os.path.isdir(\"bulk_data\"):\n", " os.mkdir(\"bulk_data\")\n", "\n", " from dask.distributed import Client, LocalCluster\n", + "\n", " cluster = LocalCluster(n_workers=1, threads_per_worker=1)\n", " client = Client(cluster)\n", "\n", " chunk_size = 100000\n", - " da_vectors = da.from_array(vectors, chunks=(chunk_size, vectors.shape[1]))\n", + " da_vectors = da.from_array(\n", + " vectors, chunks=(chunk_size, vectors.shape[1])\n", + " )\n", " da_ids = da.arange(len(vectors), chunks=(chunk_size,))\n", " da.to_npy_stack(\"bulk_data/da_embedding/\", da_vectors)\n", " da.to_npy_stack(\"bulk_data/da_id/\", da_ids)\n", " files_to_upload = []\n", " remote_path = []\n", " for chunk_nb in range(math.ceil(len(vectors) / chunk_size)):\n", - " files_to_upload.append([f\"bulk_data/da_embedding/{chunk_nb}.npy\", f\"bulk_data/da_id/{chunk_nb}.npy\"])\n", - " remote_path.append([f\"bulk_data/da_{chunk_nb}/embedding.npy\", f\"bulk_data/da__{chunk_nb}/id.npy\"])\n", + " files_to_upload.append(\n", + " [\n", + " f\"bulk_data/da_embedding/{chunk_nb}.npy\",\n", + " f\"bulk_data/da_id/{chunk_nb}.npy\",\n", + " ]\n", + " )\n", + " remote_path.append(\n", + " [\n", + " f\"bulk_data/da_{chunk_nb}/embedding.npy\",\n", + " f\"bulk_data/da__{chunk_nb}/id.npy\",\n", + " ]\n", + " )\n", "\n", - " elif bulk_writer_type == 'numpy':\n", + " elif bulk_writer_type == \"numpy\":\n", " # Directly save NPY files\n", " np.save(\"bulk_data/embedding.npy\", vectors)\n", " np.save(\"bulk_data/id.npy\", np.arange(len(vectors)))\n", " files_to_upload = [[\"bulk_data/embedding.npy\", \"bulk_data/id.npy\"]]\n", " else:\n", " raise ValueError(\"Invalid bulk writer type\")\n", - " \n", + "\n", " toc = time.perf_counter()\n", " if log_times:\n", " print(f\" - File save time: {toc - tic:.2f} seconds\")\n", @@ -310,17 +370,29 @@ " if remote_path is None:\n", " remote_path = files_to_upload\n", " upload_to_minio(files_to_upload, remote_path)\n", - " \n", - " job_ids = [utility.do_bulk_insert(collection_name, batch, using=get_milvus_client()._using) for batch in remote_path]\n", + "\n", + " job_ids = [\n", + " utility.do_bulk_insert(\n", + " collection_name, batch, using=get_milvus_client()._using\n", + " )\n", + " for batch in remote_path\n", + " ]\n", "\n", " while True:\n", - " tasks = [utility.get_bulk_insert_state(job_id, using=get_milvus_client()._using) for job_id in job_ids]\n", + " tasks = [\n", + " utility.get_bulk_insert_state(\n", + " job_id, using=get_milvus_client()._using\n", + " )\n", + " for job_id in job_ids\n", + " ]\n", " success = all(task.state_name == \"Completed\" for task in tasks)\n", " failure = any(task.state_name == \"Failed\" for task in tasks)\n", " for i in range(len(tasks)):\n", " task = tasks[i]\n", " if debug:\n", - " print(f\" - Task {i}/{len(tasks)} state: {task.state_name}, Progress percent: {task.infos['progress_percent']}, Imported row count: {task.row_count}\")\n", + " print(\n", + " f\" - Task {i}/{len(tasks)} state: {task.state_name}, Progress percent: {task.infos['progress_percent']}, Imported row count: {task.row_count}\"\n", + " )\n", " if task.state_name == \"Failed\":\n", " print(task)\n", " if success or failure:\n", @@ -334,9 +406,18 @@ " toc = time.perf_counter()\n", " if log_times:\n", " datasize = vectors.nbytes / 1024 / 1024\n", - " print(f\"- Ingestion time: {toc - tic:.2f} seconds. ({(datasize / (toc-tic)):.2f}MB/s)\")\n", + " print(\n", + " f\"- Ingestion time: {toc - tic:.2f} seconds. ({(datasize / (toc - tic)):.2f}MB/s)\"\n", + " )\n", + "\n", "\n", - "ingest_data_bulk(collection_name, np.array(corpus_embeddings), schema, bulk_writer_type='dask', log_times=True)" + "ingest_data_bulk(\n", + " collection_name,\n", + " np.array(corpus_embeddings),\n", + " schema,\n", + " bulk_writer_type=\"dask\",\n", + " log_times=True,\n", + ")" ] }, { @@ -358,8 +439,11 @@ "index_params = dict(\n", " index_type=\"GPU_IVF_PQ\",\n", " metric_type=\"L2\",\n", - " params={\"nlist\": 150, # Number of clusters\n", - " \"m\": 96}) # Product Quantization dimension\n", + " params={\n", + " \"nlist\": 150, # Number of clusters\n", + " \"m\": 96,\n", + " },\n", + ") # Product Quantization dimension\n", "\n", "# Drop the index if it exists\n", "if collection.has_index():\n", @@ -389,22 +473,27 @@ "outputs": [], "source": [ "# Search the index\n", - "def search_cuvs_pq(query, top_k = 5, n_probe = 30):\n", + "def search_cuvs_pq(query, top_k=5, n_probe=30):\n", " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", "\n", " search_params = {\"nprobe\": n_probe}\n", " tic = time.perf_counter()\n", " hits = collection.search(\n", - " data=np.array(question_embedding[None].cpu()), anns_field=EMBEDDING_FIELD, param=search_params, limit=top_k\n", - " )\n", + " data=np.array(question_embedding[None].cpu()),\n", + " anns_field=EMBEDDING_FIELD,\n", + " param=search_params,\n", + " limit=top_k,\n", + " )\n", " toc = time.perf_counter()\n", "\n", " # Output of top-k hits\n", " print(\"Input question:\", query)\n", - " print(\"Results (after {:.3f} ms):\".format((toc - tic)*1000))\n", + " print(\"Results (after {:.3f} ms):\".format((toc - tic) * 1000))\n", " for k in range(top_k):\n", - " print(\"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id]))" + " print(\n", + " \"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id])\n", + " )" ] }, { @@ -463,7 +552,7 @@ }, "outputs": [], "source": [ - "search_cuvs_pq(query = \"What is creating tides?\")" + "search_cuvs_pq(query=\"What is creating tides?\")" ] }, { @@ -487,9 +576,8 @@ "\n", "# Create the IVF Flat index\n", "index_params = dict(\n", - " index_type=\"GPU_IVF_FLAT\",\n", - " metric_type=\"L2\",\n", - " params={\"nlist\": 150}) # Number of clusters)\n", + " index_type=\"GPU_IVF_FLAT\", metric_type=\"L2\", params={\"nlist\": 150}\n", + ") # Number of clusters)\n", "tic = time.perf_counter()\n", "collection.create_index(field_name=EMBEDDING_FIELD, index_params=index_params)\n", "collection.load()\n", @@ -511,22 +599,27 @@ }, "outputs": [], "source": [ - "def search_cuvs_flat(query, top_k = 5, n_probe = 30):\n", + "def search_cuvs_flat(query, top_k=5, n_probe=30):\n", " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", - " \n", + "\n", " search_params = {\"nprobe\": n_probe}\n", " tic = time.perf_counter()\n", " hits = collection.search(\n", - " data=np.array(question_embedding[None].cpu()), anns_field=EMBEDDING_FIELD, param=search_params, limit=top_k\n", - " )\n", + " data=np.array(question_embedding[None].cpu()),\n", + " anns_field=EMBEDDING_FIELD,\n", + " param=search_params,\n", + " limit=top_k,\n", + " )\n", " toc = time.perf_counter()\n", "\n", " # Output of top-k hits\n", " print(\"Input question:\", query)\n", - " print(\"Results (after {:.3f} ms):\".format((toc - tic)*1000))\n", + " print(\"Results (after {:.3f} ms):\".format((toc - tic) * 1000))\n", " for k in range(top_k):\n", - " print(\"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id]))" + " print(\n", + " \"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id])\n", + " )" ] }, { @@ -577,7 +670,7 @@ }, "outputs": [], "source": [ - "search_cuvs_flat(query = \"What is creating tides?\")" + "search_cuvs_flat(query=\"What is creating tides?\")" ] }, { @@ -616,7 +709,13 @@ "index_params = dict(\n", " index_type=\"GPU_CAGRA\",\n", " metric_type=\"L2\",\n", - " params={\"graph_degree\": 64, \"intermediate_graph_degree\": 128, \"build_algo\": \"NN_DESCENT\", \"adapt_for_cpu\": True})\n", + " params={\n", + " \"graph_degree\": 64,\n", + " \"intermediate_graph_degree\": 128,\n", + " \"build_algo\": \"NN_DESCENT\",\n", + " \"adapt_for_cpu\": True,\n", + " },\n", + ")\n", "tic = time.perf_counter()\n", "collection.create_index(field_name=EMBEDDING_FIELD, index_params=index_params)\n", "collection.load()\n", @@ -638,22 +737,27 @@ }, "outputs": [], "source": [ - "def search_cuvs_cagra(query, top_k = 5, itopk = 32):\n", + "def search_cuvs_cagra(query, top_k=5, itopk=32):\n", " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", "\n", " search_params = {\"params\": {\"itopk\": itopk, \"ef\": 35}}\n", " tic = time.perf_counter()\n", " hits = collection.search(\n", - " data=np.array(question_embedding[None].cpu()), anns_field=EMBEDDING_FIELD, param=search_params, limit=top_k\n", - " )\n", + " data=np.array(question_embedding[None].cpu()),\n", + " anns_field=EMBEDDING_FIELD,\n", + " param=search_params,\n", + " limit=top_k,\n", + " )\n", " toc = time.perf_counter()\n", "\n", " # Output of top-k hits\n", " print(\"Input question:\", query)\n", - " print(\"Results (after {:.3f} ms):\".format((toc - tic)*1000))\n", + " print(\"Results (after {:.3f} ms):\".format((toc - tic) * 1000))\n", " for k in range(top_k):\n", - " print(\"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id]))" + " print(\n", + " \"\\t{:.3f}\\t{}\".format(hits[0][k].distance, passages[hits[0][k].id])\n", + " )" ] }, { diff --git a/notebooks/cuvs_hpo_example.ipynb b/notebooks/cuvs_hpo_example.ipynb index 964110cb76..e333f8fa9f 100644 --- a/notebooks/cuvs_hpo_example.ipynb +++ b/notebooks/cuvs_hpo_example.ipynb @@ -20,7 +20,7 @@ }, "outputs": [], "source": [ - "#Install Required Packages\n", + "# Install Required Packages\n", "%mamba install -c rapidsai-nightly -c conda-forge cuvs optuna -y\n", "%pip install cupy" ] @@ -38,13 +38,12 @@ "import numpy as np\n", "from cuvs.neighbors import ivf_flat\n", "import urllib.request\n", - "import numpy as np\n", "import time\n", "import optuna\n", "from utils import calc_recall\n", "from optuna.visualization import plot_optimization_history\n", "import math\n", - "import os\n" + "import os" ] }, { @@ -65,21 +64,24 @@ "outputs": [], "source": [ "import tarfile\n", + "\n", "home_dir = os.path.expanduser(\"~/\")\n", - "#wiki-all datasets are in tar format\n", + "\n", + "\n", + "# wiki-all datasets are in tar format\n", "def download_files(url, file):\n", " if os.path.exists(home_dir + \"/\" + file):\n", " print(\"tar file is already downloaded\")\n", " else:\n", " urllib.request.urlretrieve(url, home_dir + \"/\" + file)\n", " # Open the .tar file\n", - " with tarfile.open(home_dir + \"/\" + file, 'r') as tar:\n", + " with tarfile.open(home_dir + \"/\" + file, \"r\") as tar:\n", " filename = file.split(\".\")[0]\n", " if os.path.exists(home_dir + \"/\" + filename + \"/\"):\n", " print(\"Files already extracted\")\n", " return home_dir + \"/\" + filename + \"/\"\n", " # Extract all contents into the specified directory\n", - " extract_path=home_dir + \"/\" +file.split(\".\")[0]\n", + " extract_path = home_dir + \"/\" + file.split(\".\")[0]\n", " tar.extractall(extract_path)\n", " return extract_path" ] @@ -102,7 +104,10 @@ } ], "source": [ - "extracted_path=download_files('https://data.rapids.ai/raft/datasets/wiki_all_1M/wiki_all_1M.tar', 'wiki_all_1M.tar')" + "extracted_path = download_files(\n", + " \"https://data.rapids.ai/raft/datasets/wiki_all_1M/wiki_all_1M.tar\",\n", + " \"wiki_all_1M.tar\",\n", + ")" ] }, { @@ -136,8 +141,8 @@ "source": [ "def read_data(file_path, dtype):\n", " with open(file_path, \"rb\") as f:\n", - " rows,cols = np.fromfile(f, count=2, dtype= np.int32)\n", - " d = np.fromfile(f,count=rows*cols,dtype=dtype).reshape(rows, cols)\n", + " rows, cols = np.fromfile(f, count=2, dtype=np.int32)\n", + " d = np.fromfile(f, count=rows * cols, dtype=dtype).reshape(rows, cols)\n", " return cp.asarray(d)" ] }, @@ -150,9 +155,11 @@ }, "outputs": [], "source": [ - "vectors= read_data(extracted_path + \"/base.1M.fbin\",np.float32)\n", - "queries = read_data(extracted_path + \"/queries.fbin\",np.float32)\n", - "gt_neighbors = read_data(extracted_path + \"/groundtruth.1M.neighbors.ibin\",np.int32)" + "vectors = read_data(extracted_path + \"/base.1M.fbin\", np.float32)\n", + "queries = read_data(extracted_path + \"/queries.fbin\", np.float32)\n", + "gt_neighbors = read_data(\n", + " extracted_path + \"/groundtruth.1M.neighbors.ibin\", np.int32\n", + ")" ] }, { @@ -164,7 +171,7 @@ }, "outputs": [], "source": [ - "#Get the dataset size of database vectors\n", + "# Get the dataset size of database vectors\n", "dataset_size = vectors.shape[0]\n", "dim = vectors.shape[1]" ] @@ -190,26 +197,26 @@ "source": [ "def visualization(study_obj):\n", " \"\"\"\n", - " This function creates two Pareto front plots to visualize trade-offs between different \n", - " optimization objectives. The plots help in understanding the balance between competing \n", + " This function creates two Pareto front plots to visualize trade-offs between different\n", + " optimization objectives. The plots help in understanding the balance between competing\n", " objectives in the optimization process.\n", "\n", " Args:\n", " study_obj (optuna.Study): The Optuna study object containing the optimization results.\n", "\n", " The function produces the following plots:\n", - " 1. **Figure 1**: A Pareto front plot showing the trade-off between `build_time_in_secs` \n", - " and `recall`. It visualizes how the optimization process balances the build time \n", + " 1. **Figure 1**: A Pareto front plot showing the trade-off between `build_time_in_secs`\n", + " and `recall`. It visualizes how the optimization process balances the build time\n", " and recall score.\n", - " 2. **Figure 2**: A Pareto front plot showing the trade-off between `latency_in_ms` \n", + " 2. **Figure 2**: A Pareto front plot showing the trade-off between `latency_in_ms`\n", " and `recall`. This plot illustrates the relationship between latency and recall score.\n", - " \n", + "\n", " \"\"\"\n", - " \n", + "\n", " fig1 = optuna.visualization.plot_pareto_front(\n", - " study_obj,\n", - " targets=lambda t: (t.values[0], t.values[2]),\n", - " target_names=[\"build_time_in_secs\", \"recall\"],\n", + " study_obj,\n", + " targets=lambda t: (t.values[0], t.values[2]),\n", + " target_names=[\"build_time_in_secs\", \"recall\"],\n", " )\n", " fig1.show()\n", "\n", @@ -234,13 +241,14 @@ " print(f\"\\tnumber: {target_instance.number}\")\n", " print(f\"\\tparams: {target_instance.params}\")\n", " print(f\"\\tvalues: {target_instance.values}\")\n", - " \n", + "\n", + "\n", "def print_best_trial_values(optuna_study):\n", " \"\"\"\n", " Prints details about the trials on the Pareto front of an Optuna study.\n", "\n", - " This function analyzes the best trials from an Optuna study, which are typically \n", - " those with the most favorable trade-offs among multiple objectives. It prints \n", + " This function analyzes the best trials from an Optuna study, which are typically\n", + " those with the most favorable trade-offs among multiple objectives. It prints\n", " information on three specific metrics:\n", "\n", " 1. The number of trials on the Pareto front.\n", @@ -256,20 +264,28 @@ " - `values[0]`: Build time\n", " - `values[1]`: latency\n", " - `values[2]`: Accuracy\n", - " \n", + "\n", " \"\"\"\n", - " print(f\"Number of trials on the Pareto front: {len(optuna_study.best_trials)}\")\n", + " print(\n", + " f\"Number of trials on the Pareto front: {len(optuna_study.best_trials)}\"\n", + " )\n", "\n", - " trial_with_lowest_build_time = min(optuna_study.best_trials, key=lambda t: t.values[0])\n", - " print(f\"Trial with lowest build time in secs: \")\n", + " trial_with_lowest_build_time = min(\n", + " optuna_study.best_trials, key=lambda t: t.values[0]\n", + " )\n", + " print(\"Trial with lowest build time in secs: \")\n", " print_target_instance_summary(trial_with_lowest_build_time)\n", "\n", - " trial_with_lowest_latency = min(optuna_study.best_trials, key=lambda t: t.values[1])\n", - " print(f\"Trial with lowest latency in ms: \")\n", + " trial_with_lowest_latency = min(\n", + " optuna_study.best_trials, key=lambda t: t.values[1]\n", + " )\n", + " print(\"Trial with lowest latency in ms: \")\n", " print_target_instance_summary(trial_with_lowest_latency)\n", - " \n", - " trial_with_highest_accuracy = max(optuna_study.best_trials, key=lambda t: t.values[2])\n", - " print(f\"Trial with highest accuracy: \")\n", + "\n", + " trial_with_highest_accuracy = max(\n", + " optuna_study.best_trials, key=lambda t: t.values[2]\n", + " )\n", + " print(\"Trial with highest accuracy: \")\n", " print_target_instance_summary(trial_with_highest_accuracy)" ] }, @@ -312,9 +328,9 @@ "\n", " \"\"\"\n", " # Suggest an integer for the number of lists\n", - " n_lists = trial.suggest_int(\"n_lists\", 10, dataset_size*0.1)\n", + " n_lists = trial.suggest_int(\"n_lists\", 10, dataset_size * 0.1)\n", " # Suggest an integer for the number of probes\n", - " n_probes = trial.suggest_int(\"n_probes\",n_lists*0.01 , n_lists*0.1)\n", + " n_probes = trial.suggest_int(\"n_probes\", n_lists * 0.01, n_lists * 0.1)\n", " build_params = ivf_flat.IndexParams(\n", " n_lists=n_lists,\n", " )\n", @@ -328,12 +344,16 @@ " start_search_time = time.time()\n", " distances, indices = ivf_flat.search(search_params, index, queries, k=10)\n", " search_time = time.time() - start_search_time\n", - " \n", - " latency_in_ms = (search_time * 1000)/queries.shape[0]\n", - " \n", + "\n", + " latency_in_ms = (search_time * 1000) / queries.shape[0]\n", + "\n", " found_distances, found_indices = cp.asnumpy(distances), cp.asnumpy(indices)\n", " recall = calc_recall(found_indices, gt_neighbors)\n", - " return round(build_time_in_secs,4), round(latency_in_ms,4), round(recall,4)" + " return (\n", + " round(build_time_in_secs, 4),\n", + " round(latency_in_ms, 4),\n", + " round(recall, 4),\n", + " )" ] }, { @@ -363,7 +383,9 @@ } ], "source": [ - "ivf_flat_study = optuna.create_study(directions=['minimize', 'minimize', 'maximize'])\n", + "ivf_flat_study = optuna.create_study(\n", + " directions=[\"minimize\", \"minimize\", \"maximize\"]\n", + ")\n", "ivf_flat_study.optimize(multi_objective_ivf_flat, n_trials=10)" ] }, @@ -1419,7 +1441,7 @@ "yaxis": { "autorange": true, "range": [ - 0.9929718446601943, + 0.9929718446601944, 1.0004281553398058 ], "title": { @@ -2473,7 +2495,7 @@ "yaxis": { "autorange": true, "range": [ - 0.9929718446601943, + 0.9929718446601944, 1.0004281553398058 ], "title": { @@ -2537,7 +2559,7 @@ }, "outputs": [], "source": [ - "from cuvs.neighbors import ivf_pq,refine" + "from cuvs.neighbors import ivf_pq, refine" ] }, { @@ -2555,15 +2577,15 @@ "\n", " \"\"\"\n", " # Suggest values for build parameters\n", - " pq_dim = trial.suggest_int(\"pq_dim\", dim*0.25, dim, step=2)\n", + " pq_dim = trial.suggest_int(\"pq_dim\", dim * 0.25, dim, step=2)\n", " n_lists = 1000\n", "\n", " # Suggest an integer for the number of probes\n", - " n_probes = trial.suggest_int(\"n_probes\",n_lists*0.01 , n_lists*0.1)\n", + " n_probes = trial.suggest_int(\"n_probes\", n_lists * 0.01, n_lists * 0.1)\n", "\n", " build_params = ivf_pq.IndexParams(\n", - " n_lists=n_lists,\n", - " pq_dim=pq_dim,\n", + " n_lists=n_lists,\n", + " pq_dim=pq_dim,\n", " )\n", "\n", " start_build_time = time.time()\n", @@ -2578,12 +2600,16 @@ " distances, indices = ivf_pq.search(search_params, index, queries, k=10)\n", " search_time = time.time() - start_search_time\n", "\n", - " latency_in_ms = (search_time * 1000)/queries.shape[0]\n", + " latency_in_ms = (search_time * 1000) / queries.shape[0]\n", "\n", " found_distances, found_indices = cp.asnumpy(distances), cp.asnumpy(indices)\n", " recall = calc_recall(found_indices, gt_neighbors)\n", "\n", - " return round(build_time_in_secs,4), round(latency_in_ms, 4), round(recall,4)" + " return (\n", + " round(build_time_in_secs, 4),\n", + " round(latency_in_ms, 4),\n", + " round(recall, 4),\n", + " )" ] }, { @@ -2741,7 +2767,9 @@ } ], "source": [ - "ivf_pq_study = optuna.create_study(directions=['minimize', 'minimize', 'maximize'])\n", + "ivf_pq_study = optuna.create_study(\n", + " directions=[\"minimize\", \"minimize\", \"maximize\"]\n", + ")\n", "ivf_pq_study.optimize(multi_objective_ivf_pq, n_trials=10)" ] }, @@ -4907,7 +4935,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cuvs.neighbors import cagra\n" + "from cuvs.neighbors import cagra" ] }, { @@ -4923,13 +4951,15 @@ "\n", " \"\"\"\n", " # Suggest values for build parameters\n", - " intermediate_graph_degree = trial.suggest_int(\"intermediate_graph_degree\", 64, 128, step=2 )\n", + " intermediate_graph_degree = trial.suggest_int(\n", + " \"intermediate_graph_degree\", 64, 128, step=2\n", + " )\n", "\n", " # Suggest an integer for the number of probes\n", " itopk_size = trial.suggest_int(\"itopk_size\", 64, 128, step=2)\n", "\n", " build_params = cagra.IndexParams(\n", - " intermediate_graph_degree=intermediate_graph_degree\n", + " intermediate_graph_degree=intermediate_graph_degree\n", " )\n", "\n", " start_build_time = time.time()\n", @@ -4941,15 +4971,21 @@ "\n", " # perform search and refine to increase recall/accuracy\n", " start_search_time = time.time()\n", - " distances, indices = cagra.search(search_params, cagra_index, queries, k=10)\n", + " distances, indices = cagra.search(\n", + " search_params, cagra_index, queries, k=10\n", + " )\n", " search_time = time.time() - start_search_time\n", "\n", - " latency_in_ms = (search_time * 1000)/queries.shape[0]\n", + " latency_in_ms = (search_time * 1000) / queries.shape[0]\n", "\n", " found_distances, found_indices = cp.asnumpy(distances), cp.asnumpy(indices)\n", " recall = calc_recall(found_indices, gt_neighbors)\n", "\n", - " return round(build_time_in_secs,4), round(latency_in_ms,4), round(recall,4)" + " return (\n", + " round(build_time_in_secs, 4),\n", + " round(latency_in_ms, 4),\n", + " round(recall, 4),\n", + " )" ] }, { @@ -5047,7 +5083,9 @@ } ], "source": [ - "cagra_study = optuna.create_study(directions=['minimize', 'minimize', 'maximize'])\n", + "cagra_study = optuna.create_study(\n", + " directions=[\"minimize\", \"minimize\", \"maximize\"]\n", + ")\n", "cagra_study.optimize(multi_objective_cagra, n_trials=5)" ] }, diff --git a/notebooks/ivf_flat_example.ipynb b/notebooks/ivf_flat_example.ipynb index ce35866833..99f1d33626 100644 --- a/notebooks/ivf_flat_example.ipynb +++ b/notebooks/ivf_flat_example.ipynb @@ -53,9 +53,9 @@ "source": [ "import rmm\n", "from rmm.allocators.cupy import rmm_cupy_allocator\n", + "\n", "mr = rmm.mr.PoolMemoryResource(\n", - " rmm.mr.CudaMemoryResource(),\n", - " initial_pool_size=2**30\n", + " rmm.mr.CudaMemoryResource(), initial_pool_size=2**30\n", ")\n", "rmm.mr.set_current_device_resource(mr)\n", "cp.cuda.set_allocator(rmm_cupy_allocator)" @@ -100,7 +100,10 @@ "outputs": [], "source": [ "WORK_FOLDER = os.path.join(tempfile.gettempdir(), \"cuvs_example\")\n", - "f = load_dataset(\"http://ann-benchmarks.com/sift-128-euclidean.hdf5\", work_folder=WORK_FOLDER)" + "f = load_dataset(\n", + " \"http://ann-benchmarks.com/sift-128-euclidean.hdf5\",\n", + " work_folder=WORK_FOLDER,\n", + ")" ] }, { @@ -110,16 +113,18 @@ "metadata": {}, "outputs": [], "source": [ - "metric = f.attrs['distance']\n", + "metric = f.attrs[\"distance\"]\n", "\n", - "dataset = cp.asarray(f['train'])\n", - "queries = cp.asarray(f['test'])\n", - "gt_neighbors = cp.asarray(f['neighbors'][:])\n", - "gt_distances = cp.asarray(f['distances'][:])\n", + "dataset = cp.asarray(f[\"train\"])\n", + "queries = cp.asarray(f[\"test\"])\n", + "gt_neighbors = cp.asarray(f[\"neighbors\"][:])\n", + "gt_distances = cp.asarray(f[\"distances\"][:])\n", "\n", - "itemsize = dataset.dtype.itemsize \n", + "itemsize = dataset.dtype.itemsize\n", "\n", - "print(f\"Loaded dataset of size {dataset.shape}, {dataset.size*itemsize/(1<<30):4.1f} GiB; metric: '{metric}'.\")\n", + "print(\n", + " f\"Loaded dataset of size {dataset.shape}, {dataset.size * itemsize / (1 << 30):4.1f} GiB; metric: '{metric}'.\"\n", + ")\n", "print(f\"Number of test queries: {queries.shape[0]}\")" ] }, @@ -141,12 +146,12 @@ "source": [ "%%time\n", "build_params = ivf_flat.IndexParams(\n", - " n_lists=1024,\n", - " metric=\"euclidean\",\n", - " kmeans_trainset_fraction=0.1,\n", - " kmeans_n_iters=20,\n", - " add_data_on_build=True\n", - " )\n", + " n_lists=1024,\n", + " metric=\"euclidean\",\n", + " kmeans_trainset_fraction=0.1,\n", + " kmeans_n_iters=20,\n", + " add_data_on_build=True,\n", + ")\n", "\n", "index = ivf_flat.build(build_params, dataset)" ] @@ -211,13 +216,19 @@ "outputs": [], "source": [ "%%time\n", - "n_queries=10000\n", + "n_queries = 10000\n", "# n_probes is the number of clusters we select in the first (coarse) search step. This is the only hyper parameter for search.\n", "search_params = ivf_flat.SearchParams(n_probes=30)\n", "\n", "# Search 10 nearest neighbors.\n", - "distances, indices = ivf_flat.search(search_params, index, cp.asarray(queries[:n_queries,:]), k=10, resources=handle)\n", - " \n", + "distances, indices = ivf_flat.search(\n", + " search_params,\n", + " index,\n", + " cp.asarray(queries[:n_queries, :]),\n", + " k=10,\n", + " resources=handle,\n", + ")\n", + "\n", "# cuVS calls are asynchronous (when handle arg is provided), we need to sync before accessing the results.\n", "handle.sync()\n", "distances, neighbors = cp.asnumpy(distances), cp.asnumpy(indices)" @@ -289,9 +300,9 @@ "metadata": {}, "outputs": [], "source": [ - "n_probes = np.asarray([10, 20, 30, 50, 100, 200, 500, 1024]);\n", - "qps = np.zeros(n_probes.shape);\n", - "recall = np.zeros(n_probes.shape);\n", + "n_probes = np.asarray([10, 20, 30, 50, 100, 200, 500, 1024])\n", + "qps = np.zeros(n_probes.shape)\n", + "recall = np.zeros(n_probes.shape)\n", "\n", "for i in range(len(n_probes)):\n", " print(\"\\nBenchmarking search with n_probes =\", n_probes[i])\n", @@ -305,7 +316,7 @@ " resources=handle,\n", " )\n", " handle.sync()\n", - " \n", + "\n", " recall[i] = calc_recall(cp.asnumpy(neighbors), gt_neighbors)\n", " print(\"recall\", recall[i])\n", "\n", @@ -313,7 +324,11 @@ " avg_time = timings.mean()\n", " std_time = timings.std()\n", " qps[i] = queries.shape[0] / avg_time\n", - " print(\"Average search time: {0:7.3f} +/- {1:7.3} s\".format(avg_time, std_time))\n", + " print(\n", + " \"Average search time: {0:7.3f} +/- {1:7.3} s\".format(\n", + " avg_time, std_time\n", + " )\n", + " )\n", " print(\"Queries per second (QPS): {0:8.0f}\".format(qps[i]))" ] }, @@ -332,28 +347,28 @@ "metadata": {}, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12,3))\n", + "fig = plt.figure(figsize=(12, 3))\n", "ax = fig.add_subplot(131)\n", - "ax.plot(n_probes, recall,'o-')\n", - "#ax.set_xticks(bench_k, bench_k)\n", - "ax.set_xlabel('n_probes')\n", + "ax.plot(n_probes, recall, \"o-\")\n", + "# ax.set_xticks(bench_k, bench_k)\n", + "ax.set_xlabel(\"n_probes\")\n", "ax.grid()\n", - "ax.set_ylabel('recall (@k=10)')\n", + "ax.set_ylabel(\"recall (@k=10)\")\n", "\n", "ax = fig.add_subplot(132)\n", - "ax.plot(n_probes, qps,'o-')\n", - "#ax.set_xticks(bench_k, bench_k)\n", - "ax.set_xlabel('n_probes')\n", + "ax.plot(n_probes, qps, \"o-\")\n", + "# ax.set_xticks(bench_k, bench_k)\n", + "ax.set_xlabel(\"n_probes\")\n", "ax.grid()\n", - "ax.set_ylabel('queries per second');\n", + "ax.set_ylabel(\"queries per second\")\n", "\n", "ax = fig.add_subplot(133)\n", - "ax.plot(recall, qps,'o-')\n", - "#ax.set_xticks(bench_k, bench_k)\n", - "ax.set_xlabel('recall')\n", + "ax.plot(recall, qps, \"o-\")\n", + "# ax.set_xticks(bench_k, bench_k)\n", + "ax.set_xlabel(\"recall\")\n", "ax.grid()\n", - "ax.set_ylabel('queries per second');\n", - "#ax.set_yscale('log')" + "ax.set_ylabel(\"queries per second\");\n", + "# ax.set_yscale('log')" ] }, { @@ -375,12 +390,12 @@ "source": [ "%%time\n", "build_params = ivf_flat.IndexParams(\n", - " n_lists=100,\n", - " metric=\"euclidean\",\n", - " kmeans_trainset_fraction=1,\n", - " kmeans_n_iters=20,\n", - " add_data_on_build=True\n", - " )\n", + " n_lists=100,\n", + " metric=\"euclidean\",\n", + " kmeans_trainset_fraction=1,\n", + " kmeans_n_iters=20,\n", + " add_data_on_build=True,\n", + ")\n", "\n", "index = ivf_flat.build(build_params, dataset, resources=handle)" ] @@ -401,13 +416,19 @@ "outputs": [], "source": [ "%%time\n", - "n_queries=10000\n", + "n_queries = 10000\n", "\n", "search_params = ivf_flat.SearchParams(n_probes=10)\n", "\n", "# Search 10 nearest neighbors.\n", - "distances, indices = ivf_flat.search(search_params, index, cp.asarray(queries[:n_queries,:]), k=10, resources=handle)\n", - " \n", + "distances, indices = ivf_flat.search(\n", + " search_params,\n", + " index,\n", + " cp.asarray(queries[:n_queries, :]),\n", + " k=10,\n", + " resources=handle,\n", + ")\n", + "\n", "handle.sync()\n", "distances, neighbors = cp.asnumpy(distances), cp.asnumpy(indices)" ] @@ -439,12 +460,12 @@ "outputs": [], "source": [ "%%time\n", - "build_params = ivf_flat.IndexParams( \n", - " n_lists=100, \n", - " metric=\"sqeuclidean\", \n", - " kmeans_trainset_fraction=0.1, \n", - " kmeans_n_iters=20 \n", - " ) \n", + "build_params = ivf_flat.IndexParams(\n", + " n_lists=100,\n", + " metric=\"sqeuclidean\",\n", + " kmeans_trainset_fraction=0.1,\n", + " kmeans_n_iters=20,\n", + ")\n", "index = ivf_flat.build(build_params, dataset, resources=handle)" ] }, @@ -465,8 +486,14 @@ "source": [ "search_params = ivf_flat.SearchParams(n_probes=10)\n", "\n", - "distances, indices = ivf_flat.search(search_params, index, cp.asarray(queries[:n_queries,:]), k=10, resources=handle)\n", - " \n", + "distances, indices = ivf_flat.search(\n", + " search_params,\n", + " index,\n", + " cp.asarray(queries[:n_queries, :]),\n", + " k=10,\n", + " resources=handle,\n", + ")\n", + "\n", "handle.sync()\n", "distances, neighbors = cp.asnumpy(distances), cp.asnumpy(indices)\n", "calc_recall(neighbors, gt_neighbors)" @@ -494,16 +521,18 @@ "source": [ "# subsample the dataset\n", "n_train = 10000\n", - "train_set = dataset[cp.random.choice(dataset.shape[0], n_train, replace=False),:]\n", + "train_set = dataset[\n", + " cp.random.choice(dataset.shape[0], n_train, replace=False), :\n", + "]\n", "\n", "# build using training set\n", "build_params = ivf_flat.IndexParams(\n", - " n_lists=1024,\n", - " metric=\"sqeuclidean\",\n", - " kmeans_trainset_fraction=1,\n", - " kmeans_n_iters=20,\n", - " add_data_on_build=False\n", - " )\n", + " n_lists=1024,\n", + " metric=\"sqeuclidean\",\n", + " kmeans_trainset_fraction=1,\n", + " kmeans_n_iters=20,\n", + " add_data_on_build=False,\n", + ")\n", "index = ivf_flat.build(build_params, train_set)\n", "\n", "print(\"Index before adding vectors\", index)\n", diff --git a/notebooks/tutorial_ivf_pq.ipynb b/notebooks/tutorial_ivf_pq.ipynb index 9d59daea23..bd1119c5e5 100644 --- a/notebooks/tutorial_ivf_pq.ipynb +++ b/notebooks/tutorial_ivf_pq.ipynb @@ -56,7 +56,7 @@ " return {\n", " attr: getattr(obj, attr)\n", " for attr in dir(obj)\n", - " if type(getattr(type(obj), attr)).__name__ == 'getset_descriptor'\n", + " if type(getattr(type(obj), attr)).__name__ == \"getset_descriptor\"\n", " }" ] }, @@ -67,10 +67,10 @@ "outputs": [], "source": [ "# We'll need to load store some data in this tutorial\n", - "WORK_FOLDER = os.path.join(tempfile.gettempdir(), 'cuvs_ivf_pq_tutorial')\n", + "WORK_FOLDER = os.path.join(tempfile.gettempdir(), \"cuvs_ivf_pq_tutorial\")\n", "\n", "if not os.path.exists(WORK_FOLDER):\n", - " os.makedirs(WORK_FOLDER)\n", + " os.makedirs(WORK_FOLDER)\n", "print(\"The index and data will be saved in\", WORK_FOLDER)" ] }, @@ -100,8 +100,7 @@ "outputs": [], "source": [ "pool = rmm.mr.PoolMemoryResource(\n", - " rmm.mr.CudaMemoryResource(),\n", - " initial_pool_size=2**30\n", + " rmm.mr.CudaMemoryResource(), initial_pool_size=2**30\n", ")\n", "rmm.mr.set_current_device_resource(pool)\n", "cp.cuda.set_allocator(rmm_cupy_allocator)" @@ -141,12 +140,12 @@ "metadata": {}, "outputs": [], "source": [ - "metric = f.attrs['distance']\n", + "metric = f.attrs[\"distance\"]\n", "\n", - "dataset = cp.array(f['train'])\n", - "queries = cp.array(f['test'])\n", - "gt_neighbors = cp.array(f['neighbors'])\n", - "gt_distances = cp.array(f['distances'])\n", + "dataset = cp.array(f[\"train\"])\n", + "queries = cp.array(f[\"test\"])\n", + "gt_neighbors = cp.array(f[\"neighbors\"])\n", + "gt_distances = cp.array(f[\"distances\"])\n", "\n", "print(f\"Loaded dataset of size {dataset.shape}; metric: '{metric}'.\")\n", "print(f\"Number of test queries: {queries.shape[0]}\")" @@ -229,7 +228,7 @@ "source": [ "%%time\n", "index_filepath = os.path.join(WORK_FOLDER, \"ivf_pq.bin\")\n", - "ivf_pq.save(index_filepath, index) \n", + "ivf_pq.save(index_filepath, index)\n", "loaded_index = ivf_pq.load(index_filepath)\n", "resources.sync()\n", "index" @@ -263,7 +262,9 @@ "outputs": [], "source": [ "%%time\n", - "distances, neighbors = ivf_pq.search(search_params, index, queries, k, resources=resources)\n", + "distances, neighbors = ivf_pq.search(\n", + " search_params, index, queries, k, resources=resources\n", + ")\n", "# Sync the GPU to make sure we've got the timing right\n", "resources.sync()" ] @@ -283,7 +284,9 @@ "outputs": [], "source": [ "recall_first_try = calc_recall(neighbors, gt_neighbors)\n", - "print(f\"Got recall = {recall_first_try} with the default parameters (k = {k}).\")" + "print(\n", + " f\"Got recall = {recall_first_try} with the default parameters (k = {k}).\"\n", + ")" ] }, { @@ -304,8 +307,12 @@ "source": [ "%%time\n", "\n", - "candidates = ivf_pq.search(search_params, index, queries, k * 2, resources=resources)[1]\n", - "distances, neighbors = refine(dataset, queries, candidates, k, resources=resources)\n", + "candidates = ivf_pq.search(\n", + " search_params, index, queries, k * 2, resources=resources\n", + ")[1]\n", + "distances, neighbors = refine(\n", + " dataset, queries, candidates, k, resources=resources\n", + ")\n", "resources.sync()" ] }, @@ -354,13 +361,13 @@ " bench_avg[i] = (queries.shape[0] * r.loops / np.array(r.all_runs)).mean()\n", " bench_std[i] = (queries.shape[0] * r.loops / np.array(r.all_runs)).std()\n", "\n", - "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1/2))\n", + "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1 / 2))\n", "ax.errorbar(bench_k, bench_avg, bench_std)\n", - "ax.set_xscale('log')\n", + "ax.set_xscale(\"log\")\n", "ax.set_xticks(bench_k, bench_k)\n", - "ax.set_xlabel('k')\n", + "ax.set_xlabel(\"k\")\n", "ax.grid()\n", - "ax.set_ylabel('QPS');" + "ax.set_ylabel(\"QPS\");" ] }, { @@ -390,8 +397,10 @@ " sp = ivf_pq.SearchParams(n_probes=n_probes)\n", " r = %timeit -o ivf_pq.search(sp, index, queries, k, resources=resources); resources.sync()\n", " bench_qps[i] = (queries.shape[0] * r.loops / np.array(r.all_runs)).mean()\n", - " bench_recall[i] = calc_recall(ivf_pq.search(sp, index, queries, k, resources=resources)[1], gt_neighbors)\n", - " " + " bench_recall[i] = calc_recall(\n", + " ivf_pq.search(sp, index, queries, k, resources=resources)[1],\n", + " gt_neighbors,\n", + " )" ] }, { @@ -412,27 +421,27 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(1, 3, figsize=plt.figaspect(1/4))\n", + "fig, ax = plt.subplots(1, 3, figsize=plt.figaspect(1 / 4))\n", "\n", "ax[0].plot(bench_probes, bench_recall)\n", - "ax[0].set_xscale('log')\n", + "ax[0].set_xscale(\"log\")\n", "ax[0].set_xticks(bench_probes, bench_probes)\n", - "ax[0].set_xlabel('n_probes')\n", - "ax[0].set_ylabel('recall')\n", + "ax[0].set_xlabel(\"n_probes\")\n", + "ax[0].set_ylabel(\"recall\")\n", "ax[0].grid()\n", "\n", "ax[1].plot(bench_probes, bench_qps)\n", - "ax[1].set_xscale('log')\n", + "ax[1].set_xscale(\"log\")\n", "ax[1].set_xticks(bench_probes, bench_probes)\n", - "ax[1].set_xlabel('n_probes')\n", - "ax[1].set_ylabel('QPS')\n", - "ax[1].set_yscale('log')\n", + "ax[1].set_xlabel(\"n_probes\")\n", + "ax[1].set_ylabel(\"QPS\")\n", + "ax[1].set_yscale(\"log\")\n", "ax[1].grid()\n", "\n", "ax[2].plot(bench_recall, bench_qps)\n", - "ax[2].set_xlabel('recall')\n", - "ax[2].set_ylabel('QPS')\n", - "ax[2].set_yscale('log')\n", + "ax[2].set_xlabel(\"recall\")\n", + "ax[2].set_ylabel(\"QPS\")\n", + "ax[2].set_yscale(\"log\")\n", "ax[2].grid();" ] }, @@ -484,18 +493,39 @@ "bench_recall_s1 = np.zeros((5,), dtype=np.float32)\n", "k = 10\n", "n_probes = 256\n", - "search_params_32_32 = ivf_pq.SearchParams(n_probes=n_probes, internal_distance_dtype=np.float32, lut_dtype=np.float32)\n", - "search_params_32_16 = ivf_pq.SearchParams(n_probes=n_probes, internal_distance_dtype=np.float32, lut_dtype=np.float16)\n", - "search_params_32_08 = ivf_pq.SearchParams(n_probes=n_probes, internal_distance_dtype=np.float32, lut_dtype=np.uint8)\n", - "search_params_16_16 = ivf_pq.SearchParams(n_probes=n_probes, internal_distance_dtype=np.float16, lut_dtype=np.float16)\n", - "search_params_16_08 = ivf_pq.SearchParams(n_probes=n_probes, internal_distance_dtype=np.float16, lut_dtype=np.uint8)\n", - "search_ps = [search_params_32_32, search_params_32_16, search_params_32_08, search_params_16_16, search_params_16_08]\n", - "bench_names = ['32/32', '32/16', '32/8', '16/16', '16/8']\n", + "search_params_32_32 = ivf_pq.SearchParams(\n", + " n_probes=n_probes, internal_distance_dtype=np.float32, lut_dtype=np.float32\n", + ")\n", + "search_params_32_16 = ivf_pq.SearchParams(\n", + " n_probes=n_probes, internal_distance_dtype=np.float32, lut_dtype=np.float16\n", + ")\n", + "search_params_32_08 = ivf_pq.SearchParams(\n", + " n_probes=n_probes, internal_distance_dtype=np.float32, lut_dtype=np.uint8\n", + ")\n", + "search_params_16_16 = ivf_pq.SearchParams(\n", + " n_probes=n_probes, internal_distance_dtype=np.float16, lut_dtype=np.float16\n", + ")\n", + "search_params_16_08 = ivf_pq.SearchParams(\n", + " n_probes=n_probes, internal_distance_dtype=np.float16, lut_dtype=np.uint8\n", + ")\n", + "search_ps = [\n", + " search_params_32_32,\n", + " search_params_32_16,\n", + " search_params_32_08,\n", + " search_params_16_16,\n", + " search_params_16_08,\n", + "]\n", + "bench_names = [\"32/32\", \"32/16\", \"32/8\", \"16/16\", \"16/8\"]\n", "\n", "for i, sp in enumerate(search_ps):\n", " r = %timeit -o ivf_pq.search(sp, index, queries, k, resources=resources); resources.sync()\n", - " bench_qps_s1[i] = (queries.shape[0] * r.loops / np.array(r.all_runs)).mean()\n", - " bench_recall_s1[i] = calc_recall(ivf_pq.search(sp, index, queries, k, resources=resources)[1], gt_neighbors)" + " bench_qps_s1[i] = (\n", + " queries.shape[0] * r.loops / np.array(r.all_runs)\n", + " ).mean()\n", + " bench_recall_s1[i] = calc_recall(\n", + " ivf_pq.search(sp, index, queries, k, resources=resources)[1],\n", + " gt_neighbors,\n", + " )" ] }, { @@ -504,27 +534,35 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1/2))\n", + "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1 / 2))\n", "fig.suptitle(\n", - " f'Effects of search parameters on QPS/recall trade-off ({DATASET_NAME})\\n' + \\\n", - " f'k = {k}, n_probes = {n_probes}, pq_dim = {pq_dim}')\n", - "ax.plot(bench_recall_s1, bench_qps_s1, 'o')\n", - "ax.set_xlabel('recall')\n", - "ax.set_ylabel('QPS')\n", + " f\"Effects of search parameters on QPS/recall trade-off ({DATASET_NAME})\\n\"\n", + " + f\"k = {k}, n_probes = {n_probes}, pq_dim = {pq_dim}\"\n", + ")\n", + "ax.plot(bench_recall_s1, bench_qps_s1, \"o\")\n", + "ax.set_xlabel(\"recall\")\n", + "ax.set_ylabel(\"QPS\")\n", "ax.grid()\n", "annotations = []\n", "for i, label in enumerate(bench_names):\n", - " annotations.append(ax.text(\n", - " bench_recall_s1[i], bench_qps_s1[i],\n", - " f\" {label} \",\n", - " ha='center', va='center'))\n", + " annotations.append(\n", + " ax.text(\n", + " bench_recall_s1[i],\n", + " bench_qps_s1[i],\n", + " f\" {label} \",\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " )\n", + " )\n", "clutter = [\n", " ax.text(\n", - " 0.02, 0.08,\n", - " 'Labels denote the bitsize of: internal_distance_dtype/lut_dtype',\n", - " verticalalignment='top',\n", - " bbox={'facecolor': 'white', 'edgecolor': 'grey'},\n", - " transform = ax.transAxes)\n", + " 0.02,\n", + " 0.08,\n", + " \"Labels denote the bitsize of: internal_distance_dtype/lut_dtype\",\n", + " verticalalignment=\"top\",\n", + " bbox={\"facecolor\": \"white\", \"edgecolor\": \"grey\"},\n", + " transform=ax.transAxes,\n", + " )\n", "]\n", "adjust_text(annotations, objects=clutter);" ] @@ -554,18 +592,29 @@ "source": [ "def search_refine(ps, ratio):\n", " k_search = k * ratio\n", - " candidates = ivf_pq.search(ps, index, queries, k_search, resources=resources)[1]\n", - " return candidates if ratio == 1 else refine(dataset, queries, candidates, k, resources=resources)[1]\n", + " candidates = ivf_pq.search(\n", + " ps, index, queries, k_search, resources=resources\n", + " )[1]\n", + " return (\n", + " candidates\n", + " if ratio == 1\n", + " else refine(dataset, queries, candidates, k, resources=resources)[1]\n", + " )\n", + "\n", "\n", "ratios = [1, 2, 4]\n", "bench_qps_sr = np.zeros((len(ratios), len(search_ps)), dtype=np.float32)\n", "bench_recall_sr = np.zeros((len(ratios), len(search_ps)), dtype=np.float32)\n", "\n", - "for j, ratio in enumerate(ratios): \n", + "for j, ratio in enumerate(ratios):\n", " for i, ps in enumerate(search_ps):\n", " r = %timeit -o search_refine(ps, ratio); resources.sync()\n", - " bench_qps_sr[j, i] = (queries.shape[0] * r.loops / np.array(r.all_runs)).mean()\n", - " bench_recall_sr[j, i] = calc_recall(search_refine(ps, ratio), gt_neighbors)" + " bench_qps_sr[j, i] = (\n", + " queries.shape[0] * r.loops / np.array(r.all_runs)\n", + " ).mean()\n", + " bench_recall_sr[j, i] = calc_recall(\n", + " search_refine(ps, ratio), gt_neighbors\n", + " )" ] }, { @@ -574,34 +623,42 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1/2))\n", + "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1 / 2))\n", "fig.suptitle(\n", - " f'Effects of search parameters on QPS/recall trade-off ({DATASET_NAME})\\n' + \\\n", - " f'k = {k}, n_probes = {n_probes}, pq_dim = {pq_dim}')\n", + " f\"Effects of search parameters on QPS/recall trade-off ({DATASET_NAME})\\n\"\n", + " + f\"k = {k}, n_probes = {n_probes}, pq_dim = {pq_dim}\"\n", + ")\n", "labels = []\n", "for j, ratio in enumerate(ratios):\n", - " ax.plot(bench_recall_sr[j, :], bench_qps_sr[j, :], 'o')\n", + " ax.plot(bench_recall_sr[j, :], bench_qps_sr[j, :], \"o\")\n", " labels.append(f\"refine ratio = {ratio}\")\n", "ax.legend(labels)\n", - "ax.set_xlabel('recall')\n", - "ax.set_ylabel('QPS')\n", + "ax.set_xlabel(\"recall\")\n", + "ax.set_ylabel(\"QPS\")\n", "ax.grid()\n", "colors = plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"]\n", "annotations = []\n", "for j, ratio in enumerate(ratios):\n", " for i, label in enumerate(bench_names):\n", - " annotations.append(ax.text(\n", - " bench_recall_sr[j, i], bench_qps_sr[j, i],\n", - " f\" {label} \",\n", - " color=colors[j],\n", - " ha='center', va='center'))\n", + " annotations.append(\n", + " ax.text(\n", + " bench_recall_sr[j, i],\n", + " bench_qps_sr[j, i],\n", + " f\" {label} \",\n", + " color=colors[j],\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " )\n", + " )\n", "clutter = [\n", " ax.text(\n", - " 0.02, 0.08,\n", - " 'Labels denote the bitsize of: internal_distance_dtype/lut_dtype',\n", - " verticalalignment='top',\n", - " bbox={'facecolor': 'white', 'edgecolor': 'grey'},\n", - " transform = ax.transAxes)\n", + " 0.02,\n", + " 0.08,\n", + " \"Labels denote the bitsize of: internal_distance_dtype/lut_dtype\",\n", + " verticalalignment=\"top\",\n", + " bbox={\"facecolor\": \"white\", \"edgecolor\": \"grey\"},\n", + " transform=ax.transAxes,\n", + " )\n", "]\n", "adjust_text(annotations, objects=clutter);" ] @@ -629,18 +686,24 @@ " ps = ivf_pq.SearchParams(\n", " n_probes=n_probes,\n", " internal_distance_dtype=internal_distance_dtype,\n", - " lut_dtype=lut_dtype)\n", - " candidates = ivf_pq.search(ps, index, queries, k_search, resources=resources)[1]\n", - " return candidates if ratio == 1 else refine(dataset, queries, candidates, k, resources=resources)[1]\n", + " lut_dtype=lut_dtype,\n", + " )\n", + " candidates = ivf_pq.search(\n", + " ps, index, queries, k_search, resources=resources\n", + " )[1]\n", + " return (\n", + " candidates\n", + " if ratio == 1\n", + " else refine(dataset, queries, candidates, k, resources=resources)[1]\n", + " )\n", + "\n", "\n", "search_configs = [\n", " lambda n_probes: search_refine(np.float16, np.float16, 1, n_probes),\n", " lambda n_probes: search_refine(np.float32, np.uint8, 1, n_probes),\n", - " lambda n_probes: search_refine(np.float32, np.uint8, 2, n_probes)\n", + " lambda n_probes: search_refine(np.float32, np.uint8, 2, n_probes),\n", "]\n", - "search_config_names = [\n", - " '16/16', '32/8', '32/8/r2'\n", - "]" + "search_config_names = [\"16/16\", \"32/8\", \"32/8/r2\"]" ] }, { @@ -699,16 +762,22 @@ "search_fun = search_configs[selected_search_variant]\n", "search_label = search_config_names[selected_search_variant]\n", "\n", - "bench_qps_nl = np.zeros((len(n_list_variants), len(pl_ratio_variants)), dtype=np.float32)\n", + "bench_qps_nl = np.zeros(\n", + " (len(n_list_variants), len(pl_ratio_variants)), dtype=np.float32\n", + ")\n", "bench_recall_nl = np.zeros_like(bench_qps_nl, dtype=np.float32)\n", "\n", "for i, n_lists in enumerate(n_list_variants):\n", - " index_params = ivf_pq.IndexParams(n_lists=n_lists, metric=metric, pq_dim=pq_dim)\n", + " index_params = ivf_pq.IndexParams(\n", + " n_lists=n_lists, metric=metric, pq_dim=pq_dim\n", + " )\n", " index = ivf_pq.build(index_params, dataset, resources=resources)\n", " for j, pl_ratio in enumerate(pl_ratio_variants):\n", " n_probes = max(1, n_lists // pl_ratio)\n", " r = %timeit -o search_fun(n_probes); resources.sync()\n", - " bench_qps_nl[i, j] = (queries.shape[0] * r.loops / np.array(r.all_runs)).mean()\n", + " bench_qps_nl[i, j] = (\n", + " queries.shape[0] * r.loops / np.array(r.all_runs)\n", + " ).mean()\n", " bench_recall_nl[i, j] = calc_recall(search_fun(n_probes), gt_neighbors)\n", " del index" ] @@ -719,19 +788,20 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1/2))\n", + "fig, ax = plt.subplots(1, 1, figsize=plt.figaspect(1 / 2))\n", "fig.suptitle(\n", - " f'Effects of n_list on QPS/recall trade-off ({DATASET_NAME})\\n' + \\\n", - " f'k = {k}, pq_dim = {pq_dim}, search = {search_label}')\n", + " f\"Effects of n_list on QPS/recall trade-off ({DATASET_NAME})\\n\"\n", + " + f\"k = {k}, pq_dim = {pq_dim}, search = {search_label}\"\n", + ")\n", "labels = []\n", "for i, n_lists in enumerate(n_list_variants):\n", " ax.plot(bench_recall_nl[i, :], bench_qps_nl[i, :])\n", " labels.append(f\"n_lists = {n_lists}\")\n", "\n", "ax.legend(labels)\n", - "ax.set_xlabel('recall')\n", - "ax.set_ylabel('QPS')\n", - "ax.set_yscale('log')\n", + "ax.set_xlabel(\"recall\")\n", + "ax.set_ylabel(\"QPS\")\n", + "ax.set_yscale(\"log\")\n", "ax.grid()" ] }, @@ -867,13 +937,40 @@ "n_lists = 1000\n", "\n", "build_configs = {\n", - " '64-8-subspace': ivf_pq.IndexParams(n_lists=n_lists, metric=metric, pq_dim=64, pq_bits=8, codebook_kind=\"subspace\"),\n", - " '128-8-subspace': ivf_pq.IndexParams(n_lists=n_lists, metric=metric, pq_dim=128, pq_bits=8, codebook_kind=\"subspace\"),\n", - " '128-6-subspace': ivf_pq.IndexParams(n_lists=n_lists, metric=metric, pq_dim=128, pq_bits=6, codebook_kind=\"subspace\"),\n", - " '128-6-cluster': ivf_pq.IndexParams(n_lists=n_lists, metric=metric, pq_dim=128, pq_bits=6, codebook_kind=\"cluster\"),\n", + " \"64-8-subspace\": ivf_pq.IndexParams(\n", + " n_lists=n_lists,\n", + " metric=metric,\n", + " pq_dim=64,\n", + " pq_bits=8,\n", + " codebook_kind=\"subspace\",\n", + " ),\n", + " \"128-8-subspace\": ivf_pq.IndexParams(\n", + " n_lists=n_lists,\n", + " metric=metric,\n", + " pq_dim=128,\n", + " pq_bits=8,\n", + " codebook_kind=\"subspace\",\n", + " ),\n", + " \"128-6-subspace\": ivf_pq.IndexParams(\n", + " n_lists=n_lists,\n", + " metric=metric,\n", + " pq_dim=128,\n", + " pq_bits=6,\n", + " codebook_kind=\"subspace\",\n", + " ),\n", + " \"128-6-cluster\": ivf_pq.IndexParams(\n", + " n_lists=n_lists,\n", + " metric=metric,\n", + " pq_dim=128,\n", + " pq_bits=6,\n", + " codebook_kind=\"cluster\",\n", + " ),\n", "}\n", "\n", - "bench_qps_ip = np.zeros((len(build_configs), len(search_configs), len(n_probes_variants)), dtype=np.float32)\n", + "bench_qps_ip = np.zeros(\n", + " (len(build_configs), len(search_configs), len(n_probes_variants)),\n", + " dtype=np.float32,\n", + ")\n", "bench_recall_ip = np.zeros_like(bench_qps_ip, dtype=np.float32)\n", "\n", "for i, index_params in enumerate(build_configs.values()):\n", @@ -881,8 +978,12 @@ " for l, search_fun in enumerate(search_configs):\n", " for j, n_probes in enumerate(n_probes_variants):\n", " r = %timeit -o search_fun(n_probes); resources.sync()\n", - " bench_qps_ip[i, l, j] = (queries.shape[0] * r.loops / np.array(r.all_runs)).mean()\n", - " bench_recall_ip[i, l, j] = calc_recall(search_fun(n_probes), gt_neighbors)" + " bench_qps_ip[i, l, j] = (\n", + " queries.shape[0] * r.loops / np.array(r.all_runs)\n", + " ).mean()\n", + " bench_recall_ip[i, l, j] = calc_recall(\n", + " search_fun(n_probes), gt_neighbors\n", + " )" ] }, { @@ -891,10 +992,13 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(len(search_config_names), 1, figsize=(16, len(search_config_names)*8))\n", + "fig, ax = plt.subplots(\n", + " len(search_config_names), 1, figsize=(16, len(search_config_names) * 8)\n", + ")\n", "fig.suptitle(\n", - " f'Effects of index parameters on QPS/recall trade-off ({DATASET_NAME})\\n' + \\\n", - " f'k = {k}, n_lists = {n_lists}')\n", + " f\"Effects of index parameters on QPS/recall trade-off ({DATASET_NAME})\\n\"\n", + " + f\"k = {k}, n_lists = {n_lists}\"\n", + ")\n", "\n", "for j, search_label in enumerate(search_config_names):\n", " labels = []\n", @@ -904,9 +1008,9 @@ "\n", " ax[j].set_title(f\"search: {search_label}\")\n", " ax[j].legend(labels)\n", - " ax[j].set_xlabel('recall')\n", - " ax[j].set_ylabel('QPS')\n", - " ax[j].set_yscale('log')\n", + " ax[j].set_xlabel(\"recall\")\n", + " ax[j].set_ylabel(\"QPS\")\n", + " ax[j].set_yscale(\"log\")\n", " ax[j].grid()" ] }, diff --git a/notebooks/utils.py b/notebooks/utils.py index c8a121f531..f456e6f02d 100644 --- a/notebooks/utils.py +++ b/notebooks/utils.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 @@ -10,6 +10,7 @@ import time import urllib + ## Check the quality of the prediction (recall) def calc_recall(found_indices, ground_truth): found_indices = cp.asarray(found_indices) diff --git a/pyproject.toml b/pyproject.toml index 4175144663..9a51f4fdb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,22 +1,27 @@ -[tool.black] +# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[tool.ruff] line-length = 79 -target-version = ["py310"] -include = '\.py?$' -force-exclude = ''' -/( - thirdparty | - \.eggs | - \.git | - \.hg | - \.mypy_cache | - \.tox | - \.venv | - _build | - buck-out | - build | - dist -)/ -''' +exclude = [ + "__init__.py", +] + +[tool.ruff.lint] +ignore = [ + # whitespace before : + "E203", +] + +[tool.ruff.lint.per-file-ignores] +"*.ipynb" = [ + # unused imports + "F401", + # unused variable + "F841", + # ambiguous variable name + "E741", +] [tool.pydocstyle] # Due to https://github.com/PyCQA/pydocstyle/issues/363, we must exclude rather diff --git a/python/cuvs/cuvs/tests/test_cagra.py b/python/cuvs/cuvs/tests/test_cagra.py index 20ac07fa9a..3468a086b9 100644 --- a/python/cuvs/cuvs/tests/test_cagra.py +++ b/python/cuvs/cuvs/tests/test_cagra.py @@ -161,7 +161,6 @@ def run_cagra_build_search_test( def test_cagra_dataset_dtype_host_device( dtype, array_type, inplace, build_algo, metric, serialize ): - # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only sqeuclidean metric here. run_cagra_build_search_test( diff --git a/python/cuvs/cuvs/tests/test_mg_cagra.py b/python/cuvs/cuvs/tests/test_mg_cagra.py index 42bb8220ea..903c16ea24 100644 --- a/python/cuvs/cuvs/tests/test_mg_cagra.py +++ b/python/cuvs/cuvs/tests/test_mg_cagra.py @@ -538,9 +538,9 @@ def test_mg_cagra_simple(): # Distances should be non-negative and sorted assert np.all(distances >= 0) for i in range(n_queries): - assert np.all( - distances[i, :-1] <= distances[i, 1:] - ), f"Distances not sorted for query {i}" + assert np.all(distances[i, :-1] <= distances[i, 1:]), ( + f"Distances not sorted for query {i}" + ) # Integration test with multiple operations diff --git a/python/cuvs/cuvs/tests/test_mg_ivf_flat.py b/python/cuvs/cuvs/tests/test_mg_ivf_flat.py index 08c2610b86..99dff4e221 100644 --- a/python/cuvs/cuvs/tests/test_mg_ivf_flat.py +++ b/python/cuvs/cuvs/tests/test_mg_ivf_flat.py @@ -572,9 +572,9 @@ def test_mg_ivf_flat_simple(): # Distances should be non-negative and sorted assert np.all(distances >= 0) for i in range(n_queries): - assert np.all( - distances[i, :-1] <= distances[i, 1:] - ), f"Distances not sorted for query {i}" + assert np.all(distances[i, :-1] <= distances[i, 1:]), ( + f"Distances not sorted for query {i}" + ) # Integration test with multiple operations diff --git a/python/cuvs/cuvs/tests/test_mg_ivf_pq.py b/python/cuvs/cuvs/tests/test_mg_ivf_pq.py index d54f170153..6c6cf8415b 100644 --- a/python/cuvs/cuvs/tests/test_mg_ivf_pq.py +++ b/python/cuvs/cuvs/tests/test_mg_ivf_pq.py @@ -600,9 +600,9 @@ def test_mg_ivf_pq_simple(): # Distances should be non-negative and sorted assert np.all(distances >= 0) for i in range(n_queries): - assert np.all( - distances[i, :-1] <= distances[i, 1:] - ), f"Distances not sorted for query {i}" + assert np.all(distances[i, :-1] <= distances[i, 1:]), ( + f"Distances not sorted for query {i}" + ) # Integration test with multiple operations diff --git a/python/cuvs/cuvs/tests/test_refine.py b/python/cuvs/cuvs/tests/test_refine.py index 2a6d3a3add..bb6b373e78 100644 --- a/python/cuvs/cuvs/tests/test_refine.py +++ b/python/cuvs/cuvs/tests/test_refine.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # @@ -23,7 +23,6 @@ def run_refine( dtype=np.float32, memory_type="device", ): - dataset = generate_data((n_rows, n_cols), dtype) queries = generate_data((n_queries, n_cols), dtype) diff --git a/python/cuvs_bench/cuvs_bench/plot/__main__.py b/python/cuvs_bench/cuvs_bench/plot/__main__.py index 6d9d9cb4cd..aca08505ea 100644 --- a/python/cuvs_bench/cuvs_bench/plot/__main__.py +++ b/python/cuvs_bench/cuvs_bench/plot/__main__.py @@ -520,7 +520,6 @@ def main( time_unit: str, raw: bool, ) -> None: - args = locals() if args["algorithms"]: diff --git a/python/cuvs_bench/cuvs_bench/run/data_export.py b/python/cuvs_bench/cuvs_bench/run/data_export.py index d658c07d9e..707677a083 100644 --- a/python/cuvs_bench/cuvs_bench/run/data_export.py +++ b/python/cuvs_bench/cuvs_bench/run/data_export.py @@ -222,9 +222,9 @@ def convert_json_to_csv_search(dataset, dataset_path): write.iloc[s_index, write_ncols] = build_df.iloc[ b_index, 2 ] - write.iloc[ - s_index, write_ncols + 1 : - ] = build_df.iloc[b_index, 3:] + write.iloc[s_index, write_ncols + 1 :] = ( + build_df.iloc[b_index, 3:] + ) break # Write search data and compute frontiers write.to_csv(file.replace(".json", ",raw.csv"), index=False) @@ -256,8 +256,9 @@ def create_pointset(data, xn, yn): xm, ym = metrics[xn], metrics[yn] y_col = 4 if yn == "latency" else 3 - rev_x, rev_y = (-1 if xm["worst"] < 0 else 1), ( - -1 if ym["worst"] < 0 else 1 + rev_x, rev_y = ( + (-1 if xm["worst"] < 0 else 1), + (-1 if ym["worst"] < 0 else 1), ) # Sort data based on x and y metrics data.sort(key=lambda t: (rev_y * t[y_col], rev_x * t[2])) diff --git a/python/cuvs_bench/cuvs_bench/tests/test_cli.py b/python/cuvs_bench/cuvs_bench/tests/test_cli.py index abd5fa636f..c65f97bc2a 100644 --- a/python/cuvs_bench/cuvs_bench/tests/test_cli.py +++ b/python/cuvs_bench/cuvs_bench/tests/test_cli.py @@ -40,9 +40,9 @@ def test_get_dataset_creates_expected_files(temp_datasets_dir: Path): # Verify that each expected file exists in the datasets directory. for filename in expected_files: file_path = temp_datasets_dir / filename - assert ( - file_path.exists() - ), f"Expected file {filename} was not generated." + assert file_path.exists(), ( + f"Expected file {filename} was not generated." + ) def test_run_command_creates_results(temp_datasets_dir: Path): @@ -82,9 +82,9 @@ def test_run_command_creates_results(temp_datasets_dir: Path): "--force", ] result = runner.invoke(run_main, run_args) - assert ( - result.exit_code == 0 - ), f"Run command failed with output:\n{result.output}" + assert result.exit_code == 0, ( + f"Run command failed with output:\n{result.output}" + ) common_build_header = [ "algo_name", @@ -426,9 +426,9 @@ def test_run_command_creates_results(temp_datasets_dir: Path): for rel_path, expectations in expected_files.items(): file_path = temp_datasets_dir / rel_path assert file_path.exists(), f"Expected file {file_path} does not exist." - assert ( - file_path.stat().st_size > 0 - ), f"Expected file {file_path} is empty." + assert file_path.stat().st_size > 0, ( + f"Expected file {file_path} is empty." + ) df = pd.read_csv(file_path) @@ -436,9 +436,9 @@ def test_run_command_creates_results(temp_datasets_dir: Path): actual_rows = len(df) # breakpoint() - assert ( - actual_header == expectations["header"] - ), f"Wrong header produced in file f{rel_path}" + assert actual_header == expectations["header"], ( + f"Wrong header produced in file f{rel_path}" + ) assert actual_rows == expectations["rows"] @@ -483,9 +483,9 @@ def test_plot_command_creates_png_files(temp_datasets_dir: Path): "latency", ] result = runner.invoke(plot_main, args) - assert ( - result.exit_code == 0 - ), f"Plot command failed with output:\n{result.output}" + assert result.exit_code == 0, ( + f"Plot command failed with output:\n{result.output}" + ) # Expected output file names. expected_files = [ @@ -496,6 +496,6 @@ def test_plot_command_creates_png_files(temp_datasets_dir: Path): for filename in expected_files: file_path = temp_datasets_dir / filename assert file_path.exists(), f"Expected file {filename} does not exist." - assert ( - file_path.stat().st_size > 0 - ), f"Expected file {filename} is empty." + assert file_path.stat().st_size > 0, ( + f"Expected file {filename} is empty." + ) From 7a7dfce294a135b903c7471f31fd56f2597ebc29 Mon Sep 17 00:00:00 2001 From: Jake Awe <50372925+AyodeAwe@users.noreply.github.com> Date: Mon, 17 Nov 2025 08:57:55 -0600 Subject: [PATCH 11/32] Merge pull request #1545 from rapidsai/version-update-26.02 Update to 26.02 --- .../cuda12.9-conda/devcontainer.json | 6 ++-- .devcontainer/cuda12.9-pip/devcontainer.json | 8 ++--- .../cuda13.0-conda/devcontainer.json | 6 ++-- .devcontainer/cuda13.0-pip/devcontainer.json | 8 ++--- .github/workflows/build.yaml | 10 +++--- .github/workflows/pr.yaml | 12 +++---- .github/workflows/publish-rust.yaml | 2 +- .github/workflows/test.yaml | 2 +- README.md | 4 +-- VERSION | 2 +- .../all_cuda-129_arch-aarch64.yaml | 4 +-- .../all_cuda-129_arch-x86_64.yaml | 4 +-- .../all_cuda-130_arch-aarch64.yaml | 4 +-- .../all_cuda-130_arch-x86_64.yaml | 4 +-- .../bench_ann_cuda-129_arch-aarch64.yaml | 8 ++--- .../bench_ann_cuda-129_arch-x86_64.yaml | 8 ++--- .../bench_ann_cuda-130_arch-aarch64.yaml | 8 ++--- .../bench_ann_cuda-130_arch-x86_64.yaml | 8 ++--- .../go_cuda-129_arch-aarch64.yaml | 4 +-- .../environments/go_cuda-129_arch-x86_64.yaml | 4 +-- .../go_cuda-130_arch-aarch64.yaml | 4 +-- .../environments/go_cuda-130_arch-x86_64.yaml | 4 +-- .../rust_cuda-129_arch-aarch64.yaml | 4 +-- .../rust_cuda-129_arch-x86_64.yaml | 4 +-- .../rust_cuda-130_arch-aarch64.yaml | 4 +-- .../rust_cuda-130_arch-x86_64.yaml | 4 +-- dependencies.yaml | 32 +++++++++---------- docs/source/cuvs_bench/index.rst | 8 ++--- examples/go/README.md | 2 +- java/benchmarks/pom.xml | 4 +-- java/build.sh | 2 +- java/cuvs-java/pom.xml | 2 +- java/examples/README.md | 6 ++-- java/examples/pom.xml | 9 ++++-- python/cuvs/pyproject.toml | 10 +++--- python/cuvs_bench/pyproject.toml | 4 +-- python/libcuvs/pyproject.toml | 8 ++--- rust/Cargo.toml | 2 +- rust/cuvs/Cargo.toml | 2 +- 39 files changed, 118 insertions(+), 113 deletions(-) diff --git a/.devcontainer/cuda12.9-conda/devcontainer.json b/.devcontainer/cuda12.9-conda/devcontainer.json index f7565bbeaa..7528d19967 100644 --- a/.devcontainer/cuda12.9-conda/devcontainer.json +++ b/.devcontainer/cuda12.9-conda/devcontainer.json @@ -5,19 +5,19 @@ "args": { "CUDA": "12.9", "PYTHON_PACKAGE_MANAGER": "conda", - "BASE": "rapidsai/devcontainers:25.12-cpp-mambaforge" + "BASE": "rapidsai/devcontainers:26.02-cpp-mambaforge" } }, "runArgs": [ "--rm", "--name", - "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-25.12-cuda12.9-conda", + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-26.02-cuda12.9-conda", "--ulimit", "nofile=500000" ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" diff --git a/.devcontainer/cuda12.9-pip/devcontainer.json b/.devcontainer/cuda12.9-pip/devcontainer.json index b7b43b9b45..652d997405 100644 --- a/.devcontainer/cuda12.9-pip/devcontainer.json +++ b/.devcontainer/cuda12.9-pip/devcontainer.json @@ -5,26 +5,26 @@ "args": { "CUDA": "12.9", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:25.12-cpp-cuda12.9-ucx1.19.0-openmpi5.0.7" + "BASE": "rapidsai/devcontainers:26.02-cpp-cuda12.9-ucx1.19.0-openmpi5.0.7" } }, "runArgs": [ "--rm", "--name", - "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-25.12-cuda12.9-pip", + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-26.02-cuda12.9-pip", "--ulimit", "nofile=500000" ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/cuda:25.12": { + "ghcr.io/rapidsai/devcontainers/features/cuda:26.2": { "version": "12.9", "installcuBLAS": true, "installcuSOLVER": true, "installcuRAND": true, "installcuSPARSE": true }, - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/ucx", diff --git a/.devcontainer/cuda13.0-conda/devcontainer.json b/.devcontainer/cuda13.0-conda/devcontainer.json index f4e2e662eb..5c0beccf9c 100644 --- a/.devcontainer/cuda13.0-conda/devcontainer.json +++ b/.devcontainer/cuda13.0-conda/devcontainer.json @@ -5,19 +5,19 @@ "args": { "CUDA": "13.0", "PYTHON_PACKAGE_MANAGER": "conda", - "BASE": "rapidsai/devcontainers:25.12-cpp-mambaforge" + "BASE": "rapidsai/devcontainers:26.02-cpp-mambaforge" } }, "runArgs": [ "--rm", "--name", - "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-25.12-cuda13.0-conda", + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-26.02-cuda13.0-conda", "--ulimit", "nofile=500000" ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" diff --git a/.devcontainer/cuda13.0-pip/devcontainer.json b/.devcontainer/cuda13.0-pip/devcontainer.json index 1fd011180e..88b6bc9def 100644 --- a/.devcontainer/cuda13.0-pip/devcontainer.json +++ b/.devcontainer/cuda13.0-pip/devcontainer.json @@ -5,26 +5,26 @@ "args": { "CUDA": "13.0", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:25.12-cpp-cuda13.0-ucx1.19.0-openmpi5.0.7" + "BASE": "rapidsai/devcontainers:26.02-cpp-cuda13.0-ucx1.19.0-openmpi5.0.7" } }, "runArgs": [ "--rm", "--name", - "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-25.12-cuda13.0-pip", + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-26.02-cuda13.0-pip", "--ulimit", "nofile=500000" ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/cuda:25.12": { + "ghcr.io/rapidsai/devcontainers/features/cuda:26.2": { "version": "13.0", "installcuBLAS": true, "installcuSOLVER": true, "installcuRAND": true, "installcuSPARSE": true }, - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/ucx", diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index e76c010f26..2f3146ac48 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -56,7 +56,7 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" node_type: "cpu16" name: "${{ matrix.cuda_version }}, amd64, rockylinux8" # requires_license_builder: false @@ -81,7 +81,7 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" node_type: "gpu-l4-latest-1" script: "ci/build_rust.sh" sha: ${{ inputs.sha }} @@ -102,7 +102,7 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" node_type: "gpu-l4-latest-1" script: "ci/build_go.sh" sha: ${{ inputs.sha }} @@ -123,7 +123,7 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_java.sh" artifact-name: "cuvs-java-cuda${{ matrix.cuda_version }}" file_to_upload: "java/cuvs-java/target/" @@ -161,7 +161,7 @@ jobs: arch: "amd64" branch: ${{ inputs.branch }} build_type: ${{ inputs.build_type || 'branch' }} - container_image: "rapidsai/ci-conda:25.12-latest" + container_image: "rapidsai/ci-conda:26.02-latest" date: ${{ inputs.date }} node_type: "gpu-l4-latest-1" script: "ci/build_docs.sh" diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index de56414fc9..1c14b155d4 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -188,7 +188,7 @@ jobs: build_type: pull-request arch: "amd64" date: ${{ inputs.date }}_c - container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" node_type: "cpu16" # requires_license_builder: false script: "ci/build_standalone_c.sh --build-tests" @@ -211,7 +211,7 @@ jobs: node_type: "gpu-l4-latest-1" arch: "amd64" date: ${{ inputs.date }}_c - container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" script: "ci/test_standalone_c.sh" sha: ${{ inputs.sha }} conda-java-build-and-tests: @@ -231,7 +231,7 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/test_java.sh" artifact-name: "cuvs-java-cuda${{ matrix.cuda_version }}" file_to_upload: "java/cuvs-java/target/" @@ -252,7 +252,7 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_rust.sh" go-build: needs: [conda-cpp-build, changed-files] @@ -271,7 +271,7 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_go.sh" docs-build: needs: conda-python-build @@ -281,7 +281,7 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-latest" + container_image: "rapidsai/ci-conda:26.02-latest" script: "ci/build_docs.sh" wheel-build-libcuvs: needs: checks diff --git a/.github/workflows/publish-rust.yaml b/.github/workflows/publish-rust.yaml index aa9438e55e..3b7fc41a3b 100644 --- a/.github/workflows/publish-rust.yaml +++ b/.github/workflows/publish-rust.yaml @@ -16,7 +16,7 @@ jobs: cuda_version: - '12.9.1' container: - image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" steps: - uses: actions/checkout@v4 - name: Check if release build diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 28d72b0c74..77648919c7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -68,7 +68,7 @@ jobs: sha: ${{ inputs.sha }} node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/test_java.sh" wheel-tests-cuvs: secrets: inherit diff --git a/README.md b/README.md index 5dba0cfc38..5da834f4c7 100755 --- a/README.md +++ b/README.md @@ -108,10 +108,10 @@ If installing a version that has not yet been released, the `rapidsai` channel c ```bash # CUDA 13 -conda install -c rapidsai-nightly -c conda-forge cuvs=25.12 cuda-version=13.0 +conda install -c rapidsai-nightly -c conda-forge cuvs=26.02 cuda-version=13.0 # CUDA 12 -conda install -c rapidsai-nightly -c conda-forge cuvs=25.12 cuda-version=12.9 +conda install -c rapidsai-nightly -c conda-forge cuvs=26.02 cuda-version=12.9 ``` cuVS also has `pip` wheel packages that can be installed. Please see the [Build and Install Guide](https://docs.rapids.ai/api/cuvs/nightly/build/) for more information on installing the available cuVS packages and building from source. diff --git a/VERSION b/VERSION index 7924af6192..5c33046aca 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -25.12.00 +26.02.00 diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml index 9812a26a5d..f5aea13fd0 100644 --- a/conda/environments/all_cuda-129_arch-aarch64.yaml +++ b/conda/environments/all_cuda-129_arch-aarch64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml index 896c08e0e2..65e80d0bc4 100644 --- a/conda/environments/all_cuda-129_arch-x86_64.yaml +++ b/conda/environments/all_cuda-129_arch-x86_64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/all_cuda-130_arch-aarch64.yaml b/conda/environments/all_cuda-130_arch-aarch64.yaml index c9f180e849..da97ddd586 100644 --- a/conda/environments/all_cuda-130_arch-aarch64.yaml +++ b/conda/environments/all_cuda-130_arch-aarch64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/all_cuda-130_arch-x86_64.yaml b/conda/environments/all_cuda-130_arch-x86_64.yaml index a464e15db4..cec768aa29 100644 --- a/conda/environments/all_cuda-130_arch-x86_64.yaml +++ b/conda/environments/all_cuda-130_arch-x86_64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml index dbe568b842..cf78abc107 100644 --- a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=12.9.2,<13.0a0 - cuda-version=12.9 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -29,15 +29,15 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - nccl>=2.19 - ninja - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml index b14735c696..45219e4ba6 100644 --- a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=12.9.2,<13.0a0 - cuda-version=12.9 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -31,8 +31,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - mkl-devel=2023 - nccl>=2.19 @@ -40,7 +40,7 @@ dependencies: - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml index 6c90edabea..417ab87b88 100644 --- a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=13.0.1,<14.0a0 - cuda-version=13.0 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -29,15 +29,15 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - nccl>=2.19 - ninja - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml index e22a6900ba..30d4e2e7ca 100644 --- a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=13.0.1,<14.0a0 - cuda-version=13.0 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -31,8 +31,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - mkl-devel=2023 - nccl>=2.19 @@ -40,7 +40,7 @@ dependencies: - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/go_cuda-129_arch-aarch64.yaml b/conda/environments/go_cuda-129_arch-aarch64.yaml index b8bf557877..9ce9093e21 100644 --- a/conda/environments/go_cuda-129_arch-aarch64.yaml +++ b/conda/environments/go_cuda-129_arch-aarch64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-aarch64==2.28 diff --git a/conda/environments/go_cuda-129_arch-x86_64.yaml b/conda/environments/go_cuda-129_arch-x86_64.yaml index adc12d644b..4243077552 100644 --- a/conda/environments/go_cuda-129_arch-x86_64.yaml +++ b/conda/environments/go_cuda-129_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-64==2.28 diff --git a/conda/environments/go_cuda-130_arch-aarch64.yaml b/conda/environments/go_cuda-130_arch-aarch64.yaml index ca450a317c..962d5f1079 100644 --- a/conda/environments/go_cuda-130_arch-aarch64.yaml +++ b/conda/environments/go_cuda-130_arch-aarch64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-aarch64==2.28 diff --git a/conda/environments/go_cuda-130_arch-x86_64.yaml b/conda/environments/go_cuda-130_arch-x86_64.yaml index 5873836633..ca8dc8a88a 100644 --- a/conda/environments/go_cuda-130_arch-x86_64.yaml +++ b/conda/environments/go_cuda-130_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-64==2.28 diff --git a/conda/environments/rust_cuda-129_arch-aarch64.yaml b/conda/environments/rust_cuda-129_arch-aarch64.yaml index 28d7701d68..8da31cefbf 100644 --- a/conda/environments/rust_cuda-129_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-129_arch-aarch64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-129_arch-x86_64.yaml b/conda/environments/rust_cuda-129_arch-x86_64.yaml index a21932185b..3cbf7fad6a 100644 --- a/conda/environments/rust_cuda-129_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-129_arch-x86_64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-130_arch-aarch64.yaml b/conda/environments/rust_cuda-130_arch-aarch64.yaml index 7533f45e23..c71dff5bba 100644 --- a/conda/environments/rust_cuda-130_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-130_arch-aarch64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-130_arch-x86_64.yaml b/conda/environments/rust_cuda-130_arch-x86_64.yaml index 0b4dbd7b09..a229c27795 100644 --- a/conda/environments/rust_cuda-130_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-130_arch-x86_64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/dependencies.yaml b/dependencies.yaml index b66e9d8691..6ef7dfd768 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -470,7 +470,7 @@ dependencies: - output_types: [conda, pyproject, requirements] packages: - click - - cuvs==25.12.*,>=0.0.0a0 + - cuvs==26.2.*,>=0.0.0a0 - pandas - pyyaml - requests @@ -497,17 +497,17 @@ dependencies: common: - output_types: conda packages: - - cuvs==25.12.*,>=0.0.0a0 + - cuvs==26.2.*,>=0.0.0a0 depends_on_cuvs_bench: common: - output_types: conda packages: - - cuvs-bench==25.12.*,>=0.0.0a0 + - cuvs-bench==26.2.*,>=0.0.0a0 depends_on_libcuvs: common: - output_types: conda packages: - - &libcuvs_unsuffixed libcuvs==25.12.*,>=0.0.0a0 + - &libcuvs_unsuffixed libcuvs==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -520,23 +520,23 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - libcuvs-cu12==25.12.*,>=0.0.0a0 + - libcuvs-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - libcuvs-cu13==25.12.*,>=0.0.0a0 + - libcuvs-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*libcuvs_unsuffixed]} depends_on_libcuvs_tests: common: - output_types: conda packages: - - libcuvs-tests==25.12.*,>=0.0.0a0 + - libcuvs-tests==26.2.*,>=0.0.0a0 depends_on_libraft: common: - output_types: conda packages: - - &libraft_unsuffixed libraft==25.12.*,>=0.0.0a0 + - &libraft_unsuffixed libraft==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -549,18 +549,18 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - libraft-cu12==25.12.*,>=0.0.0a0 + - libraft-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - libraft-cu13==25.12.*,>=0.0.0a0 + - libraft-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*libraft_unsuffixed]} depends_on_librmm: common: - output_types: conda packages: - - &librmm_unsuffixed librmm==25.12.*,>=0.0.0a0 + - &librmm_unsuffixed librmm==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -573,18 +573,18 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - librmm-cu12==25.12.*,>=0.0.0a0 + - librmm-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - librmm-cu13==25.12.*,>=0.0.0a0 + - librmm-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*librmm_unsuffixed]} depends_on_pylibraft: common: - output_types: conda packages: - - &pylibraft_unsuffixed pylibraft==25.12.*,>=0.0.0a0 + - &pylibraft_unsuffixed pylibraft==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -597,12 +597,12 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - pylibraft-cu12==25.12.*,>=0.0.0a0 + - pylibraft-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - pylibraft-cu13==25.12.*,>=0.0.0a0 + - pylibraft-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*pylibraft_unsuffixed]} depends_on_nccl: common: diff --git a/docs/source/cuvs_bench/index.rst b/docs/source/cuvs_bench/index.rst index 16914ac596..cc5f2731c6 100644 --- a/docs/source/cuvs_bench/index.rst +++ b/docs/source/cuvs_bench/index.rst @@ -89,7 +89,7 @@ The following command pulls the nightly container for Python version 3.10, CUDA .. code-block:: bash - docker pull rapidsai/cuvs-bench:25.12a-cuda12.5-py3.10 # substitute cuvs-bench for the exact desired container. + docker pull rapidsai/cuvs-bench:26.02a-cuda12.5-py3.10 # substitute cuvs-bench for the exact desired container. The CUDA and python versions can be changed for the supported values: - Supported CUDA versions: 12 @@ -237,7 +237,7 @@ For GPU-enabled systems, the `DATA_FOLDER` variable should be a local folder whe export DATA_FOLDER=path/to/store/datasets/and/results docker run --gpus all --rm -it -u $(id -u) \ -v $DATA_FOLDER:/data/benchmarks \ - rapidsai/cuvs-bench:25.12-cuda12.9-py3.13 \ + rapidsai/cuvs-bench:26.02-cuda12.9-py3.13 \ "--dataset deep-image-96-angular" \ "--normalize" \ "--algorithms cuvs_cagra,cuvs_ivf_pq --batch-size 10 -k 10" \ @@ -250,7 +250,7 @@ Usage of the above command is as follows: * - Argument - Description - * - `rapidsai/cuvs-bench:25.12-cuda12.9-py3.13` + * - `rapidsai/cuvs-bench:26.02-cuda12.9-py3.13` - Image to use. Can be either `cuvs-bench` or `cuvs-bench-datasets` * - `"--dataset deep-image-96-angular"` @@ -297,7 +297,7 @@ All of the `cuvs-bench` images contain the Conda packages, so they can be used d --entrypoint /bin/bash \ --workdir /data/benchmarks \ -v $DATA_FOLDER:/data/benchmarks \ - rapidsai/cuvs-bench:25.12-cuda12.9-py3.13 + rapidsai/cuvs-bench:26.02-cuda12.9-py3.13 This will drop you into a command line in the container, with the `cuvs-bench` python package ready to use, as described in the [Running the benchmarks](#running-the-benchmarks) section above: diff --git a/examples/go/README.md b/examples/go/README.md index f49020de62..2588ae19ce 100644 --- a/examples/go/README.md +++ b/examples/go/README.md @@ -24,7 +24,7 @@ export CC=clang 2. Install the Go module: ```bash -go get github.com/rapidsai/cuvs/go@v25.12.00 # 25.02.00 being your desired version, selected from https://github.com/rapidsai/cuvs/tags +go get github.com/rapidsai/cuvs/go@v26.02.00 # 25.02.00 being your desired version, selected from https://github.com/rapidsai/cuvs/tags ``` Then you can build your project with the usual `go build`. diff --git a/java/benchmarks/pom.xml b/java/benchmarks/pom.xml index 45588933c5..52cf0130e0 100644 --- a/java/benchmarks/pom.xml +++ b/java/benchmarks/pom.xml @@ -10,7 +10,7 @@ com.nvidia.cuvs benchmarks - 25.12.0 + 26.02.0 jar cuvs-java-benchmarks @@ -30,7 +30,7 @@ com.nvidia.cuvs cuvs-java - 25.12.0 + 26.02.0 jar diff --git a/java/build.sh b/java/build.sh index d40e97adef..339857bfe8 100755 --- a/java/build.sh +++ b/java/build.sh @@ -8,7 +8,7 @@ set -e -u -o pipefail ARGS="$*" NUMARGS=$# -VERSION="25.12.0" # Note: The version is updated automatically when ci/release/update-version.sh is invoked +VERSION="26.02.0" # Note: The version is updated automatically when ci/release/update-version.sh is invoked GROUP_ID="com.nvidia.cuvs" # Identify CUDA major version. diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index 99d0eb5e09..d0eb079fe9 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -11,7 +11,7 @@ com.nvidia.cuvs cuvs-java - 25.12.0 + 26.02.0 cuvs-java This project provides Java bindings for cuVS, enabling approximate nearest neighbors search and clustering diff --git a/java/examples/README.md b/java/examples/README.md index 9a48ad6ea1..58f7acdbdb 100644 --- a/java/examples/README.md +++ b/java/examples/README.md @@ -11,17 +11,17 @@ This maven project contains examples for CAGRA, HNSW, and Bruteforce algorithms. ### CAGRA Example In the current directory do: ``` -mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-25.12.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/25.12.0/cuvs-java-25.12.0.jar com.nvidia.cuvs.examples.CagraExample +mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-26.02.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/26.02.0/cuvs-java-26.02.0.jar com.nvidia.cuvs.examples.CagraExample ``` ### HNSW Example In the current directory do: ``` -mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-25.12.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/25.12.0/cuvs-java-25.12.0.jar com.nvidia.cuvs.examples.HnswExample +mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-26.02.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/26.02.0/cuvs-java-26.02.0.jar com.nvidia.cuvs.examples.HnswExample ``` ### Bruteforce Example In the current directory do: ``` -mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-25.12.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/25.12.0/cuvs-java-25.12.0.jar com.nvidia.cuvs.examples.BruteForceExample +mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-26.02.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/26.02.0/cuvs-java-26.02.0.jar com.nvidia.cuvs.examples.BruteForceExample ``` diff --git a/java/examples/pom.xml b/java/examples/pom.xml index 8ab8a7a560..16b1b6ede6 100644 --- a/java/examples/pom.xml +++ b/java/examples/pom.xml @@ -1,3 +1,8 @@ + + @@ -5,7 +10,7 @@ com.nvidia.cuvs.examples cuvs-java-examples - 25.12.0 + 26.02.0 cuvs-java-examples @@ -18,7 +23,7 @@ com.nvidia.cuvs cuvs-java - 25.12.0 + 26.02.0 diff --git a/python/cuvs/pyproject.toml b/python/cuvs/pyproject.toml index 3d0ebe2cd8..38ee2b6f12 100644 --- a/python/cuvs/pyproject.toml +++ b/python/cuvs/pyproject.toml @@ -21,9 +21,9 @@ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ "cuda-python>=13.0.1,<14.0a0", - "libcuvs==25.12.*,>=0.0.0a0", + "libcuvs==26.2.*,>=0.0.0a0", "numpy>=1.23,<3.0a0", - "pylibraft==25.12.*,>=0.0.0a0", + "pylibraft==26.2.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", @@ -108,9 +108,9 @@ requires = [ "cmake>=3.30.4", "cuda-python>=13.0.1,<14.0a0", "cython>=3.0.0,<3.2.0a0", - "libcuvs==25.12.*,>=0.0.0a0", - "libraft==25.12.*,>=0.0.0a0", - "librmm==25.12.*,>=0.0.0a0", + "libcuvs==26.2.*,>=0.0.0a0", + "libraft==26.2.*,>=0.0.0a0", + "librmm==26.2.*,>=0.0.0a0", "ninja", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. build-backend = "scikit_build_core.build" diff --git a/python/cuvs_bench/pyproject.toml b/python/cuvs_bench/pyproject.toml index ce77211992..d7d8e3b891 100644 --- a/python/cuvs_bench/pyproject.toml +++ b/python/cuvs_bench/pyproject.toml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 [build-system] @@ -20,7 +20,7 @@ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ "click", - "cuvs==25.12.*,>=0.0.0a0", + "cuvs==26.2.*,>=0.0.0a0", "matplotlib>=3.9", "pandas", "pyyaml", diff --git a/python/libcuvs/pyproject.toml b/python/libcuvs/pyproject.toml index 9690708c27..cc60040c5a 100644 --- a/python/libcuvs/pyproject.toml +++ b/python/libcuvs/pyproject.toml @@ -20,8 +20,8 @@ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ "cuda-toolkit[cublas,curand,cusolver,cusparse]>=12,<14", - "libraft==25.12.*,>=0.0.0a0", - "librmm==25.12.*,>=0.0.0a0", + "libraft==26.2.*,>=0.0.0a0", + "librmm==26.2.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", @@ -79,8 +79,8 @@ regex = "(?P.*)" build-backend = "scikit_build_core.build" requires = [ "cmake>=3.30.4", - "libraft==25.12.*,>=0.0.0a0", - "librmm==25.12.*,>=0.0.0a0", + "libraft==26.2.*,>=0.0.0a0", + "librmm==26.2.*,>=0.0.0a0", "ninja", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. dependencies-file = "../../dependencies.yaml" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3e45ac65ba..2ad456db53 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -6,7 +6,7 @@ members = [ resolver = "2" [workspace.package] -version = "25.12.0" +version = "26.2.0" edition = "2021" repository = "https://github.com/rapidsai/cuvs" homepage = "https://github.com/rapidsai/cuvs" diff --git a/rust/cuvs/Cargo.toml b/rust/cuvs/Cargo.toml index 30429f814c..62b6d51391 100644 --- a/rust/cuvs/Cargo.toml +++ b/rust/cuvs/Cargo.toml @@ -9,7 +9,7 @@ authors.workspace = true license.workspace = true [dependencies] -ffi = { package = "cuvs-sys", path = "../cuvs-sys", version = "25.12.0" } +ffi = { package = "cuvs-sys", path = "../cuvs-sys", version = "26.2.0" } ndarray = "0.15" [dev-dependencies] From 82c3ab3ee07fd98b5e6dc738dc6d9315d1f84971 Mon Sep 17 00:00:00 2001 From: Jake Awe <50372925+AyodeAwe@users.noreply.github.com> Date: Mon, 17 Nov 2025 08:58:59 -0600 Subject: [PATCH 12/32] Revert "Forward-merge release/25.12 into main" (#1547) Reverts rapidsai/cuvs#1546 --- .github/workflows/build.yaml | 24 ++++++------- .github/workflows/pr.yaml | 36 +++++++++---------- .github/workflows/test.yaml | 10 +++--- .../trigger-breaking-change-alert.yaml | 2 +- RAPIDS_BRANCH | 2 +- README.md | 4 +-- docs/source/developer_guide.md | 4 +-- python/cuvs_bench/cuvs_bench/plot/__main__.py | 12 +++---- 8 files changed, 47 insertions(+), 47 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index edb2f54000..2f3146ac48 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -34,7 +34,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -44,7 +44,7 @@ jobs: rocky8-clib-standalone-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main strategy: fail-fast: false matrix: @@ -67,7 +67,7 @@ jobs: rust-build: needs: cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -88,7 +88,7 @@ jobs: go-build: needs: cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -109,7 +109,7 @@ jobs: java-build: needs: cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -131,7 +131,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -141,7 +141,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -156,7 +156,7 @@ jobs: if: github.ref_type == 'branch' needs: python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main with: arch: "amd64" branch: ${{ inputs.branch }} @@ -168,7 +168,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-libcuvs: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -182,7 +182,7 @@ jobs: wheel-publish-libcuvs: needs: wheel-build-libcuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -193,7 +193,7 @@ jobs: wheel-build-cuvs: needs: wheel-build-libcuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -205,7 +205,7 @@ jobs: wheel-publish-cuvs: needs: wheel-build-cuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 2e6916313d..1c14b155d4 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -29,7 +29,7 @@ jobs: - devcontainer - telemetry-setup secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@main if: always() with: needs: ${{ toJSON(needs) }} @@ -56,7 +56,7 @@ jobs: changed-files: needs: telemetry-setup secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/changed-files.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/changed-files.yaml@main with: files_yaml: | test_cpp: @@ -132,14 +132,14 @@ jobs: checks: needs: telemetry-setup secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@main with: enable_check_generated_files: false ignored_pr_jobs: "telemetry-summarize" conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@main with: build_type: pull-request node_type: cpu16 @@ -147,7 +147,7 @@ jobs: conda-cpp-tests: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp with: build_type: pull-request @@ -155,21 +155,21 @@ jobs: conda-cpp-checks: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@main with: build_type: pull-request symbol_exclusions: (void (thrust::|cub::)) conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@main with: build_type: pull-request script: ci/build_python.sh conda-python-tests: needs: [conda-python-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python with: build_type: pull-request @@ -177,7 +177,7 @@ jobs: rocky8-clib-standalone-build: needs: [checks] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main strategy: fail-fast: false matrix: @@ -198,7 +198,7 @@ jobs: rocky8-clib-tests: needs: [rocky8-clib-standalone-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp strategy: fail-fast: false @@ -217,7 +217,7 @@ jobs: conda-java-build-and-tests: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_java || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. @@ -238,7 +238,7 @@ jobs: rust-build: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_rust || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. @@ -257,7 +257,7 @@ jobs: go-build: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_go || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. @@ -276,7 +276,7 @@ jobs: docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main with: build_type: pull-request node_type: "gpu-l4-latest-1" @@ -286,7 +286,7 @@ jobs: wheel-build-libcuvs: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: pull-request script: ci/build_wheel_libcuvs.sh @@ -297,7 +297,7 @@ jobs: wheel-build-cuvs: needs: wheel-build-libcuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: pull-request script: ci/build_wheel_cuvs.sh @@ -306,7 +306,7 @@ jobs: wheel-tests-cuvs: needs: [wheel-build-cuvs, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python with: build_type: pull-request @@ -314,7 +314,7 @@ jobs: devcontainer: secrets: inherit needs: telemetry-setup - uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@main with: arch: '["amd64", "arm64"]' cuda: '["13.0"]' diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f97af13372..77648919c7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,7 +25,7 @@ on: jobs: conda-cpp-checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} @@ -34,7 +34,7 @@ jobs: symbol_exclusions: (void (thrust::|cub::)) conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} @@ -43,7 +43,7 @@ jobs: sha: ${{ inputs.sha }} conda-python-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} @@ -52,7 +52,7 @@ jobs: sha: ${{ inputs.sha }} conda-java-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -72,7 +72,7 @@ jobs: script: "ci/test_java.sh" wheel-tests-cuvs: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/trigger-breaking-change-alert.yaml b/.github/workflows/trigger-breaking-change-alert.yaml index 0b885544da..c471e2a151 100644 --- a/.github/workflows/trigger-breaking-change-alert.yaml +++ b/.github/workflows/trigger-breaking-change-alert.yaml @@ -12,7 +12,7 @@ jobs: trigger-notifier: if: contains(github.event.pull_request.labels.*.name, 'breaking') secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/breaking-change-alert.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/breaking-change-alert.yaml@main with: sender_login: ${{ github.event.sender.login }} sender_avatar: ${{ github.event.sender.avatar_url }} diff --git a/RAPIDS_BRANCH b/RAPIDS_BRANCH index 26b84372d3..ba2906d066 100644 --- a/RAPIDS_BRANCH +++ b/RAPIDS_BRANCH @@ -1 +1 @@ -release/25.12 +main diff --git a/README.md b/README.md index 601b903374..5da834f4c7 100755 --- a/README.md +++ b/README.md @@ -171,7 +171,7 @@ cuvsCagraIndexParamsDestroy(index_params); cuvsResourcesDestroy(res); ``` -For more code examples of the C APIs, including drop-in Cmake project templates, please refer to the [C examples](https://github.com/rapidsai/cuvs/tree/release/25.12/examples/c) +For more code examples of the C APIs, including drop-in Cmake project templates, please refer to the [C examples](https://github.com/rapidsai/cuvs/tree/main/examples/c) ### Rust API @@ -234,7 +234,7 @@ fn cagra_example() -> Result<()> { } ``` -For more code examples of the Rust APIs, including a drop-in project templates, please refer to the [Rust examples](https://github.com/rapidsai/cuvs/tree/release/25.12/examples/rust). +For more code examples of the Rust APIs, including a drop-in project templates, please refer to the [Rust examples](https://github.com/rapidsai/cuvs/tree/main/examples/rust). ## Contributing diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index e4081842d2..da50a44d27 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -187,7 +187,7 @@ RAFT relies on `clang-format` to enforce code style across all C++ and CUDA sour 1. Do not split empty functions/records/namespaces. 2. Two-space indentation everywhere, including the line continuations. 3. Disable reflowing of comments. - The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/cuvs/blob/release/25.12/cpp/.clang-format). + The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/cuvs/blob/main/cpp/.clang-format). [`doxygen`](https://doxygen.nl/) is used as documentation generator and also as a documentation linter. In order to run doxygen as a linter on C++/CUDA code, run @@ -205,7 +205,7 @@ you can run `codespell -i 3 -w .` from the repository root directory. This will bring up an interactive prompt to select which spelling fixes to apply. ### #include style -[include_checker.py](https://github.com/rapidsai/cuvs/blob/release/25.12/cpp/scripts/include_checker.py) is used to enforce the include style as follows: +[include_checker.py](https://github.com/rapidsai/cuvs/blob/main/cpp/scripts/include_checker.py) is used to enforce the include style as follows: 1. `#include "..."` should be used for referencing local files only. It is acceptable to be used for referencing files in a sub-folder/parent-folder of the same algorithm, but should never be used to include files in other algorithms or between algorithms and the primitives or other dependencies. 2. `#include <...>` should be used for referencing everything else diff --git a/python/cuvs_bench/cuvs_bench/plot/__main__.py b/python/cuvs_bench/cuvs_bench/plot/__main__.py index ddf687d38b..aca08505ea 100644 --- a/python/cuvs_bench/cuvs_bench/plot/__main__.py +++ b/python/cuvs_bench/cuvs_bench/plot/__main__.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # This script is inspired by -# 1: https://github.com/erikbern/ann-benchmarks/blob/release/25.12/plot.py -# 2: https://github.com/erikbern/ann-benchmarks/blob/release/25.12/ann_benchmarks/plotting/utils.py # noqa: E501 -# 3: https://github.com/erikbern/ann-benchmarks/blob/release/25.12/ann_benchmarks/plotting/metrics.py # noqa: E501 -# License: https://github.com/rapidsai/cuvs/blob/release/25.12/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 +# 1: https://github.com/erikbern/ann-benchmarks/blob/main/plot.py +# 2: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/utils.py # noqa: E501 +# 3: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/metrics.py # noqa: E501 +# License: https://github.com/rapidsai/cuvs/blob/main/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 import itertools import os @@ -502,7 +502,7 @@ def load_all_results( is_flag=True, help="Show raw results (not just Pareto frontier) of the mode argument.", ) -def release/25.12( +def main( dataset: str, dataset_path: str, output_filepath: str, @@ -603,4 +603,4 @@ def release/25.12( if __name__ == "__main__": - release/25.12() + main() From dc307dfde3cf3b4cd8e742a64e8696ead2c3f613 Mon Sep 17 00:00:00 2001 From: Jake Awe <50372925+AyodeAwe@users.noreply.github.com> Date: Mon, 17 Nov 2025 13:58:25 -0600 Subject: [PATCH 13/32] Revert "Forward-merge release/25.12 into main" (#1553) Reverts rapidsai/cuvs#1552 --- .../cuda12.9-conda/devcontainer.json | 4 +- .devcontainer/cuda12.9-pip/devcontainer.json | 6 +-- .../cuda13.0-conda/devcontainer.json | 4 +- .devcontainer/cuda13.0-pip/devcontainer.json | 6 +-- .github/workflows/build.yaml | 34 ++++++------- .github/workflows/pr.yaml | 48 +++++++++---------- .github/workflows/publish-rust.yaml | 2 +- .github/workflows/test.yaml | 12 ++--- .../trigger-breaking-change-alert.yaml | 2 +- RAPIDS_BRANCH | 2 +- README.md | 8 ++-- VERSION | 2 +- .../all_cuda-129_arch-aarch64.yaml | 4 +- .../all_cuda-129_arch-x86_64.yaml | 4 +- .../all_cuda-130_arch-aarch64.yaml | 4 +- .../all_cuda-130_arch-x86_64.yaml | 4 +- .../bench_ann_cuda-129_arch-aarch64.yaml | 8 ++-- .../bench_ann_cuda-129_arch-x86_64.yaml | 8 ++-- .../bench_ann_cuda-130_arch-aarch64.yaml | 8 ++-- .../bench_ann_cuda-130_arch-x86_64.yaml | 8 ++-- .../go_cuda-129_arch-aarch64.yaml | 4 +- .../environments/go_cuda-129_arch-x86_64.yaml | 4 +- .../go_cuda-130_arch-aarch64.yaml | 4 +- .../environments/go_cuda-130_arch-x86_64.yaml | 4 +- .../rust_cuda-129_arch-aarch64.yaml | 4 +- .../rust_cuda-129_arch-x86_64.yaml | 4 +- .../rust_cuda-130_arch-aarch64.yaml | 4 +- .../rust_cuda-130_arch-x86_64.yaml | 4 +- dependencies.yaml | 32 ++++++------- docs/source/cuvs_bench/index.rst | 8 ++-- docs/source/developer_guide.md | 4 +- examples/go/README.md | 2 +- java/benchmarks/pom.xml | 4 +- java/build.sh | 2 +- java/cuvs-java/pom.xml | 2 +- java/examples/README.md | 6 +-- java/examples/pom.xml | 4 +- python/cuvs/pyproject.toml | 10 ++-- python/cuvs_bench/cuvs_bench/plot/__main__.py | 12 ++--- python/cuvs_bench/pyproject.toml | 2 +- python/libcuvs/pyproject.toml | 8 ++-- rust/Cargo.toml | 2 +- rust/cuvs/Cargo.toml | 2 +- 43 files changed, 155 insertions(+), 155 deletions(-) diff --git a/.devcontainer/cuda12.9-conda/devcontainer.json b/.devcontainer/cuda12.9-conda/devcontainer.json index 6dd88581cb..7528d19967 100644 --- a/.devcontainer/cuda12.9-conda/devcontainer.json +++ b/.devcontainer/cuda12.9-conda/devcontainer.json @@ -5,7 +5,7 @@ "args": { "CUDA": "12.9", "PYTHON_PACKAGE_MANAGER": "conda", - "BASE": "rapidsai/devcontainers:25.12-cpp-mambaforge" + "BASE": "rapidsai/devcontainers:26.02-cpp-mambaforge" } }, "runArgs": [ @@ -17,7 +17,7 @@ ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" diff --git a/.devcontainer/cuda12.9-pip/devcontainer.json b/.devcontainer/cuda12.9-pip/devcontainer.json index ef3e78f2c5..652d997405 100644 --- a/.devcontainer/cuda12.9-pip/devcontainer.json +++ b/.devcontainer/cuda12.9-pip/devcontainer.json @@ -5,7 +5,7 @@ "args": { "CUDA": "12.9", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:25.12-cpp-cuda12.9-ucx1.19.0-openmpi5.0.7" + "BASE": "rapidsai/devcontainers:26.02-cpp-cuda12.9-ucx1.19.0-openmpi5.0.7" } }, "runArgs": [ @@ -17,14 +17,14 @@ ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/cuda:25.12": { + "ghcr.io/rapidsai/devcontainers/features/cuda:26.2": { "version": "12.9", "installcuBLAS": true, "installcuSOLVER": true, "installcuRAND": true, "installcuSPARSE": true }, - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/ucx", diff --git a/.devcontainer/cuda13.0-conda/devcontainer.json b/.devcontainer/cuda13.0-conda/devcontainer.json index ddd13e728a..5c0beccf9c 100644 --- a/.devcontainer/cuda13.0-conda/devcontainer.json +++ b/.devcontainer/cuda13.0-conda/devcontainer.json @@ -5,7 +5,7 @@ "args": { "CUDA": "13.0", "PYTHON_PACKAGE_MANAGER": "conda", - "BASE": "rapidsai/devcontainers:25.12-cpp-mambaforge" + "BASE": "rapidsai/devcontainers:26.02-cpp-mambaforge" } }, "runArgs": [ @@ -17,7 +17,7 @@ ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" diff --git a/.devcontainer/cuda13.0-pip/devcontainer.json b/.devcontainer/cuda13.0-pip/devcontainer.json index ee0044aa06..88b6bc9def 100644 --- a/.devcontainer/cuda13.0-pip/devcontainer.json +++ b/.devcontainer/cuda13.0-pip/devcontainer.json @@ -5,7 +5,7 @@ "args": { "CUDA": "13.0", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:25.12-cpp-cuda13.0-ucx1.19.0-openmpi5.0.7" + "BASE": "rapidsai/devcontainers:26.02-cpp-cuda13.0-ucx1.19.0-openmpi5.0.7" } }, "runArgs": [ @@ -17,14 +17,14 @@ ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/cuda:25.12": { + "ghcr.io/rapidsai/devcontainers/features/cuda:26.2": { "version": "13.0", "installcuBLAS": true, "installcuSOLVER": true, "installcuRAND": true, "installcuSPARSE": true }, - "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:25.12": {} + "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:26.2": {} }, "overrideFeatureInstallOrder": [ "ghcr.io/rapidsai/devcontainers/features/ucx", diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0f9499a79d..2f3146ac48 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -34,7 +34,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -44,7 +44,7 @@ jobs: rocky8-clib-standalone-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main strategy: fail-fast: false matrix: @@ -56,7 +56,7 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" node_type: "cpu16" name: "${{ matrix.cuda_version }}, amd64, rockylinux8" # requires_license_builder: false @@ -67,7 +67,7 @@ jobs: rust-build: needs: cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -81,14 +81,14 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" node_type: "gpu-l4-latest-1" script: "ci/build_rust.sh" sha: ${{ inputs.sha }} go-build: needs: cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -102,14 +102,14 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" node_type: "gpu-l4-latest-1" script: "ci/build_go.sh" sha: ${{ inputs.sha }} java-build: needs: cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -123,7 +123,7 @@ jobs: branch: ${{ inputs.branch }} arch: "amd64" date: ${{ inputs.date }} - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_java.sh" artifact-name: "cuvs-java-cuda${{ matrix.cuda_version }}" file_to_upload: "java/cuvs-java/target/" @@ -131,7 +131,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -141,7 +141,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -156,19 +156,19 @@ jobs: if: github.ref_type == 'branch' needs: python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main with: arch: "amd64" branch: ${{ inputs.branch }} build_type: ${{ inputs.build_type || 'branch' }} - container_image: "rapidsai/ci-conda:25.12-latest" + container_image: "rapidsai/ci-conda:26.02-latest" date: ${{ inputs.date }} node_type: "gpu-l4-latest-1" script: "ci/build_docs.sh" sha: ${{ inputs.sha }} wheel-build-libcuvs: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -182,7 +182,7 @@ jobs: wheel-publish-libcuvs: needs: wheel-build-libcuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -193,7 +193,7 @@ jobs: wheel-build-cuvs: needs: wheel-build-libcuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -205,7 +205,7 @@ jobs: wheel-publish-cuvs: needs: wheel-build-cuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@main with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index d4bca44463..1c14b155d4 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -29,7 +29,7 @@ jobs: - devcontainer - telemetry-setup secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@main if: always() with: needs: ${{ toJSON(needs) }} @@ -56,7 +56,7 @@ jobs: changed-files: needs: telemetry-setup secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/changed-files.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/changed-files.yaml@main with: files_yaml: | test_cpp: @@ -132,14 +132,14 @@ jobs: checks: needs: telemetry-setup secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@main with: enable_check_generated_files: false ignored_pr_jobs: "telemetry-summarize" conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@main with: build_type: pull-request node_type: cpu16 @@ -147,7 +147,7 @@ jobs: conda-cpp-tests: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp with: build_type: pull-request @@ -155,21 +155,21 @@ jobs: conda-cpp-checks: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@main with: build_type: pull-request symbol_exclusions: (void (thrust::|cub::)) conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@main with: build_type: pull-request script: ci/build_python.sh conda-python-tests: needs: [conda-python-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python with: build_type: pull-request @@ -177,7 +177,7 @@ jobs: rocky8-clib-standalone-build: needs: [checks] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main strategy: fail-fast: false matrix: @@ -188,7 +188,7 @@ jobs: build_type: pull-request arch: "amd64" date: ${{ inputs.date }}_c - container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" node_type: "cpu16" # requires_license_builder: false script: "ci/build_standalone_c.sh --build-tests" @@ -198,7 +198,7 @@ jobs: rocky8-clib-tests: needs: [rocky8-clib-standalone-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp strategy: fail-fast: false @@ -211,13 +211,13 @@ jobs: node_type: "gpu-l4-latest-1" arch: "amd64" date: ${{ inputs.date }}_c - container_image: "rapidsai/ci-wheel:25.12-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" + container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" script: "ci/test_standalone_c.sh" sha: ${{ inputs.sha }} conda-java-build-and-tests: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_java || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. @@ -231,14 +231,14 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/test_java.sh" artifact-name: "cuvs-java-cuda${{ matrix.cuda_version }}" file_to_upload: "java/cuvs-java/target/" rust-build: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_rust || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. @@ -252,12 +252,12 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_rust.sh" go-build: needs: [conda-cpp-build, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_go || fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. @@ -271,22 +271,22 @@ jobs: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/build_go.sh" docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main with: build_type: pull-request node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-latest" + container_image: "rapidsai/ci-conda:26.02-latest" script: "ci/build_docs.sh" wheel-build-libcuvs: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: pull-request script: ci/build_wheel_libcuvs.sh @@ -297,7 +297,7 @@ jobs: wheel-build-cuvs: needs: wheel-build-libcuvs secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@main with: build_type: pull-request script: ci/build_wheel_cuvs.sh @@ -306,7 +306,7 @@ jobs: wheel-tests-cuvs: needs: [wheel-build-cuvs, changed-files] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@main if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python with: build_type: pull-request @@ -314,7 +314,7 @@ jobs: devcontainer: secrets: inherit needs: telemetry-setup - uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@main with: arch: '["amd64", "arm64"]' cuda: '["13.0"]' diff --git a/.github/workflows/publish-rust.yaml b/.github/workflows/publish-rust.yaml index aa9438e55e..3b7fc41a3b 100644 --- a/.github/workflows/publish-rust.yaml +++ b/.github/workflows/publish-rust.yaml @@ -16,7 +16,7 @@ jobs: cuda_version: - '12.9.1' container: - image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" steps: - uses: actions/checkout@v4 - name: Check if release build diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1af29bbc8c..77648919c7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,7 +25,7 @@ on: jobs: conda-cpp-checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-post-build-checks.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} @@ -34,7 +34,7 @@ jobs: symbol_exclusions: (void (thrust::|cub::)) conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} @@ -43,7 +43,7 @@ jobs: sha: ${{ inputs.sha }} conda-python-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} @@ -52,7 +52,7 @@ jobs: sha: ${{ inputs.sha }} conda-java-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@main # Artifacts are not published from these jobs, so it's safe to run for multiple CUDA versions. # If these jobs start producing artifacts, the names will have to differentiate between CUDA versions. strategy: @@ -68,11 +68,11 @@ jobs: sha: ${{ inputs.sha }} node_type: "gpu-l4-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:25.12-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" + container_image: "rapidsai/ci-conda:26.02-cuda${{ matrix.cuda_version }}-ubuntu24.04-py3.13" script: "ci/test_java.sh" wheel-tests-cuvs: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@main with: build_type: ${{ inputs.build_type }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/trigger-breaking-change-alert.yaml b/.github/workflows/trigger-breaking-change-alert.yaml index 0b885544da..c471e2a151 100644 --- a/.github/workflows/trigger-breaking-change-alert.yaml +++ b/.github/workflows/trigger-breaking-change-alert.yaml @@ -12,7 +12,7 @@ jobs: trigger-notifier: if: contains(github.event.pull_request.labels.*.name, 'breaking') secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/breaking-change-alert.yaml@release/25.12 + uses: rapidsai/shared-workflows/.github/workflows/breaking-change-alert.yaml@main with: sender_login: ${{ github.event.sender.login }} sender_avatar: ${{ github.event.sender.avatar_url }} diff --git a/RAPIDS_BRANCH b/RAPIDS_BRANCH index 26b84372d3..ba2906d066 100644 --- a/RAPIDS_BRANCH +++ b/RAPIDS_BRANCH @@ -1 +1 @@ -release/25.12 +main diff --git a/README.md b/README.md index 1ad66d9c7d..5da834f4c7 100755 --- a/README.md +++ b/README.md @@ -108,10 +108,10 @@ If installing a version that has not yet been released, the `rapidsai` channel c ```bash # CUDA 13 -conda install -c rapidsai-nightly -c conda-forge cuvs=25.12 cuda-version=13.0 +conda install -c rapidsai-nightly -c conda-forge cuvs=26.02 cuda-version=13.0 # CUDA 12 -conda install -c rapidsai-nightly -c conda-forge cuvs=25.12 cuda-version=12.9 +conda install -c rapidsai-nightly -c conda-forge cuvs=26.02 cuda-version=12.9 ``` cuVS also has `pip` wheel packages that can be installed. Please see the [Build and Install Guide](https://docs.rapids.ai/api/cuvs/nightly/build/) for more information on installing the available cuVS packages and building from source. @@ -171,7 +171,7 @@ cuvsCagraIndexParamsDestroy(index_params); cuvsResourcesDestroy(res); ``` -For more code examples of the C APIs, including drop-in Cmake project templates, please refer to the [C examples](https://github.com/rapidsai/cuvs/tree/release/25.12/examples/c) +For more code examples of the C APIs, including drop-in Cmake project templates, please refer to the [C examples](https://github.com/rapidsai/cuvs/tree/main/examples/c) ### Rust API @@ -234,7 +234,7 @@ fn cagra_example() -> Result<()> { } ``` -For more code examples of the Rust APIs, including a drop-in project templates, please refer to the [Rust examples](https://github.com/rapidsai/cuvs/tree/release/25.12/examples/rust). +For more code examples of the Rust APIs, including a drop-in project templates, please refer to the [Rust examples](https://github.com/rapidsai/cuvs/tree/main/examples/rust). ## Contributing diff --git a/VERSION b/VERSION index 7924af6192..5c33046aca 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -25.12.00 +26.02.00 diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml index 9812a26a5d..f5aea13fd0 100644 --- a/conda/environments/all_cuda-129_arch-aarch64.yaml +++ b/conda/environments/all_cuda-129_arch-aarch64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml index 896c08e0e2..65e80d0bc4 100644 --- a/conda/environments/all_cuda-129_arch-x86_64.yaml +++ b/conda/environments/all_cuda-129_arch-x86_64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/all_cuda-130_arch-aarch64.yaml b/conda/environments/all_cuda-130_arch-aarch64.yaml index c9f180e849..da97ddd586 100644 --- a/conda/environments/all_cuda-130_arch-aarch64.yaml +++ b/conda/environments/all_cuda-130_arch-aarch64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/all_cuda-130_arch-x86_64.yaml b/conda/environments/all_cuda-130_arch-x86_64.yaml index a464e15db4..cec768aa29 100644 --- a/conda/environments/all_cuda-130_arch-x86_64.yaml +++ b/conda/environments/all_cuda-130_arch-x86_64.yaml @@ -31,7 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- librmm==25.12.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja @@ -39,7 +39,7 @@ dependencies: - numpydoc - openblas - pre-commit -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pytest-cov - pytest<9.0.0a0 - rapids-build-backend>=0.4.0,<0.5.0.dev0 diff --git a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml index dbe568b842..cf78abc107 100644 --- a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=12.9.2,<13.0a0 - cuda-version=12.9 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -29,15 +29,15 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - nccl>=2.19 - ninja - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml index b14735c696..45219e4ba6 100644 --- a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=12.9.2,<13.0a0 - cuda-version=12.9 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -31,8 +31,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - mkl-devel=2023 - nccl>=2.19 @@ -40,7 +40,7 @@ dependencies: - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml index 6c90edabea..417ab87b88 100644 --- a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=13.0.1,<14.0a0 - cuda-version=13.0 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -29,15 +29,15 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - nccl>=2.19 - ninja - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml index e22a6900ba..30d4e2e7ca 100644 --- a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml @@ -17,7 +17,7 @@ dependencies: - cuda-python>=13.0.1,<14.0a0 - cuda-version=13.0 - cupy>=13.6.0 -- cuvs==25.12.*,>=0.0.0a0 +- cuvs==26.2.*,>=0.0.0a0 - cxx-compiler - cython>=3.0.0,<3.2.0a0 - dlpack>=0.8,<1.0 @@ -31,8 +31,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- librmm==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- librmm==26.2.*,>=0.0.0a0 - matplotlib-base>=3.9 - mkl-devel=2023 - nccl>=2.19 @@ -40,7 +40,7 @@ dependencies: - nlohmann_json>=3.12.0 - openblas - pandas -- pylibraft==25.12.*,>=0.0.0a0 +- pylibraft==26.2.*,>=0.0.0a0 - pyyaml - rapids-build-backend>=0.4.0,<0.5.0.dev0 - requests diff --git a/conda/environments/go_cuda-129_arch-aarch64.yaml b/conda/environments/go_cuda-129_arch-aarch64.yaml index b8bf557877..9ce9093e21 100644 --- a/conda/environments/go_cuda-129_arch-aarch64.yaml +++ b/conda/environments/go_cuda-129_arch-aarch64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-aarch64==2.28 diff --git a/conda/environments/go_cuda-129_arch-x86_64.yaml b/conda/environments/go_cuda-129_arch-x86_64.yaml index adc12d644b..4243077552 100644 --- a/conda/environments/go_cuda-129_arch-x86_64.yaml +++ b/conda/environments/go_cuda-129_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-64==2.28 diff --git a/conda/environments/go_cuda-130_arch-aarch64.yaml b/conda/environments/go_cuda-130_arch-aarch64.yaml index ca450a317c..962d5f1079 100644 --- a/conda/environments/go_cuda-130_arch-aarch64.yaml +++ b/conda/environments/go_cuda-130_arch-aarch64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-aarch64==2.28 diff --git a/conda/environments/go_cuda-130_arch-x86_64.yaml b/conda/environments/go_cuda-130_arch-x86_64.yaml index 5873836633..ca8dc8a88a 100644 --- a/conda/environments/go_cuda-130_arch-x86_64.yaml +++ b/conda/environments/go_cuda-130_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - nccl>=2.19 - ninja - sysroot_linux-64==2.28 diff --git a/conda/environments/rust_cuda-129_arch-aarch64.yaml b/conda/environments/rust_cuda-129_arch-aarch64.yaml index 28d7701d68..8da31cefbf 100644 --- a/conda/environments/rust_cuda-129_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-129_arch-aarch64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-129_arch-x86_64.yaml b/conda/environments/rust_cuda-129_arch-x86_64.yaml index a21932185b..3cbf7fad6a 100644 --- a/conda/environments/rust_cuda-129_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-129_arch-x86_64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-130_arch-aarch64.yaml b/conda/environments/rust_cuda-130_arch-aarch64.yaml index 7533f45e23..c71dff5bba 100644 --- a/conda/environments/rust_cuda-130_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-130_arch-aarch64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-130_arch-x86_64.yaml b/conda/environments/rust_cuda-130_arch-x86_64.yaml index 0b4dbd7b09..a229c27795 100644 --- a/conda/environments/rust_cuda-130_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-130_arch-x86_64.yaml @@ -21,8 +21,8 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev -- libcuvs==25.12.*,>=0.0.0a0 -- libraft==25.12.*,>=0.0.0a0 +- libcuvs==26.2.*,>=0.0.0a0 +- libraft==26.2.*,>=0.0.0a0 - make - nccl>=2.19 - ninja diff --git a/dependencies.yaml b/dependencies.yaml index b66e9d8691..6ef7dfd768 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -470,7 +470,7 @@ dependencies: - output_types: [conda, pyproject, requirements] packages: - click - - cuvs==25.12.*,>=0.0.0a0 + - cuvs==26.2.*,>=0.0.0a0 - pandas - pyyaml - requests @@ -497,17 +497,17 @@ dependencies: common: - output_types: conda packages: - - cuvs==25.12.*,>=0.0.0a0 + - cuvs==26.2.*,>=0.0.0a0 depends_on_cuvs_bench: common: - output_types: conda packages: - - cuvs-bench==25.12.*,>=0.0.0a0 + - cuvs-bench==26.2.*,>=0.0.0a0 depends_on_libcuvs: common: - output_types: conda packages: - - &libcuvs_unsuffixed libcuvs==25.12.*,>=0.0.0a0 + - &libcuvs_unsuffixed libcuvs==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -520,23 +520,23 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - libcuvs-cu12==25.12.*,>=0.0.0a0 + - libcuvs-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - libcuvs-cu13==25.12.*,>=0.0.0a0 + - libcuvs-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*libcuvs_unsuffixed]} depends_on_libcuvs_tests: common: - output_types: conda packages: - - libcuvs-tests==25.12.*,>=0.0.0a0 + - libcuvs-tests==26.2.*,>=0.0.0a0 depends_on_libraft: common: - output_types: conda packages: - - &libraft_unsuffixed libraft==25.12.*,>=0.0.0a0 + - &libraft_unsuffixed libraft==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -549,18 +549,18 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - libraft-cu12==25.12.*,>=0.0.0a0 + - libraft-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - libraft-cu13==25.12.*,>=0.0.0a0 + - libraft-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*libraft_unsuffixed]} depends_on_librmm: common: - output_types: conda packages: - - &librmm_unsuffixed librmm==25.12.*,>=0.0.0a0 + - &librmm_unsuffixed librmm==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -573,18 +573,18 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - librmm-cu12==25.12.*,>=0.0.0a0 + - librmm-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - librmm-cu13==25.12.*,>=0.0.0a0 + - librmm-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*librmm_unsuffixed]} depends_on_pylibraft: common: - output_types: conda packages: - - &pylibraft_unsuffixed pylibraft==25.12.*,>=0.0.0a0 + - &pylibraft_unsuffixed pylibraft==26.2.*,>=0.0.0a0 - output_types: requirements packages: # pip recognizes the index as a global option for the requirements.txt file @@ -597,12 +597,12 @@ dependencies: cuda: "12.*" cuda_suffixed: "true" packages: - - pylibraft-cu12==25.12.*,>=0.0.0a0 + - pylibraft-cu12==26.2.*,>=0.0.0a0 - matrix: cuda: "13.*" cuda_suffixed: "true" packages: - - pylibraft-cu13==25.12.*,>=0.0.0a0 + - pylibraft-cu13==26.2.*,>=0.0.0a0 - {matrix: null, packages: [*pylibraft_unsuffixed]} depends_on_nccl: common: diff --git a/docs/source/cuvs_bench/index.rst b/docs/source/cuvs_bench/index.rst index 16914ac596..cc5f2731c6 100644 --- a/docs/source/cuvs_bench/index.rst +++ b/docs/source/cuvs_bench/index.rst @@ -89,7 +89,7 @@ The following command pulls the nightly container for Python version 3.10, CUDA .. code-block:: bash - docker pull rapidsai/cuvs-bench:25.12a-cuda12.5-py3.10 # substitute cuvs-bench for the exact desired container. + docker pull rapidsai/cuvs-bench:26.02a-cuda12.5-py3.10 # substitute cuvs-bench for the exact desired container. The CUDA and python versions can be changed for the supported values: - Supported CUDA versions: 12 @@ -237,7 +237,7 @@ For GPU-enabled systems, the `DATA_FOLDER` variable should be a local folder whe export DATA_FOLDER=path/to/store/datasets/and/results docker run --gpus all --rm -it -u $(id -u) \ -v $DATA_FOLDER:/data/benchmarks \ - rapidsai/cuvs-bench:25.12-cuda12.9-py3.13 \ + rapidsai/cuvs-bench:26.02-cuda12.9-py3.13 \ "--dataset deep-image-96-angular" \ "--normalize" \ "--algorithms cuvs_cagra,cuvs_ivf_pq --batch-size 10 -k 10" \ @@ -250,7 +250,7 @@ Usage of the above command is as follows: * - Argument - Description - * - `rapidsai/cuvs-bench:25.12-cuda12.9-py3.13` + * - `rapidsai/cuvs-bench:26.02-cuda12.9-py3.13` - Image to use. Can be either `cuvs-bench` or `cuvs-bench-datasets` * - `"--dataset deep-image-96-angular"` @@ -297,7 +297,7 @@ All of the `cuvs-bench` images contain the Conda packages, so they can be used d --entrypoint /bin/bash \ --workdir /data/benchmarks \ -v $DATA_FOLDER:/data/benchmarks \ - rapidsai/cuvs-bench:25.12-cuda12.9-py3.13 + rapidsai/cuvs-bench:26.02-cuda12.9-py3.13 This will drop you into a command line in the container, with the `cuvs-bench` python package ready to use, as described in the [Running the benchmarks](#running-the-benchmarks) section above: diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index e4081842d2..da50a44d27 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -187,7 +187,7 @@ RAFT relies on `clang-format` to enforce code style across all C++ and CUDA sour 1. Do not split empty functions/records/namespaces. 2. Two-space indentation everywhere, including the line continuations. 3. Disable reflowing of comments. - The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/cuvs/blob/release/25.12/cpp/.clang-format). + The reasons behind these deviations from the Google style guide are given in comments [here](https://github.com/rapidsai/cuvs/blob/main/cpp/.clang-format). [`doxygen`](https://doxygen.nl/) is used as documentation generator and also as a documentation linter. In order to run doxygen as a linter on C++/CUDA code, run @@ -205,7 +205,7 @@ you can run `codespell -i 3 -w .` from the repository root directory. This will bring up an interactive prompt to select which spelling fixes to apply. ### #include style -[include_checker.py](https://github.com/rapidsai/cuvs/blob/release/25.12/cpp/scripts/include_checker.py) is used to enforce the include style as follows: +[include_checker.py](https://github.com/rapidsai/cuvs/blob/main/cpp/scripts/include_checker.py) is used to enforce the include style as follows: 1. `#include "..."` should be used for referencing local files only. It is acceptable to be used for referencing files in a sub-folder/parent-folder of the same algorithm, but should never be used to include files in other algorithms or between algorithms and the primitives or other dependencies. 2. `#include <...>` should be used for referencing everything else diff --git a/examples/go/README.md b/examples/go/README.md index f49020de62..2588ae19ce 100644 --- a/examples/go/README.md +++ b/examples/go/README.md @@ -24,7 +24,7 @@ export CC=clang 2. Install the Go module: ```bash -go get github.com/rapidsai/cuvs/go@v25.12.00 # 25.02.00 being your desired version, selected from https://github.com/rapidsai/cuvs/tags +go get github.com/rapidsai/cuvs/go@v26.02.00 # 25.02.00 being your desired version, selected from https://github.com/rapidsai/cuvs/tags ``` Then you can build your project with the usual `go build`. diff --git a/java/benchmarks/pom.xml b/java/benchmarks/pom.xml index 45588933c5..52cf0130e0 100644 --- a/java/benchmarks/pom.xml +++ b/java/benchmarks/pom.xml @@ -10,7 +10,7 @@ com.nvidia.cuvs benchmarks - 25.12.0 + 26.02.0 jar cuvs-java-benchmarks @@ -30,7 +30,7 @@ com.nvidia.cuvs cuvs-java - 25.12.0 + 26.02.0 jar diff --git a/java/build.sh b/java/build.sh index d40e97adef..339857bfe8 100755 --- a/java/build.sh +++ b/java/build.sh @@ -8,7 +8,7 @@ set -e -u -o pipefail ARGS="$*" NUMARGS=$# -VERSION="25.12.0" # Note: The version is updated automatically when ci/release/update-version.sh is invoked +VERSION="26.02.0" # Note: The version is updated automatically when ci/release/update-version.sh is invoked GROUP_ID="com.nvidia.cuvs" # Identify CUDA major version. diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index 99d0eb5e09..d0eb079fe9 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -11,7 +11,7 @@ com.nvidia.cuvs cuvs-java - 25.12.0 + 26.02.0 cuvs-java This project provides Java bindings for cuVS, enabling approximate nearest neighbors search and clustering diff --git a/java/examples/README.md b/java/examples/README.md index 9a48ad6ea1..58f7acdbdb 100644 --- a/java/examples/README.md +++ b/java/examples/README.md @@ -11,17 +11,17 @@ This maven project contains examples for CAGRA, HNSW, and Bruteforce algorithms. ### CAGRA Example In the current directory do: ``` -mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-25.12.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/25.12.0/cuvs-java-25.12.0.jar com.nvidia.cuvs.examples.CagraExample +mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-26.02.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/26.02.0/cuvs-java-26.02.0.jar com.nvidia.cuvs.examples.CagraExample ``` ### HNSW Example In the current directory do: ``` -mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-25.12.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/25.12.0/cuvs-java-25.12.0.jar com.nvidia.cuvs.examples.HnswExample +mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-26.02.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/26.02.0/cuvs-java-26.02.0.jar com.nvidia.cuvs.examples.HnswExample ``` ### Bruteforce Example In the current directory do: ``` -mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-25.12.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/25.12.0/cuvs-java-25.12.0.jar com.nvidia.cuvs.examples.BruteForceExample +mvn package && java --enable-native-access=ALL-UNNAMED -cp target/cuvs-java-examples-26.02.0.jar:$HOME/.m2/repository/com/nvidia/cuvs/cuvs-java/26.02.0/cuvs-java-26.02.0.jar com.nvidia.cuvs.examples.BruteForceExample ``` diff --git a/java/examples/pom.xml b/java/examples/pom.xml index a61412aff8..16b1b6ede6 100644 --- a/java/examples/pom.xml +++ b/java/examples/pom.xml @@ -10,7 +10,7 @@ SPDX-License-Identifier: Apache-2.0 com.nvidia.cuvs.examples cuvs-java-examples - 25.12.0 + 26.02.0 cuvs-java-examples @@ -23,7 +23,7 @@ SPDX-License-Identifier: Apache-2.0 com.nvidia.cuvs cuvs-java - 25.12.0 + 26.02.0 diff --git a/python/cuvs/pyproject.toml b/python/cuvs/pyproject.toml index 3d0ebe2cd8..38ee2b6f12 100644 --- a/python/cuvs/pyproject.toml +++ b/python/cuvs/pyproject.toml @@ -21,9 +21,9 @@ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ "cuda-python>=13.0.1,<14.0a0", - "libcuvs==25.12.*,>=0.0.0a0", + "libcuvs==26.2.*,>=0.0.0a0", "numpy>=1.23,<3.0a0", - "pylibraft==25.12.*,>=0.0.0a0", + "pylibraft==26.2.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", @@ -108,9 +108,9 @@ requires = [ "cmake>=3.30.4", "cuda-python>=13.0.1,<14.0a0", "cython>=3.0.0,<3.2.0a0", - "libcuvs==25.12.*,>=0.0.0a0", - "libraft==25.12.*,>=0.0.0a0", - "librmm==25.12.*,>=0.0.0a0", + "libcuvs==26.2.*,>=0.0.0a0", + "libraft==26.2.*,>=0.0.0a0", + "librmm==26.2.*,>=0.0.0a0", "ninja", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. build-backend = "scikit_build_core.build" diff --git a/python/cuvs_bench/cuvs_bench/plot/__main__.py b/python/cuvs_bench/cuvs_bench/plot/__main__.py index ddf687d38b..aca08505ea 100644 --- a/python/cuvs_bench/cuvs_bench/plot/__main__.py +++ b/python/cuvs_bench/cuvs_bench/plot/__main__.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # This script is inspired by -# 1: https://github.com/erikbern/ann-benchmarks/blob/release/25.12/plot.py -# 2: https://github.com/erikbern/ann-benchmarks/blob/release/25.12/ann_benchmarks/plotting/utils.py # noqa: E501 -# 3: https://github.com/erikbern/ann-benchmarks/blob/release/25.12/ann_benchmarks/plotting/metrics.py # noqa: E501 -# License: https://github.com/rapidsai/cuvs/blob/release/25.12/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 +# 1: https://github.com/erikbern/ann-benchmarks/blob/main/plot.py +# 2: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/utils.py # noqa: E501 +# 3: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/metrics.py # noqa: E501 +# License: https://github.com/rapidsai/cuvs/blob/main/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 import itertools import os @@ -502,7 +502,7 @@ def load_all_results( is_flag=True, help="Show raw results (not just Pareto frontier) of the mode argument.", ) -def release/25.12( +def main( dataset: str, dataset_path: str, output_filepath: str, @@ -603,4 +603,4 @@ def release/25.12( if __name__ == "__main__": - release/25.12() + main() diff --git a/python/cuvs_bench/pyproject.toml b/python/cuvs_bench/pyproject.toml index dc69e8cad8..d7d8e3b891 100644 --- a/python/cuvs_bench/pyproject.toml +++ b/python/cuvs_bench/pyproject.toml @@ -20,7 +20,7 @@ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ "click", - "cuvs==25.12.*,>=0.0.0a0", + "cuvs==26.2.*,>=0.0.0a0", "matplotlib>=3.9", "pandas", "pyyaml", diff --git a/python/libcuvs/pyproject.toml b/python/libcuvs/pyproject.toml index 9690708c27..cc60040c5a 100644 --- a/python/libcuvs/pyproject.toml +++ b/python/libcuvs/pyproject.toml @@ -20,8 +20,8 @@ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ "cuda-toolkit[cublas,curand,cusolver,cusparse]>=12,<14", - "libraft==25.12.*,>=0.0.0a0", - "librmm==25.12.*,>=0.0.0a0", + "libraft==26.2.*,>=0.0.0a0", + "librmm==26.2.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", @@ -79,8 +79,8 @@ regex = "(?P.*)" build-backend = "scikit_build_core.build" requires = [ "cmake>=3.30.4", - "libraft==25.12.*,>=0.0.0a0", - "librmm==25.12.*,>=0.0.0a0", + "libraft==26.2.*,>=0.0.0a0", + "librmm==26.2.*,>=0.0.0a0", "ninja", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. dependencies-file = "../../dependencies.yaml" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3e45ac65ba..2ad456db53 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -6,7 +6,7 @@ members = [ resolver = "2" [workspace.package] -version = "25.12.0" +version = "26.2.0" edition = "2021" repository = "https://github.com/rapidsai/cuvs" homepage = "https://github.com/rapidsai/cuvs" diff --git a/rust/cuvs/Cargo.toml b/rust/cuvs/Cargo.toml index 30429f814c..62b6d51391 100644 --- a/rust/cuvs/Cargo.toml +++ b/rust/cuvs/Cargo.toml @@ -9,7 +9,7 @@ authors.workspace = true license.workspace = true [dependencies] -ffi = { package = "cuvs-sys", path = "../cuvs-sys", version = "25.12.0" } +ffi = { package = "cuvs-sys", path = "../cuvs-sys", version = "26.2.0" } ndarray = "0.15" [dev-dependencies] From 479fd1647eb14ed049832273c4481165f8b7af61 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 18 Nov 2025 08:57:02 -0500 Subject: [PATCH 14/32] fix(ci): remove unknown parameter `name` from rocky8 build job (#1554) --- .github/workflows/build.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2f3146ac48..bf4cbed705 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -58,7 +58,6 @@ jobs: date: ${{ inputs.date }} container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" node_type: "cpu16" - name: "${{ matrix.cuda_version }}, amd64, rockylinux8" # requires_license_builder: false script: "ci/build_standalone_c.sh" artifact-name: "libcuvs_c_${{ matrix.cuda_version }}.tar.gz" From cc8cdc3eeb23990318d76dfc66e63ab4940932f2 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 18 Nov 2025 10:47:21 -0800 Subject: [PATCH 15/32] Forward merge 25.12 into main (#1562) Admin merge as part of NBS cleanup. Replaces #1558 --------- Co-authored-by: Nate Rock Co-authored-by: Bradley Dice Co-authored-by: Paul Taylor <178183+trxcllnt@users.noreply.github.com> Co-authored-by: Gil Forsyth --- ci/release/update-version.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 9842281ed4..49da9abe83 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -141,7 +141,8 @@ elif [[ "${RUN_CONTEXT}" == "release" ]]; then # In release context, use release branch for documentation links (word boundaries to avoid partial matches) sed_runner "/rapidsai\\/cuvs/ s|\\bmain\\b|release/${NEXT_SHORT_TAG}|g" docs/source/developer_guide.md sed_runner "s|\\bmain\\b|release/${NEXT_SHORT_TAG}|g" README.md - sed_runner "s|\\bmain\\b|release/${NEXT_SHORT_TAG}|g" python/cuvs_bench/cuvs_bench/plot/__main__.py + # Only update the GitHub URL, not the main() function + sed_runner "s|/cuvs/blob/\\bmain\\b/|/cuvs/blob/release/${NEXT_SHORT_TAG}/|g" python/cuvs_bench/cuvs_bench/plot/__main__.py fi # Update cuvs-bench Docker image references (version-only, not branch-related) From b8eba1d6bd6149520f91aea481b463fe9cd01c61 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Wed, 19 Nov 2025 16:59:59 +0900 Subject: [PATCH 16/32] Integrate random orth transform into CAGRA-Q --- cpp/include/cuvs/neighbors/cagra.hpp | 28 +++++++++++++++ .../neighbors/detail/cagra/cagra_build.cuh | 36 +++++++++++++++++++ .../neighbors/detail/cagra/cagra_search.cuh | 32 ++++++++++++++++- .../detail/cagra/cagra_serialize.cuh | 36 +++++++++++++++++-- 4 files changed, 129 insertions(+), 3 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index fb1b1549af..6a93c26c31 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -150,6 +151,12 @@ struct index_params : cuvs::neighbors::index_params { */ bool attach_dataset_on_build = true; + /** + * Preprocess transformation configure + */ + std::variant + preprocess_params; + /** * @brief Create a CAGRA index parameters compatible with HNSW index * @@ -449,6 +456,22 @@ struct index : cuvs::neighbors::index { return std::nullopt; } + /** Get the preprocess taransformer */ + [[nodiscard]] inline auto preprocess_transformer() const noexcept -> const + std::variant>& + { + return preprocess_transformer_; + } + + /** Get the preprocess taransformer */ + [[nodiscard]] inline auto preprocess_transformer() noexcept -> std::variant< + std::monostate, + cuvs::preprocessing::linear_transform::random_orthogonal::transformer>& + { + return preprocess_transformer_; + } + // Don't allow copying the index for performance reasons (try avoiding copying data) /** \cond */ index(const index&) = delete; @@ -838,6 +861,11 @@ struct index : cuvs::neighbors::index { std::optional graph_fd_; std::optional mapping_fd_; + // Preprocess transformer + std::variant> + preprocess_transformer_; + void compute_dataset_norms_(raft::resources const& res); size_t n_rows_ = 0; size_t dim_ = 0; diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 5f7389493a..dafc59f1bf 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -2065,6 +2065,8 @@ index build( std::holds_alternative(knn_build_params) || std::holds_alternative(knn_build_params), "CosineExpanded distance is not supported for iterative CAGRA graph build."); + RAFT_EXPECTS(std::holds_alternative(params.preprocess_params), + "Dataset preprocessing is not supported in cagra::build with const dataset"); // Validate data type for BitwiseHamming metric RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::BitwiseHamming || @@ -2168,4 +2170,38 @@ index build( idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); return idx; } + +template , + raft::memory_type::host>> +index build( + raft::resources const& res, + const index_params& params, + raft::mdspan, raft::row_major, Accessor> dataset) +{ + std::variant> + preprocess_transformer; + if (std::holds_alternative( + params.preprocess_params)) { + const auto rand_orth_params = + std::get( + params.preprocess_params); + + auto transformer = cuvs::preprocessing::linear_transform::random_orthogonal::train( + res, rand_orth_params, raft::make_const_mdspan(dataset)); + cuvs::preprocessing::linear_transform::random_orthogonal::transform( + res, rand_orth_params, dataset, raft::make_const_mdspan(dataset)); + + preprocess_transformer.emplace(transformer); + } + index_params new_params = params; + new_params.preprocess_params = std::monostate{}; + + auto idx = build(res, params, dataset); + idx.preprocess_transformer = preprocess_transformer; + + return idx; +} } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 45328377be..76b5ad45b3 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -180,6 +180,36 @@ void search_main(raft::resources const& res, RAFT_FAIL("FP32 VPQ dataset support is coming soon"); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { + raft::device_matrix_view queries_ = queries; + + // Preprocess if needed + auto mr = raft::resource::get_workspace_resource(res); + auto preprocessed_queries = + raft::make_device_mdarray(res, mr, raft::make_extents(0, 0)); + if (std::holds_alternative< + cuvs::preprocessing::linear_transform::random_orthogonal::transformer>( + index.preprocess_transformer())) { + constexpr bool is_supported_dtype = std::is_same_v || std::is_same_v; + RAFT_EXPECTS(is_supported_dtype, + "Only the float and half data types are supported in the random orthogonal " + "transform preprocessing"); + if constexpr (is_supported_dtype) { + const auto& rand_orth_transformer = + std::get>( + index.preprocess_transformer()); + // cuvs::preprocessing::linear_transform::random_orthogonal::params params; + // auto rand_orth_transformer = + // cuvs::preprocessing::linear_transform::random_orthogonal::train(res, params, queries); + + auto preprocessed_queries = raft::make_device_mdarray( + res, mr, raft::make_extents(queries.extent(0), queries.extent(1))); + + cuvs::preprocessing::linear_transform::random_orthogonal::transform( + res, rand_orth_transformer, queries, preprocessed_queries.view()); + queries_ = raft::make_const_mdspan(preprocessed_queries.view()); + } + } + auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric(), nullptr); search_main_core( @@ -188,7 +218,7 @@ void search_main(raft::resources const& res, desc, index.graph(), index.source_indices(), - queries, + queries_, neighbors, distances, sample_filter); diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 866415b1e4..b998bbb51b 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -65,8 +65,12 @@ void serialize(raft::resources const& res, raft::serialize_mdspan(res, os, index_.graph()); include_dataset &= (index_.data().n_rows() > 0); - bool has_source_indices = index_.source_indices().has_value(); - uint32_t content_map = 0x1u * include_dataset + 0x2u * has_source_indices; + bool has_source_indices = index_.source_indices().has_value(); + bool has_rand_orth_preprocessor = std::holds_alternative< + cuvs::preprocessing::linear_transform::random_orthogonal::transformer>( + index_.preprocess_transformer()); + uint32_t content_map = + 0x1u * include_dataset + 0x2u * has_source_indices + 0x4u * has_rand_orth_preprocessor; raft::serialize_scalar(res, os, content_map); if (include_dataset) { @@ -77,6 +81,19 @@ void serialize(raft::resources const& res, } if (has_source_indices) { raft::serialize_mdspan(res, os, index_.source_indices().value()); } + if (has_rand_orth_preprocessor) { + const auto& rand_orth_transformer = + std::get>( + index_.preprocess_transformer()); + + auto host_matrix = raft::make_host_matrix(index_.dim(), index_.dim()); + raft::copy(host_matrix.data_handle(), + rand_orth_transformer.orthogonal_matrix.data_handle(), + index_.dim() * index_.dim(), + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + raft::serialize_mdspan(res, os, host_matrix.view()); + } } template @@ -301,6 +318,21 @@ void deserialize(raft::resources const& res, std::istream& is, index* i raft::resource::sync_stream( res); // Don't let the vector out of the scope before the copy is finished } + + bool has_rand_orth_preprocessor = content_map & 0x4u; + if (has_rand_orth_preprocessor) { + auto host_matrix = raft::make_host_matrix(dim, dim); + raft::deserialize_mdspan(res, is, host_matrix.view()); + + auto device_matrix = raft::make_device_matrix(res, dim, dim); + raft::copy(device_matrix.data_handle(), + host_matrix.data_handle(), + dim * dim, + raft::resource::get_cuda_stream(res)); + index_->preprocess_transformer() = + cuvs::preprocessing::linear_transform::random_orthogonal::transformer{ + std::move(device_matrix)}; + } } template From 50a8f844d2678ddeb6b275db00bdbe09f0bd930d Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 19 Nov 2025 11:53:58 -0600 Subject: [PATCH 17/32] Update FAISS patch for RMM memory resource header migration (#1566) Updates FAISS patch for RMM memory resource header migration. xref: https://github.com/rapidsai/rmm/issues/2141 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuvs/pull/1566 --- .../patches/{faiss.diff => faiss-25.12.diff} | 30 ++++++++++++++----- cpp/cmake/patches/faiss_override.json | 4 +-- examples/cpp/src/cagra_hnsw_ace_example.cu | 4 +-- 3 files changed, 27 insertions(+), 11 deletions(-) rename cpp/cmake/patches/{faiss.diff => faiss-25.12.diff} (91%) diff --git a/cpp/cmake/patches/faiss.diff b/cpp/cmake/patches/faiss-25.12.diff similarity index 91% rename from cpp/cmake/patches/faiss.diff rename to cpp/cmake/patches/faiss-25.12.diff index 185449c4e2..98edbdfb8c 100644 --- a/cpp/cmake/patches/faiss.diff +++ b/cpp/cmake/patches/faiss-25.12.diff @@ -11,17 +11,33 @@ index c82c73e7d..b9100c272 100644 auto resImpl = prov->getResources(); auto res = resImpl.get(); +diff --git a/faiss/gpu/GpuResources.h b/faiss/gpu/GpuResources.h +index c0c851a89..61d9d4dbe 100644 +--- a/faiss/gpu/GpuResources.h ++++ b/faiss/gpu/GpuResources.h +@@ -33,7 +33,7 @@ + + #if defined USE_NVIDIA_CUVS + #include +-#include ++#include + #endif + + namespace faiss { diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp -index 649b7cb5c..765fdb3d0 100644 +index 649b7cb5c..622443044 100644 --- a/faiss/gpu/StandardGpuResources.cpp +++ b/faiss/gpu/StandardGpuResources.cpp -@@ -24,8 +24,8 @@ +@@ -23,9 +23,9 @@ + #if defined USE_NVIDIA_CUVS #include - #include -+#include - #include +-#include +-#include -#include ++#include ++#include ++#include #include #endif @@ -80,7 +96,7 @@ index 649b7cb5c..765fdb3d0 100644 auto err = cudaFree(p); FAISS_ASSERT_FMT( diff --git a/faiss/gpu/StandardGpuResources.h b/faiss/gpu/StandardGpuResources.h -index f23ca19d8..c43926fce 100644 +index f23ca19d8..3ba606606 100644 --- a/faiss/gpu/StandardGpuResources.h +++ b/faiss/gpu/StandardGpuResources.h @@ -25,7 +25,8 @@ @@ -88,7 +104,7 @@ index f23ca19d8..c43926fce 100644 #if defined USE_NVIDIA_CUVS #include -#include -+#include ++#include +#include #endif diff --git a/cpp/cmake/patches/faiss_override.json b/cpp/cmake/patches/faiss_override.json index b8358da48b..989a043343 100644 --- a/cpp/cmake/patches/faiss_override.json +++ b/cpp/cmake/patches/faiss_override.json @@ -6,8 +6,8 @@ "git_tag": "v1.12.0", "patches" : [ { - "file" : "${current_json_dir}/faiss.diff", - "issue" : "Multiple fixes for cuVS compatibility", + "file" : "${current_json_dir}/faiss-25.12.diff", + "issue" : "Multiple fixes for cuVS and RMM compatibility", "fixed_in" : "" } ] diff --git a/examples/cpp/src/cagra_hnsw_ace_example.cu b/examples/cpp/src/cagra_hnsw_ace_example.cu index b2474eeab9..8907248b1f 100644 --- a/examples/cpp/src/cagra_hnsw_ace_example.cu +++ b/examples/cpp/src/cagra_hnsw_ace_example.cu @@ -14,8 +14,8 @@ #include #include -#include -#include +#include +#include #include "common.cuh" From 930d42b7548e6d96ea0e774a1f37711201062896 Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Wed, 19 Nov 2025 17:00:46 -0500 Subject: [PATCH 18/32] Assign the c/ folder to the the c code ownder group (#1573) This ensures that people are properly assigned to code review any changes to the C API Authors: - Robert Maynard (https://github.com/robertmaynard) Approvers: - Kyle Edwards (https://github.com/KyleFromNVIDIA) URL: https://github.com/rapidsai/cuvs/pull/1573 --- .github/CODEOWNERS | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 73320903c6..72a854b3df 100755 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,11 @@ +#c code owners +c/ @rapidsai/cuvs-c-codeowners +examples/c/ @rapidsai/cuvs-c-codeowners + #cpp code owners cpp/ @rapidsai/cuvs-cpp-codeowners examples/cpp/ @rapidsai/cuvs-cpp-codeowners -examples/c/ @rapidsai/cuvs-cpp-codeowners + #java code owners java/ @rapidsai/cuvs-java-codeowners From 2cf5fa7666d703dccbe655f8214656b0952bb69b Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Fri, 21 Nov 2025 19:41:47 -0600 Subject: [PATCH 19/32] Use strict priority in CI conda tests (#1583) This PR sets conda to use `strict` priority in CI tests. Mixing channel priority is frequently a cause of unexpected errors. Our CI jobs should always use strict priority in order to enforce that conda packages come from local channels with the artifacts built in CI, not mixing with older nightly artifacts from the `rapidsai-nightly` channel or other sources. xref: https://github.com/rapidsai/build-planning/issues/14 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - https://github.com/jakirkham URL: https://github.com/rapidsai/cuvs/pull/1583 --- ci/test_cpp.sh | 3 +++ ci/test_python.sh | 3 +++ conda/environments/all_cuda-129_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-129_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-130_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-130_arch-x86_64.yaml | 2 +- conda/environments/bench_ann_cuda-129_arch-aarch64.yaml | 2 +- conda/environments/bench_ann_cuda-129_arch-x86_64.yaml | 2 +- conda/environments/bench_ann_cuda-130_arch-aarch64.yaml | 2 +- conda/environments/bench_ann_cuda-130_arch-x86_64.yaml | 2 +- conda/environments/go_cuda-129_arch-aarch64.yaml | 2 +- conda/environments/go_cuda-129_arch-x86_64.yaml | 2 +- conda/environments/go_cuda-130_arch-aarch64.yaml | 2 +- conda/environments/go_cuda-130_arch-x86_64.yaml | 2 +- conda/environments/rust_cuda-129_arch-aarch64.yaml | 2 +- conda/environments/rust_cuda-129_arch-x86_64.yaml | 2 +- conda/environments/rust_cuda-130_arch-aarch64.yaml | 2 +- conda/environments/rust_cuda-130_arch-x86_64.yaml | 2 +- dependencies.yaml | 2 +- 19 files changed, 23 insertions(+), 17 deletions(-) diff --git a/ci/test_cpp.sh b/ci/test_cpp.sh index 83cc9e5d31..d1a1d2d5f4 100755 --- a/ci/test_cpp.sh +++ b/ci/test_cpp.sh @@ -6,6 +6,9 @@ set -euo pipefail . /opt/conda/etc/profile.d/conda.sh +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + CPP_CHANNEL=$(rapids-download-conda-from-github cpp) rapids-logger "Generate C++ testing dependencies" diff --git a/ci/test_python.sh b/ci/test_python.sh index bf1cfe43a4..a427b16862 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -6,6 +6,9 @@ set -euo pipefail . /opt/conda/etc/profile.d/conda.sh +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-github cpp) PYTHON_CHANNEL=$(rapids-download-conda-from-github python) diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml index f5aea13fd0..adcc024ffc 100644 --- a/conda/environments/all_cuda-129_arch-aarch64.yaml +++ b/conda/environments/all_cuda-129_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml index 65e80d0bc4..24eb7ccbbe 100644 --- a/conda/environments/all_cuda-129_arch-x86_64.yaml +++ b/conda/environments/all_cuda-129_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/all_cuda-130_arch-aarch64.yaml b/conda/environments/all_cuda-130_arch-aarch64.yaml index da97ddd586..01bbaec65e 100644 --- a/conda/environments/all_cuda-130_arch-aarch64.yaml +++ b/conda/environments/all_cuda-130_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/all_cuda-130_arch-x86_64.yaml b/conda/environments/all_cuda-130_arch-x86_64.yaml index cec768aa29..f42abeb5d6 100644 --- a/conda/environments/all_cuda-130_arch-x86_64.yaml +++ b/conda/environments/all_cuda-130_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml index cf78abc107..714f1d7c1c 100644 --- a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml index 45219e4ba6..f842c9eaed 100644 --- a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml index 417ab87b88..8fd964b088 100644 --- a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml index 30d4e2e7ca..8cee4997f4 100644 --- a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/go_cuda-129_arch-aarch64.yaml b/conda/environments/go_cuda-129_arch-aarch64.yaml index 9ce9093e21..d899c09dd4 100644 --- a/conda/environments/go_cuda-129_arch-aarch64.yaml +++ b/conda/environments/go_cuda-129_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/go_cuda-129_arch-x86_64.yaml b/conda/environments/go_cuda-129_arch-x86_64.yaml index 4243077552..1af5244cde 100644 --- a/conda/environments/go_cuda-129_arch-x86_64.yaml +++ b/conda/environments/go_cuda-129_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/go_cuda-130_arch-aarch64.yaml b/conda/environments/go_cuda-130_arch-aarch64.yaml index 962d5f1079..1e5ccc0671 100644 --- a/conda/environments/go_cuda-130_arch-aarch64.yaml +++ b/conda/environments/go_cuda-130_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/go_cuda-130_arch-x86_64.yaml b/conda/environments/go_cuda-130_arch-x86_64.yaml index ca8dc8a88a..643fea7df4 100644 --- a/conda/environments/go_cuda-130_arch-x86_64.yaml +++ b/conda/environments/go_cuda-130_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - _go_select *=cgo diff --git a/conda/environments/rust_cuda-129_arch-aarch64.yaml b/conda/environments/rust_cuda-129_arch-aarch64.yaml index 8da31cefbf..6669aa151b 100644 --- a/conda/environments/rust_cuda-129_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-129_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/rust_cuda-129_arch-x86_64.yaml b/conda/environments/rust_cuda-129_arch-x86_64.yaml index 3cbf7fad6a..a9d5f2bd53 100644 --- a/conda/environments/rust_cuda-129_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-129_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/rust_cuda-130_arch-aarch64.yaml b/conda/environments/rust_cuda-130_arch-aarch64.yaml index c71dff5bba..b975c685f9 100644 --- a/conda/environments/rust_cuda-130_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-130_arch-aarch64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/conda/environments/rust_cuda-130_arch-x86_64.yaml b/conda/environments/rust_cuda-130_arch-x86_64.yaml index a229c27795..5394f45a27 100644 --- a/conda/environments/rust_cuda-130_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-130_arch-x86_64.yaml @@ -1,8 +1,8 @@ # This file is generated by `rapids-dependency-file-generator`. # To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. channels: -- rapidsai - rapidsai-nightly +- rapidsai - conda-forge dependencies: - c-compiler diff --git a/dependencies.yaml b/dependencies.yaml index 6ef7dfd768..89dcb38b99 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -214,8 +214,8 @@ files: includes: - bench_python channels: - - rapidsai - rapidsai-nightly + - rapidsai - conda-forge dependencies: build: From 94b59e20cc69776c0d5409b682998e4d23063e1d Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 25 Nov 2025 17:18:21 -0600 Subject: [PATCH 20/32] Update FAISS from 1.12.0 to 1.13.0 (#1585) ## Summary - Update FAISS dependency from 1.12.0 to 1.13.0 - Remove thrust include patches already present in FAISS 1.13.0 - All other RMM API compatibility patches still apply cleanly Verified that updated patches apply cleanly to FAISS v1.13.0. Follow-up to https://github.com/rapidsai/cuvs/pull/1566. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/1585 --- ...-25.12.diff => faiss-1.13-cuvs-25.12.diff} | 28 ------------------- cpp/cmake/patches/faiss_override.json | 6 ++-- 2 files changed, 3 insertions(+), 31 deletions(-) rename cpp/cmake/patches/{faiss-25.12.diff => faiss-1.13-cuvs-25.12.diff} (90%) diff --git a/cpp/cmake/patches/faiss-25.12.diff b/cpp/cmake/patches/faiss-1.13-cuvs-25.12.diff similarity index 90% rename from cpp/cmake/patches/faiss-25.12.diff rename to cpp/cmake/patches/faiss-1.13-cuvs-25.12.diff index 98edbdfb8c..7fabbc675f 100644 --- a/cpp/cmake/patches/faiss-25.12.diff +++ b/cpp/cmake/patches/faiss-1.13-cuvs-25.12.diff @@ -120,34 +120,6 @@ index f23ca19d8..3ba606606 100644 #endif /// Pinned memory allocation for use with this GPU -diff --git a/faiss/gpu/impl/BinaryCuvsCagra.cu b/faiss/gpu/impl/BinaryCuvsCagra.cu -index 0ca21dc5f..b331fdc8f 100644 ---- a/faiss/gpu/impl/BinaryCuvsCagra.cu -+++ b/faiss/gpu/impl/BinaryCuvsCagra.cu -@@ -32,6 +32,9 @@ - #include - #include - -+#include -+#include -+ - namespace faiss { - namespace gpu { - -diff --git a/faiss/gpu/impl/CuvsCagra.cu b/faiss/gpu/impl/CuvsCagra.cu -index 482e4d672..4246776e8 100644 ---- a/faiss/gpu/impl/CuvsCagra.cu -+++ b/faiss/gpu/impl/CuvsCagra.cu -@@ -31,6 +31,9 @@ - #include - #include - -+#include -+#include -+ - namespace faiss { - namespace gpu { - diff --git a/faiss/gpu/impl/CuvsFlatIndex.cu b/faiss/gpu/impl/CuvsFlatIndex.cu index 15cf427cf..d877e766d 100644 --- a/faiss/gpu/impl/CuvsFlatIndex.cu diff --git a/cpp/cmake/patches/faiss_override.json b/cpp/cmake/patches/faiss_override.json index 989a043343..7d2d755740 100644 --- a/cpp/cmake/patches/faiss_override.json +++ b/cpp/cmake/patches/faiss_override.json @@ -1,12 +1,12 @@ { "packages" : { "faiss" : { - "version": "1.12.0", + "version": "1.13.0", "git_url": "https://github.com/facebookresearch/faiss.git", - "git_tag": "v1.12.0", + "git_tag": "v1.13.0", "patches" : [ { - "file" : "${current_json_dir}/faiss-25.12.diff", + "file" : "${current_json_dir}/faiss-1.13-cuvs-25.12.diff", "issue" : "Multiple fixes for cuVS and RMM compatibility", "fixed_in" : "" } From 91c51b1cc43d45cc3d949830e153725de5a2c972 Mon Sep 17 00:00:00 2001 From: irina-resh-nvda Date: Wed, 26 Nov 2025 15:54:30 +0100 Subject: [PATCH 21/32] CMake check for FAISS use in benchmarks (#1591) CUVS_ANN_BENCH_USE_FAISS is now set to OFF if all relevant flags are set OFF. The status is reported in the cmake log: -- Finding or building hnswlib -- Checking for FAISS use in benchmarks... -- CUVS_ANN_BENCH_USE_FAISS is OFF closes #1590. Authors: - https://github.com/irina-resh-nvda Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/cuvs/pull/1591 --- cpp/bench/ann/CMakeLists.txt | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 46bbd318d2..8d254c0933 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -51,7 +51,29 @@ option(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benc find_package(Threads REQUIRED) -set(CUVS_ANN_BENCH_USE_FAISS ON) +# ----- FAISS use in Benchmarks ---- +get_cmake_property(_variableNames VARIABLES) + +set(CUVS_ANN_BENCH_USE_FAISS OFF) +message(STATUS "Checking for FAISS use in benchmarks...") +foreach(_varName ${_variableNames}) + if(_varName MATCHES "CUVS_ANN_BENCH_USE_FAISS.+") + if(${_varName}) + set(CUVS_ANN_BENCH_USE_FAISS ON) + message(STATUS "${_varName} is detected as ON.") + break() + endif() + endif() +endforeach() + +if(CUVS_ANN_BENCH_USE_FAISS) + message(STATUS "CUVS_ANN_BENCH_USE_FAISS is switched ON") +else() + message(STATUS "CUVS_ANN_BENCH_USE_FAISS is switched OFF") +endif() + +# ---------------------------------- + set(CUVS_FAISS_ENABLE_GPU ON) set(CUVS_USE_FAISS_STATIC ON) From 1af437eb77fa52b09841780abd1d5ad619bc779f Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Tue, 2 Dec 2025 12:27:54 -0500 Subject: [PATCH 22/32] Add arm64 builds to the libcuvs_c rocky8 matrix (#1570) Extend the `rocky8-clib-standalone-build` to include arm64 builds Authors: - Robert Maynard (https://github.com/robertmaynard) Approvers: - James Lamb (https://github.com/jameslamb) - Kyle Edwards (https://github.com/KyleFromNVIDIA) URL: https://github.com/rapidsai/cuvs/pull/1570 --- .github/workflows/pr.yaml | 14 ++++++++++---- ci/build_standalone_c.sh | 16 +++------------- ci/test_standalone_c.sh | 28 ++++++++++++---------------- 3 files changed, 25 insertions(+), 33 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 1c14b155d4..d74ab4abbe 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -184,15 +184,18 @@ jobs: cuda_version: - &latest_cuda12 '12.9.1' - &latest_cuda13 '13.0.2' + arch: + - amd64 + - arm64 with: build_type: pull-request - arch: "amd64" + arch: "${{matrix.arch}}" date: ${{ inputs.date }}_c container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" node_type: "cpu16" # requires_license_builder: false script: "ci/build_standalone_c.sh --build-tests" - artifact-name: "libcuvs_c_${{ matrix.cuda_version }}.tar.gz" + artifact-name: "libcuvs_c_${{ matrix.cuda_version }}_${{ matrix.arch }}.tar.gz" file_to_upload: "libcuvs_c.tar.gz" sha: ${{ inputs.sha }} rocky8-clib-tests: @@ -206,13 +209,16 @@ jobs: cuda_version: - *latest_cuda12 - *latest_cuda13 + arch: + - amd64 + - arm64 with: build_type: pull-request node_type: "gpu-l4-latest-1" - arch: "amd64" + arch: "${{matrix.arch}}" date: ${{ inputs.date }}_c container_image: "rapidsai/ci-wheel:26.02-cuda${{ matrix.cuda_version }}-rockylinux8-py3.10" - script: "ci/test_standalone_c.sh" + script: "ci/test_standalone_c.sh libcuvs_c_${{ matrix.cuda_version }}_${{ matrix.arch }}.tar.gz" sha: ${{ inputs.sha }} conda-java-build-and-tests: needs: [conda-cpp-build, changed-files] diff --git a/ci/build_standalone_c.sh b/ci/build_standalone_c.sh index 88043b10ad..bee25e0d37 100755 --- a/ci/build_standalone_c.sh +++ b/ci/build_standalone_c.sh @@ -5,8 +5,6 @@ set -euo pipefail TOOLSET_VERSION=14 -CMAKE_VERSION=3.31.8 -CMAKE_ARCH=x86_64 BUILD_C_LIB_TESTS="OFF" if [[ "${1:-}" == "--build-tests" ]]; then @@ -18,20 +16,12 @@ dnf install -y \ tar \ make -# Fetch and install CMake. -if [ ! -e "/usr/local/bin/cmake" ]; then - pushd /usr/local - wget --quiet https://github.com/Kitware/CMake/releases/download/v"${CMAKE_VERSION}"/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz - tar zxf cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz - rm cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz - ln -s /usr/local/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}"/bin/cmake /usr/local/bin/cmake - popd -fi - source rapids-configure-sccache - source rapids-date-string +rapids-pip-retry install cmake +pyenv rehash + rapids-print-env rapids-logger "Begin cpp build" diff --git a/ci/test_standalone_c.sh b/ci/test_standalone_c.sh index 123f14a061..894d687bff 100755 --- a/ci/test_standalone_c.sh +++ b/ci/test_standalone_c.sh @@ -4,31 +4,27 @@ set -euo pipefail -CMAKE_VERSION=4.1.2 -CMAKE_ARCH=x86_64 - -# Fetch and install CMake. -if [ ! -e "/usr/local/bin/cmake" ]; then - pushd /usr/local - wget --quiet https://github.com/Kitware/CMake/releases/download/v"${CMAKE_VERSION}"/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz - tar zxf cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz - rm cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}".tar.gz - ln -s /usr/local/cmake-"${CMAKE_VERSION}"-linux-"${CMAKE_ARCH}"/bin/cmake /usr/local/bin/cmake - popd -fi +rapids-pip-retry install cmake +pyenv rehash + +INSTALL_PREFIX="${PWD}/libcuvs_c_install" +mkdir -p "${INSTALL_PREFIX}" # Download the standalone C library artifact -payload_name="libcuvs_c_${RAPIDS_CUDA_VERSION}.tar.gz" +if [ -z "$1" ]; then + echo "Error: name of the standalone C library artifact is missing" + exit 1 +fi + +payload_name="$1" pkg_name="libcuvs_c.tar.gz" rapids-logger "Download ${payload_name} artifacts from previous jobs" DOWNLOAD_LOCATION=$(rapids-download-from-github "${payload_name}") # Extract the artifact to a staging directory -INSTALL_PREFIX="${PWD}/libcuvs_c_install" -mkdir -p "${INSTALL_PREFIX}" -ls -l "${DOWNLOAD_LOCATION}" tar -xf "${DOWNLOAD_LOCATION}/${pkg_name}" -C "${INSTALL_PREFIX}" + rapids-logger "Run C API tests" ls -l "${INSTALL_PREFIX}" cd "$INSTALL_PREFIX"/bin/gtests/libcuvs From 55fca1ef17db1adf62acb033f587b603736e0a4f Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 2 Dec 2025 17:48:32 -0600 Subject: [PATCH 23/32] Use strict priority in CI conda tests (#1606) This PR sets conda to use `strict` priority in CI tests. Mixing channel priority is frequently a cause of unexpected errors. Our CI jobs should always use strict priority in order to enforce that conda packages come from local channels with the artifacts built in CI, not mixing with older nightly artifacts from the `rapidsai-nightly` channel or other sources. xref: https://github.com/rapidsai/build-planning/issues/14 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/cuvs/pull/1606 --- ci/build_docs.sh | 3 +++ ci/build_go.sh | 5 ++++- ci/build_java.sh | 3 +++ ci/build_rust.sh | 5 ++++- ci/check_style.sh | 3 +++ 5 files changed, 17 insertions(+), 2 deletions(-) diff --git a/ci/build_docs.sh b/ci/build_docs.sh index a4eea8b915..f9ab38721b 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -11,6 +11,9 @@ PYTHON_CHANNEL=$(rapids-download-conda-from-github python) rapids-logger "Create test conda environment" . /opt/conda/etc/profile.d/conda.sh +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + RAPIDS_VERSION_MAJOR_MINOR="$(rapids-version-major-minor)" export RAPIDS_VERSION_MAJOR_MINOR diff --git a/ci/build_go.sh b/ci/build_go.sh index 5e9cac68bc..925dfb9153 100755 --- a/ci/build_go.sh +++ b/ci/build_go.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 set -euo pipefail @@ -10,6 +10,9 @@ CPP_CHANNEL=$(rapids-download-conda-from-github cpp) rapids-logger "Create test conda environment" . /opt/conda/etc/profile.d/conda.sh +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + rapids-dependency-file-generator \ --output conda \ --file-key go \ diff --git a/ci/build_java.sh b/ci/build_java.sh index d5352910f8..922483446f 100755 --- a/ci/build_java.sh +++ b/ci/build_java.sh @@ -16,6 +16,9 @@ if [ -e "/opt/conda/etc/profile.d/conda.sh" ]; then . /opt/conda/etc/profile.d/conda.sh fi +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-github cpp) diff --git a/ci/build_rust.sh b/ci/build_rust.sh index adba4bd71f..e0f0b023fa 100755 --- a/ci/build_rust.sh +++ b/ci/build_rust.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 set -euo pipefail @@ -10,6 +10,9 @@ CPP_CHANNEL=$(rapids-download-conda-from-github cpp) rapids-logger "Create test conda environment" . /opt/conda/etc/profile.d/conda.sh +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + rapids-dependency-file-generator \ --output conda \ --file-key rust \ diff --git a/ci/check_style.sh b/ci/check_style.sh index 26b2bbeee4..cddb338d3e 100755 --- a/ci/check_style.sh +++ b/ci/check_style.sh @@ -7,6 +7,9 @@ set -euo pipefail rapids-logger "Create checks conda environment" . /opt/conda/etc/profile.d/conda.sh +rapids-logger "Configuring conda strict channel priority" +conda config --set channel_priority strict + rapids-dependency-file-generator \ --output conda \ --file-key checks \ From 3a79fbe059994da6785bc8169b09f8837933da04 Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Tue, 2 Dec 2025 18:16:45 -0800 Subject: [PATCH 24/32] Fix overflow in `preprocess_data_kernel` of NN Descent (#1596) `preprocess_data_kernel` in NN Descent had overflow issues. casting `blockIdx.x` to `size_t` to avoid overflow. Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuvs/pull/1596 --- cpp/src/neighbors/detail/nn_descent.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index db26fbe6d9..184cbc72cd 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -243,7 +243,11 @@ RAFT_KERNEL preprocess_data_kernel( Data_t* s_vec = (Data_t*)buffer; size_t list_id = list_offset + blockIdx.x; - load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % raft::warp_size()); + load_vec(s_vec, + input_data + static_cast(blockIdx.x) * dim, + dim, + dim, + threadIdx.x % raft::warp_size()); if (threadIdx.x == 0) { l2_norm = 0; } __syncthreads(); From 113a64d996c823cf1a40ffcd4beb411e57fb957b Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Wed, 3 Dec 2025 11:46:12 -0800 Subject: [PATCH 25/32] Deduplicate `{unpack/pack}_list_data_kernel` (#1609) Closes https://github.com/rapidsai/cuvs/issues/1578 This PR refactors the code so that we have pre-compiled code for launching `{unpack/pack}_list_data_kernel`. `.cuh`: with declarations for including in other files `*_impl.cuh`, `*.cu`: actual implementation CUDA 12: 1104.44 MB -> 1100.26 MB CUDA 13: 437.47 MB -> 435.32 MB Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuvs/pull/1609 --- cpp/CMakeLists.txt | 1 + .../ivf_pq/detail/ivf_pq_list_data.cu | 31 +++ cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 178 +----------------- cpp/src/neighbors/ivf_pq/ivf_pq_list_data.hpp | 59 ++++++ .../ivf_pq/ivf_pq_list_data_impl.cuh | 163 ++++++++++++++++ 5 files changed, 255 insertions(+), 177 deletions(-) create mode 100644 cpp/src/neighbors/ivf_pq/detail/ivf_pq_list_data.cu create mode 100644 cpp/src/neighbors/ivf_pq/ivf_pq_list_data.hpp create mode 100644 cpp/src/neighbors/ivf_pq/ivf_pq_list_data_impl.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index af15e4a399..76abede565 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -488,6 +488,7 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_false.cu src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_true.cu src/neighbors/ivf_pq/detail/ivf_pq_contiguous_list_data.cu + src/neighbors/ivf_pq/detail/ivf_pq_list_data.cu src/neighbors/ivf_pq/detail/ivf_pq_process_and_fill_codes.cu src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_list_data.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_list_data.cu new file mode 100644 index 0000000000..acef4c516e --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_list_data.cu @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../ivf_pq_list_data_impl.cuh" +#include + +namespace cuvs::neighbors::ivf_pq::detail { +void unpack_list_data( + raft::device_matrix_view codes, + raft::device_mdspan::list_extents, raft::row_major> + list_data, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + unpack_list_data_impl(codes, list_data, offset_or_indices, pq_bits, stream); +}; + +void pack_list_data( + raft::device_mdspan::list_extents, raft::row_major> + list_data, + raft::device_matrix_view codes, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + pack_list_data_impl(list_data, codes, offset_or_indices, pq_bits, stream); +}; +}; // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 31f2a12989..582bc890c0 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -10,6 +10,7 @@ #include "../ivf_list.cuh" #include "ivf_pq_codepacking.cuh" #include "ivf_pq_contiguous_list_data.cuh" +#include "ivf_pq_list_data.hpp" #include "ivf_pq_process_and_fill_codes.cuh" #include #include @@ -492,96 +493,6 @@ void train_per_cluster(raft::resources const& handle, transpose_pq_centers(handle, index, pq_centers_tmp.data()); } -/** - * A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes - * one-per-byte. That is, independent of the code width (pq_bits), one code uses - * the whole byte, hence one vectors uses pq_dim bytes. - */ -struct unpack_codes { - raft::device_matrix_view out_codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_codes the destination for the read codes. - */ - __device__ inline unpack_codes( - raft::device_matrix_view out_codes) - : out_codes{out_codes} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - out_codes(i, j) = code; - } -}; - -template -__launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel( - raft::device_matrix_view out_codes, - raft::device_mdspan::list_extents, raft::row_major> - in_list_data, - std::variant offset_or_indices) -{ - const uint32_t pq_dim = out_codes.extent(1); - auto unpack_action = unpack_codes{out_codes}; - run_on_list(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action); -} - -/** - * Unpack flat PQ codes from an existing list by the given offset. - * - * @param[out] codes flat PQ codes, one code per byte [n_rows, pq_dim] - * @param[in] list_data the packed ivf::list data. - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void unpack_list_data( - raft::device_matrix_view codes, - raft::device_mdspan::list_extents, raft::row_major> - list_data, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - auto n_rows = codes.extent(0); - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return unpack_list_data_kernel; - case 5: return unpack_list_data_kernel; - case 6: return unpack_list_data_kernel; - case 7: return unpack_list_data_kernel; - case 8: return unpack_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(codes, list_data, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** Unpack the list data; see the public interface for the api and usage. */ -template -void unpack_list_data(raft::resources const& res, - const index& index, - raft::device_matrix_view out_codes, - uint32_t label, - std::variant offset_or_indices) -{ - unpack_list_data(out_codes, - index.lists()[label]->data.view(), - offset_or_indices, - index.pq_bits(), - raft::resource::get_cuda_stream(res)); -} - /** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. */ struct reconstruct_vectors { @@ -740,93 +651,6 @@ void reconstruct_list_data(raft::resources const& res, } } -/** - * A producer for the `write_list` and `write_vector` reads the codes byte-by-byte. That is, - * independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses - * pq_dim bytes. - */ -struct pass_codes { - raft::device_matrix_view codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[in] codes the source codes. - */ - __device__ inline pass_codes( - raft::device_matrix_view codes) - : codes{codes} - { - } - - /** Read j-th component (code) of the i-th vector from the source. */ - __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); } -}; - -template -__launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - raft::device_matrix_view codes, - std::variant offset_or_indices) -{ - write_list( - list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes}); -} - -/** - * Write flat PQ codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). - * - * @param[out] list_data the packed ivf::list data. - * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void pack_list_data( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - raft::device_matrix_view codes, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - auto n_rows = codes.extent(0); - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return pack_list_data_kernel; - case 5: return pack_list_data_kernel; - case 6: return pack_list_data_kernel; - case 7: return pack_list_data_kernel; - case 8: return pack_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(list_data, codes, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void pack_list_data(raft::resources const& res, - index* index, - raft::device_matrix_view new_codes, - uint32_t label, - std::variant offset_or_indices) -{ - pack_list_data(index->lists()[label]->data.view(), - new_codes, - offset_or_indices, - index->pq_bits(), - raft::resource::get_cuda_stream(res)); -} - template __launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel( raft::device_mdspan::list_extents, raft::row_major> diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_list_data.hpp b/cpp/src/neighbors/ivf_pq/ivf_pq_list_data.hpp new file mode 100644 index 0000000000..f9c9e28b3b --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_list_data.hpp @@ -0,0 +1,59 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include +#include +#include + +namespace cuvs::neighbors::ivf_pq::detail { + +void unpack_list_data( + raft::device_matrix_view codes, + raft::device_mdspan::list_extents, raft::row_major> + list_data, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream); + +/** Unpack the list data; see the public interface for the api and usage. */ +template +void unpack_list_data(raft::resources const& res, + const index& index, + raft::device_matrix_view out_codes, + uint32_t label, + std::variant offset_or_indices) +{ + unpack_list_data(out_codes, + index.lists()[label]->data.view(), + offset_or_indices, + index.pq_bits(), + raft::resource::get_cuda_stream(res)); +} + +void pack_list_data( + raft::device_mdspan::list_extents, raft::row_major> + list_data, + raft::device_matrix_view codes, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream); + +template +void pack_list_data(raft::resources const& res, + index* index, + raft::device_matrix_view new_codes, + uint32_t label, + std::variant offset_or_indices) +{ + pack_list_data(index->lists()[label]->data.view(), + new_codes, + offset_or_indices, + index->pq_bits(), + raft::resource::get_cuda_stream(res)); +} + +} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_list_data_impl.cuh new file mode 100644 index 0000000000..c20857cb10 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_list_data_impl.cuh @@ -0,0 +1,163 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once +#include "ivf_pq_codepacking.cuh" +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::ivf_pq::detail { + +/** + * A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes + * one-per-byte. That is, independent of the code width (pq_bits), one code uses + * the whole byte, hence one vectors uses pq_dim bytes. + */ +struct unpack_codes { + raft::device_matrix_view out_codes; + + /** + * Create a callable to be passed to `run_on_list`. + * + * @param[out] out_codes the destination for the read codes. + */ + __device__ inline unpack_codes( + raft::device_matrix_view out_codes) + : out_codes{out_codes} + { + } + + /** Write j-th component (code) of the i-th vector into the output array. */ + __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) + { + out_codes(i, j) = code; + } +}; + +template +__launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel( + raft::device_matrix_view out_codes, + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + std::variant offset_or_indices) +{ + const uint32_t pq_dim = out_codes.extent(1); + auto unpack_action = unpack_codes{out_codes}; + run_on_list(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action); +} + +/** + * Unpack flat PQ codes from an existing list by the given offset. + * + * @param[out] codes flat PQ codes, one code per byte [n_rows, pq_dim] + * @param[in] list_data the packed ivf::list data. + * @param[in] offset_or_indices how many records in the list to skip or the exact indices. + * @param[in] pq_bits codebook size (1 << pq_bits) + * @param[in] stream + */ +inline void unpack_list_data_impl( + raft::device_matrix_view codes, + raft::device_mdspan::list_extents, raft::row_major> + list_data, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + auto n_rows = codes.extent(0); + if (n_rows == 0) { return; } + + constexpr uint32_t kBlockSize = 256; + dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [pq_bits]() { + switch (pq_bits) { + case 4: return unpack_list_data_kernel; + case 5: return unpack_list_data_kernel; + case 6: return unpack_list_data_kernel; + case 7: return unpack_list_data_kernel; + case 8: return unpack_list_data_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(); + kernel<<>>(codes, list_data, offset_or_indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * A producer for the `write_list` and `write_vector` reads the codes byte-by-byte. That is, + * independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses + * pq_dim bytes. + */ +struct pass_codes { + raft::device_matrix_view codes; + + /** + * Create a callable to be passed to `run_on_list`. + * + * @param[in] codes the source codes. + */ + __device__ inline pass_codes( + raft::device_matrix_view codes) + : codes{codes} + { + } + + /** Read j-th component (code) of the i-th vector from the source. */ + __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); } +}; + +template +__launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel( + raft::device_mdspan::list_extents, raft::row_major> + list_data, + raft::device_matrix_view codes, + std::variant offset_or_indices) +{ + write_list( + list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes}); +} + +/** + * Write flat PQ codes into an existing list by the given offset. + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). + * + * @param[out] list_data the packed ivf::list data. + * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] + * @param[in] offset_or_indices how many records in the list to skip or the exact indices. + * @param[in] pq_bits codebook size (1 << pq_bits) + * @param[in] stream + */ +inline void pack_list_data_impl( + raft::device_mdspan::list_extents, raft::row_major> + list_data, + raft::device_matrix_view codes, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + auto n_rows = codes.extent(0); + if (n_rows == 0) { return; } + + constexpr uint32_t kBlockSize = 256; + dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [pq_bits]() { + switch (pq_bits) { + case 4: return pack_list_data_kernel; + case 5: return pack_list_data_kernel; + case 6: return pack_list_data_kernel; + case 7: return pack_list_data_kernel; + case 8: return pack_list_data_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(); + kernel<<>>(list_data, codes, offset_or_indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} +}; // namespace cuvs::neighbors::ivf_pq::detail From 84a8a7c5563236794471c89b44c9393d7e92f2f7 Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Wed, 3 Dec 2025 13:44:19 -0800 Subject: [PATCH 26/32] Expose NN Descent fp16 data type support to python (#1616) Closes https://github.com/rapidsai/cuvs/issues/1586 NN Descent python wrapper fails the `_check_input_array` check when given fp16 data. Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuvs/pull/1616 --- python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx | 2 +- python/cuvs/cuvs/tests/test_nn_descent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx index 62486e46d3..afae653f35 100644 --- a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx +++ b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx @@ -223,7 +223,7 @@ def build(IndexParams index_params, dataset, graph=None, resources=None): >>> graph = index.graph """ dataset_ai = wrap_array(dataset) - _check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('byte'), + _check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('float16'), np.dtype('byte'), np.dtype('ubyte')]) cdef Index idx = Index() diff --git a/python/cuvs/cuvs/tests/test_nn_descent.py b/python/cuvs/cuvs/tests/test_nn_descent.py index c2d128edbb..862f82dc12 100644 --- a/python/cuvs/cuvs/tests/test_nn_descent.py +++ b/python/cuvs/cuvs/tests/test_nn_descent.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("n_rows", [1024, 2048]) @pytest.mark.parametrize("n_cols", [32, 64]) @pytest.mark.parametrize("device_memory", [True, False]) -@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16]) @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("return_distances", [True, False]) def test_nn_descent( From 2fc8e9934782c5f12f9a60bfc7f9bbeb1e80718e Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Thu, 4 Dec 2025 03:12:39 -0500 Subject: [PATCH 27/32] [FEA] Enforce tighter link restrictions on libcuvs_c (#1614) This makes sure we don't leak unneeded dependencies in our `PUBLIC` target_link_libraries for cuvs_c Authors: - Robert Maynard (https://github.com/robertmaynard) Approvers: - Divye Gala (https://github.com/divyegala) - Kyle Edwards (https://github.com/KyleFromNVIDIA) URL: https://github.com/rapidsai/cuvs/pull/1614 --- c/CMakeLists.txt | 6 ++++-- c/tests/CMakeLists.txt | 9 +++++++-- ci/build_standalone_c.sh | 1 - examples/c/CMakeLists.txt | 20 +++++++++++++++----- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/c/CMakeLists.txt b/c/CMakeLists.txt index 5c66cad9fe..30f9299788 100644 --- a/c/CMakeLists.txt +++ b/c/CMakeLists.txt @@ -155,8 +155,10 @@ target_include_directories( target_link_libraries( cuvs_c - PUBLIC $,cuvs::cuvs_static,cuvs::cuvs> - PRIVATE raft::raft $ + PRIVATE # we don't want any dependencies exported + $,cuvs::cuvs_static,cuvs::cuvs> # + $ # enforce we shouldn't use raft symbols + $ ) # ################################################################################################## diff --git a/c/tests/CMakeLists.txt b/c/tests/CMakeLists.txt index 0218f8d4b8..6d52e5b174 100644 --- a/c/tests/CMakeLists.txt +++ b/c/tests/CMakeLists.txt @@ -43,8 +43,12 @@ function(ConfigureTest) add_executable(${TEST_NAME} ${_CUVS_TEST_PATH}) target_link_libraries( - ${TEST_NAME} PRIVATE cuvs::c_api GTest::gtest GTest::gtest_main - $ + ${TEST_NAME} + PRIVATE cuvs::c_api + GTest::gtest + GTest::gtest_main + $,$,cuvs::cuvs> + $ ) set_target_properties( ${TEST_NAME} @@ -90,6 +94,7 @@ endif() ConfigureTest(NAME cuvs_c_headers PATH core/headers.c) ConfigureTest(NAME cuvs_c_test PATH core/c_api.c) +target_link_libraries(cuvs_c_test PRIVATE CUDA::cudart) ConfigureTest(NAME cuvs_c_neighbors_test PATH neighbors/c_api.c) # ################################################################################################## diff --git a/ci/build_standalone_c.sh b/ci/build_standalone_c.sh index bee25e0d37..98477ae0db 100755 --- a/ci/build_standalone_c.sh +++ b/ci/build_standalone_c.sh @@ -40,7 +40,6 @@ scl enable gcc-toolset-${TOOLSET_VERSION} -- \ cmake -S cpp -B cpp/build/ \ -DCMAKE_CUDA_HOST_COMPILER=/opt/rh/gcc-toolset-${TOOLSET_VERSION}/root/usr/bin/gcc \ -DCMAKE_CUDA_ARCHITECTURES=RAPIDS \ - -DBUILD_SHARED_LIBS=OFF \ -DCUTLASS_ENABLE_TESTS=OFF \ -DDISABLE_OPENMP=OFF \ -DBUILD_TESTS=OFF \ diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt index 6a1c6a4e83..b6bee50e26 100644 --- a/examples/c/CMakeLists.txt +++ b/examples/c/CMakeLists.txt @@ -32,20 +32,30 @@ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra -Werror") add_executable(CAGRA_C_EXAMPLE src/cagra_c_example.c) target_include_directories(CAGRA_C_EXAMPLE PUBLIC "$") -target_link_libraries(CAGRA_C_EXAMPLE PRIVATE cuvs::c_api $) +target_link_libraries( + CAGRA_C_EXAMPLE PRIVATE cuvs::c_api CUDA::cudart $ +) add_executable(L2_C_EXAMPLE src/L2_c_example.c) target_include_directories(L2_C_EXAMPLE PUBLIC "$") -target_link_libraries(L2_C_EXAMPLE PRIVATE cuvs::c_api $) +target_link_libraries( + L2_C_EXAMPLE PRIVATE cuvs::c_api CUDA::cudart $ +) add_executable(IVF_FLAT_C_EXAMPLE src/ivf_flat_c_example.c) target_include_directories(IVF_FLAT_C_EXAMPLE PUBLIC "$") -target_link_libraries(IVF_FLAT_C_EXAMPLE PRIVATE cuvs::c_api $) +target_link_libraries( + IVF_FLAT_C_EXAMPLE PRIVATE cuvs::c_api CUDA::cudart $ +) add_executable(IVF_PQ_C_EXAMPLE src/ivf_pq_c_example.c) target_include_directories(IVF_PQ_C_EXAMPLE PUBLIC "$") -target_link_libraries(IVF_PQ_C_EXAMPLE PRIVATE cuvs::c_api $) +target_link_libraries( + IVF_PQ_C_EXAMPLE PRIVATE cuvs::c_api CUDA::cudart $ +) add_executable(BRUTEFORCE_C_EXAMPLE src/bruteforce_c_example.c) target_include_directories(BRUTEFORCE_C_EXAMPLE PUBLIC "$") -target_link_libraries(BRUTEFORCE_C_EXAMPLE PRIVATE cuvs::c_api $) +target_link_libraries( + BRUTEFORCE_C_EXAMPLE PRIVATE cuvs::c_api CUDA::cudart $ +) From ba67db12dcd2b4f894e9b0693c11e85dff4c3626 Mon Sep 17 00:00:00 2001 From: Anupam <54245698+aamijar@users.noreply.github.com> Date: Thu, 4 Dec 2025 08:48:52 -0800 Subject: [PATCH 28/32] cmake is missing `sparse/gram.cu` gtest (#1611) The sparse/gram apis were moved from raft in https://github.com/rapidsai/cuvs/pull/463. However, the cmake has not been updated to compile the tests. Authors: - Anupam (https://github.com/aamijar) Approvers: - Robert Maynard (https://github.com/robertmaynard) URL: https://github.com/rapidsai/cuvs/pull/1611 --- cpp/tests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 85a28950ec..1a2d7fa23d 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -309,6 +309,7 @@ ConfigureTest( distance/gram.cu distance/masked_nn.cu distance/sparse_distance.cu + sparse/gram.cu sparse/neighbors/cross_component_nn.cu GPUS 1 PERCENT 100 From 94795b09444746eda80dc27f775a73c9bf8b2ecd Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Fri, 5 Dec 2025 07:11:08 -0800 Subject: [PATCH 29/32] Use CCCL's mdspan implementation (#1605) Based on https://github.com/rapidsai/raft/pull/2836 Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Bradley Dice (https://github.com/bdice) - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuvs/pull/1605 --- c/src/core/detail/interop.hpp | 2 +- cpp/cmake/thirdparty/get_raft.cmake | 7 ++- cpp/include/cuvs/cluster/kmeans.hpp | 45 +++++++++++-------- cpp/include/cuvs/neighbors/common.hpp | 4 +- cpp/include/cuvs/neighbors/ivf_pq.hpp | 14 +++--- .../cuvs/neighbors/knn_merge_parts.hpp | 6 +-- cpp/src/cluster/detail/kmeans.cuh | 6 +-- cpp/src/cluster/detail/kmeans_auto_find_k.cuh | 19 ++++---- cpp/src/cluster/detail/single_linkage.cuh | 2 +- cpp/src/cluster/kmeans_fit_double.cu | 8 ++-- cpp/src/cluster/kmeans_fit_float.cu | 8 ++-- cpp/src/cluster/kmeans_fit_mg_double.cu | 10 ++--- cpp/src/cluster/kmeans_fit_mg_float.cu | 10 ++--- cpp/src/cluster/kmeans_impl_fit_predict.cuh | 20 +++++++-- cpp/src/cluster/kmeans_mg.hpp | 18 ++++---- cpp/src/cluster/kmeans_predict_double.cu | 8 ++-- cpp/src/cluster/kmeans_predict_float.cu | 9 ++-- cpp/src/neighbors/cagra.cuh | 31 ++++++------- cpp/src/neighbors/detail/cagra/add_nodes.cuh | 2 +- .../neighbors/detail/cagra/cagra_build.cuh | 35 +++++++-------- .../neighbors/detail/cagra/cagra_merge.cuh | 4 +- cpp/src/neighbors/detail/cagra/graph_core.cuh | 20 ++++----- cpp/src/neighbors/detail/knn_brute_force.cuh | 4 +- cpp/src/neighbors/detail/nn_descent.cuh | 12 ++--- cpp/src/neighbors/detail/sparse_knn.cuh | 4 +- .../neighbors/detail/vamana/greedy_search.cuh | 6 +-- .../neighbors/detail/vamana/robust_prune.cuh | 6 +-- .../neighbors/detail/vamana/vamana_build.cuh | 12 ++--- .../detail/vamana/vamana_structs.cuh | 6 +-- .../iface/iface_cagra_float_uint32_t.cu | 16 +++---- .../iface/iface_cagra_half_uint32_t.cu | 16 +++---- .../iface/iface_cagra_int8_t_uint32_t.cu | 16 +++---- .../iface/iface_cagra_uint8_t_uint32_t.cu | 16 +++---- .../iface/iface_flat_float_int64_t.cu | 16 +++---- .../iface/iface_flat_half_int64_t.cu | 14 +++--- .../iface/iface_flat_int8_t_int64_t.cu | 16 +++---- .../iface/iface_flat_uint8_t_int64_t.cu | 16 +++---- .../neighbors/iface/iface_pq_float_int64_t.cu | 16 +++---- .../neighbors/iface/iface_pq_half_int64_t.cu | 16 +++---- .../iface/iface_pq_int8_t_int64_t.cu | 16 +++---- .../iface/iface_pq_uint8_t_int64_t.cu | 16 +++---- .../detail/ivf_pq_process_and_fill_codes.cu | 2 +- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 2 +- .../neighbors/ivf_pq/ivf_pq_build_common.cu | 8 ++-- .../ivf_pq/ivf_pq_process_and_fill_codes.cuh | 13 +++--- .../ivf_pq_process_and_fill_codes_impl.cuh | 13 +++--- cpp/src/neighbors/knn_merge_parts.cu | 8 ++-- cpp/src/neighbors/scann/detail/scann_avq.cuh | 24 +++++----- .../neighbors/scann/detail/scann_build.cuh | 6 +-- .../neighbors/scann/detail/scann_quantize.cuh | 6 +-- .../scann/detail/scann_serialize.cuh | 4 +- cpp/src/neighbors/scann/scann.cuh | 6 +-- cpp/src/neighbors/vamana.cuh | 6 +-- cpp/src/sparse/cluster/cluster_solvers.cuh | 11 ++--- cpp/tests/cluster/kmeans.cu | 4 +- 55 files changed, 331 insertions(+), 310 deletions(-) diff --git a/c/src/core/detail/interop.hpp b/c/src/core/detail/interop.hpp index 3b94feb78d..5974d1aa15 100644 --- a/c/src/core/detail/interop.hpp +++ b/c/src/core/detail/interop.hpp @@ -129,7 +129,7 @@ inline MdspanType from_dlpack(DLManagedTensor* managed_tensor) "ndim mismatch between return mdspan and DLTensor"); // auto exts = typename MdspanType::extents_type{tensor.shape}; - std::array shape{}; + cuda::std::array shape{}; for (int64_t i = 0; i < tensor.ndim; ++i) { shape[i] = tensor.shape[i]; } diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 492ae6cba1..8ecf3686be 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -14,13 +14,18 @@ function(find_and_configure_raft) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) + # Set BUILD_SHARED_LIBS whenever building static dependencies + if(PKG_BUILD_STATIC_DEPS) + set(BUILD_SHARED_LIBS OFF) + endif() + + # Determine whether to clone raft locally if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "${rapids-cmake-checkout-tag}") message(STATUS "cuVS: RAFT pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") set(CPM_DOWNLOAD_raft ON) elseif(PKG_BUILD_STATIC_DEPS AND (NOT CPM_raft_SOURCE)) message(STATUS "cuVS: Cloning raft locally to build static libraries.") set(CPM_DOWNLOAD_raft ON) - set(BUILD_SHARED_LIBS OFF) endif() set(RAFT_COMPONENTS "") diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index c87e006315..a8aa6b9807 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -181,8 +181,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @brief Find clusters with k-means algorithm. @@ -232,8 +232,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @brief Find clusters with k-means algorithm. @@ -282,8 +282,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @brief Find clusters with k-means algorithm. @@ -333,8 +333,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @brief Find clusters with k-means algorithm. @@ -383,8 +383,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); /** * @brief Find balanced clusters with k-means algorithm. @@ -581,6 +581,15 @@ void predict(raft::resources const& handle, bool normalize_weight, raft::host_scalar_view inertia); +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + /** * @brief Predict the closest cluster each sample in X belongs to. * @@ -632,10 +641,10 @@ void predict(raft::resources const& handle, */ void predict(raft::resources const& handle, const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, bool normalize_weight, raft::host_scalar_view inertia); @@ -748,10 +757,10 @@ void predict(raft::resources const& handle, */ void predict(raft::resources const& handle, const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, bool normalize_weight, raft::host_scalar_view inertia); diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 749561896f..4f697b3604 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -248,7 +248,7 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t } // Something is wrong: have to make a copy and produce an owning dataset auto out_layout = - raft::make_strided_layout(src.extents(), std::array{required_stride, 1}); + raft::make_strided_layout(src.extents(), cuda::std::array{required_stride, 1}); auto out_array = raft::make_device_matrix(res, src.extent(0), required_stride); @@ -310,7 +310,7 @@ auto make_strided_dataset( const bool stride_matches = required_stride == src_stride; auto out_layout = - raft::make_strided_layout(src.extents(), std::array{required_stride, 1}); + raft::make_strided_layout(src.extents(), cuda::std::array{required_stride, 1}); using out_mdarray_type = raft::device_matrix; using out_layout_type = typename out_mdarray_type::layout_type; diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 6836757cea..cfb344b04c 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -269,10 +269,8 @@ constexpr typename list_spec::list_extents list_spec:: { // how many elems of pq_dim fit into one kIndexGroupVecLen-byte chunk auto pq_chunk = (kIndexGroupVecLen * 8u) / pq_bits; - return raft::make_extents(raft::div_rounding_up_safe(n_rows, kIndexGroupSize), - raft::div_rounding_up_safe(pq_dim, pq_chunk), - kIndexGroupSize, - kIndexGroupVecLen); + return list_extents{raft::div_rounding_up_safe(n_rows, kIndexGroupSize), + raft::div_rounding_up_safe(pq_dim, pq_chunk)}; } template @@ -335,8 +333,8 @@ struct index : cuvs::neighbors::index { static_assert(!raft::is_narrowing_v, "IdxT must be able to represent all values of uint32_t"); - using pq_centers_extents = std::experimental:: - extents; + using pq_centers_extents = + raft::extents; public: index(const index&) = delete; @@ -2875,7 +2873,7 @@ void make_rotation_matrix(raft::resources const& res, */ void set_centers(raft::resources const& res, index* index, - raft::device_matrix_view cluster_centers); + raft::device_matrix_view cluster_centers); /** * @brief Public helper API for fetching a trained index's IVF centroids @@ -2896,7 +2894,7 @@ void set_centers(raft::resources const& res, */ void extract_centers(raft::resources const& res, const index& index, - raft::device_matrix_view cluster_centers); + raft::device_matrix_view cluster_centers); /** @copydoc extract_centers */ void extract_centers(raft::resources const& res, diff --git a/cpp/include/cuvs/neighbors/knn_merge_parts.hpp b/cpp/include/cuvs/neighbors/knn_merge_parts.hpp index 01551987fc..7581a28c7c 100644 --- a/cpp/include/cuvs/neighbors/knn_merge_parts.hpp +++ b/cpp/include/cuvs/neighbors/knn_merge_parts.hpp @@ -27,17 +27,17 @@ void knn_merge_parts(raft::resources const& res, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations); + raft::device_vector_view translations); void knn_merge_parts(raft::resources const& res, raft::device_matrix_view inK, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations); + raft::device_vector_view translations); void knn_merge_parts(raft::resources const& res, raft::device_matrix_view inK, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations); + raft::device_vector_view translations); } // namespace cuvs::neighbors diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index c3096945b4..7dc3fdd963 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -1117,9 +1117,9 @@ void kmeans_predict(raft::resources const& handle, template void kmeans_transform(raft::resources const& handle, const cuvs::cluster::kmeans::params& pams, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_matrix_view X_new) { raft::common::nvtx::range fun_scope("kmeans_transform"); raft::default_logger().set_level(pams.verbosity); diff --git a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh index ae94db341d..594a63e8da 100644 --- a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh +++ b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh @@ -25,8 +25,8 @@ void compute_dispersion(raft::resources const& handle, raft::device_matrix_view X, cuvs::cluster::kmeans::params& params, raft::device_matrix_view centroids_view, - raft::device_vector_view labels, - raft::device_vector_view clusterSizes, + raft::device_vector_view labels, + raft::device_vector_view clusterSizes, rmm::device_uvector& workspace, raft::host_vector_view clusterDispertionView, raft::host_vector_view resultsView, @@ -109,12 +109,15 @@ void find_k(raft::resources const& handle, auto centroids_view = raft::make_device_matrix_view(centroids.data_handle(), left, d); + auto labels_view = raft::make_device_vector_view(labels.data_handle(), n); + auto clusterSizes_view = + raft::make_device_vector_view(clusterSizes.data_handle(), kmax); compute_dispersion(handle, X, params, centroids_view, - labels.view(), - clusterSizes.view(), + labels_view, + clusterSizes_view, workspace, clusterDispertionView, resultsView, @@ -133,8 +136,8 @@ void find_k(raft::resources const& handle, X, params, centroids_view, - labels.view(), - clusterSizes.view(), + labels_view, + clusterSizes_view, workspace, clusterDispertionView, resultsView, @@ -159,8 +162,8 @@ void find_k(raft::resources const& handle, X, params, centroids_view, - labels.view(), - clusterSizes.view(), + labels_view, + clusterSizes_view, workspace, clusterDispertionView, resultsView, diff --git a/cpp/src/cluster/detail/single_linkage.cuh b/cpp/src/cluster/detail/single_linkage.cuh index b095e7e73c..f8d4615c75 100644 --- a/cpp/src/cluster/detail/single_linkage.cuh +++ b/cpp/src/cluster/detail/single_linkage.cuh @@ -43,7 +43,7 @@ namespace cuvs::cluster::agglomerative::detail { template >> + typename Accessor = raft::device_accessor>> void build_mr_linkage( raft::resources const& handle, raft::mdspan, raft::row_major, Accessor> X, diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 0bf06d5f31..43f457a29a 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -44,8 +44,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); @@ -56,8 +56,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 2511814427..5624151943 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -44,8 +44,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); @@ -56,8 +56,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); diff --git a/cpp/src/cluster/kmeans_fit_mg_double.cu b/cpp/src/cluster/kmeans_fit_mg_double.cu index 73e40742e9..bd7f8453c1 100644 --- a/cpp/src/cluster/kmeans_fit_mg_double.cu +++ b/cpp/src/cluster/kmeans_fit_mg_double.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,8 +14,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); @@ -28,8 +28,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); diff --git a/cpp/src/cluster/kmeans_fit_mg_float.cu b/cpp/src/cluster/kmeans_fit_mg_float.cu index 4352e8859c..ae7c5722b7 100644 --- a/cpp/src/cluster/kmeans_fit_mg_float.cu +++ b/cpp/src/cluster/kmeans_fit_mg_float.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -14,8 +14,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); @@ -28,8 +28,8 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); diff --git a/cpp/src/cluster/kmeans_impl_fit_predict.cuh b/cpp/src/cluster/kmeans_impl_fit_predict.cuh index 6489934a03..e350b7d6ed 100644 --- a/cpp/src/cluster/kmeans_impl_fit_predict.cuh +++ b/cpp/src/cluster/kmeans_impl_fit_predict.cuh @@ -26,12 +26,24 @@ void fit_predict(raft::resources const& handle, raft::make_device_matrix(handle, pams.n_clusters, n_features); cuvs::cluster::kmeans::fit( handle, pams, X, sample_weight, centroids_matrix.view(), inertia, n_iter); - cuvs::cluster::kmeans::predict( - handle, pams, X, sample_weight, centroids_matrix.view(), labels, true, inertia); + cuvs::cluster::kmeans::predict(handle, + pams, + X, + sample_weight, + raft::make_const_mdspan(centroids_matrix.view()), + labels, + true, + inertia); } else { cuvs::cluster::kmeans::fit(handle, pams, X, sample_weight, centroids.value(), inertia, n_iter); - cuvs::cluster::kmeans::predict( - handle, pams, X, sample_weight, centroids.value(), labels, true, inertia); + cuvs::cluster::kmeans::predict(handle, + pams, + X, + sample_weight, + raft::make_const_mdspan(centroids.value()), + labels, + true, + inertia); } } diff --git a/cpp/src/cluster/kmeans_mg.hpp b/cpp/src/cluster/kmeans_mg.hpp index 7281b28fec..77cceff962 100644 --- a/cpp/src/cluster/kmeans_mg.hpp +++ b/cpp/src/cluster/kmeans_mg.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -35,30 +35,30 @@ void fit(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); } // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_predict_double.cu b/cpp/src/cluster/kmeans_predict_double.cu index 8d65d6fb4f..52d120a232 100644 --- a/cpp/src/cluster/kmeans_predict_double.cu +++ b/cpp/src/cluster/kmeans_predict_double.cu @@ -41,10 +41,10 @@ void predict(raft::resources const& handle, void predict(raft::resources const& handle, const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, bool normalize_weight, raft::host_scalar_view inertia) diff --git a/cpp/src/cluster/kmeans_predict_float.cu b/cpp/src/cluster/kmeans_predict_float.cu index 0e5180fc05..30812aa141 100644 --- a/cpp/src/cluster/kmeans_predict_float.cu +++ b/cpp/src/cluster/kmeans_predict_float.cu @@ -38,12 +38,13 @@ void predict(raft::resources const& handle, cuvs::cluster::kmeans::predict( handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } + void predict(raft::resources const& handle, const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, bool normalize_weight, raft::host_scalar_view inertia) diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index cf65ef7c4d..81d181acbd 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -163,7 +163,7 @@ void build_knn_graph( */ template , + typename accessor = raft::host_device_accessor, raft::memory_type::device>> void build_knn_graph( raft::resources const& res, @@ -206,13 +206,12 @@ void build_knn_graph( * @param[in,out] knn_graph a matrix view (host or device) of the input knn graph [n_rows, * knn_graph_degree] */ -template < - typename DataT, - typename IdxT = uint32_t, - typename d_accessor = raft::host_device_accessor, - raft::memory_type::device>, - typename g_accessor = - raft::host_device_accessor, raft::memory_type::host>> +template , + raft::memory_type::device>, + typename g_accessor = + raft::host_device_accessor, raft::memory_type::host>> void sort_knn_graph( raft::resources const& res, cuvs::distance::DistanceType metric, @@ -222,8 +221,7 @@ void sort_knn_graph( using internal_IdxT = typename std::make_unsigned::type; using g_accessor_internal = - raft::host_device_accessor, - g_accessor::mem_type>; + raft::host_device_accessor, g_accessor::mem_type>; auto knn_graph_internal = raft::mdspan, raft::row_major, g_accessor_internal>( reinterpret_cast(knn_graph.data_handle()), @@ -251,10 +249,9 @@ void sort_knn_graph( * knn_graph_degree] * @param[out] new_graph a host matrix view of the optimized knn graph [n_rows, graph_degree] */ -template < - typename IdxT = uint32_t, - typename g_accessor = - raft::host_device_accessor, raft::memory_type::host>> +template , raft::memory_type::host>> void optimize( raft::resources const& res, raft::mdspan, raft::row_major, g_accessor> knn_graph, @@ -265,9 +262,9 @@ void optimize( } template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 9d70f7848c..5a8de8454b 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -464,7 +464,7 @@ void extend_core( using out_owning_type = owning_dataset; auto out_layout = raft::make_strided_layout(updated_dataset_view.extents(), - std::array{stride, 1}); + cuda::std::array{stride, 1}); index.update_dataset(handle, out_owning_type{std::move(updated_dataset), out_layout}); } diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 5f7389493a..675ce34c57 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -1597,13 +1597,11 @@ void build_knn_graph( size_t previous_batch_offset = 0; for (const auto& batch : vec_batches) { - // Map int64_t to uint32_t because ivf_pq requires the latter. - // TODO(tfeher): remove this mapping once ivf_pq accepts mdspan with int64_t index type - auto queries_view = raft::make_device_matrix_view( + auto queries_view = raft::make_device_matrix_view( batch.data(), batch.size(), batch.row_width()); - auto neighbors_view = raft::make_device_matrix_view( + auto neighbors_view = raft::make_device_matrix_view( neighbors.data_handle(), batch.size(), neighbors.extent(1)); - auto distances_view = raft::make_device_matrix_view( + auto distances_view = raft::make_device_matrix_view( distances.data_handle(), batch.size(), distances.extent(1)); cuvs::neighbors::ivf_pq::search( @@ -1664,7 +1662,7 @@ void build_knn_graph( gpu_top_k); } } else { - auto neighbor_candidates_view = raft::make_device_matrix_view( + auto neighbor_candidates_view = raft::make_device_matrix_view( neighbors.data_handle(), batch.size(), gpu_top_k); auto refined_neighbors_view = raft::make_device_matrix_view( refined_neighbors.data_handle(), batch.size(), top_k); @@ -1744,8 +1742,7 @@ void build_knn_graph( using internal_IdxT = typename std::make_unsigned::type; using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; using g_accessor_internal = - raft::host_device_accessor, - g_accessor::mem_type>; + raft::host_device_accessor, g_accessor::mem_type>; auto knn_graph_internal = raft::mdspan, raft::row_major, g_accessor_internal>( @@ -1757,10 +1754,9 @@ void build_knn_graph( res, build_params.metric, dataset, knn_graph_internal); } -template < - typename IdxT = uint32_t, - typename g_accessor = - raft::host_device_accessor, raft::memory_type::host>> +template , raft::memory_type::host>> void optimize( raft::resources const& res, raft::mdspan, raft::row_major, g_accessor> knn_graph, @@ -1775,8 +1771,7 @@ void optimize( new_graph.extent(1)); using g_accessor_internal = - raft::host_device_accessor, - raft::memory_type::host>; + raft::host_device_accessor, raft::memory_type::host>; auto knn_graph_internal = raft::mdspan, raft::row_major, g_accessor_internal>( reinterpret_cast(knn_graph.data_handle()), @@ -1834,9 +1829,9 @@ struct mmap_owner { }; template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> auto iterative_build_graph( raft::resources const& res, const index_params& params, @@ -2020,9 +2015,9 @@ auto iterative_build_graph( } template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/neighbors/detail/cagra/cagra_merge.cuh b/cpp/src/neighbors/detail/cagra/cagra_merge.cuh index 78c28ac16b..144b7ada56 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_merge.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_merge.cuh @@ -100,7 +100,7 @@ index merge(raft::resources const& handle, using container_policy_t = typename matrix_t::container_policy_type; using owning_t = owning_dataset; auto out_layout = raft::make_strided_layout(updated_dataset.view().extents(), - std::array{stride, 1}); + cuda::std::array{stride, 1}); merged_index.update_dataset(handle, owning_t{std::move(updated_dataset), out_layout}); } RAFT_LOG_DEBUG("cagra merge: using device memory for merged dataset"); @@ -122,7 +122,7 @@ index merge(raft::resources const& handle, using container_policy_t = typename matrix_t::container_policy_type; using owning_t = owning_dataset; auto out_layout = raft::make_strided_layout(updated_dataset.view().extents(), - std::array{stride, 1}); + cuda::std::array{stride, 1}); merged_index.update_dataset(handle, owning_t{std::move(updated_dataset), out_layout}); } return merged_index; diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f8091c9e51..e2e775116d 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -489,13 +489,12 @@ void shift_array(T* array, uint64_t num) } } // namespace -template < - typename DataT, - typename IdxT = uint32_t, - typename d_accessor = raft::host_device_accessor, - raft::memory_type::device>, - typename g_accessor = - raft::host_device_accessor, raft::memory_type::host>> +template , + raft::memory_type::device>, + typename g_accessor = + raft::host_device_accessor, raft::memory_type::host>> void sort_knn_graph( raft::resources const& res, const cuvs::distance::DistanceType metric, @@ -1157,10 +1156,9 @@ void count_2hop_detours(raft::host_matrix_view k } } -template < - typename IdxT = uint32_t, - typename g_accessor = - raft::host_device_accessor, raft::memory_type::host>> +template , raft::memory_type::host>> void optimize( raft::resources const& res, raft::mdspan, raft::row_major, g_accessor> knn_graph, diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 118d377e7d..6dca237107 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -519,7 +519,7 @@ void brute_force_knn_impl( raft::make_device_matrix_view(out_I, n, input.size() * k), raft::make_device_matrix_view(res_D, n, k), raft::make_device_matrix_view(res_I, n, k), - raft::make_device_vector_view(trans.data(), input.size())); + raft::make_device_vector_view(trans.data(), input.size())); } }; diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 184cbc72cd..18422daeb2 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -1302,9 +1302,9 @@ void GNND::build(Data_t* data, } template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> void build(raft::resources const& res, const index_params& params, raft::mdspan, raft::row_major, Accessor> dataset, @@ -1348,9 +1348,9 @@ void build(raft::resources const& res, } template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/neighbors/detail/sparse_knn.cuh b/cpp/src/neighbors/detail/sparse_knn.cuh index 8432928193..f24b1d586c 100644 --- a/cpp/src/neighbors/detail/sparse_knn.cuh +++ b/cpp/src/neighbors/detail/sparse_knn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -335,7 +335,7 @@ class sparse_knn_t { raft::make_device_matrix_view(merge_buffer_indices, rows, 2 * k), raft::make_device_matrix_view(out_dists, rows, k), raft::make_device_matrix_view(out_indices, rows, k), - raft::make_device_vector_view(trans.data(), id_ranges.size())); + raft::make_device_vector_view(trans.data(), id_ranges.size())); } void perform_k_selection(csr_batcher_t idx_batcher, diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index 9248cf90b0..717e389c32 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -82,9 +82,9 @@ __global__ void SortPairsKernel(void* query_list_ptr, int num_queries, int topk) **********************************************************************************************/ template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> __global__ void GreedySearchKernel( raft::device_matrix_view graph, raft::mdspan, raft::row_major, Accessor> dataset, diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index 1515cd246f..9fe3c01a8c 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -51,9 +51,9 @@ namespace { **********************************************************************************************/ template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> __global__ void RobustPruneKernel( raft::device_matrix_view graph, raft::mdspan, raft::row_major, Accessor> dataset, diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index caf08d770c..9468bcc5c4 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -94,9 +94,9 @@ __global__ void print_queryIds(void* query_list_ptr) *******************************************************************************************/ template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> void batched_insert_vamana( raft::resources const& res, const index_params& params, @@ -567,9 +567,9 @@ auto quantize_all_vectors(raft::resources const& res, } template , - raft::memory_type::host>> + typename IdxT = uint64_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index c14e77812a..f413020879 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -440,9 +440,9 @@ __global__ void populate_reverse_list_struct(QueryCandidates* revers // Recompute distances of reverse list. Allows us to avoid keeping distances during sort template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> __global__ void recompute_reverse_dists( QueryCandidates* reverse_list, raft::mdspan, raft::row_major, Accessor> dataset, diff --git a/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu index 10d0d51067..4817452ca3 100644 --- a/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_CAGRA(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu index 4e6773f887..2c0c6d050b 100644 --- a/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_CAGRA(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu index 877536a25f..2cea26f515 100644 --- a/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_CAGRA(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu index bb15489849..a8d719ab80 100644 --- a/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_CAGRA(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu index 26dc7e24ab..d497a4cbd0 100644 --- a/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_FLAT(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_flat_half_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_half_int64_t.cu index 5b0e8d8af7..76e8debf3c 100644 --- a/cpp/src/neighbors/iface/iface_flat_half_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_half_int64_t.cu @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_FLAT(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu index 408f71a9e0..eb6f740fcc 100644 --- a/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_FLAT(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu index e820f5bdb2..392b5b306f 100644 --- a/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_flat_uint8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_FLAT(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu index 204f47a948..0f0395e8a6 100644 --- a/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_PQ(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu index 782ea1ffb7..b605150b5d 100644 --- a/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_half_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_PQ(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu index 73b1405364..129e8d254a 100644 --- a/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_PQ(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu index bc76cd3991..8bbbed7597 100644 --- a/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/iface/iface_pq_uint8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,14 +17,14 @@ namespace cuvs::neighbors { #define CUVS_INST_MG_PQ(T, IdxT) \ - using T_ha = raft::host_device_accessor, \ + using T_ha = \ + raft::host_device_accessor, raft::memory_type::device>; \ + using T_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ + using IdxT_ha = raft::host_device_accessor, \ raft::memory_type::device>; \ - using T_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ - using IdxT_ha = raft::host_device_accessor, \ - raft::memory_type::device>; \ - using IdxT_da = raft::host_device_accessor, \ - raft::memory_type::host>; \ + using IdxT_da = \ + raft::host_device_accessor, raft::memory_type::host>; \ \ template void build( \ const raft::resources& handle, \ diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_process_and_fill_codes.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_process_and_fill_codes.cu index f67677080a..e7d74b81bf 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_process_and_fill_codes.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_process_and_fill_codes.cu @@ -10,7 +10,7 @@ template void cuvs::neighbors::ivf_pq::detail::launch_process_and_fill_codes_kernel( \ raft::resources const& handle, \ cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view new_vectors_residual, \ + raft::device_matrix_view new_vectors_residual, \ std::variant src_offset_or_indices, \ const uint32_t* new_labels, \ IdxT n_rows); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 582bc890c0..c6c1de45e4 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -1238,7 +1238,7 @@ auto build(raft::resources const& handle, // Make rotation matrix helpers::make_rotation_matrix(handle, &index, params.force_random_rotation); - helpers::set_centers(handle, &index, raft::make_const_mdspan(centers_view)); + helpers::set_centers(handle, &index, centers_const_view); // Train PQ codebooks switch (index.codebook_kind()) { diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu b/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu index 983d0d8b95..f3f0ae311a 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build_common.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -269,7 +269,7 @@ void make_rotation_matrix(raft::resources const& res, void set_centers(raft::resources const& handle, index* index, - raft::device_matrix_view cluster_centers) + raft::device_matrix_view cluster_centers) { RAFT_EXPECTS(cluster_centers.extent(0) == index->n_lists(), "Number of rows in the new centers must be equal to the number of IVF lists"); @@ -281,14 +281,14 @@ void set_centers(raft::resources const& handle, void extract_centers(raft::resources const& res, const cuvs::neighbors::ivf_pq::index& index, - raft::device_matrix_view cluster_centers) + raft::device_matrix_view cluster_centers) { detail::extract_centers(res, index, cluster_centers); } void extract_centers(raft::resources const& res, const cuvs::neighbors::ivf_pq::index& index, - raft::host_matrix_view cluster_centers) + raft::host_matrix_view cluster_centers) { detail::extract_centers(res, index, cluster_centers); } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes.cuh index 3ecfb22f5d..3102bf3c0d 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes.cuh @@ -127,11 +127,12 @@ struct encode_vectors { }; template -void launch_process_and_fill_codes_kernel(raft::resources const& handle, - index& index, - raft::device_matrix_view new_vectors_residual, - std::variant src_offset_or_indices, - const uint32_t* new_labels, - IdxT n_rows); +void launch_process_and_fill_codes_kernel( + raft::resources const& handle, + index& index, + raft::device_matrix_view new_vectors_residual, + std::variant src_offset_or_indices, + const uint32_t* new_labels, + IdxT n_rows); } // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes_impl.cuh index 7496064e47..ea56b79426 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_process_and_fill_codes_impl.cuh @@ -61,12 +61,13 @@ __launch_bounds__(BlockSize) static __global__ void process_and_fill_codes_kerne } template -void launch_process_and_fill_codes_kernel(raft::resources const& handle, - index& index, - raft::device_matrix_view new_vectors_residual, - std::variant src_offset_or_indices, - const uint32_t* new_labels, - IdxT n_rows) +void launch_process_and_fill_codes_kernel( + raft::resources const& handle, + index& index, + raft::device_matrix_view new_vectors_residual, + std::variant src_offset_or_indices, + const uint32_t* new_labels, + IdxT n_rows) { constexpr uint32_t kBlockSize = 256; const uint32_t threads_per_vec = std::min(raft::WarpSize, index.pq_book_size()); diff --git a/cpp/src/neighbors/knn_merge_parts.cu b/cpp/src/neighbors/knn_merge_parts.cu index fcac8d2a7f..da7110a475 100644 --- a/cpp/src/neighbors/knn_merge_parts.cu +++ b/cpp/src/neighbors/knn_merge_parts.cu @@ -15,7 +15,7 @@ void _knn_merge_parts(raft::resources const& res, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations) + raft::device_vector_view translations) { auto parts = translations.extent(0); auto rows = outK.extent(0); @@ -38,7 +38,7 @@ void knn_merge_parts(raft::resources const& res, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations) + raft::device_vector_view translations) { _knn_merge_parts(res, inK, inV, outK, outV, translations); } @@ -47,7 +47,7 @@ void knn_merge_parts(raft::resources const& res, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations) + raft::device_vector_view translations) { _knn_merge_parts(res, inK, inV, outK, outV, translations); } @@ -56,7 +56,7 @@ void knn_merge_parts(raft::resources const& res, raft::device_matrix_view inV, raft::device_matrix_view outK, raft::device_matrix_view outV, - raft::device_vector_view translations) + raft::device_vector_view translations) { _knn_merge_parts(res, inK, inV, outK, outV, translations); } diff --git a/cpp/src/neighbors/scann/detail/scann_avq.cuh b/cpp/src/neighbors/scann/detail/scann_avq.cuh index 95f2d84499..3e8238c5e2 100644 --- a/cpp/src/neighbors/scann/detail/scann_avq.cuh +++ b/cpp/src/neighbors/scann/detail/scann_avq.cuh @@ -395,8 +395,8 @@ class cluster_loader { raft::device_matrix d_cluster_buf_; raft::device_matrix d_cluster_copy_buf_; const T* dataset_ptr_; - raft::host_vector_view h_cluster_offsets_; - raft::device_vector_view cluster_ids_; + raft::host_vector_view h_cluster_offsets_; + raft::device_vector_view cluster_ids_; cudaStream_t stream_; int64_t dim_; int64_t n_rows_; @@ -419,8 +419,8 @@ class cluster_loader { int64_t n_rows, int64_t max_cluster_size, int64_t h_buf_size, - raft::host_vector_view h_cluster_offsets, - raft::device_vector_view cluster_ids, + raft::host_vector_view h_cluster_offsets, + raft::device_vector_view cluster_ids, bool needs_copy, cudaStream_t stream) : dim_(dim), @@ -440,8 +440,8 @@ class cluster_loader { public: cluster_loader(raft::resources const& res, raft::device_matrix_view dataset_view, - raft::host_vector_view h_cluster_offsets, - raft::device_vector_view cluster_ids, + raft::host_vector_view h_cluster_offsets, + raft::device_vector_view cluster_ids, int64_t max_cluster_size, cudaStream_t stream) : cluster_loader(res, @@ -460,8 +460,8 @@ class cluster_loader { cluster_loader(raft::resources const& res, raft::host_matrix_view dataset_view, - raft::host_vector_view h_cluster_offsets, - raft::device_vector_view cluster_ids, + raft::host_vector_view h_cluster_offsets, + raft::device_vector_view cluster_ids, int64_t max_cluster_size, cudaStream_t stream) : cluster_loader(res, @@ -577,10 +577,10 @@ class cluster_loader { * @param eta the weight for the parallel component of the residual in the avq update */ template , - raft::memory_type::host>> + typename IdxT = int64_t, + typename LabelT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> void apply_avq(raft::resources const& res, raft::mdspan, raft::row_major, Accessor> dataset, raft::device_matrix_view centroids_view, diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 2c5163286c..8902fc2051 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -34,9 +34,9 @@ using namespace cuvs::spatial::knn::detail; // NOLINT */ template , - raft::memory_type::host>> + typename IdxT = int64_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/neighbors/scann/detail/scann_quantize.cuh b/cpp/src/neighbors/scann/detail/scann_quantize.cuh index 69c7ca08e8..0b122d38b6 100644 --- a/cpp/src/neighbors/scann/detail/scann_quantize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_quantize.cuh @@ -326,7 +326,7 @@ template __launch_bounds__(BlockSize) RAFT_KERNEL quantize_bfloat16_noise_shaped_kernel(raft::device_matrix_view dataset, raft::device_matrix_view bf16_dataset, - raft::device_vector_view sq_norms, + raft::device_vector_view sq_norms, float noise_shaping_threshold) { IdxT row_idx = raft::Pow2<32>::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); @@ -516,8 +516,8 @@ void quantize_bfloat16(raft::resources const& res, */ template , - raft::memory_type::host>> + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> auto sample_training_residuals( raft::resources const& res, random::RngState random_state, diff --git a/cpp/src/neighbors/scann/detail/scann_serialize.cuh b/cpp/src/neighbors/scann/detail/scann_serialize.cuh index 8da1e0d9fc..c3b4fdf357 100644 --- a/cpp/src/neighbors/scann/detail/scann_serialize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_serialize.cuh @@ -39,8 +39,8 @@ std::ofstream open_file(std::string file_name) // Helper for serializing device/host matrix to a given file template , - raft::memory_type::host>> + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> void serialize_matrix( raft::resources const& res, std::filesystem::path file_path, diff --git a/cpp/src/neighbors/scann/scann.cuh b/cpp/src/neighbors/scann/scann.cuh index e6844bc44e..56b302e3f9 100644 --- a/cpp/src/neighbors/scann/scann.cuh +++ b/cpp/src/neighbors/scann/scann.cuh @@ -53,9 +53,9 @@ namespace cuvs::neighbors::experimental::scann { * @return the constructed scann index */ template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/neighbors/vamana.cuh b/cpp/src/neighbors/vamana.cuh index 65fd54ca83..d2a73809ed 100644 --- a/cpp/src/neighbors/vamana.cuh +++ b/cpp/src/neighbors/vamana.cuh @@ -67,9 +67,9 @@ namespace cuvs::neighbors::vamana { * @return the constructed vamana index */ template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + raft::host_device_accessor, raft::memory_type::host>> index build( raft::resources const& res, const index_params& params, diff --git a/cpp/src/sparse/cluster/cluster_solvers.cuh b/cpp/src/sparse/cluster/cluster_solvers.cuh index 6df61b4696..0b6044453d 100644 --- a/cpp/src/sparse/cluster/cluster_solvers.cuh +++ b/cpp/src/sparse/cluster/cluster_solvers.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -55,8 +55,8 @@ struct kmeans_solver_t { km_params.max_iter = config_.maxIter; km_params.rng_state.seed = config_.seed; - auto X = raft::make_device_matrix_view(obs, n_obs_vecs, dim); - auto labels = raft::make_device_vector_view(codes, n_obs_vecs); + auto X = raft::make_device_matrix_view(obs, n_obs_vecs, dim); + auto labels = raft::make_device_vector_view(codes, n_obs_vecs); auto centroids = raft::make_device_matrix(handle, config_.n_clusters, dim); auto weight = raft::make_device_vector(handle, n_obs_vecs); @@ -65,12 +65,13 @@ struct kmeans_solver_t { weight.data_handle() + n_obs_vecs, 1); - auto sw = std::make_optional((raft::device_vector_view)weight.view()); + auto sw = + std::make_optional((raft::device_vector_view)weight.view()); cuvs::cluster::kmeans::fit_predict(handle, km_params, X, sw, - centroids.view(), + std::make_optional(centroids.view()), labels, raft::make_host_scalar_view(&residual), raft::make_host_scalar_view(&iters)); diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index bb4868a54c..05668ff4da 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -256,7 +256,7 @@ class KmeansTest : public ::testing::TestWithParam> { d_labels_ref.resize(n_samples, stream); d_centroids.resize(params.n_clusters * n_features, stream); - std::optional> d_sw = std::nullopt; + std::optional> d_sw = std::nullopt; auto d_centroids_view = raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); if (testparams.weighted) { From c4c1daf1c91e3370ccdfb234d3c342c1d5af7add Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Fri, 5 Dec 2025 14:55:01 -0600 Subject: [PATCH 30/32] Remove alpha specs from non-RAPIDS dependencies (#1618) This PR removes pre-release upper bound pinnings from non-RAPIDS dependencies. The presence of pre-release indicators like `<...a0` tells pip "pre-releases are OK, even if `--pre` was not passed to pip install." RAPIDS projects currently use such constraints in situations where it's not actually desirable to get pre-releases. xref: https://github.com/rapidsai/build-planning/issues/144 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cuvs/pull/1618 --- conda/environments/all_cuda-129_arch-aarch64.yaml | 10 +++++----- conda/environments/all_cuda-129_arch-x86_64.yaml | 10 +++++----- conda/environments/all_cuda-130_arch-aarch64.yaml | 10 +++++----- conda/environments/all_cuda-130_arch-x86_64.yaml | 10 +++++----- .../bench_ann_cuda-129_arch-aarch64.yaml | 6 +++--- .../bench_ann_cuda-129_arch-x86_64.yaml | 6 +++--- .../bench_ann_cuda-130_arch-aarch64.yaml | 6 +++--- .../bench_ann_cuda-130_arch-x86_64.yaml | 6 +++--- conda/recipes/cuvs-bench-cpu/recipe.yaml | 4 ++-- conda/recipes/cuvs-bench/recipe.yaml | 2 +- conda/recipes/cuvs/recipe.yaml | 14 +++++++------- dependencies.yaml | 12 ++++++------ python/cuvs/pyproject.toml | 12 ++++++------ python/cuvs_bench/pyproject.toml | 2 +- python/libcuvs/pyproject.toml | 2 +- 15 files changed, 56 insertions(+), 56 deletions(-) diff --git a/conda/environments/all_cuda-129_arch-aarch64.yaml b/conda/environments/all_cuda-129_arch-aarch64.yaml index adcc024ffc..7dd7df95f1 100644 --- a/conda/environments/all_cuda-129_arch-aarch64.yaml +++ b/conda/environments/all_cuda-129_arch-aarch64.yaml @@ -15,11 +15,11 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=12.9.2,<13.0a0 +- cuda-python>=12.9.2,<13.0 - cuda-version=12.9 - cupy>=13.6.0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - doxygen>=1.8.20 - gcc_linux-aarch64=14.* @@ -35,14 +35,14 @@ dependencies: - make - nccl>=2.19 - ninja -- numpy>=1.23,<3.0a0 +- numpy>=1.23,<3.0 - numpydoc - openblas - pre-commit - pylibraft==26.2.*,>=0.0.0a0 - pytest-cov -- pytest<9.0.0a0 -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- pytest<9.0.0 +- rapids-build-backend>=0.4.0,<0.5.0 - recommonmark - rust - scikit-build-core>=0.10.0 diff --git a/conda/environments/all_cuda-129_arch-x86_64.yaml b/conda/environments/all_cuda-129_arch-x86_64.yaml index 24eb7ccbbe..46a60255dd 100644 --- a/conda/environments/all_cuda-129_arch-x86_64.yaml +++ b/conda/environments/all_cuda-129_arch-x86_64.yaml @@ -15,11 +15,11 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=12.9.2,<13.0a0 +- cuda-python>=12.9.2,<13.0 - cuda-version=12.9 - cupy>=13.6.0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - doxygen>=1.8.20 - gcc_linux-64=14.* @@ -35,14 +35,14 @@ dependencies: - make - nccl>=2.19 - ninja -- numpy>=1.23,<3.0a0 +- numpy>=1.23,<3.0 - numpydoc - openblas - pre-commit - pylibraft==26.2.*,>=0.0.0a0 - pytest-cov -- pytest<9.0.0a0 -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- pytest<9.0.0 +- rapids-build-backend>=0.4.0,<0.5.0 - recommonmark - rust - scikit-build-core>=0.10.0 diff --git a/conda/environments/all_cuda-130_arch-aarch64.yaml b/conda/environments/all_cuda-130_arch-aarch64.yaml index 01bbaec65e..ff50da3cfb 100644 --- a/conda/environments/all_cuda-130_arch-aarch64.yaml +++ b/conda/environments/all_cuda-130_arch-aarch64.yaml @@ -15,11 +15,11 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=13.0.1,<14.0a0 +- cuda-python>=13.0.1,<14.0 - cuda-version=13.0 - cupy>=13.6.0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - doxygen>=1.8.20 - gcc_linux-aarch64=14.* @@ -35,14 +35,14 @@ dependencies: - make - nccl>=2.19 - ninja -- numpy>=1.23,<3.0a0 +- numpy>=1.23,<3.0 - numpydoc - openblas - pre-commit - pylibraft==26.2.*,>=0.0.0a0 - pytest-cov -- pytest<9.0.0a0 -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- pytest<9.0.0 +- rapids-build-backend>=0.4.0,<0.5.0 - recommonmark - rust - scikit-build-core>=0.10.0 diff --git a/conda/environments/all_cuda-130_arch-x86_64.yaml b/conda/environments/all_cuda-130_arch-x86_64.yaml index f42abeb5d6..62d0526fe6 100644 --- a/conda/environments/all_cuda-130_arch-x86_64.yaml +++ b/conda/environments/all_cuda-130_arch-x86_64.yaml @@ -15,11 +15,11 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=13.0.1,<14.0a0 +- cuda-python>=13.0.1,<14.0 - cuda-version=13.0 - cupy>=13.6.0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - doxygen>=1.8.20 - gcc_linux-64=14.* @@ -35,14 +35,14 @@ dependencies: - make - nccl>=2.19 - ninja -- numpy>=1.23,<3.0a0 +- numpy>=1.23,<3.0 - numpydoc - openblas - pre-commit - pylibraft==26.2.*,>=0.0.0a0 - pytest-cov -- pytest<9.0.0a0 -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- pytest<9.0.0 +- rapids-build-backend>=0.4.0,<0.5.0 - recommonmark - rust - scikit-build-core>=0.10.0 diff --git a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml index 714f1d7c1c..2710242ccd 100644 --- a/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-aarch64.yaml @@ -14,12 +14,12 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=12.9.2,<13.0a0 +- cuda-python>=12.9.2,<13.0 - cuda-version=12.9 - cupy>=13.6.0 - cuvs==26.2.*,>=0.0.0a0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - gcc_linux-aarch64=14.* - glog>=0.6.0 @@ -39,7 +39,7 @@ dependencies: - pandas - pylibraft==26.2.*,>=0.0.0a0 - pyyaml -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- rapids-build-backend>=0.4.0,<0.5.0 - requests - scikit-learn - setuptools diff --git a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml index f842c9eaed..3000bcd781 100644 --- a/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-129_arch-x86_64.yaml @@ -14,12 +14,12 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=12.9.2,<13.0a0 +- cuda-python>=12.9.2,<13.0 - cuda-version=12.9 - cupy>=13.6.0 - cuvs==26.2.*,>=0.0.0a0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - gcc_linux-64=14.* - glog>=0.6.0 @@ -42,7 +42,7 @@ dependencies: - pandas - pylibraft==26.2.*,>=0.0.0a0 - pyyaml -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- rapids-build-backend>=0.4.0,<0.5.0 - requests - scikit-learn - setuptools diff --git a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml index 8fd964b088..e1b83a7aed 100644 --- a/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-aarch64.yaml @@ -14,12 +14,12 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=13.0.1,<14.0a0 +- cuda-python>=13.0.1,<14.0 - cuda-version=13.0 - cupy>=13.6.0 - cuvs==26.2.*,>=0.0.0a0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - gcc_linux-aarch64=14.* - glog>=0.6.0 @@ -39,7 +39,7 @@ dependencies: - pandas - pylibraft==26.2.*,>=0.0.0a0 - pyyaml -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- rapids-build-backend>=0.4.0,<0.5.0 - requests - scikit-learn - setuptools diff --git a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml index 8cee4997f4..ab9a48686d 100644 --- a/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-130_arch-x86_64.yaml @@ -14,12 +14,12 @@ dependencies: - cuda-nvcc - cuda-nvtx-dev - cuda-profiler-api -- cuda-python>=13.0.1,<14.0a0 +- cuda-python>=13.0.1,<14.0 - cuda-version=13.0 - cupy>=13.6.0 - cuvs==26.2.*,>=0.0.0a0 - cxx-compiler -- cython>=3.0.0,<3.2.0a0 +- cython>=3.0.0,<3.2.0 - dlpack>=0.8,<1.0 - gcc_linux-64=14.* - glog>=0.6.0 @@ -42,7 +42,7 @@ dependencies: - pandas - pylibraft==26.2.*,>=0.0.0a0 - pyyaml -- rapids-build-backend>=0.4.0,<0.5.0.dev0 +- rapids-build-backend>=0.4.0,<0.5.0 - requests - scikit-learn - setuptools diff --git a/conda/recipes/cuvs-bench-cpu/recipe.yaml b/conda/recipes/cuvs-bench-cpu/recipe.yaml index 6ae5f9cbce..1a8c11228f 100644 --- a/conda/recipes/cuvs-bench-cpu/recipe.yaml +++ b/conda/recipes/cuvs-bench-cpu/recipe.yaml @@ -63,7 +63,7 @@ requirements: - openblas - pip - python =${{ py_version }} - - rapids-build-backend>=0.4.0,<0.5.0.dev0 + - rapids-build-backend>=0.4.0,<0.5.0 - setuptools >=64.0.0 - if: linux64 then: @@ -75,7 +75,7 @@ requirements: - glog ${{ glog_version }} - h5py ${{ h5py_version }} - matplotlib-base>=3.9 - - numpy >=1.23,<3.0a0 + - numpy >=1.23,<3.0 - pandas - pyyaml - python diff --git a/conda/recipes/cuvs-bench/recipe.yaml b/conda/recipes/cuvs-bench/recipe.yaml index 382a4c3f72..614b21189f 100644 --- a/conda/recipes/cuvs-bench/recipe.yaml +++ b/conda/recipes/cuvs-bench/recipe.yaml @@ -33,7 +33,7 @@ requirements: - libcuvs-bench-ann =${{ version }} - python =${{ py_version }} - pip - - rapids-build-backend >=0.4.0,<0.5.0.dev0 + - rapids-build-backend >=0.4.0,<0.5.0 - rmm =${{ minor_version }} - setuptools >=64.0.0 run: diff --git a/conda/recipes/cuvs/recipe.yaml b/conda/recipes/cuvs/recipe.yaml index f2de8a2b19..0debf2dd75 100644 --- a/conda/recipes/cuvs/recipe.yaml +++ b/conda/recipes/cuvs/recipe.yaml @@ -51,27 +51,27 @@ requirements: - ${{ stdlib("c") }} host: - cuda-version =${{ cuda_version }} - - cython >=3.0.0,<3.2.0a0 + - cython >=3.0.0,<3.2.0 - dlpack >=0.8 - libcuvs =${{ version }} - pip - pylibraft =${{ minor_version }} - python =${{ py_version }} - - rapids-build-backend >=0.4.0,<0.5.0.dev0 + - rapids-build-backend >=0.4.0,<0.5.0 - scikit-build-core >=0.10.0 - if: cuda_major == "12" - then: cuda-python >=12.9.2,<13.0a0 - else: cuda-python >=13.0.1,<14.0a0 + then: cuda-python >=12.9.2,<13.0 + else: cuda-python >=13.0.1,<14.0 - cuda-cudart-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - libcuvs =${{ version }} - pylibraft =${{ minor_version }} - python - - numpy >=1.23,<3.0a0 + - numpy >=1.23,<3.0 - if: cuda_major == "12" - then: cuda-python >=12.9.2,<13.0a0 - else: cuda-python >=13.0.1,<14.0a0 + then: cuda-python >=12.9.2,<13.0 + else: cuda-python >=13.0.1,<14.0 - cuda-cudart ignore_run_exports: by_name: diff --git a/dependencies.yaml b/dependencies.yaml index 89dcb38b99..e92b1e30a6 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -222,7 +222,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - &rapids_build_backend rapids-build-backend>=0.4.0,<0.5.0.dev0 + - &rapids_build_backend rapids-build-backend>=0.4.0,<0.5.0 - output_types: [conda] packages: - scikit-build-core>=0.10.0 @@ -233,7 +233,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - cython>=3.0.0,<3.2.0a0 + - cython>=3.0.0,<3.2.0 rapids_build: common: - output_types: [conda, requirements, pyproject] @@ -425,12 +425,12 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - &numpy numpy>=1.23,<3.0a0 + - &numpy numpy>=1.23,<3.0 test_python_common: common: - output_types: [conda, requirements, pyproject] packages: - - pytest<9.0.0a0 + - pytest<9.0.0 - pytest-cov test_py_cuvs: common: @@ -488,11 +488,11 @@ dependencies: - matrix: cuda: "12.*" packages: - - cuda-python>=12.9.2,<13.0a0 + - cuda-python>=12.9.2,<13.0 # fallback to CUDA 13 versions if 'cuda' is '13.*' or not provided - matrix: packages: - - cuda-python>=13.0.1,<14.0a0 + - cuda-python>=13.0.1,<14.0 depends_on_cuvs: common: - output_types: conda diff --git a/python/cuvs/pyproject.toml b/python/cuvs/pyproject.toml index 38ee2b6f12..e467be2e4b 100644 --- a/python/cuvs/pyproject.toml +++ b/python/cuvs/pyproject.toml @@ -4,7 +4,7 @@ [build-system] requires = [ - "rapids-build-backend>=0.4.0,<0.5.0.dev0", + "rapids-build-backend>=0.4.0,<0.5.0", "scikit-build-core[pyproject]>=0.10.0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. build-backend = "rapids_build_backend.build" @@ -20,9 +20,9 @@ authors = [ license = { text = "Apache-2.0" } requires-python = ">=3.10" dependencies = [ - "cuda-python>=13.0.1,<14.0a0", + "cuda-python>=13.0.1,<14.0", "libcuvs==26.2.*,>=0.0.0a0", - "numpy>=1.23,<3.0a0", + "numpy>=1.23,<3.0", "pylibraft==26.2.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ @@ -38,7 +38,7 @@ classifiers = [ test = [ "cupy-cuda13x>=13.6.0", "pytest-cov", - "pytest<9.0.0a0", + "pytest<9.0.0", "scikit-learn", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. @@ -106,8 +106,8 @@ regex = "(?P.*)" [tool.rapids-build-backend] requires = [ "cmake>=3.30.4", - "cuda-python>=13.0.1,<14.0a0", - "cython>=3.0.0,<3.2.0a0", + "cuda-python>=13.0.1,<14.0", + "cython>=3.0.0,<3.2.0", "libcuvs==26.2.*,>=0.0.0a0", "libraft==26.2.*,>=0.0.0a0", "librmm==26.2.*,>=0.0.0a0", diff --git a/python/cuvs_bench/pyproject.toml b/python/cuvs_bench/pyproject.toml index d7d8e3b891..38a44844aa 100644 --- a/python/cuvs_bench/pyproject.toml +++ b/python/cuvs_bench/pyproject.toml @@ -4,7 +4,7 @@ [build-system] build-backend = "rapids_build_backend.build" requires = [ - "rapids-build-backend>=0.4.0,<0.5.0.dev0", + "rapids-build-backend>=0.4.0,<0.5.0", "setuptools", "wheel", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. diff --git a/python/libcuvs/pyproject.toml b/python/libcuvs/pyproject.toml index cc60040c5a..89c994740a 100644 --- a/python/libcuvs/pyproject.toml +++ b/python/libcuvs/pyproject.toml @@ -3,7 +3,7 @@ [build-system] requires = [ - "rapids-build-backend>=0.4.0,<0.5.0.dev0", + "rapids-build-backend>=0.4.0,<0.5.0", "scikit-build-core[pyproject]>=0.10.0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. build-backend = "rapids_build_backend.build" From ef2dc3db7e7bd5294b7e2bd9bf16f997d83dc3a3 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 25 Nov 2025 17:18:21 -0600 Subject: [PATCH 31/32] Update FAISS from 1.12.0 to 1.13.0 (#1585) ## Summary - Update FAISS dependency from 1.12.0 to 1.13.0 - Remove thrust include patches already present in FAISS 1.13.0 - All other RMM API compatibility patches still apply cleanly Verified that updated patches apply cleanly to FAISS v1.13.0. Follow-up to https://github.com/rapidsai/cuvs/pull/1566. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/1585 --- ...-25.12.diff => faiss-1.13-cuvs-25.12.diff} | 28 ------------------- cpp/cmake/patches/faiss_override.json | 6 ++-- 2 files changed, 3 insertions(+), 31 deletions(-) rename cpp/cmake/patches/{faiss-25.12.diff => faiss-1.13-cuvs-25.12.diff} (90%) diff --git a/cpp/cmake/patches/faiss-25.12.diff b/cpp/cmake/patches/faiss-1.13-cuvs-25.12.diff similarity index 90% rename from cpp/cmake/patches/faiss-25.12.diff rename to cpp/cmake/patches/faiss-1.13-cuvs-25.12.diff index 98edbdfb8c..7fabbc675f 100644 --- a/cpp/cmake/patches/faiss-25.12.diff +++ b/cpp/cmake/patches/faiss-1.13-cuvs-25.12.diff @@ -120,34 +120,6 @@ index f23ca19d8..3ba606606 100644 #endif /// Pinned memory allocation for use with this GPU -diff --git a/faiss/gpu/impl/BinaryCuvsCagra.cu b/faiss/gpu/impl/BinaryCuvsCagra.cu -index 0ca21dc5f..b331fdc8f 100644 ---- a/faiss/gpu/impl/BinaryCuvsCagra.cu -+++ b/faiss/gpu/impl/BinaryCuvsCagra.cu -@@ -32,6 +32,9 @@ - #include - #include - -+#include -+#include -+ - namespace faiss { - namespace gpu { - -diff --git a/faiss/gpu/impl/CuvsCagra.cu b/faiss/gpu/impl/CuvsCagra.cu -index 482e4d672..4246776e8 100644 ---- a/faiss/gpu/impl/CuvsCagra.cu -+++ b/faiss/gpu/impl/CuvsCagra.cu -@@ -31,6 +31,9 @@ - #include - #include - -+#include -+#include -+ - namespace faiss { - namespace gpu { - diff --git a/faiss/gpu/impl/CuvsFlatIndex.cu b/faiss/gpu/impl/CuvsFlatIndex.cu index 15cf427cf..d877e766d 100644 --- a/faiss/gpu/impl/CuvsFlatIndex.cu diff --git a/cpp/cmake/patches/faiss_override.json b/cpp/cmake/patches/faiss_override.json index 989a043343..7d2d755740 100644 --- a/cpp/cmake/patches/faiss_override.json +++ b/cpp/cmake/patches/faiss_override.json @@ -1,12 +1,12 @@ { "packages" : { "faiss" : { - "version": "1.12.0", + "version": "1.13.0", "git_url": "https://github.com/facebookresearch/faiss.git", - "git_tag": "v1.12.0", + "git_tag": "v1.13.0", "patches" : [ { - "file" : "${current_json_dir}/faiss-25.12.diff", + "file" : "${current_json_dir}/faiss-1.13-cuvs-25.12.diff", "issue" : "Multiple fixes for cuVS and RMM compatibility", "fixed_in" : "" } From 21d16afbd2ca2cf6cc35cfaf914f08ea2afb4a1e Mon Sep 17 00:00:00 2001 From: irina-resh-nvda Date: Wed, 26 Nov 2025 15:54:30 +0100 Subject: [PATCH 32/32] CMake check for FAISS use in benchmarks (#1591) CUVS_ANN_BENCH_USE_FAISS is now set to OFF if all relevant flags are set OFF. The status is reported in the cmake log: -- Finding or building hnswlib -- Checking for FAISS use in benchmarks... -- CUVS_ANN_BENCH_USE_FAISS is OFF closes #1590. Authors: - https://github.com/irina-resh-nvda Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/cuvs/pull/1591 --- cpp/bench/ann/CMakeLists.txt | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 46bbd318d2..8d254c0933 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -51,7 +51,29 @@ option(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benc find_package(Threads REQUIRED) -set(CUVS_ANN_BENCH_USE_FAISS ON) +# ----- FAISS use in Benchmarks ---- +get_cmake_property(_variableNames VARIABLES) + +set(CUVS_ANN_BENCH_USE_FAISS OFF) +message(STATUS "Checking for FAISS use in benchmarks...") +foreach(_varName ${_variableNames}) + if(_varName MATCHES "CUVS_ANN_BENCH_USE_FAISS.+") + if(${_varName}) + set(CUVS_ANN_BENCH_USE_FAISS ON) + message(STATUS "${_varName} is detected as ON.") + break() + endif() + endif() +endforeach() + +if(CUVS_ANN_BENCH_USE_FAISS) + message(STATUS "CUVS_ANN_BENCH_USE_FAISS is switched ON") +else() + message(STATUS "CUVS_ANN_BENCH_USE_FAISS is switched OFF") +endif() + +# ---------------------------------- + set(CUVS_FAISS_ENABLE_GPU ON) set(CUVS_USE_FAISS_STATIC ON)