PDS/homework_1/inc/v1.hpp

109 lines
3.4 KiB
C++

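/**
* \file v1.hpp
* \brief Recursive (divide-and-conquer) k-nearest-neighbor search: splits corpus and
*        queries into halves, solves the sub-problems and merges the partial results,
*        falling back to v0::knnsearch for small inputs.
*
* \author
* Christos Choutouridis AEM:8997
* <cchoutou@ece.auth.gr>
*/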
#ifndef V1_HPP_
#define V1_HPP_
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>
#include "matrix.hpp"
#include "v0.hpp"
#include "config.h"
namespace v1 {

/**
 * Merges two partial k-NN results that refer to the same set of query rows.
 *
 * For every query row q, the candidates of (N1, D1) and (N2, D2) are pooled, and the
 * smallest min(m, N1.columns() + N2.columns()) of them are kept in ascending order.
 * Candidates are ordered by distance, with ties broken by index. The first k kept
 * candidates are written to row q of (N, D). Output slots that cannot be filled
 * (fewer than k candidates kept) are padded with distance +infinity and index
 * static_cast<IndexType>(-1).
 *
 * \param N1, D1  Neighbor indices / distances of the first partial result (one row per query).
 * \param N2, D2  Neighbor indices / distances of the second partial result (same rows as N1).
 * \param k       Number of output columns written per query row.
 * \param m       Maximum number of pooled candidates kept per query row.
 * \param N, D    Output indices / distances; must hold N1.rows() rows and k columns.
 */
template <typename DataType, typename IndexType>
void mergeResultsWithM(mtx::Matrix<IndexType>& N1, mtx::Matrix<DataType>& D1,
                       mtx::Matrix<IndexType>& N2, mtx::Matrix<DataType>& D2,
                       size_t k, size_t m,
                       mtx::Matrix<IndexType>& N, mtx::Matrix<DataType>& D) {
   const size_t numQueries = N1.rows();
   const size_t cols1 = N1.columns();
   const size_t cols2 = N2.columns();
   // Computed in size_t: casting m to a (possibly narrow) IndexType first could truncate it.
   const size_t keep = std::min(m, cols1 + cols2);

   // One buffer for all queries; every slot is overwritten for each query row.
   std::vector<std::pair<DataType, IndexType>> candidates(cols1 + cols2);
   for (size_t q = 0; q < numQueries; ++q) {
      // Concatenate the q-th rows of both partial results
      for (size_t i = 0; i < cols1; ++i)
         candidates[i] = {D1.get(q, i), N1.get(q, i)};
      for (size_t i = 0; i < cols2; ++i)
         candidates[cols1 + i] = {D2.get(q, i), N2.get(q, i)};

      // Keep only the `keep` smallest candidates, sorted, at the front
      std::partial_sort(candidates.begin(), candidates.begin() + keep, candidates.end());

      // Emit the first k of them; pad with invalid values when keep < k
      for (size_t i = 0; i < k; ++i) {
         if (i < keep) {
            D.set(candidates[i].first, q, i);
            N.set(candidates[i].second, q, i);
         } else {
            D.set(std::numeric_limits<DataType>::infinity(), q, i);
            N.set(static_cast<IndexType>(-1), q, i);   // Invalid index (end)
         }
      }
   }
}

/**
 * Recursive (divide-and-conquer) k-NN search.
 *
 * Splits the corpus C and the queries Q into two row halves each and solves the four
 * (corpus-half, query-half) sub-problems recursively. The two partial results of each
 * query half are then merged into the corresponding rows of idx/dst. Small inputs are
 * delegated to v0::knnsearch.
 *
 * \param C          Corpus matrix, one point per row.
 * \param Q          Query matrix, one point per row.
 * \param idx_offset Global row number of C's first row. It is forwarded to
 *                   v0::knnsearch and shifted for the second corpus half.
 * \param k          Number of neighbors written per query row.
 * \param m          Candidate limit forwarded to v0::knnsearch and mergeResultsWithM.
 * \param idx        Output neighbor indices (Q.rows() x k).
 * \param dst        Output neighbor distances (Q.rows() x k).
 */
template<typename MatrixD, typename MatrixI>
void knnsearch(const MatrixD& C, const MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
   using DstType = typename MatrixD::dataType;
   using IdxType = typename MatrixI::dataType;

   if (C.rows() <= 8 || Q.rows() <= 4) {
      // Base case: Call knnsearch directly
      v0::knnsearch(C, Q, idx_offset, k, m, idx, dst);
      return;
   }

   // Divide corpus and queries into two halves. The second half takes the remainder,
   // so odd row counts do not silently drop the last row.
   const size_t rowsC1 = C.rows() / 2;
   const size_t rowsC2 = C.rows() - rowsC1;
   const size_t rowsQ1 = Q.rows() / 2;
   const size_t rowsQ2 = Q.rows() - rowsQ1;

   // Row slices (assumed ctor signature: data, first row, rows, columns). The slicing
   // constructor takes a mutable pointer, hence the const_cast. These slices are only
   // passed on as const references.
   DstType* cData = const_cast<DstType*>(C.data());
   DstType* qData = const_cast<DstType*>(Q.data());
   MatrixD C1(cData, 0, rowsC1, C.columns());
   MatrixD C2(cData, rowsC1, rowsC2, C.columns());
   MatrixD Q1(qData, 0, rowsQ1, Q.columns());
   MatrixD Q2(qData, rowsQ1, rowsQ2, Q.columns());

   // Temporary results for every (query half, corpus half) combination
   MatrixI N1_1(rowsQ1, k), N1_2(rowsQ1, k), N2_1(rowsQ2, k), N2_2(rowsQ2, k);
   MatrixD D1_1(rowsQ1, k), D1_2(rowsQ1, k), D2_1(rowsQ2, k), D2_2(rowsQ2, k);

   // Recursive calls; the second corpus half starts rowsC1 rows further
   knnsearch(C1, Q1, idx_offset, k, m, N1_1, D1_1);
   knnsearch(C2, Q1, idx_offset + rowsC1, k, m, N1_2, D1_2);
   knnsearch(C1, Q2, idx_offset, k, m, N2_1, D2_1);
   knnsearch(C2, Q2, idx_offset + rowsC1, k, m, N2_2, D2_2);

   // Slice the output matrices at the same row split as the queries
   IdxType* idxData = const_cast<IdxType*>(idx.data());
   DstType* dstData = const_cast<DstType*>(dst.data());
   MatrixI N1(idxData, 0, rowsQ1, k);
   MatrixI N2(idxData, rowsQ1, rowsQ2, k);
   MatrixD D1(dstData, 0, rowsQ1, k);
   MatrixD D2(dstData, rowsQ1, rowsQ2, k);

   // Merge the two corpus-half results of each query half directly into the outputs
   mergeResultsWithM(N1_1, D1_1, N1_2, D1_2, k, m, N1, D1);
   mergeResultsWithM(N2_1, D2_1, N2_2, D2_2, k, m, N2, D2);
}

} // namespace v1
#endif /* V1_HPP_ */