130 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			130 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/**
 * \file    v0.hpp
 * \brief   Brute-force k-nearest-neighbour search (version 0): CBLAS-based
 *          pairwise distances followed by a per-query k-smallest selection.
 *
 * \author
 *    Christos Choutouridis AEM:8997
 *    <cchoutou@ece.auth.gr>
 */
 | |
| #ifndef V0_HPP_
 | |
| #define V0_HPP_
 | |
| 
 | |
| #include <cblas.h>
 | |
| #include <cmath>
 | |
| #include <vector>
 | |
| #include <algorithm>
 | |
| 
 | |
| #include "matrix.hpp"
 | |
| #include "config.h"
 | |
| 
 | |
| namespace v0 {
 | |
| 
 | |
| /*!
 | |
|  * Function to compute squared Euclidean distances
 | |
|  *
 | |
|  * \fn void pdist2(const double*, const double*, double*, int, int, int)
 | |
|  * \param X    m x d matrix (Column major)
 | |
|  * \param Y    n x d matrix (Column major)
 | |
|  * \param D2   m x n matrix to store distances (Column major)
 | |
|  * \param m    number of rows in X
 | |
|  * \param n    number of rows in Y
 | |
|  * \param d    number of columns in both X and Y
 | |
|  */
 | |
| template<typename Matrix>
 | |
| void pdist2(const Matrix& X, const Matrix& Y, Matrix& D2) {
 | |
|    using DataType = typename Matrix::dataType;
 | |
| 
 | |
|    int M = X.rows();
 | |
|    int N = Y.rows();
 | |
|    int d = X.columns();
 | |
| 
 | |
|    // Compute the squared norms of each row in X and Y
 | |
|    std::vector<DataType> X_norms(M), Y_norms(N);
 | |
|    for (int i = 0; i < M ; ++i) {
 | |
|       X_norms[i] = cblas_ddot(d, X.data() + i * d, 1, X.data() + i * d, 1);
 | |
|    }
 | |
|    for (int j = 0; j < N ; ++j) {
 | |
|       Y_norms[j] = cblas_ddot(d, Y.data() + j * d, 1, Y.data() + j * d, 1);
 | |
|    }
 | |
| 
 | |
|    // Compute -2 * X * Y'
 | |
|    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, d, -2.0, X.data(), d, Y.data(), d, 0.0, D2.data(), N);
 | |
| 
 | |
|    // Step 3: Add the squared norms to each entry in D2
 | |
|    for (int i = 0; i < M ; ++i) {
 | |
|       for (int j = 0; j < N; ++j) {
 | |
|          D2.set(D2.get(i, j) + X_norms[i] + Y_norms[j], i, j);
 | |
|          D2.set(std::max(D2.get(i, j), 0.0),            i, j); // Ensure non-negative
 | |
|          D2.set(std::sqrt(D2.get(i, j)),                i, j); // Take the square root of each
 | |
|       }
 | |
|    }
 | |
|    M++;
 | |
| }
 | |
| 
 | |
/*!
 * Quick select implementation
 *
 * Keeps only the k pairs with the smallest distance (pair.first) in \p vec.
 * The surviving elements are NOT sorted; only their membership is guaranteed.
 *
 * \fn void quickselect(std::vector<std::pair<DataType,IndexType>>&, int)
 * \tparam DataType    The type of the distance (first member of each pair)
 * \tparam IndexType   The type of the index (second member of each pair)
 * \param vec  Vector of pairs (distance, index) to partially order over distance.
 *             It is modified in place and truncated to at most k elements.
 * \param k    The number of elements to select. Values <= 0 empty the vector,
 *             values >= vec.size() leave the vector untouched.
 */
template<typename DataType, typename IndexType>
void quickselect(std::vector<std::pair<DataType, IndexType>>& vec, int k) {
   using Pair = std::pair<DataType, IndexType>;

   if (k <= 0) {
      vec.clear();      // nothing requested
      return;
   }
   const size_t count = static_cast<size_t>(k);
   if (count >= vec.size())
      return;           // every element is already among the "k smallest"

   // From here on 0 < k < vec.size(), so vec.begin() + k is a valid iterator
   // and resize() can only shrink the vector.
   std::nth_element(
      vec.begin(),
      vec.begin() + k,
      vec.end(),
      [](const Pair& a, const Pair& b) {
         return a.first < b.first;
   });
   vec.resize(count);   // Keep only the k smallest elements
}
 | |
| 
 | |
| /*!
 | |
|  * \param C    Is a MxD matrix (Corpus)
 | |
|  * \param Q    Is a NxD matrix (Query)
 | |
|  * \param idx_offset The offset of the indexes for output (to match with the actual Corpus indexes)
 | |
|  * \param k    The number of nearest neighbors needed
 | |
|  * \param idx  Is the Nxk matrix with the k indexes of the C points, that are
 | |
|  *             neighbors of the nth point of Q
 | |
|  * \param dst  Is the Nxk matrix with the k distances to the C points of the nth
 | |
|  *             point of Q
 | |
|  */
 | |
| template<typename MatrixD, typename MatrixI>
 | |
| void knnsearch(MatrixD& C, MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
 | |
| 
 | |
|    using DstType = typename MatrixD::dataType;
 | |
|    using IdxType = typename MatrixI::dataType;
 | |
| 
 | |
|    size_t M = C.rows();
 | |
|    size_t N = Q.rows();
 | |
| 
 | |
|    mtx::Matrix<DstType> D(M, N);
 | |
| 
 | |
|    pdist2(C, Q, D);
 | |
| 
 | |
|    for (size_t j = 0; j < N; ++j) {
 | |
|       // Create a vector of pairs (distance, index) for the j-th query
 | |
|       std::vector<std::pair<DstType, IdxType>> dst_idx(M);
 | |
|       for (size_t i = 0; i < M; ++i) {
 | |
|          dst_idx[i] = {D.data()[i * N + j], i};
 | |
|       }
 | |
|       // Find the k smallest distances using quickSelectKSmallest
 | |
|       quickselect(dst_idx, k);
 | |
| 
 | |
|       // Sort the k smallest results by distance for consistency
 | |
|       std::sort(dst_idx.begin(), dst_idx.end());
 | |
| 
 | |
|       // Store the indices and distances
 | |
|       for (size_t i = 0; i < k; ++i) {
 | |
|          dst.set(dst_idx[i].first, j, i);
 | |
|          idx.set(dst_idx[i].second + idx_offset, j, i);
 | |
|       }
 | |
|    }
 | |
| }
 | |
| 
 | |
| }
 | |
| 
 | |
| #endif /* V0_HPP_ */
 |