Commit ec60b5ae authored by Lukas Weber's avatar Lukas Weber

fix scatter and improve logging

parent ba0bc95b
......@@ -94,7 +94,7 @@ static double linear_regression(const std::vector<double> &x, const std::vector<
return (sxy - sx * sy / n) / (sx2 - sx * sx / n);
}
void pt_chain::optimize_params(int linreg_len) {
std::tuple<double, double> pt_chain::optimize_params(int linreg_len) {
std::vector<double> eta(params.size());
std::vector<double> f(params.size());
std::vector<double> new_params(params.size());
......@@ -104,13 +104,18 @@ void pt_chain::optimize_params(int linreg_len) {
new_params[0] = params[0];
new_params[params.size() - 1] = params[params.size() - 1];
double flinearity = 0;
for(size_t i = 0; i < params.size(); i++) {
if(nup_histogram[i] + ndown_histogram[i] == 0) {
f[i] = 0;
} else {
f[i] = nup_histogram[i] / static_cast<double>(nup_histogram[i] + ndown_histogram[i]);
}
double ideal_f = 1-i/static_cast<double>(params.size());
flinearity += (f[i]-ideal_f)*(f[i]-ideal_f);
}
flinearity = sqrt(flinearity)/params.size();
double norm = 0;
for(size_t i = 0; i < params.size() - 1; i++) {
......@@ -122,6 +127,8 @@ void pt_chain::optimize_params(int linreg_len) {
for(auto &v : eta) {
v /= norm;
}
double convergence = 0;
for(size_t i = 1; i < params.size() - 1; i++) {
double target = static_cast<double>(i) / (params.size() - 1);
int etai = 0;
......@@ -133,8 +140,17 @@ void pt_chain::optimize_params(int linreg_len) {
target -= deta;
}
new_params[i] = params[etai] + target / eta[etai];
convergence += (new_params[i]-params[i])*(new_params[i]-params[i]);
}
params = new_params;
convergence = sqrt(convergence)/(params.size()-2);
for(size_t i = 0; i < params.size(); i++) {
double relaxation_fac = 1;
params[i] = params[i]*(1-relaxation_fac) + relaxation_fac*new_params[i];
}
return std::tie(flinearity, convergence);
}
bool pt_chain::is_done() {
......@@ -237,7 +253,7 @@ void runner_pt_master::construct_pt_chains() {
c.nup_histogram.resize(chain_len_);
c.ndown_histogram.resize(chain_len_);
c.entries_before_optimization =
job_.jobfile["jobconfig"].get<int>("pt_parameter_optimization_nsamples_initial", 500);
job_.jobfile["jobconfig"].get<int>("pt_parameter_optimization_nsamples_initial", 10000);
}
if(chain_len_ == -1) {
throw std::runtime_error{
......@@ -338,7 +354,7 @@ void runner_pt_master::start() {
for(int i = 1; i < num_active_ranks_; i++) {
group_idx[i] = (i - 1) / chain_len_;
}
MPI_Scatter(group_idx.data(), 1, MPI_INT, group_idx.data(), 1, MPI_INT, MASTER, MPI_COMM_WORLD);
MPI_Scatter(group_idx.data(), 1, MPI_INT, MPI_IN_PLACE, 1, MPI_INT, MASTER, MPI_COMM_WORLD);
MPI_Comm tmp;
MPI_Comm_split(MPI_COMM_WORLD, MPI_UNDEFINED, 0, &tmp);
......@@ -407,6 +423,34 @@ int runner_pt_master::assign_new_chain(int rank_section) {
return chain_run_id;
}
void runner_pt_master::pt_param_optimization(pt_chain &chain, pt_chain_run &chain_run) {
for(size_t rank = 0; rank < chain_run.rank_to_pos.size(); rank++) {
if(chain_run.rank_to_pos[rank] == 0) {
chain_run.last_visited[rank] = 1;
}
if(chain_run.rank_to_pos[rank] ==
static_cast<int>(chain_run.rank_to_pos.size()) - 1) {
chain_run.last_visited[rank] = -1;
}
chain.ndown_histogram[chain_run.rank_to_pos[rank]] +=
chain_run.last_visited[rank] == -1;
chain.nup_histogram[chain_run.rank_to_pos[rank]] +=
chain_run.last_visited[rank] == 1;
}
chain.histogram_entries++;
if(chain.histogram_entries >= chain.entries_before_optimization) {
auto [flinearity, convergence] = chain.optimize_params(job_.jobfile["jobconfig"].get<int>(
"pt_parameter_optimization_linreg_len", 2));
chain.clear_histograms();
job_.log(fmt::format("chain {}: pt param optimization: entries={}, f linearity={:.2g}, convergence={:.2g}",
chain.id, chain.entries_before_optimization, flinearity, convergence));
chain.entries_before_optimization *= job_.jobfile["jobconfig"].get<double>(
"pt_parameter_optimization_nsamples_growth", 1.5);
}
}
void runner_pt_master::react() {
int rank_status;
MPI_Status stat;
......@@ -450,31 +494,7 @@ void runner_pt_master::react() {
int rank_section = rank / chain_len_;
if(use_param_optimization_) {
for(size_t rank = 0; rank < chain_run.rank_to_pos.size(); rank++) {
if(chain_run.rank_to_pos[rank] == 0) {
chain_run.last_visited[rank] = 1;
}
if(chain_run.rank_to_pos[rank] ==
static_cast<int>(chain_run.rank_to_pos.size()) - 1) {
chain_run.last_visited[rank] = -1;
}
chain.ndown_histogram[chain_run.rank_to_pos[rank]] +=
chain_run.last_visited[rank] == -1;
chain.nup_histogram[chain_run.rank_to_pos[rank]] +=
chain_run.last_visited[rank] == 1;
}
chain.histogram_entries++;
if(chain.histogram_entries >= chain.entries_before_optimization) {
chain.optimize_params(job_.jobfile["jobconfig"].get<int>(
"pt_parameter_optimization_regression_len", 2));
chain.clear_histograms();
job_.log(fmt::format("chain {}: pt feedback optimization: {:.2g}, entries: {}",
chain.id, fmt::join(chain.params, ", "),
chain.entries_before_optimization));
chain.entries_before_optimization *= job_.jobfile["jobconfig"].get<double>(
"pt_parameter_optimization_nsamples_growth", 1.5);
}
pt_param_optimization(chain, chain_run);
}
for(int target = 0; target < chain_len_; target++) {
......@@ -552,7 +572,7 @@ void runner_pt_slave::start() {
time_last_checkpoint_ = time_start_;
int group_idx;
MPI_Scatter(&group_idx, 1, MPI_INT, &group_idx, 1, MPI_INT, MASTER, MPI_COMM_WORLD);
MPI_Scatter(NULL, 1, MPI_INT, &group_idx, 1, MPI_INT, MASTER, MPI_COMM_WORLD);
MPI_Comm_split(MPI_COMM_WORLD, group_idx, 0, &chain_comm_);
MPI_Comm_rank(chain_comm_, &chain_rank_);
......
......@@ -26,7 +26,7 @@ struct pt_chain {
void checkpoint_write(const iodump::group &g);
void clear_histograms();
void optimize_params(int linreg_len);
std::tuple<double, double> optimize_params(int linreg_len);
};
struct pt_chain_run {
......@@ -73,6 +73,7 @@ private:
int schedule_chain_run();
void pt_global_update(pt_chain &chain, pt_chain_run &chain_run);
void pt_param_optimization(pt_chain &chain, pt_chain_run &chain_run);
void react();
void send_action(int action, int destination);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment