...
 
......@@ -6,6 +6,10 @@
path = extern/kenlm
url = https://github.com/kpu/kenlm.git
branch = master
[submodule "extern/openfst"]
path = extern/openfst
url = https://github.com/mjansche/openfst.git
branch = 1.7.3
[submodule "extern/ParseOggVorbis"]
path = extern/ParseOggVorbis
url = https://github.com/albertz/ParseOggVorbis.git
......
......@@ -2,13 +2,13 @@
language: python
python:
- "3.6"
- "3.7"
- "2.7"
# https://docs.travis-ci.com/user/reference/overview/
# https://blog.travis-ci.com/2017-04-17-precise-EOL
# Also needed for TensorFlow for glibc 2.7.
dist: trusty
dist: xenial
sudo: false # should use the faster container based image
# command to install dependencies
......@@ -18,6 +18,7 @@ sudo: false # should use the faster container based image
# pip2 without --user (if pip is PY3) also does not work because we don't have root access.
# pip2 with --user (if pip is PY2) does not work, because it is the same as pip.
install:
- 'if { python -V 2>&1 | grep -q "Python 3."; } && test -n "$PY3_VER"; then source ~/virtualenv/python$PY3_VER/bin/activate; python -V; fi;'
- pip2 install -qq --upgrade pip setuptools wheel six | cat # Python2<->Python3
- pip2 install --user -r requirements.txt | cat # need for Python2<->Python3 communication tests
- pip2 install --user typing | cat
......@@ -45,8 +46,9 @@ env:
matrix:
- TEST=TFEngine
- TEST=TFNativeOp
- TEST=TFNativeOp TF_PACKAGE=tensorflow==1.8.0
- TEST=TFNativeOp TF_PACKAGE=tensorflow==1.4.0
# There are no Python >=3.7 pip packages for older TF versions.
- TEST=TFNativeOp TF_PACKAGE=tensorflow==1.8.0 PY3_VER=3.6
- TEST=TFNativeOp TF_PACKAGE=tensorflow==1.4.0 PY3_VER=3.6
- TEST=TFNetworkLayer
- TEST=TFNetworkRecLayer
- TEST=TFNetworkSigProcLayer
......@@ -54,7 +56,8 @@ env:
- TEST=TFUtil
- TEST=Config
- TEST=Dataset
- TEST=demos
# Theano using NativeOp is somewhat broken on Python 3.7 in some cases, thus we use Python 3.6.
- TEST=demos PY3_VER=3.6
- TEST=Device
- TEST=EngineTask
- TEST=EngineUtil
......@@ -66,8 +69,8 @@ env:
- TEST=LearningRateControl
- TEST=Log
- TEST=multi_target
- TEST=MultiBatchBeam
- TEST=NativeOp
- TEST=MultiBatchBeam PY3_VER=3.6
- TEST=NativeOp PY3_VER=3.6
- TEST=NativeOp_chunk
- TEST=NativeOp_sparse
- TEST=NativeOp_subtensor_batched_index
......
......@@ -6,30 +6,6 @@ or any changes which could potentially break or change the behavior of existing
This is intentionally kept short. For a full change log, just see the Git log.
## 2019-08-20: Pretrain `#config` can overwrite datasets (`train`, `dev`, `eval`)
## 2019-08-13: `Data` `batch_shape_meta` extra debug repr output
This will show the same information as before, but much more compact,
and also in addition the dimension tags (`DimensionTag`),
which also got improved in many further cases.
## 2019-08-07: overlay nets (`extra_nets`)
You can have e.g. multiple additional networks which redefine
existing layers (they would automatically share params),
which can use different flags (e.g. enable the search flag).
## 2019-07: multiple stochastic (latent) variables
It was designed to support this from the very beginning,
but the implementation was never fully finished.
Now examples like hard attention work.
## 2019-05: better support for RETURNN as a framework
`pip install returnn`, and then `import returnn`.
## 2019-03-29: remove hard Theano dependency
## 2019-03-24 and ongoing: automatic linter checks
......
......@@ -38,8 +38,9 @@ class CachedDataset(Dataset):
self.alloc_intervals = None # type: list
self._seq_start = [] # [numpy.array([0,0])] # uses sorted seq idx, see set_batching()
self._seq_index = []; """ :type: list[int] """ # Via init_seq_order(). seq_index idx -> hdf seq idx
self._seq_index_inv = {}; """ :type: dict[int,int] """ # Via init_seq_order(). hdf seq idx -> seq_index idx
self._index_map = range(len(self._seq_index)) # sorted seq idx -> seq_index idx
self._seq_lengths = numpy.zeros((0, 0)) # real seq idx -> tuple of len of data and all targets
self._tags = []; """ :type: list[str|bytes] """ # uses real seq idx. access via _get_tag_by_real_idx
self._tag_idx = {}; ":type: dict[str,int] " # map of tag -> real-seq-idx. call _update_tag_idx
self.targets = {}
self.target_keys = []
......@@ -71,7 +72,7 @@ class CachedDataset(Dataset):
self._update_tag_idx()
seq_index = [self._tag_idx[tag] for tag in seq_list]
else:
seq_index = self.get_seq_order_for_epoch(epoch, self._num_seqs, lambda s: self._get_seq_length_by_real_idx(s)[0])
seq_index = self.get_seq_order_for_epoch(epoch, self._num_seqs, lambda s: self._seq_lengths[s][0])
old_index_map = self._index_map[:]
self._index_map = range(len(seq_index)) # sorted seq idx -> seq_index idx
......@@ -83,29 +84,21 @@ class CachedDataset(Dataset):
# Give some hint to the user in case he is wondering why the cache is reloading.
print("Reinitialize dataset seq order for epoch %i." % epoch, file=log.v4)
if (self.cache_byte_size_limit_at_start == 0
or self.num_seqs_cached_at_start != len(seq_index)
or not self.start_cache_initialized):
if self.num_seqs_cached_at_start != len(seq_index) or not self.start_cache_initialized:
self._seq_index = seq_index
self._seq_index_inv = {} # reset, create later if needed
self._seq_index_inv = dict(zip(seq_index, range(len(seq_index)))) # hdf seq idx -> seq_index idx
self._init_seq_starts()
self._init_alloc_intervals()
self._init_start_cache()
self.start_cache_initialized = True
else:
if not self._seq_index_inv:
self._seq_index_inv = dict(zip(self._seq_index, range(len(self._seq_index)))) # hdf seq idx -> seq_index idx
self._index_map = [self._seq_index_inv[i] for i in seq_index] # sorted seq idx -> seq_index idx
if self._index_map == old_index_map:
return False
return True
def get_current_seq_order(self):
assert self.cache_byte_size_limit_at_start == 0 # not implemented otherwise, we ignore _index_map
return self._seq_index
def _get_tag_by_real_idx(self, real_idx):
raise NotImplementedError
return self._tags[real_idx]
def _update_tag_idx(self):
if self._tag_idx:
......@@ -136,7 +129,7 @@ class CachedDataset(Dataset):
self._seq_start = [self._seq_start[0] * 0] # idx like in seq_index, *not* real idx
for i in range(self.num_seqs):
ids = self._seq_index[i]
self._seq_start.append(self._seq_start[-1] + self._get_seq_length_by_real_idx(ids))
self._seq_start.append(self._seq_start[-1] + self._seq_lengths[ids])
def _init_start_cache(self):
if self.cache_byte_size_limit_at_start == 0:
......@@ -150,7 +143,7 @@ class CachedDataset(Dataset):
cached_bytes = 0
for i in range(self.num_seqs):
if i == num_cached:
nbytes = self.get_seq_length_nd(i)[0] * self.nbytes
nbytes = self.get_seq_length_2d(i)[0] * self.nbytes
if self.cache_byte_size_limit_at_start >= cached_bytes + nbytes:
num_cached = i + 1
cached_bytes += nbytes
......@@ -206,7 +199,7 @@ class CachedDataset(Dataset):
gc.collect()
# Preload as much as we can so that we fill up the cache.
while end < self.num_seqs:
num_needed_cache_frames = self.get_seq_length_nd(end)[0]
num_needed_cache_frames = self.get_seq_length_2d(end)[0]
if self.cache_num_frames_free - num_needed_cache_frames < 0:
break
self.cache_num_frames_free -= num_needed_cache_frames
......@@ -423,7 +416,7 @@ class CachedDataset(Dataset):
if ai[1] > self.num_seqs_cached_at_start and ai[0] < ai[1]:
removed = self.remove_alloc_interval(max(ai[0],self.num_seqs_cached_at_start), ai[1])
self.preload_set -= set(removed)
deleted += sum([self._get_seq_length_by_real_idx(self._seq_index[i])[0] for i in removed])
deleted += sum([self._seq_lengths[self._seq_index[i]][0] for i in removed])
else:
i += 1
return deleted
......@@ -453,27 +446,19 @@ class CachedDataset(Dataset):
return True
return set(range(start,end)) <= self.preload_set
def _get_seq_length_by_real_idx(self, real_seq_idx):
"""
:param int real_seq_idx:
:return: length of the sequence with index 'real_seq_idx'
:rtype: numpy.ndarray
"""
raise NotImplementedError
def get_seq_length_nd(self, sorted_seq_idx):
def get_seq_length_2d(self, sorted_seq_idx):
"""
:type sorted_seq_idx: int
:rtype: numpy.ndarray
:rtype: (int,int)
"""
real_seq_idx = self._seq_index[self._index_map[sorted_seq_idx]]
return self._get_seq_length_by_real_idx(real_seq_idx)
return self._seq_lengths[real_seq_idx]
def get_seq_length(self, seq_idx):
"""
:rtype: NumbersDict
"""
lengths = self.get_seq_length_nd(seq_idx)
lengths = self.get_seq_length_2d(seq_idx)
d = {"data": lengths[0]}
for k, l in zip(self.target_keys, lengths[1:]):
d[k] = l
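# Illustration (hypothetical numbers): with target_keys == ["classes"] and
# lengths == [120, 17], d becomes {"data": 120, "classes": 17}; the NumbersDict
# return value documented above is built from this dict.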
......@@ -488,7 +473,7 @@ class CachedDataset(Dataset):
def get_times(self, sorted_seq_idx):
seq_start = self.get_seq_start(sorted_seq_idx)[0]
seq_len = self.get_seq_length_nd(sorted_seq_idx)[0]
seq_len = self.get_seq_length_2d(sorted_seq_idx)[0]
return self.timestamps[seq_start:seq_start + seq_len]
def get_input_data(self, sorted_seq_idx):
......@@ -498,7 +483,7 @@ class CachedDataset(Dataset):
alloc_start_seq, alloc_end_seq, alloc_data = self.alloc_intervals[idi]
o = self.get_seq_start(seq_idx)[0] - self.get_seq_start(alloc_start_seq)[0]
assert o >= 0
l = self.get_seq_length_nd(sorted_seq_idx)[0]
l = self.get_seq_length_2d(sorted_seq_idx)[0]
assert alloc_data.shape[0] >= o + l
return alloc_data[o:o + l]
......@@ -511,7 +496,7 @@ class CachedDataset(Dataset):
seq_idx = self._index_map[sorted_seq_idx]
idx = self.target_keys.index(target) + 1
seq_start = self.get_seq_start(seq_idx)[idx]
seq_len = self.get_seq_length_nd(sorted_seq_idx)[idx]
seq_len = self.get_seq_length_2d(sorted_seq_idx)[idx]
return self.targets[target][seq_start:seq_start + seq_len]
def get_target_list(self):
......@@ -536,4 +521,6 @@ class CachedDataset(Dataset):
:return: the sequence index as-is in the original corpus. only defined if self.have_corpus_seq_idx()
:rtype: int
"""
if self.seq_ordering == "default":
return seq_idx
return self._seq_index[self._index_map[seq_idx]]
......@@ -265,7 +265,7 @@ class CachedDataset2(Dataset):
:rtype: str
"""
self._load_something()
return self.added_data[0].get_data(key).dtype
return str(self.added_data[0].get_data(key).dtype)
class SingleStreamPipeDataset(CachedDataset2):
......
......@@ -519,12 +519,15 @@ def get_global_config(raise_exception=True, auto_create=False):
return _global_config
import TaskSystem
import Util
if Util.BackendEngine.is_theano_selected():
import Device
if not TaskSystem.isMainProcess:
# We expect that we are a Device subprocess.
assert Device.asyncChildGlobalDevice is not None
return Device.asyncChildGlobalDevice.config
try:
if Util.BackendEngine.is_theano_selected():
import Device
if not TaskSystem.isMainProcess:
# We expect that we are a Device subprocess.
assert Device.asyncChildGlobalDevice is not None
return Device.asyncChildGlobalDevice.config
except Util.BackendEngine.CannotSelectEngine:
pass # ignore
# We are the main process.
import sys
main_mod = sys.modules["__main__"] # should be rnn.py
......
......@@ -1014,6 +1014,7 @@ class Dataset(object):
def _generate_batches(self, recurrent_net,
batch_size, max_seqs=-1, max_seq_length=sys.maxsize,
max_pad_size=None,
min_seq_length=0, pruning=0.0,
seq_drop=0.0, max_total_num_seqs=-1,
used_data_keys=None):
......@@ -1021,6 +1022,7 @@ class Dataset(object):
:param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
:param int|dict[str,int]|NumbersDict batch_size: Max number of frames in one batch.
:param int|dict[str,int]|NumbersDict max_pad_size: Max number of zero-padded frames in one batch.
:param int max_seqs: Max number of seqs per batch.
:param int max_total_num_seqs:
:param int|dict[str,int]|NumbersDict max_seq_length:
......@@ -1030,6 +1032,7 @@ class Dataset(object):
batch_size = sys.maxsize
batch_size = NumbersDict(batch_size)
assert not batch_size.any_compare(NumbersDict(0), (lambda a, b: a <= b))
max_pad_size = NumbersDict(max_pad_size)
if max_seqs == -1:
max_seqs = float('inf')
if not max_seq_length:
......@@ -1074,9 +1077,16 @@ class Dataset(object):
if self.rnd_seq_drop.random() < seq_drop:
continue
dt, ds = batch.try_sequence_as_slice(length)
if ds > 1 and ((dt * ds).any_compare(batch_size, (lambda a, b: a > b)) or ds > max_seqs):
yield batch
batch = Batch()
if batch.num_slices >= 1:
if (dt * ds).any_compare(batch_size, (lambda a, b: a > b)):
yield batch
batch = Batch()
elif ds > max_seqs:
yield batch
batch = Batch()
elif (dt * ds - batch.get_total_num_frames() - length).any_compare(max_pad_size, (lambda a, b: a > b)):
yield batch
batch = Batch()
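# Padding illustration (hypothetical numbers, per data key): suppose the batch
# already holds 2 slices with get_total_num_frames() == 150 and a max slice
# length of 80; adding a sequence of length 90 gives dt == 90, ds == 3, so
# dt * ds == 270 allocated frames, of which 270 - 150 - 90 == 30 would be zero
# padding. If that exceeds max_pad_size, the current batch is yielded first.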
batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
else: # Not recurrent.
while t_start.max_value() < t_end.max_value():
......@@ -1231,6 +1241,11 @@ def init_dataset(kwargs, extra_kwargs=None, default_kwargs=None):
if isinstance(kwargs, (str, unicode)):
if kwargs.startswith("{"):
kwargs = eval(kwargs)
elif kwargs.startswith("config:"):
from Config import get_global_config
config = get_global_config()
data = eval(kwargs[len("config:"):], config.typed_dict, config.typed_dict)
return init_dataset(data, extra_kwargs=extra_kwargs, default_kwargs=default_kwargs)
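# Hedged usage sketch: if the config (a Python file) defines e.g.
#   def get_dataset(name): return {"class": "HDFDataset", "files": [...]}
# then a dataset spec like "config:get_dataset('train')" is evaluated inside
# the config namespace and the resulting dict is fed back into init_dataset().
# (The helper name and dict contents here are hypothetical.)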
else:
config_str = kwargs
kwargs = {}
......
......@@ -437,9 +437,8 @@ class Engine(EngineBase):
forwarder = ClassificationTaskThread(self.network, self.devices, dataset, batches)
forwarder.join()
output = list(forwarder.result.values())[0][0]
assert output.shape[1] == 1
return output[:, 0]
assert forwarder.output.shape[1] == 1
return forwarder.output[:, 0]
def forward_to_hdf(self, data, output_file, combine_labels='', batch_size=0):
"""
......
......@@ -73,10 +73,14 @@ class EngineBase(object):
assert start_epoch >= 1
load_model_epoch_filename = config.value('load', '')
if load_model_epoch_filename.endswith(".meta"):
load_model_epoch_filename = load_model_epoch_filename[:-len(".meta")]
if load_model_epoch_filename:
assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())
import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
if import_model_train_epoch1.endswith(".meta"):
import_model_train_epoch1 = import_model_train_epoch1[:-len(".meta")]
if import_model_train_epoch1:
assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())
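# In other words, a TF checkpoint name like "net.model.023.meta" (hypothetical
# path) is accepted for "load" / "import_model_train_epoch1" and treated the
# same as "net.model.023".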
......
......@@ -2609,21 +2609,22 @@ class OggZipDataset(CachedDataset2):
import zipfile
import Util
from MetaDataset import EpochWiseFilter
name, ext = os.path.splitext(os.path.basename(path))
if ext != ".zip" and os.path.isdir(path) and os.path.isfile(path + ".txt"):
if not isinstance(path, list) and os.path.splitext(path)[1] != ".zip" and os.path.isdir(path) and os.path.isfile(path + ".txt"):
# Special case (mostly for debugging) to directly access the filesystem, not via zip-file.
path, name = os.path.dirname(path), os.path.basename(path)
self._zip_file = None
self.paths = [os.path.dirname(path)]
self._names = [os.path.basename(path)]
self._zip_files = None
assert not use_cache_manager, "cache manager only for zip file"
else:
assert ext == ".zip"
self._zip_file = zipfile.ZipFile(path)
kwargs.setdefault("name", name)
self.paths = path if isinstance(path, list) else [path]
for path in self.paths:
assert os.path.splitext(path)[1] == ".zip"
self._names = [os.path.splitext(os.path.basename(path))[0] for path in self.paths]
if use_cache_manager:
self.paths = [Util.cf(path) for path in self.paths]
self._zip_files = [zipfile.ZipFile(path) for path in self.paths]
kwargs.setdefault("name", self._names[0])
super(OggZipDataset, self).__init__(**kwargs)
if use_cache_manager:
assert self._zip_file is not None, "cache manager only for zip file"
path = Util.cf(path)
self.path = path
self._name = name
self.targets = Vocabulary.create_vocab(**targets) if targets is not None else None
if self.targets:
self.labels["classes"] = self.targets.labels
......@@ -2651,21 +2652,24 @@ class OggZipDataset(CachedDataset2):
self._seq_order = None # type: typing.Optional[typing.List[int]]
self.init_seq_order()
def _read(self, filename):
def _read(self, filename, zip_index):
"""
:param str filename: in zip-file
:param int zip_index: index of the zip file to load, unused when loading without zip
:rtype: bytes
"""
if self._zip_file is not None:
return self._zip_file.read(filename)
return open("%s/%s" % (self.path, filename), "rb").read()
if self._zip_files is not None:
return self._zip_files[zip_index].read(filename)
return open("%s/%s" % (self.paths[0], filename), "rb").read()
def _collect_data(self):
def _collect_data_part(self, zip_index):
"""
:return: entries
collect all the entries of a single zip-file or txt file
:param int zip_index: index of the zip-file in self._zip_files, unused when loading without zip
:return: data entries
:rtype: list[dict[str]]
"""
data = eval(self._read("%s.txt" % self._name)) # type: typing.List[typing.Dict[str]]
data = eval(self._read("%s.txt" % self._names[zip_index], zip_index)) # type: typing.List[typing.Dict[str]]
assert data and isinstance(data, list)
first_entry = data[0]
assert isinstance(first_entry, dict)
......@@ -2677,6 +2681,24 @@ class OggZipDataset(CachedDataset2):
else:
assert self.feature_extractor, "feature extraction is enabled, but no audio files are specified"
assert isinstance(first_entry["seq_name"], str)
# add index to data list
for entry in data:
entry['_zip_file_index'] = zip_index
return data
def _collect_data(self):
"""
:return: entries
:rtype: list[dict[str]]
"""
data = []
if self._zip_files:
for zip_index in range(len(self._zip_files)):
zip_data = self._collect_data_part(zip_index)
data += zip_data
else:
# collect data from a txt file
data = self._collect_data_part(0)
return data
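# Hedged usage sketch of the new multi-zip support (paths and options are
# hypothetical; OggZipDataset also takes further arguments not shown here):
#
#   dataset = OggZipDataset(
#     path=["/data/corpus_part1.zip", "/data/corpus_part2.zip"],
#     audio={"features": "mfcc"},
#     targets=None)
#
# Each zip is expected to contain a "<name>.txt" file with the entry list;
# entries get a "_zip_file_index" key so that _read() knows which zip to open.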
def _filter_fixed_random_subset(self, fixed_random_subset):
......@@ -2813,8 +2835,8 @@ class OggZipDataset(CachedDataset2):
"""
import io
seq = self._data[self._get_ref_seq_idx(seq_idx)]
audio_fn = "%s/%s" % (self._name, seq["file"])
raw_bytes = self._read(audio_fn)
audio_fn = "%s/%s" % (self._names[seq['_zip_file_index']], seq["file"])
raw_bytes = self._read(audio_fn, seq['_zip_file_index'])
return io.BytesIO(raw_bytes)
def _collect_single_seq(self, seq_idx):
......
......@@ -36,6 +36,8 @@ class HDFDataset(CachedDataset):
:param bool use_cache_manager: uses :func:`Util.cf` for files
"""
super(HDFDataset, self).__init__(**kwargs)
assert self.partition_epoch == 1 or self.cache_byte_size_total_limit == 0, \
"To use partition_epoch in HDFDatasets, disable caching by setting cache_byte_size=0"
self._use_cache_manager = use_cache_manager
self.files = [] # type: typing.List[str] # file names
self.h5_files = [] # type: typing.List[h5py.File]
......
......@@ -20,7 +20,9 @@ import time
import re
import typing
from random import Random
import math
from sklearn.cluster import KMeans
import h5py
class LmDataset(CachedDataset2):
"""
......@@ -1096,6 +1098,8 @@ class TranslationDataset(CachedDataset2):
self._thread = Thread(name="%r reader" % self, target=self._thread_main)
self._thread.daemon = True
self._thread.start()
self.seq_difficulty = []
def _extend_data(self, k, data_strs):
vocab = self._vocabs[k]
......@@ -1318,6 +1322,152 @@ class TranslationDataset(CachedDataset2):
self._num_seqs = len(self._seq_order)
return True
def get_seq_difficulty(self, curriculum_learning):
mode = curriculum_learning['difficulty']
if mode == 'seq_len_only_source':
cur_seq_difficulty = [len(seq) for seq in self._data[self._main_data_key]]
elif mode == 'seq_len_only_target':
cur_seq_difficulty = [len(seq) for seq in self._data[self._main_classes_key]]
elif mode == 'seq_len_source_and_target':
source_diff = [len(seq) for seq in self._data[self._main_data_key]]
target_diff = [len(seq) for seq in self._data[self._main_classes_key]]
# just add them; will be normalized later on
cur_seq_difficulty = [x + y for x, y in zip(source_diff, target_diff)]
elif mode == 'word_rarity_only_source':
cur_seq_difficulty = self.get_word_rarity('source', curriculum_learning)
elif mode == 'word_rarity_only_target':
cur_seq_difficulty = self.get_word_rarity('target', curriculum_learning)
elif mode == 'word_rarity':
source_seq_difficulty = self.get_word_rarity('source', curriculum_learning)
target_seq_difficulty = self.get_word_rarity('target', curriculum_learning)
# just add them; will be normalized later on
cur_seq_difficulty = [x + y for x, y in zip(source_seq_difficulty, target_seq_difficulty)]
elif mode == 'neg_log_likelihood':
# neg log likelihood should be 0 <= x < infinity; the bigger x the easier the sentence -> reverse order
with open(curriculum_learning['neg_log_likelihood_file'], "r") as f:
cur_seq_difficulty = [float(x) for x in f.readlines()]
assert self._num_seqs == len(cur_seq_difficulty), "We expect the train_data and the score_file to have the" \
" same length, but {} vs {}".format(self._num_seqs,
len(cur_seq_difficulty))
if 'reverse' in curriculum_learning and curriculum_learning['reverse']:
curriculum_learning['reverse'] = False
else:
curriculum_learning['reverse'] = True
elif mode == 'kmeans':
# sequence length
source_diff = [len(seq) for seq in self._data[self._main_data_key]]
target_diff = [len(seq) for seq in self._data[self._main_classes_key]]
seq_len = [x + y for x, y in zip(source_diff, target_diff)]
# word rarity
source_seq_difficulty = self.get_word_rarity('source', curriculum_learning)
target_seq_difficulty = self.get_word_rarity('target', curriculum_learning)
word_rarity = [x + y for x, y in zip(source_seq_difficulty, target_seq_difficulty)]
# negative log-likelihood
with open(curriculum_learning['neg_log_likelihood_file'], "r") as f:
neg_log_likeli = [float(x) for x in f.readlines()]
assert self._num_seqs == len(neg_log_likeli), "We expect the train_data and the score_file to have the" \
" same length, but {} vs {}".format(self._num_seqs,
len(neg_log_likeli))
difficulty_arr = [[x, y, z] for x, y, z in zip(seq_len, word_rarity, neg_log_likeli)]
kmeans = KMeans(n_clusters=curriculum_learning['number_of_clusters'], random_state=0).fit(difficulty_arr)
# ranges from 0 to n_clusters-1
cur_seq_difficulty = kmeans.labels_
if curriculum_learning['competence'] == 'discrete':
curriculum_learning['norm'] = 'no_norm'
else:
raise NotImplementedError("This difficulty mode is not implemented.")
# norm it between 0 and 1
if 'norm' not in curriculum_learning:
curriculum_learning['norm'] = 'equally_distant'
if curriculum_learning['norm'] == 'longest_sentence':
max_seq = numpy.amax(cur_seq_difficulty)
cur_seq_difficulty = cur_seq_difficulty / max_seq
elif curriculum_learning['norm'] == 'equally_distant':
idx_sorted = numpy.argsort(cur_seq_difficulty)
max_seq = numpy.amax(idx_sorted)
cur_seq_difficulty = idx_sorted / max_seq
elif curriculum_learning['norm'] == 'no_norm':
pass
else:
raise NotImplementedError("This norm mode is not implemented.")
if 'reverse' in curriculum_learning and curriculum_learning['reverse']:
cur_seq_difficulty = [1 - difficulty for difficulty in cur_seq_difficulty]
self.seq_difficulty = cur_seq_difficulty
def get_word_rarity(self, mode, curriculum_learning):
# 'source' -> main data key (source side), 'target' -> main classes key (target side)
if mode == 'source':
data_key = self._main_data_key
else:
data_key = self._main_classes_key
num_total = 0
num_words = {}
for i in range(len(self._data[data_key])):
cur_seq = self._get_data(key=data_key, line_nr=i)
num_total += len(cur_seq)
for word in cur_seq:
if word in num_words:
num_words[word] += 1
else:
num_words[word] = 1
# +1 because every value should be 0 < x < 1
max_num_of_words = (max(num_words.values()) + 1)
for word in num_words:
num_words[word] = num_words[word] / max_num_of_words
cur_seq_difficulty = []
for i in range(len(self._data[data_key])):
cur_seq = self._get_data(key=data_key, line_nr=i)
score = 0
for word in cur_seq:
score -= math.log(num_words[word])
if 'norm_by_len' not in curriculum_learning or curriculum_learning['norm_by_len']:
score = score / len(cur_seq)
cur_seq_difficulty.append(score)
return cur_seq_difficulty
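# Worked example (hypothetical counts): if a word occurs 50 times and the most
# frequent word occurs 99 times, its normalized frequency is 50 / (99 + 1) = 0.5
# and it contributes -log(0.5) ~= 0.69 to the sequence score; rarer words
# contribute more, and with norm_by_len the sum is divided by the sequence length.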
def get_model_competence(self, curriculum_learning):
mode = curriculum_learning['competence']
T = curriculum_learning['total_number_of_iterations']
c0 = curriculum_learning['initial_competence']
if mode == 'linear':
competence = min(1, self.epoch * ((1 - c0) / T) + c0)
elif mode == 'pth-root':
p = curriculum_learning['p']
competence = min(1, (self.epoch * ((1 - c0 ** p) / T) + c0 ** p) ** (1. / p))
elif mode == 'discrete':
assert curriculum_learning['difficulty'] == 'kmeans', "Only use discrete competence with kmeans"
competence = self.epoch * curriculum_learning['number_of_clusters']\
/ curriculum_learning['total_number_of_iterations']
else:
raise NotImplementedError("This competence mode is not implemented.")
return competence
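# Worked example (hypothetical values): with c0 = 0.1 and T = 10, linear
# competence in epoch 5 is min(1, 5 * (1 - 0.1) / 10 + 0.1) = 0.55; with p = 2,
# pth-root competence is min(1, sqrt(5 * (1 - 0.01) / 10 + 0.01)) ~= 0.71.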
def make_cur_slice(self, curriculum_learning):
competence = self.get_model_competence(curriculum_learning)
if curriculum_learning['total_number_of_iterations'] < self.epoch:
print("Competence is " + str(competence) + "high. From now on usual Training. We are in Epoch "
+ str(self.epoch) + ".")
quit(2222)
self._seq_order = [i for i in range(len(self._data[self._main_data_key])) if (self.seq_difficulty[i] < competence)]
self._num_seqs = len(self._seq_order)
if 'seq_order_len' in curriculum_learning:
if self._num_seqs > curriculum_learning['seq_order_len']:
rnd_seed = 9876543
rnd = Random(rnd_seed)
rnd.shuffle(self._seq_order)
self._seq_order = self._seq_order[:curriculum_learning['seq_order_len']]
self._num_seqs = len(self._seq_order)
if 'n_times_per_slice' in curriculum_learning:
self._seq_order = numpy.tile(self._seq_order, curriculum_learning['n_times_per_slice'])
self._num_seqs = len(self._seq_order)
return True
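# Hedged sketch of a matching config entry (keys as read by the methods above
# and by the Engine diff below; the values are hypothetical):
#
#   curriculum_learning = {
#     "use_curriculum_learning": True,
#     "difficulty": "seq_len_source_and_target",  # or word_rarity*, neg_log_likelihood, kmeans
#     "competence": "linear",  # or "pth-root" (needs "p"), "discrete" (kmeans only)
#     "total_number_of_iterations": 10,
#     "initial_competence": 0.1,
#     "norm": "equally_distant",  # or "longest_sentence", "no_norm"
#   }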
def _collect_single_seq(self, seq_idx):
if seq_idx >= self._num_seqs:
return None
......
......@@ -30,6 +30,12 @@
#define HOST_FUNC
#endif
#ifdef isinf
#undef isinf
#endif
#ifdef isnan
#undef isnan
#endif
#define assert_cmp(a, cmp, b) \
......@@ -1024,11 +1030,3 @@ void debug_print_shape(OpKernelContext* context, tensorflow::Tensor* tensor, con
}
#endif
#ifdef isinf
#undef isinf
#endif
#ifdef isnan
#undef isnan
#endif
......@@ -3504,7 +3504,7 @@ common_fast_bw_kernels = {
std::ofstream output(path.c_str(), std::ios::trunc | std::ios::out);
for (size_t i1 = 0ul; i1 < n_d1; i1++) {
T val = buffer[i1];
if (!std::numeric_limits<T>::has_infinity or !std::isinf(val)) {
if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {
output << i1 << ' ' << val << '\\n';
}
}
......@@ -3519,7 +3519,7 @@ common_fast_bw_kernels = {
for (size_t i1 = 0ul; i1 < n_d1; i1++) {
for (size_t i2 = 0ul; i2 < n_d2; i2++) {
T val = buffer[i1 * n_d2 + i2];
if (!std::numeric_limits<T>::has_infinity or !std::isinf(val)) {
if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {
output << i1 << ' ' << i2 << ' ' << val << '\\n';
}
}
......@@ -3536,7 +3536,7 @@ common_fast_bw_kernels = {
for (size_t i2 = 0ul; i2 < n_d2; i2++) {
for (size_t i3 = 0ul; i3 < n_d3; i3++) {
T val = buffer[i1 * n_d2 * n_d3 + i2 * n_d3 + i3];
if (!std::numeric_limits<T>::has_infinity or !std::isinf(val)) {
if (!std::numeric_limits<T>::has_infinity or !isinf(val)) {
output << i1 << ' ' << i2 << ' ' << i3 << ' ' << val << '\\n';
}
}
......@@ -3691,7 +3691,7 @@ class FastBaumWelchOp(NativeOpGenBase):
for (unsigned s = start_states[seq]; s <= end_states[seq]; s++) {
const float val = state_buffer[t * n_states + s];
float diff = val - sum;
if (!std::isnan(diff)) {
if (!isnan(diff)) {
sum = -log1p(exp(-abs(diff))) + fminf(sum, val);
}
}
......
......@@ -585,8 +585,8 @@ class Layer(Container):
# In some cases, e.g. forwarding, the target index (for "classes") might have shape[0]==0.
# Or shape[0]==1 with index[0]==0. See Dataset.shapes_for_batches().
# Use source index in that case.
have_zero = ifelse(T.lt(index.shape[0], 1), 1, T.cast(T.le(index.shape[0], 1) * T.eq(T.sum(index[0]), 0), 'int8'))
index = ifelse(have_zero, self.sources[0].index, index)
have_zero = T.le(index.shape[0], 1) * T.eq(T.sum(index[0]), 0)
index = ifelse(have_zero, T.cast(self.sources[0].index,'int8'), T.cast(index,'int8'))
return index
def find_data_layer(self):
......
......@@ -426,9 +426,9 @@ class Pretrain:
# -------------- Public interface
def __str__(self):
return ("Pretrain construction algo %r, "
"number of pretrain epochs: %i (repetitions: %r)") % (
self._construction_algo, self.get_train_num_epochs(), self.repetitions)
return ("Default layerwise construction+pretraining, starting with input+hidden+output. " +
"Number of pretrain epochs: %i (repetitions: %r)") % (
self.get_train_num_epochs(), self.repetitions)
def get_train_num_epochs(self):
"""
......
......@@ -496,45 +496,21 @@ class FileArchiveBundle:
File archive bundle.
"""
def __init__(self, filename=None):
def __init__(self, filename):
"""
:param str|None filename: .bundle file
:param str filename: .bundle file
"""
# filename -> FileArchive
self.archives = {} # type: typing.Dict[str,FileArchive]
# archive content file -> FileArchive
self.files = {} # type: typing.Dict[str,FileArchive]
self._short_seg_names = {}
if filename is not None:
self.add_bundle(filename=filename)
def add_bundle(self, filename):
"""
:param str filename: bundle
"""
for line in open(filename).read().splitlines():
self.add_archive(filename=line)
def add_archive(self, filename):
"""
:param str filename: single archive
"""
if filename in self.archives:
return
self.archives[filename] = a = FileArchive(filename, must_exists=True)
for f in a.ft.keys():
self.files[f] = a
# noinspection PyProtectedMember
self._short_seg_names.update(a._short_seg_names)
def add_bundle_or_archive(self, filename):
"""
:param str filename:
"""
if filename.endswith(".bundle"):
self.add_bundle(filename)
else:
self.add_archive(filename)
self.archives[line] = a = FileArchive(line, must_exists=True)
for f in a.ft.keys():
self.files[f] = a
# noinspection PyProtectedMember
self._short_seg_names.update(a._short_seg_names)
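# Note: the .bundle file is read line by line, i.e. it is a plain text file
# listing one Sprint archive path per line.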
def file_list(self):
"""
......
......@@ -58,7 +58,7 @@ class SprintDatasetBase(Dataset):
:param dict[str,str|dict] target_maps: e.g. {"speaker": "speaker_map.txt"}
:param bool str_add_final_zero: adds e.g. "orth0" with '\0'-ending
:param float input_stddev: if != 1, will divide the input "data" by that
:param str|list[str]|((str)->str)|None orth_post_process: :func:`get_post_processor_function`, applied on orth
:param str|list[str]|None orth_post_process: :func:`get_post_processor_function`, applied on orth
:param None|dict[str] bpe: if given, will be opts for :class:`BytePairEncoding`
:param None|dict[str] orth_vocab: if given, orth_vocab is applied to orth and orth_classes is an available target
:param bool suppress_load_seqs_print: less verbose
......@@ -76,19 +76,11 @@ class SprintDatasetBase(Dataset):
self.target_maps = target_maps
self.str_add_final_zero = str_add_final_zero
self.input_stddev = input_stddev
# Note: "orth" is actually the raw bytes of the utf8 string,
# so it does not quite make sense to associate a single str with each byte.
# However, some other code might expect that the labels are all strings, not bytes,
# and the API requires the labels to be strings.
# The code in Dataset.serialize_data tries to decode this case as utf8 (if possible).
self.labels["orth"] = [chr(i) for i in range(255)]
self.orth_post_process = None # type: typing.Optional[typing.Callable[[str],str]]
self.orth_post_process = None
if orth_post_process:
if callable(orth_post_process):
self.orth_post_process = orth_post_process
else:
from LmDataset import get_post_processor_function
self.orth_post_process = get_post_processor_function(orth_post_process)
from LmDataset import get_post_processor_function
self.orth_post_process = get_post_processor_function(orth_post_process)
self.bpe = None
if bpe:
from GeneratingDataset import BytePairEncoding
......@@ -628,37 +620,16 @@ class ExternSprintDataset(SprintDatasetBase):
# This is our workaround. We check for it in self.run_inner().
self.python_exit = False
atexit.register(self._exit_handler)
# We don't know about num_outputs yet, but we should.
# Thus we call Sprint and immediately exit it.
self._start_child(epoch=None, get_dim_only=True)
def finish_epoch(self):
"""
Called at the end of the epoch.
"""
with self.lock:
# Reset epoch such that exiting the child will go smoothly.
super(ExternSprintDataset, self).init_seq_order(epoch=None, seq_list=None)
# Exit child, before we overwrite anything, such as new epoch or seq_list.
self._exit_child(wait_thread=True)
super(ExternSprintDataset, self).finish_epoch()
def _exit_handler(self):
"""
Called at exit.
"""
assert os.getpid() == self.parent_pid
self.python_exit = True
self._exit_child(wait_thread=False)
self.init_seq_order()
def _exit_child(self, wait_thread=True):
"""
:param bool wait_thread:
"""
if self.child_pid:
expected_exit_status = 0 if wait_thread and not self.python_exit else None
expected_exit_status = 0 if not self.python_exit else None
if self._join_child(wait=False, expected_exit_status=expected_exit_status) is False: # Not yet terminated.
interrupt = not self.reached_final_seq_seen_all or not wait_thread
interrupt = not self.reached_final_seq_seen_all
if interrupt:
print("%s: interrupt child proc %s" % (self, self.child_pid), file=log.v5)
os.kill(self.child_pid, signal.SIGKILL)
......@@ -667,7 +638,7 @@ class ExternSprintDataset(SprintDatasetBase):
self.child_pid = None
else: # child process terminated
self.child_pid = None
if wait_thread and self.reader_thread:
if wait_thread:
# Load all remaining data so that the reader thread is not waiting in self.add_new_data().
while self.is_less_than_num_seqs(self.expected_load_seq_start + 1):
if self.reached_final_seq: # this is set by the reader thread
......@@ -684,13 +655,13 @@ class ExternSprintDataset(SprintDatasetBase):
except IOError:
pass
if self.child_pid:
self._join_child(wait=True, expected_exit_status=expected_exit_status)
self._join_child(wait=True, expected_exit_status=0)
self.child_pid = None
def _start_child(self, epoch, get_dim_only=False):
def _start_child(self, epoch):
"""
:param int|None epoch:
:param bool get_dim_only:
:param epoch:
:return:
"""
assert self.child_pid is None
assert self.reader_thread is None
......@@ -740,14 +711,10 @@ class ExternSprintDataset(SprintDatasetBase):
self._exit_child(wait_thread=False)
raise Exception("%s Sprint init failed" % self)
if get_dim_only:
self._exit_child(wait_thread=False)
else:
self.reader_thread = Thread(target=self._reader_thread_proc, args=(pid, epoch),
name="%s reader thread" % self)
self.reader_thread.daemon = True
self.reader_thread.start()
self.reader_thread = Thread(target=self._reader_thread_proc, args=(pid, epoch,),
name="%s reader thread" % self)
self.reader_thread.daemon = True
self.reader_thread.start()
# noinspection PyMethodMayBeStatic
def _pipe_open(self):
......@@ -952,6 +919,14 @@ class ExternSprintDataset(SprintDatasetBase):
# Exceptions are fatal. If we can recover, we should handle it in run_inner().
interrupt_main()
def _exit_handler(self):
"""
Called at exit.
"""
assert os.getpid() == self.parent_pid
self.python_exit = True
self._exit_child(wait_thread=False)
def init_seq_order(self, epoch=None, seq_list=None):
"""
:param int epoch:
......
......@@ -958,6 +958,7 @@ def _forward(segment_name, features):
else:
raise NotImplementedError("unknown backend engine")
# If we have a sequence training criterion, posteriors might be in format (time,seq|batch,emission).
if posteriors.ndim == 3:
assert posteriors.shape == (num_time, 1, OutputDim * MaxSegmentLength)
......
......@@ -871,6 +871,7 @@ class Engine(EngineBase):
self.learning_rate = self.learning_rate_control.default_learning_rate
self.initial_learning_rate = self.learning_rate
self.pretrain_learning_rate = config.float('pretrain_learning_rate', self.learning_rate)
self.curriculum_learning = config.typed_value('curriculum_learning', {'use_curriculum_learning': False})
self.final_epoch = self.config_get_final_epoch(config) # Inclusive.
self.max_seqs = config.int('max_seqs', -1)
self.ctc_prior_file = config.value('ctc_prior_file', None)
......@@ -886,6 +887,7 @@ class Engine(EngineBase):
if isinstance(self.max_seq_length, dict):
self.max_seq_length = NumbersDict(self.max_seq_length)
assert isinstance(self.max_seq_length, (int, float, NumbersDict))
self.max_pad_size = config.typed_value("max_pad_size", None)
# And also initialize the network. That depends on some vars here such as pretrain.
self.init_network_from_config(config)
......@@ -938,6 +940,8 @@ class Engine(EngineBase):
# - SubnetworkLayer also has a load_on_init option.
# - LayerBase has custom_param_importer which is quite flexible.
print("Start pre-loading weights...", file=log.v2)
# model_name will not be used directly, but it defines the order in which we apply the preloading.
# Variables are initialized by the first preload.
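# Hedged sketch of a preload_from_files config entry (the "init_for_train" and
# "ignore_missing" keys appear below; "filename" is assumed from
# CustomCheckpointLoader, and the path here is hypothetical):
#
#   preload_from_files = {
#     "base_encoder": {
#       "filename": "/path/to/pretrained.model.100",
#       "init_for_train": True,
#       "ignore_missing": True,
#     },
#   }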
for model_name, opts in sorted(self.preload_from_files.items()):
assert isinstance(opts, dict)
if opts.get("init_for_train", False):
......@@ -956,8 +960,10 @@ class Engine(EngineBase):
saveable_params=self.network.get_params_list(),
params_prefix=self_prefix, load_if_prefix=load_if_prefix,
ignore_missing=opts.get("ignore_missing", False))
# `set_as_custom_init` is also a marker for the vars, that they are preloaded,
# such that further checkpoint loaders will not load them again.
loader.set_as_custom_init()
self.network.initialize_params(session=self.tf_session)
loader.load_now(session=self.tf_session)
if model_epoch_filename:
print("loading weights from", model_epoch_filename, file=log.v2)
......@@ -1178,6 +1184,10 @@ class Engine(EngineBase):
print("using batch size: %r, max seqs: %i" % (self.batch_size, self.max_seqs), file=log.v4)
print("learning rate control:", self.learning_rate_control, file=log.v4)
print("pretrain:", self.pretrain, file=log.v4)
if self.curriculum_learning['use_curriculum_learning']:
self.train_data.get_seq_difficulty(self.curriculum_learning)
self.dataset_batches.clear()
assert self.start_epoch >= 1, "Epochs start at 1."
......@@ -1205,6 +1215,12 @@ class Engine(EngineBase):
# In case of random seq ordering, we want to reorder each epoch.
if self.train_data.init_seq_order(epoch=self.epoch):
self.dataset_batches.pop("train", None)
if self.curriculum_learning['use_curriculum_learning']:
self.dataset_batches.pop("train", None)
self.curriculum_learning['use_curriculum_learning'] = self.train_data.make_cur_slice(self.curriculum_learning)
for dataset_name, dataset in self.get_eval_datasets().items():
if dataset.init_seq_order(epoch=self.epoch):
self.dataset_batches.pop(dataset_name, None)
......@@ -1308,6 +1324,7 @@ class Engine(EngineBase):
batch_size=self.batch_size,
max_seqs=self.max_seqs,
max_seq_length=self.max_seq_length,
max_pad_size=self.max_pad_size,
seq_drop=self.seq_drop,
shuffle_batches=self.shuffle_batches,
used_data_keys=self.network.get_used_data_keys())
......@@ -1317,7 +1334,10 @@ class Engine(EngineBase):
train_batches = self.dataset_batches['train']
self.updater.set_learning_rate(self.learning_rate, session=self.tf_session)
trainer = Runner(engine=self, dataset=self.train_data, batches=train_batches, train=True)
trainer = Runner(
engine=self,
dataset=self.train_data, batches=train_batches,
train=self.network.layers_desc.get("#trainable", True))
trainer.run(report_prefix=("pre" if self.is_pretrain_epoch() else "") + "train epoch %s" % self.epoch)
if not trainer.finalized:
......
......@@ -11,6 +11,13 @@ import tensorflow as tf
returnn_dir = os.path.dirname(os.path.abspath(__file__))
kenlm_dir = returnn_dir + "/extern/kenlm"
def kenlm_checked_out():
"""
:rtype: bool
"""
return os.path.exists("%s/lm/test.arpa" % kenlm_dir)
# https://www.tensorflow.org/guide/extend/op
# Also see TFUitl.TFArrayContainer for TF resources.
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.h
......@@ -351,6 +358,7 @@ def get_tf_mod(verbose=False):
# https://github.com/kpu/kenlm/blob/master/compile_query_only.sh
# Collect files.
assert kenlm_checked_out(), "submodule in %r not checked out?" % kenlm_dir
files = glob('%s/util/*.cc' % kenlm_dir)
files += glob('%s/lm/*.cc' % kenlm_dir)
files += glob('%s/util/double-conversion/*.cc' % kenlm_dir)
......
......@@ -272,7 +272,8 @@ class TFNetwork(object):
:param bool search_flag: whether we perform a beam-search. see usage
:param TFNetworkLayer.LayerBase|None parent_layer:
:param TFNetwork|None parent_net:
:param TFNetwork|None extra_parent_net:
:param TFNetwork|None extra_parent_net: we are on the same level (not really a child),
but an "extra" net of extra_parent_net
:param bool is_inside_rec_layer: at template construction, use this
:param str|None absolute_name_prefix:
:param str name: only for debugging
......@@ -342,7 +343,7 @@ class TFNetwork(object):
self.recurrent = False
self._assigner_cache = {} # type: typing.Dict[tf.Variable,VariableAssigner]
self.concat_sources_dropout_cache = {} # type: typing.Dict[typing.Tuple[typing.Tuple[LayerBase,...],float,typing.Optional[typing.Tuple[typing.Optional[int],...]]],Data] # nopep8
self._batch_dim = None # see get_batch_dim
self._batch_dim = None # see get_data_batch_dim
self._merge_all_summaries = None # type: typing.Optional[tf.Tensor]
self._graph_reset_callbacks = [] # type: typing.List[typing.Callable]
......@@ -397,7 +398,9 @@ class TFNetwork(object):
if self.parent_net:
return self.parent_net.get_absolute_name_prefix()
if self.extra_parent_net:
return self.extra_parent_net.get_absolute_name_prefix()
prefixes = {net: prefix for (prefix, net) in self.extra_parent_net.extra_nets.items()}
my_prefix = ("%s:" % prefixes[self]) if self in prefixes else ""
return self.extra_parent_net.get_absolute_name_prefix() + my_prefix
return ""
def construct_from(self, list_or_dict):
......@@ -1209,24 +1212,30 @@ class TFNetwork(object):
ls.append(param)
return ls
def declare_train_params(self, hidden_layer_selection=None, with_output=None):
def declare_train_params(self, hidden_layer_selection=None, with_output=None, global_trainable=None):
"""
:param list[str]|None hidden_layer_selection:
:param bool|None with_output:
:param bool|None global_trainable:
"""
if hidden_layer_selection is None:
hidden_layer_selection = [name for (name, layer) in self.layers.items() if not layer.is_output_layer()]
if global_trainable is None:
global_trainable = self.layers_desc.get("#trainable", True)
if global_trainable:
if hidden_layer_selection is None:
hidden_layer_selection = [name for (name, layer) in self.layers.items() if not layer.is_output_layer()]
else:
hidden_layer_selection = list(hidden_layer_selection)
if with_output is None:
with_output = True
if with_output:
hidden_layer_selection += [name for (name, layer) in self.layers.items() if layer.is_output_layer()]
hidden_layer_selection = set(hidden_layer_selection)
else:
hidden_layer_selection = list(hidden_layer_selection)
if with_output is None:
with_output = True
if with_output:
hidden_layer_selection += [name for (name, layer) in self.layers.items() if layer.is_output_layer()]
hidden_layer_selection = set(hidden_layer_selection)
hidden_layer_selection = set()
self._selected_train_layers = sorted(hidden_layer_selection)
if self.extra_nets:
for _, extra_net in self.extra_nets.items():
extra_net.declare_train_params() # select all, currently...
extra_net.declare_train_params(global_trainable=global_trainable) # select all, currently...
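# Hedged usage sketch of the network-level "#trainable" flag read above
# (the layer entry is illustrative):
#
#   network = {
#     "#trainable": False,  # freeze all layers of this (sub)network
#     "output": {"class": "softmax", "loss": "ce"},
#   }
#
# With this, declare_train_params() selects no layers, and Engine.train_epoch()
# passes train=False to the Runner (see the TFEngine diff above).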
def get_num_params(self):
"""
......@@ -1607,6 +1616,9 @@ class TFNetwork(object):
normalized_src = src.get_normalized_layer()
if normalized_src != src:
assert _normalized_to_layer.setdefault(normalized_src, src) == src # Currently expecting that this is unique.
if src.search_choices:
assert normalized_src.search_choices, "normalized %s vs %s (choices %s)" % (
normalized_src, src, src.search_choices)
if src.search_choices:
if src.search_choices.is_decided:
return []
......@@ -1635,13 +1647,16 @@ class TFNetwork(object):
if layer not in layers:
layers.append(layer)
if not layers:
if self.parent_layer:
# Use parent layer if available.
# Note that we should not mix layers from different context frames,
# e.g. inside and outside a rec loop, as the search choices cannot be compared.
if self.parent_layer and not self.is_inside_rec_layer():
# noinspection PyProtectedMember
return self.parent_layer.network._get_all_search_choices(sources=self.parent_layer.get_dep_layers())
return []
if base_search_choice is not None:
normalized_base = base_search_choice.get_normalized_layer()
if normalized_base != base_search_choice:
if normalized_base != base_search_choice: # from prev frame or so
# Just make sure we visit these as well.
normalized_choices = self._get_all_search_choices(
base_search_choice=normalized_base,
......@@ -1653,9 +1668,10 @@ class TFNetwork(object):
# Get corresponding "prev:..." layers.
from pprint import pformat
assert all([l in _normalized_to_layer for l in normalized_choices]), "\n".join([
"No cur -> prev mapping for some layers.", "Base: %s" % base_search_choice,
"Prev choices:", pformat(layers),
"Cur choices:", pformat(normalized_choices), "Mapping:", pformat(_normalized_to_layer)])
"No cur -> prev mapping for some layers.", "",
"Base: %s" % base_search_choice, "", "Cur (normalized) base: %s" % normalized_base, "",
"Prev choices:", pformat(layers), "", "Cur (normalized) choices:", pformat(normalized_choices), "",
"Mapping:", pformat(_normalized_to_layer), ""])
layers = [_normalized_to_layer[l] for l in normalized_choices]
_layer_to_search_choices[base_search_choice] = layers
return layers
......@@ -1710,13 +1726,14 @@ class TFNetwork(object):
:rtype: int|tf.Tensor
"""
from TFUtil import get_shape_dim, reuse_name_scope_of_tensor
if self._batch_dim is not None:
return self._batch_dim
# First check parent because there we might get the true batch dim.
# (Do not check parent_layer, because that potentially includes a beam.)
if self.parent_net:
return self.parent_net.get_data_batch_dim()
if self.extra_parent_net:
return self.extra_parent_net.get_data_batch_dim()
if self._batch_dim is not None:
return self._batch_dim
for key, data in self.extern_data.get_sorted_data_items():
assert isinstance(data, Data)
if data.available_for_inference:
......@@ -2250,14 +2267,19 @@ class CannotHandleUndefinedSourcesException(Exception):
Raised when some layer gets None (undefined) source(s) (e.g. during RecLayer template construction)
and cannot handle it (e.g. cannot infer the out_type in that case).
"""
def __init__(self, layer_name, layer_desc):
def __init__(self, layer_name, layer_desc, extended_info_str=None):
"""
:param str layer_name:
:param dict[str] layer_desc:
:param str|None extended_info_str:
"""
from pprint import pformat
super(CannotHandleUndefinedSourcesException, self).__init__(
"%r: cannot handle undefined sources without defined out_type.\n%s" % (layer_name, pformat(layer_desc)))
info_strs = [
"%r: cannot handle undefined sources without defined out_type." % layer_name,
pformat(layer_desc)]
if extended_info_str:
info_strs.append(extended_info_str)
super(CannotHandleUndefinedSourcesException, self).__init__("\n".join(info_strs))
self.layer_name = layer_name
self.layer_desc = layer_desc
......@@ -2521,15 +2543,17 @@ class CustomCheckpointLoader:
self.params_prefix = params_prefix
self.load_if_prefix = load_if_prefix
self.saveable_params = []
count = 0
for param in saveable_params:
if load_if_prefix and self._get_param_name(param, assert_load_if_prefix_match=False) is None:
continue
count += 1
custom_post_init = getattr(param, "custom_post_init", None)
if custom_post_init:
print("Not loading pre-initialized variables %s" % param, file=log.v2)
continue
if load_if_prefix and self._get_param_name(param, assert_load_if_prefix_match=False) is None:
print("%s: Not loading pre-initialized variable %s" % (self, param), file=log.v2)
continue
self.saveable_params.append(param)
assert self.saveable_params, "no saveable vars"
assert count > 0, "%s: no saveable vars" % self
self.reader = tf.train.NewCheckpointReader(filename)
self.net_vars = [v for v in self.saveable_params if isinstance(v, tf.Variable)]
self.net_saveables = [v for v in self.saveable_params if not isinstance(v, tf.Variable)]
......
......@@ -147,7 +147,7 @@ class ComplexToAlternatingRealLayer(_ConcatInputLayer):
self.output.placeholder = _interleaveVectors(real_value, imag_value)
self.output.size_placeholder = {0: self.input_data.size_placeholder[self.input_data.time_dim_axis_excluding_batch]}
class MaskBasedGevBeamformingLayer(LayerBase):
"""
This layer applies GEV beamforming to a multichannel signal. The different
......@@ -425,7 +425,7 @@ class MultiChannelMultiResolutionStftLayer(_ConcatInputLayer):
return tf.concat([input_signal, tf.ones([get_shape(input_signal)[0], frame_size-self._reference_frame_size, get_shape(input_signal)[2]])*1e-7], axis=1)
else:
return input_placeholder
input_signal = _padTimeSignal(input_placeholder, frame_size)
if self._use_rfft:
channel_wise_stft = tf.contrib.signal.stft(
......@@ -461,9 +461,7 @@ class MultiChannelMultiResolutionStftLayer(_ConcatInputLayer):
n_out = np.sum([cls._get_n_out_by_fft_config(fft_size, use_rfft, nr_of_channels) for fft_size in fft_sizes])
if 'n_out' not in kwargs:
kwargs['n_out'] = n_out
return (super(MultiChannelMultiResolutionStftLayer, cls)
.get_out_data_from_opts(**kwargs)
.copy_template(dtype="complex64"))
return super(MultiChannelMultiResolutionStftLayer, cls).get_out_data_from_opts(**kwargs)
class MultiChannelStftLayer(MultiChannelMultiResolutionStftLayer):
......@@ -475,7 +473,7 @@ class MultiChannelStftLayer(MultiChannelMultiResolutionStftLayer):
def __init__(self, frame_shift, frame_size, fft_size, window="hanning", use_rfft=True, nr_of_channels=1, pad_last_frame=False, **kwargs):
kwargs['frame_shift'] = frame_shift
kwargs['window'] = window
kwargs['window'] = window
kwargs['use_rfft'] = use_rfft
kwargs['nr_of_channels'] = nr_of_channels
kwargs['pad_last_frame'] = pad_last_frame
......@@ -483,9 +481,7 @@ class MultiChannelStftLayer(MultiChannelMultiResolutionStftLayer):
@classmethod
def get_out_data_from_opts(cls, fft_size, use_rfft=True, nr_of_channels=1, **kwargs):
return (super(MultiChannelStftLayer, cls)
.get_out_data_from_opts(fft_sizes=[fft_size], use_rfft=use_rfft, nr_of_channels=nr_of_channels, **kwargs)
.copy_template(dtype="complex64"))
return super(MultiChannelStftLayer, cls).get_out_data_from_opts(fft_sizes=[fft_size], use_rfft=use_rfft, nr_of_channels=nr_of_channels, **kwargs)
class NoiseEstimationByFirstTFramesLayer(_ConcatInputLayer):
......
......@@ -16,59 +16,33 @@ from TFNetwork import TFNetwork
from TFUtil import tf_version_tuple, assert_min_tf_version, CustomUpdate, add_check_numerics_ops, \
get_non_deterministic_ops_from_graph
_OptimizerClassesDictInitialized = False
_OptimizerClassesDict = {} # type: typing.Dict[str,typing.Callable[[],Optimizer]]
def _init_optimizer_classes_dict():
global _OptimizerClassesDictInitialized
if _OptimizerClassesDictInitialized:
return
_OptimizerClassesDictInitialized = True
potential_list = list(vars(tf.train).items())
if tf_version_tuple() >= (1, 2, 0):
from tensorflow.contrib import opt
potential_list += list(vars(opt).items())
potential_list += list(globals().items())
for name, v in potential_list:
assert isinstance(name, str)
if v is Optimizer:
continue
if not isinstance(v, type) or not issubclass(v, Optimizer):
continue
register_optimizer_class(v, name=name)
def register_optimizer_class(cls, name=None):