diff --git a/src/runner_pt.cpp b/src/runner_pt.cpp
index 934cb31d6d23853cfa9b384bccd1cc3a30a78798..ba6272df022760a6272ae8d7c71fce9d0f659397 100644
--- a/src/runner_pt.cpp
+++ b/src/runner_pt.cpp
@@ -69,7 +69,6 @@ void pt_chain::checkpoint_read(const iodump::group &g) {
 	g.read("nup_histogram", nup_histogram);
 	g.read("ndown_histogram", ndown_histogram);
 	g.read("entries_before_optimization", entries_before_optimization);
-	g.read("histogram_entries", histogram_entries);
 }
 
 void pt_chain::checkpoint_write(const iodump::group &g) {
@@ -77,13 +76,18 @@ void pt_chain::checkpoint_write(const iodump::group &g) {
 	g.write("nup_histogram", nup_histogram);
 	g.write("ndown_histogram", ndown_histogram);
 	g.write("entries_before_optimization", entries_before_optimization);
-	g.write("histogram_entries", histogram_entries);
+}
+
+// Entry count since the last clear_histograms(), taken from nup_histogram.at(0)
+// (throws std::out_of_range if the histogram is empty).
+// NOTE(review): assumes bin 0 gains exactly one "up" count per update -- confirm.
+int pt_chain::histogram_entries() const {
+	return nup_histogram.at(0);
 }
 
 void pt_chain::clear_histograms() {
 	std::fill(nup_histogram.begin(), nup_histogram.end(), 0);
 	std::fill(ndown_histogram.begin(), ndown_histogram.end(), 0);
-	histogram_entries = 0;
 }
 
 static double linear_regression(const std::vector<double> &x, const std::vector<double> &y, int i,
@@ -444,19 +448,23 @@ void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chai
 		chain.ndown_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == -1;
 		chain.nup_histogram[chain_run.rank_to_pos[rank]] += chain_run.last_visited[rank] == 1;
 	}
-	chain.histogram_entries++;
-	if(chain.histogram_entries >= chain.entries_before_optimization) {
-		auto [fnonlinearity, convergence] = chain.optimize_params(
-		    job_.jobfile["jobconfig"].get<int>("pt_parameter_optimization_linreg_len", 2));
-		job_.log(
-		    fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
-		                "convergence={:.2g}",
-		                chain.id, chain.entries_before_optimization, fnonlinearity, convergence));
-
+	if(chain.histogram_entries() >= chain.entries_before_optimization) {
 		chain.entries_before_optimization
 		    *= job_.jobfile["jobconfig"].get<double>("pt_parameter_optimization_nsamples_growth", 1.5);
-		checkpoint_write();
-		chain.clear_histograms();
+		// Optimize only once every rank carries an up/down label (last_visited != 0).
+		if(std::any_of(chain_run.last_visited.begin(), chain_run.last_visited.end(), [](int label) { return label == 0; })) {
+			job_.log(fmt::format("chain {}: some ranks still have no label. holding off parameter optimization due to insufficient statistics", chain.id));
+		} else {
+			auto [fnonlinearity, convergence] = chain.optimize_params(
+			    job_.jobfile["jobconfig"].get<int>("pt_parameter_optimization_linreg_len", 2));
+			job_.log(
+			    fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
+			                "convergence={:.2g}",
+			                chain.id, chain.entries_before_optimization, fnonlinearity, convergence));
+
+			checkpoint_write();
+			chain.clear_histograms();
+		}
 	}
 }
 
diff --git a/src/runner_pt.h b/src/runner_pt.h
index 32b24917c606e0c4246956bac8921186e3ec918e..3b66382e3082224fdc191a5245d93a3459d9c7f1 100644
--- a/src/runner_pt.h
+++ b/src/runner_pt.h
@@ -19,13 +19,13 @@ struct pt_chain {
 	std::vector<int> nup_histogram;
 	std::vector<int> ndown_histogram;
 	int entries_before_optimization{0};
-	int histogram_entries{};
 
 	bool is_done();
 	void checkpoint_read(const iodump::group &g);
 	void checkpoint_write(const iodump::group &g);
 
 	void clear_histograms();
+	int histogram_entries() const;
 	std::tuple<double, double> optimize_params(int linreg_len);
 };
 