...
 
Commits (3)
......@@ -1379,17 +1379,9 @@ class TranslationDataset(CachedDataset2):
max_seq = numpy.amax(cur_seq_difficulty)
cur_seq_difficulty = cur_seq_difficulty / max_seq
elif curriculum_learing['norm'] == 'equally_distant':
print('cur_seq_difficulty_at_the_start')
print(cur_seq_difficulty)
idx_sorted = numpy.argsort(cur_seq_difficulty)
print('idx_sorted')
print(idx_sorted)
max_seq = numpy.amax(idx_sorted)
print('max_seq')
print(max_seq)
cur_seq_difficulty = idx_sorted / max_seq
print('cur_seq_difficulty_at_the_end')
print(cur_seq_difficulty)
elif curriculum_learing['norm'] == 'no_norm':
pass
else:
......
......@@ -1422,7 +1422,6 @@ class Engine(EngineBase):
:return: nothing
"""
extra_fetches = None
print("marc in eval model start")
if output_per_seq_file:
assert output_per_seq_file_format in {"txt", "py"}
allowed_outputs = {"seq_tag", "seq_len", "score", "error", "pos_score", "pos_error"}
......@@ -1452,7 +1451,6 @@ class Engine(EngineBase):
if "seq_len" in output_per_seq_format or has_positional_fetch:
extra_fetches["seq_len"] = loss_holder.loss.output.get_sequence_lengths()
if "score" in output_per_seq_format:
print("marc in eval_model in score")
extra_fetches["score"] = loss_holder.get_normalized_loss_value_per_seq()
if "error" in output_per_seq_format:
extra_fetches["error"] = loss_holder.get_normalized_error_value_per_seq()
......@@ -1472,17 +1470,8 @@ class Engine(EngineBase):
:param list[str] seq_tags:
:param dict[str,numpy.ndarray] extra_fetches_out: see extra_fetches
"""
print("seq_idx")
print(seq_idx)
print("seq_tags")
print(seq_tags)
for batch_idx in range(len(seq_idx)):
corpus_seq_idx = dataset.get_corpus_seq_idx(seq_idx[batch_idx])
print("corpus_seq_idx")
print(corpus_seq_idx)
seq_idx_to_tag[corpus_seq_idx] = seq_tags[batch_idx]
for name, value in extra_fetches_out.items():
......@@ -1556,17 +1545,9 @@ class Engine(EngineBase):
f.write(better_repr(results) + '\n')
if output_per_seq_file:
print('Write eval results per seq to %r' % output_per_seq_file, file=log.v3)
print("marc problemzone")
print(seq_idx_to_tag)
print("output_per_seq_format")
print(output_per_seq_format)
print("results_per_seq")
print(results_per_seq)
with open(output_per_seq_file, 'w') as f:
if output_per_seq_file_format == "txt":
for seq_idx in range(len(results_per_seq)):
print("marc seq_idx")
print(seq_idx)
seq_tag = seq_idx_to_tag[seq_idx]
value_list = [results_per_seq[seq_tag][req_out] for req_out in output_per_seq_format]
value_list = [' '.join(map(str, v)) if isinstance(v, numpy.ndarray) else str(v) for v in value_list]
......@@ -1576,8 +1557,6 @@ class Engine(EngineBase):
elif output_per_seq_file_format == "py":
f.write("{\n")
for seq_idx in range(len(results_per_seq)):
print("marc seq_idx")
print(seq_idx)
seq_tag = seq_idx_to_tag[seq_idx]
f.write("%r: {" % seq_tag)
f.write(", ".join([
......