// runner_pt.h
#pragma once
#include "runner.h"

#include <map>
#include <memory>
#include <tuple>
#include <vector>

namespace loadl {

// Bookkeeping for a single chain of tasks that is simulated together.
// NOTE(review): "pt" presumably stands for parallel tempering — confirm.
struct pt_chain {
	int id{};
	// Ids of the tasks that belong to this chain.
	std::vector<int> task_ids;
	// NOTE(review): presumably one parameter value per entry of task_ids —
	// verify against the implementation file.
	std::vector<double> params;

	// The three counters below default to -1.
	// NOTE(review): -1 looks like a "not loaded yet" marker — confirm.
	int sweeps{-1};
	int target_sweeps{-1};
	int target_thermalization{-1};
	int scheduled_runs{};

	// parameter optimization
	int entries_before_optimization{0};

	std::vector<double> rejection_rates;
	// Starts out as the two-element vector {0, 0} (initializer-list
	// construction, not "zero elements of value 0").
	std::vector<int> rejection_rate_entries{0, 0};

	bool is_done();
	void checkpoint_read(const iodump::group &g);
	void checkpoint_write(const iodump::group &g);

	void clear_histograms();
	// Returns a pair of doubles.
	// NOTE(review): the meaning of the two values is not visible from this
	// header — document them once confirmed in the implementation.
	std::tuple<double, double> optimize_params();
};

struct pt_chain_run {
private:
	pt_chain_run() = default;
35

36
public:
Lukas Weber's avatar
ehhhh  
Lukas Weber committed
37 38
	int id{};
	int run_id{};
39
	bool swap_odd{};
Lukas Weber's avatar
Lukas Weber committed
40

41
	std::vector<int> rank_to_pos;
42
	std::vector<int> switch_partners;
Lukas Weber's avatar
ehhhh  
Lukas Weber committed
43
	std::vector<double> weight_ratios;
44

45 46 47
	pt_chain_run(const pt_chain &chain, int run_id);
	static pt_chain_run checkpoint_read(const iodump::group &g);
	void checkpoint_write(const iodump::group &g);
48 49 50 51 52 53 54 55 56 57 58
};

// Entry point of the pt runner. Takes the job description by value, the
// factory used to create simulation objects, and the program's command line.
// NOTE(review): the int result is presumably a process exit status — confirm
// against the implementation.
int runner_pt_start(jobinfo job, const mc_factory &mccreator, int argc, char **argv);

class runner_pt_master {
private:
	jobinfo job_;
	int num_active_ranks_{0};

	double time_last_checkpoint_{0};

59 60 61 62 63 64 65
	// parameter optimization
	struct {
		bool enabled{};
		int nsamples_initial{};
		double nsamples_growth{};
	} po_config_;

66 67 68 69 70
	std::vector<pt_chain> pt_chains_;
	std::vector<pt_chain_run> pt_chain_runs_;
	int chain_len_;
	std::unique_ptr<random_number_generator> rng_;

71
	std::map<int, int> rank_to_chain_run_;
72 73
	int current_chain_id_{-1};

74 75
	measurements pt_meas_;

76 77 78
	void construct_pt_chains();
	void checkpoint_write();
	void checkpoint_read();
79
	void write_params_json();
80 81
	void write_statistics(const pt_chain_run &chain_run);
	void write_param_optimization_statistics();
82 83

	int schedule_chain_run();
84
	void pt_global_update(pt_chain &chain, pt_chain_run &chain_run);
85
	void pt_param_optimization(pt_chain &chain);
86 87 88

	void react();
	void send_action(int action, int destination);
89
	int assign_new_chain(int rank_section);
90 91 92 93 94 95 96 97 98 99 100 101 102

public:
	runner_pt_master(jobinfo job);
	void start();
};

class runner_pt_slave {
private:
	jobinfo job_;

	mc_factory mccreator_;
	std::unique_ptr<mc> sys_;

103 104 105
	MPI_Comm chain_comm_;
	int chain_rank_{};

106 107 108 109 110 111 112 113 114 115
	double time_last_checkpoint_{0};
	double time_start_{0};

	int rank_{};
	int sweeps_since_last_query_{};
	int sweeps_before_communication_{};
	int sweeps_per_global_update_{};
	int task_id_{-1};
	int run_id_{-1};

116 117
	double current_param_{};

118 119
	void pt_global_update();

120 121
	int negotiate_timeout();

122 123 124 125 126 127 128 129 130 131 132 133
	void send_status(int status);
	int recv_action();
	void checkpoint_write();
	void merge_measurements();
	bool accept_new_chain();
	int what_is_next(int status);

public:
	runner_pt_slave(jobinfo job, mc_factory mccreator);
	void start();
};
}