Commit 0e51bdd5 authored by Lukas Weber's avatar Lukas Weber
Browse files

BREAKING: require register_evalables to be a static function

* eliminates a source of bugs (relying on state that was not set in the mc constructor)
* allows speeding up your merges: mine were slow because of all the things that happen in the constructor

However you may need to rearrange your code to calculate all the things you need for your evalables in register_evalables.
parent 61aae46a
......@@ -57,7 +57,7 @@ std::string jobinfo::rundir(int task_id, int run_id) const {
return fmt::format("{}/run{:04d}", taskdir(task_id), run_id);
}
jobinfo::jobinfo(const std::string &jobfile_name) : jobfile{jobfile_name} {
jobinfo::jobinfo(const std::string &jobfile_name, register_evalables_func evalable_func) : evalable_func_{evalable_func}, jobfile{jobfile_name} {
for(auto node : jobfile["tasks"]) {
std::string task_name = node.first;
task_names.push_back(task_name);
......@@ -135,16 +135,14 @@ void jobinfo::concatenate_results() {
cat_results << "]\n";
}
void jobinfo::merge_task(int task_id, const mc_factory &mccreator) {
std::unique_ptr<mc> sys{mccreator(jobfile["tasks"][task_names[task_id]])};
void jobinfo::merge_task(int task_id) {
std::vector<std::string> meas_files = list_run_files(taskdir(task_id), "meas\\.h5");
size_t rebinning_bin_length = jobfile["jobconfig"].get<size_t>("merge_rebin_length", 0);
size_t sample_skip = jobfile["jobconfig"].get<size_t>("merge_sample_skip", 0);
results results = merge(meas_files, rebinning_bin_length, sample_skip);
evaluator eval{results};
sys->register_evalables(eval);
evalable_func_(eval, jobfile["tasks"][task_names[task_id]]);
eval.append_results();
std::string result_filename = fmt::format("{}/results.json", taskdir(task_id));
......
......@@ -2,14 +2,18 @@
#include "evalable.h"
#include "iodump.h"
#include "mc.h"
#include "parser.h"
#include <string>
#include <vector>
namespace loadl {
struct jobinfo {
using register_evalables_func = std::function<void (evaluator &, const parser &)>;
class jobinfo {
private:
register_evalables_func evalable_func_;
public:
parser jobfile;
std::string jobname;
......@@ -18,7 +22,7 @@ struct jobinfo {
double checkpoint_time{};
double runtime{};
jobinfo(const std::string &jobfile_name);
jobinfo(const std::string &jobfile_name, register_evalables_func evalable_func);
std::string jobdir() const;
std::string rundir(int task_id, int run_id) const;
......@@ -27,7 +31,7 @@ struct jobinfo {
static std::vector<std::string> list_run_files(const std::string &taskdir,
const std::string &file_ending);
int read_dump_progress(int task_id) const;
void merge_task(int task_id, const mc_factory &mccreator);
void merge_task(int task_id);
void concatenate_results();
void log(const std::string &message);
};
......
......@@ -7,9 +7,9 @@
namespace loadl {
inline int merge_only(jobinfo job, const mc_factory &mccreator, int, char **) {
inline int merge_only(jobinfo job, const mc_factory &, int, char **) {
for(size_t task_id = 0; task_id < job.task_names.size(); task_id++) {
job.merge_task(task_id, mccreator);
job.merge_task(task_id);
std::cout << fmt::format("-- {} merged\n", job.taskdir(task_id));
}
......@@ -17,8 +17,9 @@ inline int merge_only(jobinfo job, const mc_factory &mccreator, int, char **) {
return 0;
}
inline int run_mc(int (*starter)(jobinfo job, const mc_factory &, int argc, char **argv),
mc_factory mccreator, int argc, char **argv) {
template <typename mc_implementation>
int run_mc(int (*starter)(jobinfo job, const mc_factory &, int argc, char **argv),
int argc, char **argv) {
if(argc < 2) {
std::cerr << fmt::format(
"{0} JOBFILE\n{0} single JOBFILE\n{0} merge JOBFILE\n\n Without further flags, the MPI "
......@@ -29,7 +30,9 @@ inline int run_mc(int (*starter)(jobinfo job, const mc_factory &, int argc, char
}
std::string jobfile{argv[1]};
jobinfo job{jobfile};
auto mccreator = [&](const parser &p) -> mc * { return new mc_implementation{p}; };
jobinfo job{jobfile, &mc_implementation::register_evalables};
// bad hack because hdf5 locking features will happily kill your
// production run in the middle of writing measurements if you block
......@@ -45,14 +48,13 @@ inline int run_mc(int (*starter)(jobinfo job, const mc_factory &, int argc, char
// run this function from main() in your code.
template<class mc_implementation>
int run(int argc, char **argv) {
auto mccreator = [&](const parser &p) -> mc * { return new mc_implementation{p}; };
if(argc > 1 && std::string(argv[1]) == "merge") {
return run_mc(merge_only, mccreator, argc - 1, argv + 1);
return run_mc<mc_implementation>(merge_only, argc - 1, argv + 1);
} else if(argc > 1 && std::string(argv[1]) == "single") {
return run_mc(runner_single_start, mccreator, argc - 1, argv + 1);
return run_mc<mc_implementation>(runner_single_start, argc - 1, argv + 1);
}
return run_mc(runner_mpi_start, mccreator, argc, argv);
return run_mc<mc_implementation>(runner_mpi_start, argc, argv);
}
}
......@@ -42,7 +42,9 @@ public:
size_t sweep() const;
virtual void register_evalables(evaluator &evalables) = 0;
// implement this static function in your class!
//static void register_evalables(evaluator &evalables);
virtual void write_output(const std::string &filename);
// these functions do a little more, like taking care of the
......
......@@ -262,6 +262,6 @@ void runner_slave::merge_measurements() {
std::string unique_filename = job_.taskdir(task_id_);
sys_->write_output(unique_filename);
job_.merge_task(task_id_, mccreator_);
job_.merge_task(task_id_);
}
}
......@@ -788,7 +788,7 @@ void runner_pt_slave::merge_measurements() {
std::string unique_filename = job_.taskdir(task_id_);
sys_->write_output(unique_filename);
job_.merge_task(task_id_, mccreator_);
job_.merge_task(task_id_);
}
}
......@@ -98,6 +98,6 @@ void runner_single::merge_measurements() {
sys_->write_output(unique_filename);
job_.log(fmt::format("merging {}", job_.taskdir(task_id_)));
job_.merge_task(task_id_, mccreator_);
job_.merge_task(task_id_);
}
}
......@@ -29,7 +29,7 @@ void silly_mc::checkpoint_read(const loadl::iodump::group &d) {
d.read("idx", idx_);
}
void silly_mc::register_evalables(loadl::evaluator &eval) {
void silly_mc::register_evalables(loadl::evaluator &eval, const loadl::parser &) {
eval.evaluate("AntimagicNumber", {"MagicNumber", "MagicNumber2"},
[](const std::vector<std::vector<double>> &obs) {
double mag = obs[0][0];
......
......@@ -15,7 +15,7 @@ public:
void checkpoint_write(const loadl::iodump::group &out);
void checkpoint_read(const loadl::iodump::group &in);
void register_evalables(loadl::evaluator &eval);
static void register_evalables(loadl::evaluator &eval, const loadl::parser &p);
silly_mc(const loadl::parser &p);
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment