Commit 798b506b authored by Lukas Weber's avatar Lukas Weber

parallel tempering wip

parent 296e2f1e
......@@ -24,6 +24,7 @@ void mc::_init() {
if(param.get<bool>("pt_statistics", false)) {
measure.add_observable("_ll_pt_rank", 1);
measure.add_observable("_ll_pt_weight_ratio", 1);
}
if(param.defined("seed")) {
......@@ -68,10 +69,10 @@ void mc::_do_update() {
void mc::_pt_update_param(double new_param, const std::string &new_dir) {
// take over the bins of the new target dir
{
iodump dump_file = iodump::create(new_dir + ".dump.h5.tmp");
/*{
iodump dump_file = iodump::open_readonly(new_dir + ".dump.h5");
measure.checkpoint_read(dump_file.get_root().open_group("measurements"));
}
}*/
if(param.get<bool>("pt_statistics", false)) {
int rank;
......@@ -81,6 +82,14 @@ void mc::_pt_update_param(double new_param, const std::string &new_dir) {
pt_update_param(new_param);
}
double mc::_pt_weight_ratio(double new_param) {
double wr = pt_weight_ratio(new_param);
if(param.get<bool>("pt_statistics", false)) {
measure.add("_ll_pt_weight_ratio", wr);
}
return wr;
}
void mc::_write(const std::string &dir) {
struct timespec tstart, tend;
clock_gettime(CLOCK_MONOTONIC_RAW, &tstart);
......
......@@ -34,20 +34,21 @@ protected:
virtual void pt_update_param(double /*new_param*/) {
throw std::runtime_error{"running parallel tempering, but pt_update_param not implemented"};
}
virtual double pt_weight_ratio(double /*new_param*/) {
throw std::runtime_error{"running parallel tempering, but pt_weight_ratio not implemented"};
return 1;
}
public:
double random01();
int sweep() const;
virtual void register_evalables(std::vector<evalable> &evalables) = 0;
virtual double pt_weight_ratio(double /*new_param*/) {
throw std::runtime_error{"running parallel tempering, but pt_weight_ratio not implemented"};
return 1;
}
// these functions do a little more, like taking care of the
// random number generator state, then call the child class versions.
void _init();
void _write(const std::string &dir);
bool _read(const std::string &dir);
......@@ -56,8 +57,12 @@ public:
void _do_update();
void _do_measurement();
void _pt_update_param(double new_param, const std::string &new_dir);
double _pt_weight_ratio(double new_param);
double safe_exit_interval();
// write only measurement data (useful for parallel tempering)
void measurement_write(const std::string &dir);
bool is_thermalized();
measurements measure;
......
......@@ -334,7 +334,7 @@ void runner_pt_master::react() {
if(partner_pos < 0 || partner_pos >= chain_len_) {
int response = GA_SKIP;
MPI_Send(&response, 1, MPI_INT, node + 1, T_GLOBAL, MPI_COMM_WORLD);
chain_run.weight_ratios[chain_run.node_to_pos[node % chain_len_]] = 1;
chain_run.weight_ratios[pos] = 1;
} else {
int response = GA_CALC_WEIGHT;
MPI_Send(&response, 1, MPI_INT, node + 1, T_GLOBAL, MPI_COMM_WORLD);
......@@ -345,7 +345,7 @@ void runner_pt_master::react() {
double weight;
MPI_Recv(&weight, 1, MPI_DOUBLE, node + 1, T_GLOBAL, MPI_COMM_WORLD, &stat);
assert(weight >= 0);
chain_run.weight_ratios[chain_run.node_to_pos[node % chain_len_]] = weight;
chain_run.weight_ratios[pos] = weight;
}
bool all_ready =
......@@ -381,7 +381,13 @@ void runner_pt_master::pt_global_update(pt_chain &chain, pt_chain_run &chain_run
double r = rng_->random_double();
if(r < w1 * w2) {
std::swap(chain_run.node_to_pos[i], chain_run.node_to_pos[i + 1]);
for(auto &p : chain_run.node_to_pos) {
if(p == i) {
p = i+1;
} else if(p == i+1) {
p = i;
}
}
}
}
......@@ -411,6 +417,7 @@ void runner_pt_slave::start() {
}
if(sys_->sweep() % sweeps_per_global_update_ == 0) {
checkpoint_write();
pt_global_update();
}
......@@ -444,20 +451,20 @@ void runner_pt_slave::pt_global_update() {
int response;
MPI_Recv(&response, 1, MPI_INT, MASTER, T_GLOBAL, MPI_COMM_WORLD, &stat);
job_.log(fmt::format(" * rank {}: ready for global update", rank_));
//job_.log(fmt::format(" * rank {}: ready for global update", rank_));
if(response == GA_DONE) {
job_.log(fmt::format(" * rank {}: everything done", rank_));
//job_.log(fmt::format(" * rank {}: everything done", rank_));
time_last_checkpoint_ = 0; // time to call back!
return;
} else if(response == GA_CALC_WEIGHT) {
double partner_param;
MPI_Recv(&partner_param, 1, MPI_DOUBLE, MASTER, T_GLOBAL, MPI_COMM_WORLD, &stat);
double weight_ratio = sys_->pt_weight_ratio(partner_param);
double weight_ratio = sys_->_pt_weight_ratio(partner_param);
MPI_Send(&weight_ratio, 1, MPI_DOUBLE, MASTER, T_GLOBAL, MPI_COMM_WORLD);
job_.log(fmt::format(" * rank {}: weight sent", rank_));
//job_.log(fmt::format(" * rank {}: weight sent", rank_));
} else {
job_.log(fmt::format(" * rank {}: no weight needed", rank_));
//job_.log(fmt::format(" * rank {}: no weight needed", rank_));
}
......@@ -470,7 +477,6 @@ void runner_pt_slave::pt_global_update() {
job_.jobfile["tasks"][job_.task_names[task_id_]].get<int>("pt_sweeps_per_global_update");
sys_->_pt_update_param(new_param, job_.rundir(task_id_, run_id_));
job_.log(fmt::format(" * rank {}: global update received param {}", rank_, new_param));
}
bool runner_pt_slave::accept_new_chain() {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment