109 lines
3.4 KiB
C++
109 lines
3.4 KiB
C++
/**
|
|
* \file v1.hpp
|
|
* \brief Recursive divide-and-conquer k-nearest-neighbour search (v1)
|
|
*
|
|
* \author
|
|
* Christos Choutouridis AEM:8997
|
|
* <cchoutou@ece.auth.gr>
|
|
*/
|
|
#ifndef V1_HPP_
|
|
#define V1_HPP_
|
|
|
|
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

#include "matrix.hpp"
#include "v0.hpp"
#include "config.h"
|
|
|
|
namespace v1 {
|
|
|
|
template <typename DataType, typename IndexType>
|
|
void mergeResultsWithM(mtx::Matrix<IndexType>& N1, mtx::Matrix<DataType>& D1,
|
|
mtx::Matrix<IndexType>& N2, mtx::Matrix<DataType>& D2,
|
|
size_t k, size_t m,
|
|
mtx::Matrix<IndexType>& N, mtx::Matrix<DataType>& D) {
|
|
size_t numQueries = N1.rows();
|
|
size_t maxCandidates = std::min((IndexType)m, (IndexType)(N1.columns() + N2.columns()));
|
|
|
|
for (size_t q = 0; q < numQueries; ++q) {
|
|
// Combine distances and neighbors
|
|
std::vector<std::pair<DataType, IndexType>> candidates(N1.columns() + N2.columns());
|
|
|
|
// Concatenate N1 and N2 rows
|
|
for (size_t i = 0; i < N1.columns(); ++i) {
|
|
candidates[i] = {D1.get(q, i), N1.get(q, i)};
|
|
}
|
|
for (size_t i = 0; i < N2.columns(); ++i) {
|
|
candidates[i + N1.columns()] = {D2.get(q, i), N2.get(q, i)};
|
|
}
|
|
|
|
// Keep only the top-m candidates
|
|
v0::quickselect(candidates, maxCandidates);
|
|
|
|
// Sort the top-m candidates
|
|
std::sort(candidates.begin(), candidates.begin() + maxCandidates);
|
|
|
|
// If m < k, pad the remaining slots with invalid values
|
|
for (size_t i = 0; i < k; ++i) {
|
|
if (i < maxCandidates) {
|
|
D.set(candidates[i].first, q, i);
|
|
N.set(candidates[i].second, q, i);
|
|
} else {
|
|
D.set(std::numeric_limits<DataType>::infinity(), q, i);
|
|
N.set(static_cast<IndexType>(-1), q, i); // Invalid index (end)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
template<typename MatrixD, typename MatrixI>
|
|
void knnsearch(const MatrixD& C, const MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
|
|
|
|
using DstType = typename MatrixD::dataType;
|
|
using IdxType = typename MatrixI::dataType;
|
|
|
|
if (C.rows() <= 8 || Q.rows() <= 4) {
|
|
// Base case: Call knnsearch directly
|
|
v0::knnsearch(C, Q, idx_offset, k, m, idx, dst);
|
|
return;
|
|
}
|
|
|
|
// Divide Corpus and Query into subsets
|
|
IdxType midC = C.rows() / 2;
|
|
IdxType midQ = Q.rows() / 2;
|
|
|
|
// Slice corpus and query matrixes
|
|
MatrixD C1((DstType*)C.data(), 0, midC, C.columns());
|
|
MatrixD C2((DstType*)C.data(), midC, midC, C.columns());
|
|
MatrixD Q1((DstType*)Q.data(), 0, midQ, Q.columns());
|
|
MatrixD Q2((DstType*)Q.data(), midQ, midQ, Q.columns());
|
|
|
|
// Allocate temporary matrixes for all permutations
|
|
MatrixI N1_1(midQ, k), N1_2(midQ, k), N2_1(midQ, k), N2_2(midQ, k);
|
|
MatrixD D1_1(midQ, k), D1_2(midQ, k), D2_1(midQ, k), D2_2(midQ, k);
|
|
|
|
// Recursive calls
|
|
knnsearch(C1, Q1, idx_offset, k, m, N1_1, D1_1);
|
|
knnsearch(C2, Q1, idx_offset + midC, k, m, N1_2, D1_2);
|
|
knnsearch(C1, Q2, idx_offset, k, m, N2_1, D2_1);
|
|
knnsearch(C2, Q2, idx_offset + midC, k, m, N2_2, D2_2);
|
|
|
|
// slice output matrixes
|
|
MatrixI N1((IdxType*)idx.data(), 0, midQ, k);
|
|
MatrixI N2((IdxType*)idx.data(), midQ, midQ, k);
|
|
MatrixD D1((DstType*)dst.data(), 0, midQ, k);
|
|
MatrixD D2((DstType*)dst.data(), midQ, midQ, k);
|
|
|
|
// Merge results in place
|
|
mergeResultsWithM(N1_1, D1_1, N1_2, D1_2, k, m, N1, D1);
|
|
mergeResultsWithM(N2_1, D2_1, N2_2, D2_2, k, m, N2, D2);
|
|
|
|
}
|
|
|
|
} // namespace v1
|
|
|
|
#endif /* V1_HPP_ */
|