Commit e680d6bb authored by Evgeny Kusmenko

Merge branch 'develop' into 'master'

Shared code, updated for CNNArchLang, etc.

See merge request !17
parents ff353c4f 7184ebe9
Pipeline #158838 passed with stages in 6 minutes and 11 seconds
@@ -15,9 +15,9 @@
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.16-SNAPSHOT</CNNArch2MXNet.version>
<CNNArch2X.version>0.0.2-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
@@ -63,8 +63,8 @@
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>${CNNArch2MXNet.version}</version>
<artifactId>cnnarch-generator</artifactId>
<version>${CNNArch2X.version}</version>
</dependency>
<dependency>
@@ -116,6 +116,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.github.stefanbirkner</groupId>
<artifactId>system-rules</artifactId>
<version>1.3.0</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
...
@@ -21,23 +21,27 @@
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.Target;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.Target;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.se_rwth.commons.logging.Log;
import java.util.HashMap;
import java.util.Map;
public class CNNArch2Gluon extends CNNArch2MxNet {
public class CNNArch2Gluon extends CNNArchGenerator {
public CNNArch2Gluon() {
architectureSupportChecker = new CNNArch2GluonArchitectureSupportChecker();
layerSupportChecker = new CNNArch2GluonLayerSupportChecker();
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
@Override
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = compileFileContentMap(architecture);
checkValidGeneration(architecture);
return fileContentMap;
}
@@ -85,6 +89,9 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
temp = controller.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = controller.process("CNNSupervisedTrainer", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
...
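For orientation, a minimal usage sketch of the refactored generator as it appears in this hunk. It only uses the constructor and generateStrings shown above; the architecture symbol is assumed to be already resolved and checked with CNNArchCocos.checkAll, as the comment in the hunk requires.

import java.util.Map;

import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;

public class GluonGenerationSketch {

    // 'architecture' is assumed to be resolved and coco-checked beforehand.
    public static Map<String, String> generate(ArchitectureSymbol architecture) {
        CNNArch2Gluon generator = new CNNArch2Gluon();
        // Maps generated file names (e.g. the CNNPredictor and CNNSupervisedTrainer artifacts) to their content.
        return generator.generateStrings(architecture);
    }
}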
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.generator.ArchitectureSupportChecker;
public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupportChecker {
public CNNArch2GluonArchitectureSupportChecker() {}
@Override
protected boolean checkMultipleStreams(ArchitectureSymbol architecture) {
return true;
}
@Override
protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
return true;
}
@Override
protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
return true;
}
@Override
protected boolean checkConstants(ArchitectureSymbol architecture) {
return true;
}
}
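Backends opt into architecture features by overriding the corresponding check hooks to return true. As a purely hypothetical illustration (the class name MyBackendArchitectureSupportChecker is invented here and is not part of this merge request), a backend that only supports multiple inputs would override just that one hook:

import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.generator.ArchitectureSupportChecker;

// Hypothetical example, not part of this merge request.
public class MyBackendArchitectureSupportChecker extends ArchitectureSupportChecker {

    @Override
    protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
        // Allow multi-input architectures; all other checks keep the base-class defaults.
        return true;
    }
}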
@@ -20,8 +20,8 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.GenericCNNArchCli;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.GenericCNNArchCli;
public class CNNArch2GluonCli {
public static void main(String[] args) {
...
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.generator.LayerSupportChecker;
public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
public CNNArch2GluonLayerSupportChecker() {
supportedLayerList.add(AllPredefinedLayers.FULLY_CONNECTED_NAME);
supportedLayerList.add(AllPredefinedLayers.CONVOLUTION_NAME);
supportedLayerList.add(AllPredefinedLayers.SOFTMAX_NAME);
supportedLayerList.add(AllPredefinedLayers.SIGMOID_NAME);
supportedLayerList.add(AllPredefinedLayers.TANH_NAME);
supportedLayerList.add(AllPredefinedLayers.RELU_NAME);
supportedLayerList.add(AllPredefinedLayers.DROPOUT_NAME);
supportedLayerList.add(AllPredefinedLayers.POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.GLOBAL_POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.LRN_NAME);
supportedLayerList.add(AllPredefinedLayers.BATCHNORM_NAME);
supportedLayerList.add(AllPredefinedLayers.SPLIT_NAME);
supportedLayerList.add(AllPredefinedLayers.GET_NAME);
supportedLayerList.add(AllPredefinedLayers.ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.CONCATENATE_NAME);
supportedLayerList.add(AllPredefinedLayers.FLATTEN_NAME);
supportedLayerList.add(AllPredefinedLayers.ONE_HOT_NAME);
}
}
@@ -20,17 +20,17 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ArchitectureElementData;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch.generator.ArchitectureElementData;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import java.io.Writer;
import java.util.*;
public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public static final String NET_DEFINITION_MODE_KEY = "definition_mode";
public static final String NET_DEFINITION_MODE_KEY = "mode";
public CNNArch2GluonTemplateController(ArchitectureSymbol architecture,
TemplateConfiguration templateConfiguration) {
@@ -42,7 +42,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
ftlContext.put(NET_DEFINITION_MODE_KEY, netDefinitionMode);
ftlContext.put(NET_DEFINITION_MODE_KEY, netDefinitionMode.toString());
getTemplateConfiguration().processTemplate(ftlContext, templatePath, writer);
}
@@ -65,6 +65,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
setCurrentElement(previousElement);
}
public void include(ConstantSymbol constant, Writer writer, NetDefinitionMode netDefinitionMode) {
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(constant);
if (constant.isAtomic()) {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
@@ -99,6 +113,9 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer, netDefinitionMode);
}
else if (architectureElement instanceof ConstantSymbol) {
include((ConstantSymbol) architectureElement, writer, netDefinitionMode);
}
else {
include((IOSymbol) architectureElement, writer, netDefinitionMode);
}
@@ -114,4 +131,28 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
include(architectureElement, getWriter(), netDefinitionMode);
}
public String ioNameToCpp(String ioName) {
return ioName.replaceAll("_([0-9]+)_", "[$1]");
}
public List<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
List<String> names = new ArrayList<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
names.add(getName(element));
}
return names;
}
public List<String> getStreamOutputNames(SerialCompositeElementSymbol stream) {
List<String> names = new ArrayList<>();
for (ArchitectureElementSymbol element : stream.getLastAtomicElements()) {
names.add(getName(element));
}
return names;
}
}
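The new ioNameToCpp helper rewrites array-style port names for the C++ side by turning an embedded "_<index>_" into "[<index>]". A standalone sketch of that mapping (the example input names are made up here):

public class IoNameToCppDemo {

    // Same regex as CNNArch2GluonTemplateController.ioNameToCpp above.
    static String ioNameToCpp(String ioName) {
        return ioName.replaceAll("_([0-9]+)_", "[$1]");
    }

    public static void main(String[] args) {
        System.out.println(ioNameToCpp("data_0_"));    // data[0]
        System.out.println(ioNameToCpp("target_12_")); // target[12]
    }
}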
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.generator.TrainParamSupportChecker;
public class CNNArch2GluonTrainParamSupportChecker extends TrainParamSupportChecker {
}
\ No newline at end of file
@@ -6,10 +6,10 @@ import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.Cr
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
@@ -31,7 +31,7 @@ import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
public class CNNTrain2Gluon extends CNNTrain2MxNet {
public class CNNTrain2Gluon extends CNNTrainGenerator {
private static final String REINFORCEMENT_LEARNING_FRAMEWORK_MODULE = "reinforcement_learning";
private final RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
@@ -46,7 +46,8 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
}
public CNNTrain2Gluon(RewardFunctionSourceGenerator rewardFunctionSourceGenerator) {
super();
trainParamSupportChecker = new CNNArch2GluonTrainParamSupportChecker();
this.rewardFunctionSourceGenerator = rewardFunctionSourceGenerator;
}
@@ -114,9 +115,6 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
String cnnSupervisedTrainerContent = templateConfiguration.processTemplate(ftlContext, "CNNSupervisedTrainer.ftl");
fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
...
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import freemarker.template.Configuration;
/**
*
*/
public class GluonTemplateConfiguration extends TemplateConfiguration {
private static Configuration configuration;
...
@@ -5,7 +5,9 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
*/
public enum NetDefinitionMode {
ARCHITECTURE_DEFINITION,
FORWARD_FUNCTION;
FORWARD_FUNCTION,
PYTHON_INLINE,
CPP_INLINE;
public static NetDefinitionMode fromString(final String netDefinitionMode) {
switch(netDefinitionMode) {
@@ -13,6 +15,10 @@ public enum NetDefinitionMode {
return ARCHITECTURE_DEFINITION;
case "FORWARD_FUNCTION":
return FORWARD_FUNCTION;
case "PYTHON_INLINE":
return PYTHON_INLINE;
case "CPP_INLINE":
return CPP_INLINE;
default:
throw new IllegalArgumentException("Unknown Net Definition Mode");
}
...
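Because the template controller now stores netDefinitionMode.toString() in the FreeMarker context (under the "mode" key), the extended fromString has to round-trip the two new values. A minimal sketch of that round trip, assuming the NetDefinitionMode enum shown above is on the classpath:

import de.monticore.lang.monticar.cnnarch.gluongenerator.NetDefinitionMode;

public class NetDefinitionModeRoundTrip {
    public static void main(String[] args) {
        NetDefinitionMode mode = NetDefinitionMode.PYTHON_INLINE;

        String contextValue = mode.toString();                          // value placed into the ftl context
        NetDefinitionMode parsed = NetDefinitionMode.fromString(contextValue);

        System.out.println(parsed);                                     // PYTHON_INLINE
        // Unknown strings still raise IllegalArgumentException("Unknown Net Definition Mode").
    }
}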
@@ -2,16 +2,13 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import java.util.*;
/**
*
*/
public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
...
@@ -3,9 +3,10 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArchSymbolCompiler;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.AllowAllLayerSupportChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonArchitectureSupportChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSupportChecker;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
@@ -81,7 +82,8 @@ public class CriticNetworkGenerator {
gluonGenerator.setGenerationTargetPath(this.getGenerationTargetPath());
Map<String, String> fileContentMap = new HashMap<>();
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new AllowAllLayerSupportChecker());
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new CNNArch2GluonArchitectureSupportChecker(),
new CNNArch2GluonLayerSupportChecker());
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbolFromModelsDir(directoryOfCnnArchFile, criticNetworkName);
architectureSymbol.setComponentName(criticNetworkName);
fileContentMap.putAll(gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true));
...
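A condensed sketch of the critic compilation step shown above, using only the classes and methods visible in this hunk; the models directory, network name, and target path are placeholders, and error handling is omitted:

import java.util.Map;

import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchSymbolCompiler;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonArchitectureSupportChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2GluonLayerSupportChecker;

public class CriticGenerationSketch {

    public static Map<String, String> compileAndGenerate(String modelsDir, String criticNetworkName) {
        // The compiler is now wired with the Gluon-specific checkers
        // instead of the previous AllowAllLayerSupportChecker.
        CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(
                new CNNArch2GluonArchitectureSupportChecker(),
                new CNNArch2GluonLayerSupportChecker());

        ArchitectureSymbol architecture =
                symbolCompiler.compileArchitectureSymbolFromModelsDir(modelsDir, criticNetworkName);
        architecture.setComponentName(criticNetworkName);

        CNNArch2Gluon generator = new CNNArch2Gluon();
        generator.setGenerationTargetPath("target/critic/"); // placeholder path
        return generator.generateStringsAllowMultipleIO(architecture, true);
    }
}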
import mxnet as mx
import logging
import os
from CNNNet_${tc.fullArchitectureName} import Net
<#list tc.architecture.streams as stream>
<#if stream.isNetwork()>
from CNNNet_${tc.fullArchitectureName} import Net_${stream?index}
</#if>
</#list>
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "model"
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")},)<#if input?has_next>,</#if></#list>]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
def get_input_shapes(self):
return self._input_shapes_
self.networks = {}
def load(self, context):
lastEpoch = 0
param_file = None
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_newest-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_newest-symbol.json")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
if param_file is None:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch
earliestLastEpoch = None
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-symbol.json")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
earliestLastEpoch = lastEpoch
return earliestLastEpoch
def construct(self, context, data_mean=None, data_std=None):
self.net = Net(data_mean=data_mean, data_std=data_std)
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(<#list tc.architecture.inputs as input>mx.nd.zeros((1,)+self._input_shapes_[${input?index}], ctx=context)<#if input?has_next>,</#if></#list>)
<#list tc.architecture.streams as stream>
<#if stream.isNetwork()>
self.networks[${stream?index}] = Net_${stream?index}(data_mean=data_mean, data_std=data_std)
self.networks[${stream?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${stream?index}].hybridize()
self.networks[${stream?index}](<#list stream.getFirstAtomicElements() as input>mx.nd.zeros((1, ${tc.join(input.definition.type.dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
@@ -3,8 +3,9 @@ import h5py
import mxnet as mx
import logging
import sys
from mxnet import nd
class ${tc.fullArchitectureName}DataLoader:
class ${tc.fileNameWithoutEnding}:
_input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}]
_output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
@@ -14,21 +15,38 @@ class ${tc.fullArchitectureName}DataLoader:
def load_data(self, batch_size):
train_h5, test_h5 = self.load_h5_files()
data_mean = train_h5[self._input_names_[0]][:].mean(axis=0)
data_std = train_h5[self._input_names_[0]][:].std(axis=0) + 1e-5
train_data = {}
data_mean = {}
data_std = {}
for input_name in self._input_names_:
train_data[input_name] = train_h5[input_name]
data_mean[input_name] = nd.array(train_h5[input_name][:].mean(axis=0))
data_std[input_name] = nd.array(train_h5[input_name][:].std(axis=0) + 1e-5)
train_label = {}
for output_name in self._output_names_:
train_label[output_name] = train_h5[output_name]
train_iter = mx.io.NDArrayIter(data=train_data,
label=train_label,
batch_size=batch_size)
train_iter = mx.io.NDArrayIter(train_h5[self._input_names_[0]],
train_h5[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
test_iter = None
if test_h5 != None:
test_iter = mx.io.NDArrayIter(test_h5[self._input_names_[0]],
test_h5[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
test_data = {}
for input_name in self._input_names_:
test_data[input_name] = test_h5[input_name]
test_label = {}
for output_name in self._output_names_:
test_label[output_name] = test_h5[output_name]
test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label,
batch_size=batch_size)
return train_iter, test_iter, data_mean, data_std
def load_h5_files(self):
@@ -36,21 +54,39 @@ class ${tc.fullArchitectureName}DataLoader:
test_h5 = None
train_path = self._data_dir + "train.h5"
test_path = self._data_dir + "test.h5"
if os.path.isfile(train_path):
train_h5 = h5py.File(train_path, 'r')
if not (self._input_names_[0] in train_h5 and self._output_names_[0] in train_h5):
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the datasets: "
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
sys.exit(1)
test_iter = None
for input_name in self._input_names_:
if not input_name in train_h5:
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
+ "'" + input_name + "'")
sys.exit(1)
for output_name in self._output_names_:
if not output_name in train_h5:
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
+ "'" + output_name + "'")
sys.exit(1)
if os.path.isfile(test_path):
test_h5 = h5py.File(test_path, 'r')
if not (self._input_names_[0] in test_h5 and self._output_names_[0] in test_h5):
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the datasets: "
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
sys.exit(1)
for input_name in self._input_names_:
if not input_name in test_h5:
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
+ "'" + input_name + "'")
sys.exit(1)
for output_name in self._output_names_:
if not output_name in test_h5:
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
+ "'" + output_name + "'")
sys.exit(1)
else:
logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
return train_h5, test_h5
else:
logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
...
@@ -2,6 +2,16 @@ import mxnet as mx
import numpy as np
from mxnet import gluon
class OneHot(gluon.HybridBlock):
def __init__(self, size, **kwargs):
super(OneHot, self).__init__(**kwargs)
with self.name_scope():
self.size = size
def hybrid_forward(self, F, x):
return F.one_hot(indices=F.argmax(data=x, axis=1), depth=self.size)
class Softmax(gluon.HybridBlock):