Added an association-based initialization method. Works great.

parent 33ab9d01
......@@ -15,22 +15,26 @@ du_algorithm_t update_dictionary = update_dictionary;
ml_algorithm_t learn_model = learn_model_traditional;
ml_algorithm_t learn_model_inner = learn_model_traditional;
mi_algorithm_t mi_algorithm_catalog[] = {initialize_dictionary_neighbor,
mi_algorithm_t mi_algorithm_catalog[] = {
initialize_dictionary_neighbor,
initialize_dictionary_partition,
initialize_dictionary_sdct,
initialize_dictionary_sdct2d,
initialize_dictionary_hamming,
initialize_dictionary_random,
initialize_dictionary_samples,
initialize_dictionary_association,
0};
const char* mi_algorithm_names[] = {"Neighbor initialization",
const char* mi_algorithm_names[] = {
"Neighbor initialization",
"Partition initialization",
"sign of DCT" ,
" sign of 2D DCT",
"Hamming basis",
"purely random dictionary",
"random samples from data",
"Association initialization",
0
};
......@@ -78,7 +82,7 @@ const char* lm_algorithm_names[] = {"Model learning by traditional alternate des
void learn_model_setup(int mi_algo, int es_algo, int du_algo, int lm_algo, int lmi_algo) {
if (mi_algo > 6) { std::cerr << "Invalid model initialization algorithm (0-" << 5 << ')' << std::endl; exit(-1); }
if (mi_algo > 7) { std::cerr << "Invalid model initialization algorithm (0-" << 5 << ')' << std::endl; exit(-1); }
if (es_algo > 2) { std::cerr << "Invalid coefficients update algorithm (0-" << 2 << ')' << std::endl; exit(-1); }
if (du_algo > 3) { std::cerr << "Invalid dictionary update algorithm (0-" << 3 << ')' << std::endl; exit(-1); }
if (lm_algo > 6) { std::cerr << "Invalid model learning algorithm (0-" << 6 << ')' << std::endl; exit(-1); }
......
......@@ -17,11 +17,64 @@ void initialize_dictionary_random(const binary_matrix& E,
binary_matrix& A) {
const idx_t K = D.get_rows();
const idx_t M = D.get_cols();
const idx_t N = E.get_rows();
double* col_densities = new double[M];
for (idx_t j = 0; j < M; j++) {
col_densities[j] = E.col_weight(j) / (double) N;
//std::cout << "density for col " << j << ":" << col_densities[j] << std::endl;
}
for (idx_t k = 0; k < K; k++) {
for (idx_t j = 0; j < M; j++) {
D.set(k,j,get_bernoulli_sample(density));
D.set(k,j,get_bernoulli_sample(col_densities[j]));
}
}
delete[] col_densities;
A.clear();
}
/// computes an association score for each row i
/// s_i = \sum_{j neq i} C(i,j) where C(i,j) = h(E_i AND E_j) (*)
/// and then chooses the K rows with the highest s_i
/// as initial dictionary atoms.
///
/// (*) actually, original version is normalized: C(i,j) = h(E_i AND E_j) / h(E_i)
///
void initialize_dictionary_association(const binary_matrix& E,
const binary_matrix& H,
binary_matrix& D,
binary_matrix& A) {
const idx_t K = D.get_rows();
const idx_t M = D.get_cols();
const idx_t N = E.get_rows();
binary_matrix rowi,rowj,iandj;
rowi.allocate(1, M);
rowj.allocate(1, M);
iandj.allocate ( 1, M );
aux_t* scores = new aux_t[N];
for (idx_t i = 0; i < N; i++) {
scores[i].second = i; // original index
scores[i].first = 0; // score
E.copy_row_to(i,rowi);
for (idx_t j = 0; j < N; j++) {
if (j == i) continue;
//std::cout << "density for col " << j << ":" << col_densities[j] << std::endl;
bool_and(rowi,rowj,iandj);
scores[i].first += iandj.weight();
}
}
counting_sort(scores,N);
for (idx_t i = 0; i < K; i++) {
std::cout << i << ":" << scores[N-i-1].first << " " << scores[N-i-1].second << std::endl;
E.copy_row_to(scores[N-i-1].second, rowi);
D.set_row(i,rowi);
}
delete[] scores;
rowi.destroy();
rowj.destroy();
iandj.destroy();
A.clear();
}
......
......@@ -3,6 +3,14 @@
#include "binmat.h"
/**
* Initialize dictionary to the K samples that are more similar to
* the rest of the rows.
*/
void initialize_dictionary_association(const binary_matrix& E,
const binary_matrix& H,
binary_matrix& D,
binary_matrix& A);
/**
......
#ifndef RANDOM_NUMBER_GENERATION
#define RANDOM_NUMBER_GENERATION
#include "gsl/gsl_randist.h"
#include <gsl/gsl_randist.h>
/**
* random seed used
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment