Commit 7cd8b32c authored by Lukas Weber's avatar Lukas Weber

be more conservative about insufficiently filled histograms

parent 3cc738ac
......@@ -69,7 +69,6 @@ void pt_chain::checkpoint_read(const iodump::group &g) {
g.read("nup_histogram", nup_histogram);
g.read("ndown_histogram", ndown_histogram);
g.read("entries_before_optimization", entries_before_optimization);
g.read("histogram_entries", histogram_entries);
}
void pt_chain::checkpoint_write(const iodump::group &g) {
......@@ -77,13 +76,15 @@ void pt_chain::checkpoint_write(const iodump::group &g) {
g.write("nup_histogram", nup_histogram);
g.write("ndown_histogram", ndown_histogram);
g.write("entries_before_optimization", entries_before_optimization);
g.write("histogram_entries", histogram_entries);
}
int pt_chain::histogram_entries() {
return nup_histogram.at(0);
}
void pt_chain::clear_histograms() {
std::fill(nup_histogram.begin(), nup_histogram.end(), 0);
std::fill(ndown_histogram.begin(), ndown_histogram.end(), 0);
histogram_entries = 0;
}
static double linear_regression(const std::vector<double> &x, const std::vector<double> &y, int i,
......@@ -444,19 +445,22 @@ void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chai
chain.ndown_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == -1;
chain.nup_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == 1;
}
chain.histogram_entries++;
if(chain.histogram_entries >= chain.entries_before_optimization) {
auto [fnonlinearity, convergence] = chain.optimize_params(
job_.jobfile["jobconfig"].get<int>("pt_parameter_optimization_linreg_len", 2));
job_.log(
fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
"convergence={:.2g}",
chain.id, chain.entries_before_optimization, fnonlinearity, convergence));
if(chain.histogram_entries() >= chain.entries_before_optimization) {
chain.entries_before_optimization *=
job_.jobfile["jobconfig"].get<double>("pt_parameter_optimization_nsamples_growth", 1.5);
checkpoint_write();
chain.clear_histograms();
if(std::any_of(chain_run.last_visited.begin(), chain_run.last_visited.end(), [](int label) { return label == 0; })) {
job_.log(fmt::format("chain {}: some ranks still have no label. holding off parameter optimization due to insufficient statistics"));
} else {
auto [fnonlinearity, convergence] = chain.optimize_params(
job_.jobfile["jobconfig"].get<int>("pt_parameter_optimization_linreg_len", 2));
job_.log(
fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
"convergence={:.2g}",
chain.id, chain.entries_before_optimization, fnonlinearity, convergence));
checkpoint_write();
chain.clear_histograms();
}
}
}
......
......@@ -19,13 +19,13 @@ struct pt_chain {
std::vector<int> nup_histogram;
std::vector<int> ndown_histogram;
int entries_before_optimization{0};
int histogram_entries{};
bool is_done();
void checkpoint_read(const iodump::group &g);
void checkpoint_write(const iodump::group &g);
void clear_histograms();
int histogram_entries();
std::tuple<double, double> optimize_params(int linreg_len);
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment