Commit d024e973 authored by Lukas Weber

fix multichain behavior

parent 905db47e
......@@ -88,12 +88,9 @@ void pt_chain::checkpoint_write(const iodump::group &g) {
g.write("entries_before_optimization", entries_before_optimization);
}
// Reports the value stored in slot 0 of nup_histogram, read through at().
int pt_chain::histogram_entries() {
    const auto &first_slot = nup_histogram.at(0);
    return first_slot;
}
void pt_chain::clear_histograms() {
rejection_rate_entries = 0;
rejection_rate_entries[0] = 0;
rejection_rate_entries[1] = 0;
std::fill(rejection_rates.begin(), rejection_rates.end(), 0);
std::fill(nup_histogram.begin(), nup_histogram.end(), 0);
std::fill(ndown_histogram.begin(), ndown_histogram.end(), 0);
......@@ -103,8 +100,10 @@ void pt_chain::clear_histograms() {
// https://arxiv.org/pdf/1905.02939.pdf
std::tuple<double, double> pt_chain::optimize_params() {
std::vector<double> rejection_est(rejection_rates);
bool odd = false;
for(auto& r : rejection_est) {
r /= rejection_rate_entries/2;
r /= rejection_rate_entries[odd];
odd = !odd;
}
std::vector<double> comm_barrier(params.size());
......@@ -314,8 +313,10 @@ void runner_pt_master::write_param_optimization_stats() {
cg.insert_back("params", chain.params);
std::vector<double> rejection_est(chain.rejection_rates);
bool odd = false;
for(auto &r : rejection_est) {
r /= chain.rejection_rate_entries/2.;
r /= chain.rejection_rate_entries[odd];
odd = !odd;
}
cg.insert_back("rejection_rates", rejection_est);
......@@ -448,14 +449,14 @@ void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chai
chain.ndown_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == -1;
chain.nup_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == 1;
}
if(chain.histogram_entries() >= chain.entries_before_optimization) {
if(std::min(chain.rejection_rate_entries[0], chain.rejection_rate_entries[1]) >= chain.entries_before_optimization) {
chain.entries_before_optimization *= po_config_.nsamples_growth;
auto [fnonlinearity, convergence] = chain.optimize_params();
job_.log(
fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
"convergence={:.2g}",
chain.id, chain.histogram_entries(), fnonlinearity, convergence));
chain.id, chain.rejection_rate_entries[0], fnonlinearity, convergence));
checkpoint_write();
write_param_optimization_stats();
chain.clear_histograms();
......@@ -573,7 +574,7 @@ void runner_pt_master::pt_global_update(pt_chain &chain, pt_chain_run &chain_run
double r = rng_->random_double();
chain.rejection_rates[i] += 1 - std::min(w1 * w2, 1.);
chain.rejection_rate_entries++;
chain.rejection_rate_entries[chain_run.swap_odd]++;
if(r < w1 * w2) {
int rank0{};
int rank1{};
......
......@@ -21,7 +21,7 @@ struct pt_chain {
int entries_before_optimization{0};
std::vector<double> rejection_rates;
int rejection_rate_entries{0};
std::vector<int> rejection_rate_entries{0,0};
bool is_done();
void checkpoint_read(const iodump::group &g);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment