std::vector<double> aux;
45std::transform(a.begin(), a.end(), b.begin(), std::back_inserter(aux),
46[](T x1, T x2) { return std::pow((x1 - x2), 2); });
48 returnstd::sqrt(std::accumulate(aux.begin(), aux.end(), 0.0));
// Training feature vectors, one row per stored sample. Row X_.at(i) is paired
// with the label Y_.at(i) (see the lookup loop in predict), so both containers
// are expected to have the same length.
std::vector<std::vector<double>> X_{};
// Class labels for the stored samples, index-aligned with X_.
std::vector<int> Y_{};
67 explicit Knn(std::vector<std::vector<double>>& X, std::vector<int>& Y)
103 int predict(std::vector<double>& sample,
intk) {
104std::vector<int> neighbors;
105std::vector<std::pair<double, int>> distances;
106 for(
size_ti = 0; i < this->X_.size(); ++i) {
107 autocurrent = this->X_.at(i);
108 autolabel = this->Y_.at(i);
110distances.emplace_back(distance, label);
112std::sort(distances.begin(), distances.end());
113 for(
inti = 0; i < k; i++) {
114 autolabel = distances.at(i).second;
115neighbors.push_back(label);
117std::unordered_map<int, int> frequency;
118 for(
autoneighbor : neighbors) {
119++frequency[neighbor];
121std::pair<int, int> predicted;
122predicted.first = -1;
123predicted.second = -1;
124 for(
auto& kv : frequency) {
125 if(kv.second > predicted.second) {
126predicted.second = kv.second;
127predicted.first = kv.first;
130 returnpredicted.first;
141std::cout <<
"------- Test 1 -------"<< std::endl;
142std::vector<std::vector<double>> X1 = {{0.0, 0.0}, {0.25, 0.25},
143{0.0, 0.5}, {0.5, 0.5},
144{1.0, 0.5}, {1.0, 1.0}};
145std::vector<int> Y1 = {1, 1, 1, 1, 2, 2};
147std::vector<double> sample1 = {1.2, 1.2};
148std::vector<double> sample2 = {0.1, 0.1};
149std::vector<double> sample3 = {0.1, 0.5};
150std::vector<double> sample4 = {1.0, 0.75};
151assert(model1.predict(sample1, 2) == 2);
152assert(model1.predict(sample2, 2) == 1);
153assert(model1.predict(sample3, 2) == 1);
154assert(model1.predict(sample4, 2) == 2);
155std::cout <<
"... Passed"<< std::endl;
156std::cout <<
"------- Test 2 -------"<< std::endl;
157std::vector<std::vector<double>> X2 = {
158{0.0, 0.0, 0.0}, {0.25, 0.25, 0.0}, {0.0, 0.5, 0.0}, {0.5, 0.5, 0.0},
159{1.0, 0.5, 0.0}, {1.0, 1.0, 0.0}, {1.0, 1.0, 1.0}, {1.5, 1.5, 1.0}};
160std::vector<int> Y2 = {1, 1, 1, 1, 2, 2, 3, 3};
162std::vector<double> sample5 = {1.2, 1.2, 0.0};
163std::vector<double> sample6 = {0.1, 0.1, 0.0};
164std::vector<double> sample7 = {0.1, 0.5, 0.0};
165std::vector<double> sample8 = {1.0, 0.75, 1.0};
166assert(model2.predict(sample5, 2) == 2);
167assert(model2.predict(sample6, 2) == 1);
168assert(model2.predict(sample7, 2) == 1);
169assert(model2.predict(sample8, 2) == 3);
170std::cout <<
"... Passed"<< std::endl;
171std::cout <<
"------- Test 3 -------"<< std::endl;
172std::vector<std::vector<double>> X3 = {{0.0}, {1.0}, {2.0}, {3.0},
173{4.0}, {5.0}, {6.0}, {7.0}};
174std::vector<int> Y3 = {1, 1, 1, 1, 2, 2, 2, 2};
176std::vector<double> sample9 = {0.5};
177std::vector<double> sample10 = {2.9};
178std::vector<double> sample11 = {5.5};
179std::vector<double> sample12 = {7.5};
180assert(model3.predict(sample9, 3) == 1);
181assert(model3.predict(sample10, 3) == 1);
182assert(model3.predict(sample11, 3) == 2);
183assert(model3.predict(sample12, 3) == 2);
184std::cout <<
"... Passed"<< std::endl;
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4