Commit d024e973 authored by Lukas Weber's avatar Lukas Weber

fix multichain behavior

parent 905db47e
...@@ -88,12 +88,9 @@ void pt_chain::checkpoint_write(const iodump::group &g) { ...@@ -88,12 +88,9 @@ void pt_chain::checkpoint_write(const iodump::group &g) {
g.write("entries_before_optimization", entries_before_optimization); g.write("entries_before_optimization", entries_before_optimization);
} }
int pt_chain::histogram_entries() {
return nup_histogram.at(0);
}
void pt_chain::clear_histograms() { void pt_chain::clear_histograms() {
rejection_rate_entries = 0; rejection_rate_entries[0] = 0;
rejection_rate_entries[1] = 0;
std::fill(rejection_rates.begin(), rejection_rates.end(), 0); std::fill(rejection_rates.begin(), rejection_rates.end(), 0);
std::fill(nup_histogram.begin(), nup_histogram.end(), 0); std::fill(nup_histogram.begin(), nup_histogram.end(), 0);
std::fill(ndown_histogram.begin(), ndown_histogram.end(), 0); std::fill(ndown_histogram.begin(), ndown_histogram.end(), 0);
...@@ -103,8 +100,10 @@ void pt_chain::clear_histograms() { ...@@ -103,8 +100,10 @@ void pt_chain::clear_histograms() {
// https://arxiv.org/pdf/1905.02939.pdf // https://arxiv.org/pdf/1905.02939.pdf
std::tuple<double, double> pt_chain::optimize_params() { std::tuple<double, double> pt_chain::optimize_params() {
std::vector<double> rejection_est(rejection_rates); std::vector<double> rejection_est(rejection_rates);
bool odd = false;
for(auto& r : rejection_est) { for(auto& r : rejection_est) {
r /= rejection_rate_entries/2; r /= rejection_rate_entries[odd];
odd = !odd;
} }
std::vector<double> comm_barrier(params.size()); std::vector<double> comm_barrier(params.size());
...@@ -314,8 +313,10 @@ void runner_pt_master::write_param_optimization_stats() { ...@@ -314,8 +313,10 @@ void runner_pt_master::write_param_optimization_stats() {
cg.insert_back("params", chain.params); cg.insert_back("params", chain.params);
std::vector<double> rejection_est(chain.rejection_rates); std::vector<double> rejection_est(chain.rejection_rates);
bool odd = false;
for(auto &r : rejection_est) { for(auto &r : rejection_est) {
r /= chain.rejection_rate_entries/2.; r /= chain.rejection_rate_entries[odd];
odd = !odd;
} }
cg.insert_back("rejection_rates", rejection_est); cg.insert_back("rejection_rates", rejection_est);
...@@ -448,14 +449,14 @@ void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chai ...@@ -448,14 +449,14 @@ void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chai
chain.ndown_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == -1; chain.ndown_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == -1;
chain.nup_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == 1; chain.nup_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == 1;
} }
if(chain.histogram_entries() >= chain.entries_before_optimization) { if(std::min(chain.rejection_rate_entries[0], chain.rejection_rate_entries[1]) >= chain.entries_before_optimization) {
chain.entries_before_optimization *= po_config_.nsamples_growth; chain.entries_before_optimization *= po_config_.nsamples_growth;
auto [fnonlinearity, convergence] = chain.optimize_params(); auto [fnonlinearity, convergence] = chain.optimize_params();
job_.log( job_.log(
fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, " fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
"convergence={:.2g}", "convergence={:.2g}",
chain.id, chain.histogram_entries(), fnonlinearity, convergence)); chain.id, chain.rejection_rate_entries[0], fnonlinearity, convergence));
checkpoint_write(); checkpoint_write();
write_param_optimization_stats(); write_param_optimization_stats();
chain.clear_histograms(); chain.clear_histograms();
...@@ -573,7 +574,7 @@ void runner_pt_master::pt_global_update(pt_chain &chain, pt_chain_run &chain_run ...@@ -573,7 +574,7 @@ void runner_pt_master::pt_global_update(pt_chain &chain, pt_chain_run &chain_run
double r = rng_->random_double(); double r = rng_->random_double();
chain.rejection_rates[i] += 1 - std::min(w1 * w2, 1.); chain.rejection_rates[i] += 1 - std::min(w1 * w2, 1.);
chain.rejection_rate_entries++; chain.rejection_rate_entries[chain_run.swap_odd]++;
if(r < w1 * w2) { if(r < w1 * w2) {
int rank0{}; int rank0{};
int rank1{}; int rank1{};
......
...@@ -21,7 +21,7 @@ struct pt_chain { ...@@ -21,7 +21,7 @@ struct pt_chain {
int entries_before_optimization{0}; int entries_before_optimization{0};
std::vector<double> rejection_rates; std::vector<double> rejection_rates;
int rejection_rate_entries{0}; std::vector<int> rejection_rate_entries{0,0};
bool is_done(); bool is_done();
void checkpoint_read(const iodump::group &g); void checkpoint_read(const iodump::group &g);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment