Commit d03d1284 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Develop

See merge request !28
parents 67579b12 06caaddd
# (c) https://github.com/MontiCore/monticore
stages:
- windows
#- windows
- linux
masterJobLinux:
......@@ -20,17 +20,17 @@ masterJobLinux:
- .gitlab-ci.yml
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10
except:
changes:
- README.md
- .gitignore
- .gitlab-ci.yml
#masterJobWindows:
# stage: windows
# script:
# - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
# tags:
# - Windows10
# except:
# changes:
# - README.md
# - .gitignore
# - .gitlab-ci.yml
BranchJobLinux:
......
This diff is collapsed.
......@@ -17,12 +17,12 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.5-SNAPSHOT</CNNArch2X.version>
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.6-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
<junit.version>4.12</junit.version>
......@@ -144,7 +144,7 @@
</dependency>
</dependencies>
<!-- == PROJECT BUILD SETTINGS =========================================== -->
......
......@@ -84,8 +84,6 @@ public class CNNArch2Gluon extends CNNArchGenerator {
CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController(
architecture, templateConfiguration);
archTc.getStreamOutputDomains(archTc.getArchitecture().getStreams().get(0));
fileContentMap.putAll(compilePythonFiles(archTc, architecture));
fileContentMap.putAll(compileCppFiles(archTc));
......
......@@ -9,7 +9,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
public CNNArch2GluonLayerSupportChecker() {
supportedLayerList.add(AllPredefinedLayers.FULLY_CONNECTED_NAME);
supportedLayerList.add(AllPredefinedLayers.CONVOLUTION_NAME);
supportedLayerList.add(AllPredefinedLayers.TRANS_CONV_NAME);
supportedLayerList.add(AllPredefinedLayers.UP_CONVOLUTION_NAME);
supportedLayerList.add(AllPredefinedLayers.SOFTMAX_NAME);
supportedLayerList.add(AllPredefinedLayers.SIGMOID_NAME);
supportedLayerList.add(AllPredefinedLayers.TANH_NAME);
......@@ -40,6 +40,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.REDUCE_SUM_NAME);
supportedLayerList.add(AllPredefinedLayers.BROADCAST_ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.RESHAPE_NAME);
// supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
}
}
......@@ -415,6 +415,9 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
dimensions.add(intDimension.toString());
}
if (dimensions.isEmpty())
dimensions.add("unknown");
String name = getName(element);
if (outputAsArray && element.isOutput() && element instanceof VariableSymbol) {
......
......@@ -2,9 +2,12 @@
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingPortChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -78,9 +81,9 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
try {
Iterator var6 = fileContents.keySet().iterator();
while(var6.hasNext()) {
String fileName = (String)var6.next();
genCPP.generateFile(new FileContent((String)fileContents.get(fileName), fileName));
while (var6.hasNext()) {
String fileName = (String) var6.next();
genCPP.generateFile(new FileContent((String) fileContents.get(fileName), fileName));
}
} catch (IOException var8) {
Log.error("CNNTrainer file could not be generated" + var8.getMessage());
......@@ -98,6 +101,19 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
generateFilesFromConfigurationSymbol(configurationSymbol);
}
public void generate(Path modelsDirPath,
String rootModelName,
NNArchitectureSymbol trainedArchitecture,
NNArchitectureSymbol discriminatorNetwork,
NNArchitectureSymbol qNetwork) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
configurationSymbol.setDiscriminatorNetwork(discriminatorNetwork);
configurationSymbol.setQNetwork(qNetwork);
this.setRootProjectModelsDir(modelsDirPath.toString());
generateFilesFromConfigurationSymbol(configurationSymbol);
}
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
generate(modelsDirPath, rootModelName, trainedArchitecture, null);
}
......@@ -117,16 +133,16 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
} else if(configData.isGan()) {
} else if (configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
if(!configuration.getDiscriminatorNetwork().isPresent()) {
if (!configuration.getDiscriminatorNetwork().isPresent()) {
Log.error("No architecture model for discriminator available but is required for chosen " +
"GAN");
}
NNArchitectureSymbol genericDisArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol disArchitectureSymbol
= ((ArchitectureAdapter)genericDisArchitectureSymbol).getArchitectureSymbol();
= ((ArchitectureAdapter) genericDisArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
......@@ -147,7 +163,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (configuration.hasQNetwork()) {
NNArchitectureSymbol genericQArchitectureSymbol = configuration.getQNetwork().get();
ArchitectureSymbol qArchitectureSymbol
= ((ArchitectureAdapter)genericQArchitectureSymbol).getArchitectureSymbol();
= ((ArchitectureAdapter) genericQArchitectureSymbol).getArchitectureSymbol();
Map<String, String> qArchitectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(qArchitectureSymbol, true);
......@@ -179,7 +195,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
if (rlAlgorithm.equals(RLAlgorithm.DDPG)
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
if (!configuration.getCriticNetwork().isPresent()) {
Log.error("No architecture model for critic available but is required for chosen " +
......@@ -187,7 +203,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getCriticNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
= ((ArchitectureAdapter) genericArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
......@@ -201,8 +217,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticInstanceName);
......@@ -215,7 +231,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
} else {
Log.error("No architecture model for the trained neural network but is required for " +
"reinforcement learning configuration.");
"reinforcement learning configuration.");
}
}
......@@ -234,7 +250,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
private void generateRewardFunction(NNArchitectureSymbol trainedArchitecture,
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
......@@ -271,7 +287,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
}
private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile){
private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile) {
final File brokenFile = pathToBrokenFile.toFile();
if (brokenFile.exists()) {
try {
......@@ -301,19 +317,19 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
final String reinforcementStrategyContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/Strategy.ftl");
ftlContext, "reinforcement/agent/Strategy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/strategy.py", reinforcementStrategyContent);
final String replayMemoryContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
ftlContext, "reinforcement/agent/ReplayMemory.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
final String environmentContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/environment/Environment.ftl");
ftlContext, "reinforcement/environment/Environment.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
final String utilContent = templateConfiguration.processTemplate(
ftlContext, "reinforcement/util/Util.ftl");
ftlContext, "reinforcement/util/Util.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
final String initContent = "";
......@@ -322,3 +338,4 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return fileContentMap;
}
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing;
import de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter;
import de.monticore.lang.monticar.cnntrain.annotations.RewardFunctionParameter;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
*
*/
/**
 * Adapter that exposes the port information of a generated component
 * ({@link ComponentPortInformation}) through the
 * {@link PreprocessingComponentParameter} interface.
 *
 * Besides plain delegation it lazily derives the conventional parameter
 * names of a preprocessing component: the single output port, the "state"
 * input port and the boolean-scalar "terminal" input port.
 */
public class PreprocessingComponentParameterAdapter implements PreprocessingComponentParameter {
    private final ComponentPortInformation adaptee;
    // Lazily resolved names; each stays null until successfully determined.
    private String outputParameterName;
    private String inputStateParameterName;
    private String inputTerminalParameterName;

    public PreprocessingComponentParameterAdapter(final ComponentPortInformation componentPortInformation) {
        this.adaptee = componentPortInformation;
    }

    @Override
    public List<String> getInputNames() {
        return this.adaptee.getAllInputs().stream()
                .map(input -> input.getVariableName())
                .collect(Collectors.toList());
    }

    @Override
    public List<String> getOutputNames() {
        return this.adaptee.getAllOutputs().stream()
                .map(output -> output.getVariableName())
                .collect(Collectors.toList());
    }

    @Override
    public Optional<String> getTypeOfInputPort(String portName) {
        Optional<PortVariable> match = this.adaptee.getAllInputs().stream()
                .filter(input -> input.getVariableName().equals(portName))
                .findFirst();
        return match.map(input -> input.getEmadlType().toString());
    }

    @Override
    public Optional<String> getTypeOfOutputPort(String portName) {
        Optional<PortVariable> match = this.adaptee.getAllOutputs().stream()
                .filter(output -> output.getVariableName().equals(portName))
                .findFirst();
        return match.map(output -> output.getEmadlType().toString());
    }

    @Override
    public Optional<List<Integer>> getInputPortDimensionOfPort(String portName) {
        return this.adaptee.getAllInputs().stream()
                .filter(input -> input.getVariableName().equals(portName))
                .findFirst()
                .map(input -> input.getDimension());
    }

    @Override
    public Optional<List<Integer>> getOutputPortDimensionOfPort(String portName) {
        return this.adaptee.getAllOutputs().stream()
                .filter(output -> output.getVariableName().equals(portName))
                .findFirst()
                .map(output -> output.getDimension());
    }

    /**
     * Name of the single output port, or empty when the component does not
     * have exactly one output.
     */
    public Optional<String> getOutputParameterName() {
        if (this.outputParameterName != null) {
            return Optional.of(this.outputParameterName);
        }
        List<String> outputNames = this.getOutputNames();
        if (outputNames.size() != 1) {
            return Optional.empty();
        }
        this.outputParameterName = outputNames.get(0);
        return Optional.of(this.outputParameterName);
    }

    // A port counts as boolean scalar when its EMADL type is B and it is
    // one-dimensional with a single entry.
    private boolean isBooleanScalar(final PortVariable candidate) {
        if (!candidate.getEmadlType().equals(EmadlType.B)) {
            return false;
        }
        List<Integer> dimension = candidate.getDimension();
        return dimension.size() == 1 && dimension.get(0) == 1;
    }

    /**
     * Tries to identify the boolean-scalar "terminal" input and the
     * remaining "state" input. Requires exactly two inputs; on failure both
     * cached names stay (or are reset to) null.
     *
     * @return true when both names could be determined
     */
    private boolean determineInputNames() {
        if (this.getInputNames().size() != 2) {
            return false;
        }
        Optional<String> terminalName = this.adaptee.getAllInputs().stream()
                .filter(this::isBooleanScalar)
                .map(input -> input.getVariableName())
                .findFirst();
        if (!terminalName.isPresent()) {
            return false;
        }
        this.inputTerminalParameterName = terminalName.get();
        Optional<String> stateName = this.adaptee.getAllInputs().stream()
                .filter(input -> !input.getVariableName().equals(this.inputTerminalParameterName))
                .filter(input -> !isBooleanScalar(input))
                .map(input -> input.getVariableName())
                .findFirst();
        if (!stateName.isPresent()) {
            // Roll back so a later call starts from a clean state.
            this.inputTerminalParameterName = null;
            return false;
        }
        this.inputStateParameterName = stateName.get();
        return true;
    }

    public Optional<String> getInputStateParameterName() {
        if (this.inputStateParameterName == null) {
            this.determineInputNames();
        }
        return Optional.ofNullable(this.inputStateParameterName);
    }

    public Optional<String> getInputTerminalParameter() {
        if (this.inputTerminalParameterName == null) {
            this.determineInputNames();
        }
        return Optional.ofNullable(this.inputTerminalParameterName);
    }
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing;
import de.se_rwth.commons.logging.Log;

import java.util.HashSet;
import java.util.ListIterator;
import java.util.Set;
import java.util.stream.Collectors;
/**
*
*/
/**
 * Static validity checks for the ports of a preprocessing component wrapped
 * in a {@link PreprocessingComponentParameterAdapter}.
 *
 * A valid component has the same number of input and output ports, and the
 * set of output names equals the set of input names each suffixed with
 * "_out". Violations are reported via {@link Log#error}.
 */
public class PreprocessingPortChecker {
    public PreprocessingPortChecker() { }

    /**
     * Runs all port checks on the given adapter.
     *
     * @param preprocessingComponentParameter adapter to check; must not be null
     */
    public static void check(final PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        assert preprocessingComponentParameter != null;
        checkEqualNumberOfInAndOutPorts(preprocessingComponentParameter);
        checkCorrectPortNames(preprocessingComponentParameter);
    }

    private static void checkEqualNumberOfInAndOutPorts(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        failIfConditionFails(equalNumberOfInAndOutPorts(preprocessingComponentParameter),
            "The number of in- and output ports of the " +
                "preprocessing component is not equal");
    }

    private static boolean equalNumberOfInAndOutPorts(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        return preprocessingComponentParameter.getInputNames().size()
            == preprocessingComponentParameter.getOutputNames().size();
    }

    private static void checkCorrectPortNames(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        failIfConditionFails(correctPortNames(preprocessingComponentParameter),
            "The output ports are not correctly named with \"_out\" appendix");
    }

    private static boolean correctPortNames(PreprocessingComponentParameterAdapter preprocessingComponentParameter) {
        // Each input name with "_out" appended must match the output names exactly.
        Set<String> expectedOutputs = preprocessingComponentParameter.getInputNames().stream()
            .map(inputName -> inputName + "_out")
            .collect(Collectors.toSet());
        Set<String> actualOutputs = new HashSet<String>(preprocessingComponentParameter.getOutputNames());
        return expectedOutputs.equals(actualOutputs);
    }

    private static void failIfConditionFails(final boolean condition, final String message) {
        if (!condition) {
            fail(message);
        }
    }

    private static void fail(final String message) {
        // Intentionally only logs the violation; generation continues
        // (process termination was deliberately disabled).
        Log.error(message);
    }
}
......@@ -2,6 +2,7 @@
import mxnet as mx
import logging
import os
import shutil
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
......@@ -14,6 +15,11 @@ class ${tc.fileNameWithoutEnding}:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
<#if (tc.weightsPath)??>
self._weights_dir_ = "${tc.weightsPath}/"
<#else>
self._weights_dir_ = None
</#if>
def load(self, context):
earliestLastEpoch = None
......@@ -50,6 +56,29 @@ class ${tc.fileNameWithoutEnding}:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
......@@ -63,3 +92,29 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
input_dimensions = (${tc.join(dimensions[name], ",")},)
input_domains = (${tc.join(domains[name], ",")},)
inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
return inputs
def getOutputs(self):
outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
output_dimensions = (${tc.join(dimensions[name], ",")},)
output_domains = (${tc.join(domains[name], ",")},)
outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
return outputs
......@@ -5,7 +5,6 @@ import mxnet as mx
import logging
import sys
import numpy as np
import cv2
import importlib
from mxnet import nd
......@@ -79,6 +78,7 @@ class ${tc.fileNameWithoutEnding}:
train_label = {}
data_mean = {}
data_std = {}
train_images = {}
shape_output = self.preprocess_data(instance, inp, 0, train_h5)
train_len = len(train_h5[self._input_names_[0]])
......@@ -141,6 +141,7 @@ class ${tc.fileNameWithoutEnding}:
for output_name in self._output_names_:
test_label[output_name][i] = getattr(shape_output, output_name + "_out")
test_images = {}
if 'images' in test_h5:
test_images = test_h5['images']
......@@ -152,7 +153,7 @@ class ${tc.fileNameWithoutEnding}:
def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
for input_name in self._input_names_:
data = data_h5[input_name][0]
data = data_h5[input_name][index]
attr = getattr(input_wrapper, input_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
......@@ -160,7 +161,7 @@ class ${tc.fileNameWithoutEnding}:
data = type(attr)(data)
setattr(input_wrapper, input_name, data)
for output_name in self._output_names_:
data = data_h5[output_name][0]
data = data_h5[output_name][index]
attr = getattr(input_wrapper, output_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
......
......@@ -105,28 +105,3 @@ ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
</#if>
</#list>
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
input_dimensions = (${tc.join(dimensions[name], ",")})
input_domains = (${tc.join(domains[name], ",")})
inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
return inputs
def getOutputs(self):
outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
output_dimensions = (${tc.join(dimensions[name], ",")})
output_domains = (${tc.join(domains[name], ",")})
outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
return outputs
......@@ -51,13 +51,89 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)