Commit c94359b5 authored by Lukas Weber

more pt statistics, smarter optimization algorithm

parent 39bcbb1d
......@@ -19,12 +19,6 @@ void mc::_init() {
measure.register_observable("_ll_measurement_time", 1000);
measure.register_observable("_ll_sweep_time", 1000);
if(pt_mode_) {
if(param.get<bool>("pt_statistics", false)) {
measure.register_observable("_ll_pt_rank", 1);
}
}
if(param.defined("seed")) {
rng.reset(new random_number_generator(param.get<uint64_t>("seed")));
} else {
......@@ -70,14 +64,6 @@ void mc::_pt_update_param(int target_rank, const std::string &param_name, double
pt_update_param(param_name, new_param);
}
// Records the calling process's rank within MPI_COMM_WORLD under the
// "_ll_pt_rank" observable. Does nothing unless the "pt_statistics"
// parameter is set (the lookup passes `false` as its fallback argument).
void mc::pt_measure_statistics() {
	const bool statistics_enabled = param.get<bool>("pt_statistics", false);
	if(!statistics_enabled) {
		return;
	}

	int rank;
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	measure.add("_ll_pt_rank", rank);
}
double mc::_pt_weight_ratio(const std::string &param_name, double new_param) {
double wr = pt_weight_ratio(param_name, new_param);
return wr;
......
......@@ -65,8 +65,6 @@ public:
void _pt_update_param(int target_rank, const std::string &param_name, double new_param);
double _pt_weight_ratio(const std::string &param_name, double new_param);
void pt_measure_statistics();
double safe_exit_interval();
bool is_thermalized();
......
......@@ -40,6 +40,10 @@ pt_chain_run::pt_chain_run(const pt_chain &chain, int run_id) : id{chain.id}, ru
}
}
// Resets every entry of `last_visited` to 0, keeping the container's size.
void pt_chain_run::clear_histograms() {
	for(auto &label : last_visited) {
		label = 0;
	}
}
pt_chain_run pt_chain_run::checkpoint_read(const iodump::group &g) {
pt_chain_run run;
g.read("id", run.id);
......@@ -87,7 +91,7 @@ void pt_chain::clear_histograms() {
std::fill(ndown_histogram.begin(), ndown_histogram.end(), 0);
}
/*
static double linear_regression(const std::vector<double> &x, const std::vector<double> &y, int i,
int range) {
int lower = std::max(0, i - range + 1);
......@@ -106,8 +110,9 @@ static double linear_regression(const std::vector<double> &x, const std::vector<
int n = upper - lower + 1;
return (sxy - sx * sy / n) / (sx2 - sx * sx / n);
}*/
}
/*
// x and y are ordered
static double linear_inverse(const std::vector<double> &x, const std::vector<double> &y, double y0) {
for(size_t i = 0; i < y.size()-1; i++) {
......@@ -116,7 +121,7 @@ static double linear_inverse(const std::vector<double> &x, const std::vector<dou
}
}
throw std::out_of_range{"y0 outside of image of y(x)"};
}
}*/
std::tuple<double, double> pt_chain::optimize_params(double relaxation_fac) {
std::vector<double> eta(params.size());
......@@ -143,11 +148,17 @@ std::tuple<double, double> pt_chain::optimize_params(double relaxation_fac) {
fnonlinearity += (f[i] - ideal_f) * (f[i] - ideal_f);
}
fnonlinearity = sqrt(fnonlinearity) / fnonlinearity_worst;
/*
double norm = 0;
for(size_t i = 0; i < params.size() - 1; i++) {
double dfdp = linear_regression(params, f, i, linreg_len);
double dfdp = 0;
size_t linreg_len = 0;
double dp = params[i + 1] - params[i];
do {
linreg_len++;
dfdp = linear_regression(params, f, i, linreg_len);
} while(dfdp * dp > 0 && linreg_len < params.size()/2);
eta[i] = sqrt(std::max(0.01, -dfdp / dp));
norm += eta[i] * dp;
}
......@@ -168,13 +179,14 @@ std::tuple<double, double> pt_chain::optimize_params(double relaxation_fac) {
}
new_params[i] = params[etai] + target / eta[etai];
convergence += (new_params[i] - params[i]) * (new_params[i] - params[i]);
}*/
}
/*
double convergence = 0;
for(size_t i = 1; i < params.size() - 1; i++) {
double target = 1-static_cast<double>(i) / (params.size()-1);
new_params[i] = linear_inverse(params, f, target);
convergence += (new_params[i] - params[i]) * (new_params[i] - params[i]);
}
}*/
convergence = sqrt(convergence) / (params.size() - 2);
......@@ -341,6 +353,24 @@ void runner_pt_master::write_params_yaml() {
file << params.c_str() << "\n";
}
void runner_pt_master::write_param_optimization_stats() {
std::string stat_name = job_.jobdir() + "/pt_param_stats.h5";
iodump stat = iodump::open_readwrite(stat_name);
auto g = stat.get_root();
g.write("chain_length", chain_len_);
for(auto &chain : pt_chains_) {
auto cg = g.open_group(fmt::format("chain{:04d}", chain.id));
std::vector<double> f(chain_len_);
for(size_t i = 0; i < f.size(); i++) {
f[i] = chain.nup_histogram[i]*1./(chain.ndown_histogram[i] + chain.nup_histogram[i]);
}
cg.insert_back("f", f);
cg.insert_back("params", chain.params);
}
}
void runner_pt_master::checkpoint_write() {
std::string master_dump_name = job_.jobdir() + "/pt_master.dump.h5";
......@@ -470,18 +500,20 @@ void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chai
}
if(chain.histogram_entries() >= chain.entries_before_optimization) {
chain.entries_before_optimization *= po_config_.nsamples_growth;
if(std::any_of(chain_run.last_visited.begin(), chain_run.last_visited.end(), [](int label) { return label == 0; })) {
job_.log(fmt::format("chain {}: some ranks still have no label. holding off parameter optimization due to insufficient statistics", chain.id));
checkpoint_write();
} else {
auto [fnonlinearity, convergence] = chain.optimize_params(po_config_.relaxation_fac);
job_.log(
fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
"convergence={:.2g}",
chain.id, chain.histogram_entries(), fnonlinearity, convergence));
checkpoint_write();
chain.clear_histograms();
auto [fnonlinearity, convergence] = chain.optimize_params(po_config_.relaxation_fac);
job_.log(
fmt::format("chain {}: pt param optimization: entries={}, f nonlinearity={:.2g}, "
"convergence={:.2g}",
chain.id, chain.histogram_entries(), fnonlinearity, convergence));
checkpoint_write();
write_param_optimization_stats();
chain.clear_histograms();
for(auto &cr : pt_chain_runs_) {
if(cr.id == chain.id) {
cr.clear_histograms();
}
}
}
}
......@@ -647,7 +679,6 @@ void runner_pt_slave::start() {
}
if(sys_->sweep() % sweeps_per_global_update_ == 0) {
sys_->pt_measure_statistics();
pt_global_update();
sweeps_since_last_query_++;
......
......@@ -46,6 +46,8 @@ public:
pt_chain_run(const pt_chain &chain, int run_id);
static pt_chain_run checkpoint_read(const iodump::group &g);
void checkpoint_write(const iodump::group &g);
void clear_histograms();
};
int runner_pt_start(jobinfo job, const mc_factory &mccreator, int argc, char **argv);
......@@ -79,6 +81,7 @@ private:
void checkpoint_write();
void checkpoint_read();
void write_params_yaml();
void write_param_optimization_stats();
int schedule_chain_run();
void pt_global_update(pt_chain &chain, pt_chain_run &chain_run);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment