Commit 3cde5fb4 authored by nilsfreyer

Merge branch 'master' into oneclick_nn_training

parents cddf24ac 980d2c19
Pipeline #106734 passed with stages in 2 minutes and 57 seconds
@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.8</CNNArch.version>
-<CNNTrain.version>0.2.5</CNNTrain.version>
+<CNNTrain.version>0.2.6</CNNTrain.version>
<embedded-montiarc-math-generator>0.1.2-SNAPSHOT</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
......
@@ -20,14 +20,12 @@
*/
package de.monticore.lang.monticar.cnnarch.caffe2generator;
-import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import javax.annotation.Nullable;
-import java.util.Arrays;
import java.util.List;
public class ArchitectureElementData {
@@ -165,31 +163,21 @@ public class ArchitectureElementData {
}
@Nullable
-public List<Integer> getPadding(){
+public Integer getPadding(){
return getPadding((LayerSymbol) getElement());
}
@Nullable
-public List<Integer> getPadding(LayerSymbol layer){
-List<Integer> kernel = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
-List<Integer> stride = layer.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
-ArchTypeSymbol inputType = layer.getInputTypes().get(0);
-ArchTypeSymbol outputType = layer.getOutputTypes().get(0);
-int heightWithPad = kernel.get(0) + stride.get(0)*(outputType.getHeight() - 1);
-int widthWithPad = kernel.get(1) + stride.get(1)*(outputType.getWidth() - 1);
-int heightPad = Math.max(0, heightWithPad - inputType.getHeight());
-int widthPad = Math.max(0, widthWithPad - inputType.getWidth());
-int topPad = (int)Math.ceil(heightPad / 2.0);
-int bottomPad = (int)Math.floor(heightPad / 2.0);
-int leftPad = (int)Math.ceil(widthPad / 2.0);
-int rightPad = (int)Math.floor(widthPad / 2.0);
-if (topPad == 0 && bottomPad == 0 && leftPad == 0 && rightPad == 0){
-return null;
+public Integer getPadding(LayerSymbol layer){
+String padding_type = ((LayerSymbol) getElement()).getStringValue(AllPredefinedLayers.PADDING_NAME).get();
+Integer pad=0;
+if (padding_type.equals(AllPredefinedLayers.PADDING_VALID)){
+pad = 0;
+} else if (padding_type.equals(AllPredefinedLayers.PADDING_SAME)){
+pad = 1;
}
-return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad);
+return pad;
}
}
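Note on the change above: the generator no longer derives explicit per-side padding from kernel, stride, and input/output shapes; it now maps the CNNArch `padding` attribute straight to a single Caffe2 pad value (0 for "valid", 1 for "same"). The single value 1 reproduces SAME padding only for 3x3 kernels with stride 1; the removed code handled the general, possibly asymmetric case. A minimal sketch of the arithmetic the removed code performed (illustrative Python, not part of the commit):

import math

def same_padding_per_side(kernel, stride, in_size, out_size):
    # Per-side (begin, end) padding along one spatial dimension,
    # as the removed getPadding() computed it.
    needed = max(0, kernel + stride * (out_size - 1) - in_size)
    return int(math.ceil(needed / 2.0)), int(math.floor(needed / 2.0))

# 3x3 kernel, stride 1, 28x28 in and out ("same"): both sides get pad 1,
# matching the single value the new getPadding() returns for PADDING_SAME.
print(same_padding_per_side(3, 1, 28, 28))  # (1, 1)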
@@ -26,7 +26,6 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
-import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
@@ -53,18 +52,18 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
List<ArchitectureElementSymbol> constructLayerElemList;
-if (!(element instanceof IOSymbol) && (element.getResolvedThis().get() instanceof CompositeElementSymbol))
-{
+if (element.getResolvedThis().get() instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
-if (!isSupportedLayer(constructedLayerElement, layerChecker)) return false;
+if (!isSupportedLayer(constructedLayerElement, layerChecker)) {
+return false;
+}
}
}
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend CAFFE2.");
return false;
-}
-else {
+} else {
return true;
}
}
@@ -72,7 +71,9 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private boolean supportCheck(ArchitectureSymbol architecture){
LayerSupportChecker layerChecker = new LayerSupportChecker();
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
-if(!isSupportedLayer(element, layerChecker)) return false;
+if(!isSupportedLayer(element, layerChecker)) {
+return false;
+}
}
return true;
}
@@ -84,6 +85,11 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
public void setModelPath(Path modelPath){
this.modelPath = modelPath.toString();
}
+private static void quitGeneration(){
+Log.error("Code generation is aborted");
+System.exit(1);
+}
public CNNArch2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
@@ -116,13 +122,12 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("could not resolve architecture " + rootModelName);
-System.exit(1);
+quitGeneration();
}
CNNArchCocos.checkAll(compilationUnit.get());
if (!supportCheck(compilationUnit.get().getArchitecture())){
-Log.error("Code generation aborted.");
-System.exit(1);
+quitGeneration();
}
try{
@@ -132,8 +137,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
compilationUnit.get().getArchitecture().setDataPath(dataPath);
compilationUnit.get().getArchitecture().setComponentName(rootModelName);
generateFiles(compilationUnit.get().getArchitecture());
-}
-catch (IOException e){
+} catch (IOException e){
Log.error(e.toString());
}
}
@@ -188,7 +192,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
try {
generateFromFilecontentsMap(fileContentMap);
} catch (IOException e) {
-e.printStackTrace();
+Log.error("CMake file could not be generated: " + e.getMessage());
}
}
@@ -211,8 +215,8 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)");
cMakeConfig.addCMakeCommand("find_package(CUDA)" + "\n"
+ "set(INCLUDE_DIRS ${INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS})" + "\n"
-+ "set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})" + "\n"); //Needed since CUDA cannot be found correctly (including CUDA_curand_LIBRARY) and as optional using CMakeFindModule
++ "set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})" + "\n");
+//Needed since CUDA cannot be found correctly (including CUDA_curand_LIBRARY)
cMakeConfig.addCMakeCommand("if(CUDA_FOUND)" + "\n" + " set(LIBS ${LIBS} caffe2 caffe2_gpu)"
+ "\n" + "else()" + "\n" + " set(LIBS ${LIBS} caffe2)" + "\n" + "endif()");
......
@@ -19,6 +19,7 @@
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.caffe2generator;
+import de.se_rwth.commons.logging.Log;
import org.apache.commons.cli.*;
@@ -73,13 +74,18 @@ public class CNNArch2Caffe2Cli {
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
-System.err.println("argument parsing exception: " + e.getMessage());
-System.exit(1);
+Log.error("argument parsing exception: " + e.getMessage());
+quitGeneration();
+return null;
}
return cliArgs;
}
+private static void quitGeneration(){
+Log.error("Code generation is aborted");
+System.exit(1);
+}
private static void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
......
@@ -37,6 +37,7 @@ public class CNNArchTemplateController {
private LayerNameCreator nameManager;
private ArchitectureSymbol architecture;
+private String loss;
//temporary attributes. They are set after calling process()
private Writer writer;
@@ -44,6 +45,8 @@
private Target targetLanguage;
private ArchitectureElementData dataElement;
+public static final String CROSS_ENTROPY = "cross_entropy";
+public static final String EUCLIDEAN = "euclidean";
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture);
@@ -96,8 +99,7 @@
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
inputNames = getLayerInputs(layer.getInputElement().get());
-}
-else {
+} else {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
if (input.getOutputTypes().size() == 1) {
inputNames.add(getName(input));
@@ -132,6 +134,9 @@
return getArchitecture().getComponentName();
}
+public String getArchitectureLoss(){
+return this.loss;
+}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
@@ -148,12 +153,10 @@
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
-}
-else {
+} else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
-}
-else {
+} else {
include(ioElement.getResolvedThis().get(), writer);
}
@@ -170,8 +173,7 @@
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
-}
-else {
+} else {
include(layer.getResolvedThis().get(), writer);
}
@@ -192,11 +194,9 @@
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
-}
-else if (architectureElement instanceof LayerSymbol){
+} else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
-}
-else {
+} else {
include((IOSymbol) architectureElement, writer);
}
}
@@ -209,15 +209,15 @@
}
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
-StringWriter writer = new StringWriter();
+StringWriter newWriter = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
-this.writer = writer;
+this.writer = newWriter;
-include("", templateNameWithoutEnding, writer);
+include("", templateNameWithoutEnding, newWriter);
String fileEnding = targetLanguage.toString();
String fileName = getFileNameWithoutEnding() + fileEnding;
-Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
+Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
@@ -246,27 +246,39 @@
public boolean isLogisticRegressionOutput(ArchitectureElementSymbol architectureElement){
-return isTOutput(Sigmoid.class, architectureElement);
+if (isTOutput(Sigmoid.class, architectureElement)){
+this.loss = CROSS_ENTROPY;
+return true;
+}
+return false;
}
public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
-return architectureElement.isOutput()
+if (architectureElement.isOutput()
&& !isLogisticRegressionOutput(architectureElement)
-&& !isSoftmaxOutput(architectureElement);
+&& !isSoftmaxOutput(architectureElement)){
+this.loss = EUCLIDEAN;
+return true;
+}
+return false;
}
public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){
-return isTOutput(Softmax.class, architectureElement);
+if (isTOutput(Softmax.class, architectureElement)){
+this.loss = CROSS_ENTROPY;
+return true;
+}
+return false;
}
private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){
-if (architectureElement.isOutput()){
-if (architectureElement.getInputElement().isPresent() && architectureElement.getInputElement().get() instanceof LayerSymbol){
-LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
-if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){
-return true;
-}
+if (architectureElement.isOutput()
+&& architectureElement.getInputElement().isPresent()
+&& architectureElement.getInputElement().get() instanceof LayerSymbol){
+LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
+if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){
+return true;
+}
+}
}
return false;
......
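Taken together, the template-controller changes thread the loss choice through to the generated trainer: a softmax or sigmoid output records CROSS_ENTROPY, any other output records EUCLIDEAN, and getArchitectureLoss() exposes the value the template reads as ${tc.architectureLoss}. A rough sketch of the resulting mapping (illustrative Python, not part of the commit):

# Output activation -> Caffe2 loss, as selected by the is*Output() methods.
LOSS_BY_OUTPUT = {
    'softmax': 'cross_entropy',   # isSoftmaxOutput
    'sigmoid': 'cross_entropy',   # isLogisticRegressionOutput
}

def architecture_loss(output_activation):
    # Everything else counts as linear regression -> euclidean loss.
    return LOSS_BY_OUTPUT.get(output_activation, 'euclidean')

assert architecture_loss('softmax') == 'cross_entropy'
assert architecture_loss('linear') == 'euclidean'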
@@ -8,6 +8,7 @@ import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
@@ -37,7 +38,9 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
-if (funcChecker.getUnsupportedElemList().contains(key)) it.remove();
+if (funcChecker.getUnsupportedElemList().contains(key)) {
+it.remove();
+}
}
}
@@ -47,17 +50,25 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
astOptimizer.accept(funcChecker);
if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
-configuration.setOptimizer(null);
+OptimizerSymbol adamOptimizer = new OptimizerSymbol("adam");
+configuration.setOptimizer(adamOptimizer); //Set default as adam optimizer
}else {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
-if (funcChecker.getUnsupportedElemList().contains(key)) it.remove();
+if (funcChecker.getUnsupportedElemList().contains(key)) {
+it.remove();
+}
}
}
}
}
+private static void quitGeneration(){
+Log.error("Code generation is aborted");
+System.exit(1);
+}
public CNNTrain2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
@@ -89,7 +100,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
Optional<CNNTrainCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()) {
Log.error("could not resolve training configuration " + rootModelName);
-System.exit(1);
+quitGeneration();
}
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
@@ -107,7 +118,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName));
}
} catch (IOException e) {
-e.printStackTrace();
+Log.error("CNNTrainer file could not be generated: " + e.getMessage());
}
}
......
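The effect of these trainer-generator changes: unsupported training parameters are removed from the configuration, and an unsupported optimizer is now replaced by a default adam OptimizerSymbol instead of being set to null. A rough sketch of that fallback behavior (illustrative Python, names hypothetical):

def filter_training_config(entries, unsupported, optimizer):
    # Drop training parameters the Caffe2 backend cannot handle.
    kept = {k: v for k, v in entries.items() if k not in unsupported}
    # Replace an unsupported optimizer with the default instead of none.
    if optimizer in unsupported:
        optimizer = 'adam'
    return kept, optimizer

print(filter_training_config({'num_epoch': 10, 'shuffle_data': True},
                             {'shuffle_data', 'rmsprop'}, 'rmsprop'))
# -> ({'num_epoch': 10}, 'adam')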
@@ -67,6 +67,13 @@ public class ConfigurationData {
return getConfiguration().getEntry("eval_metric").getValue().toString();
}
+public String getLoss() {
+if (!getConfiguration().getEntryMap().containsKey("loss")) {
+return null;
+}
+return getConfiguration().getEntry("loss").getValue().toString();
+}
public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
return null;
@@ -89,8 +96,7 @@ public class ConfigurationData {
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
-}
-else if (lrPolicyClasses.contains(realClass)) {
+} else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
......
@@ -47,17 +47,14 @@ public class LayerNameCreator {
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof CompositeElementSymbol){
return nameComposite((CompositeElementSymbol) architectureElement, stage, streamIndices);
-}
-else{
+} else{
if (architectureElement.isAtomic()){
if (architectureElement.getMaxSerialLength().get() > 0){
return add(architectureElement, stage, streamIndices);
-}
-else {
+} else {
return stage;
}
-}
-else {
+} else {
ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices);
}
@@ -78,8 +75,7 @@ public class LayerNameCreator {
streamIndices.remove(lastIndex);
return Collections.max(endStages) + 1;
-}
-else {
+} else {
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
@@ -113,8 +109,7 @@ public class LayerNameCreator {
name = name + "_" + arrayAccess + "_";
}
return name;
-}
-else {
+} else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
}
}
@@ -132,11 +127,9 @@ public class LayerNameCreator {
} else {
return layerDeclaration.getName().toLowerCase();
}
-}
-else if (architectureElement instanceof CompositeElementSymbol){
+} else if (architectureElement instanceof CompositeElementSymbol){
return "group";
-}
-else {
+} else {
return architectureElement.getName();
}
}
......
@@ -43,6 +43,11 @@ public class TemplateConfiguration {
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
}
+private static void quitGeneration(){
+Log.error("Code generation is aborted");
+System.exit(1);
+}
public Configuration getConfiguration() {
return configuration;
}
@@ -58,14 +63,12 @@ public class TemplateConfiguration {
try{
Template template = TemplateConfiguration.get().getTemplate(templatePath);
template.process(ftlContext, writer);
-}
-catch (IOException e) {
+} catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
-System.exit(1);
-}
-catch (TemplateException e){
+quitGeneration();
+} catch (TemplateException e){
Log.error("An exception occurred in template " + templatePath + " :\n" + e.getMessage());
-System.exit(1);
+quitGeneration();
}
}
......
@@ -25,12 +25,14 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
public TrainParamSupportChecker() {
}
-public String unsupportedOptFlag = "unsupported_optimizer";
+public static final String unsupportedOptFlag = "unsupported_optimizer";
public List getUnsupportedElemList(){
return this.unsupportedElemList;
}
+//Empty visit method denotes that the corresponding training parameter is supported.
+//To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList
public void visit(ASTNumEpochEntry node){}
public void visit(ASTBatchSizeEntry node){}
@@ -76,10 +78,7 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
public void visit(ASTWeightDecayEntry node){}
-public void visit(ASTLRDecayEntry node){
-printUnsupportedOptimizerParam(node.getName());
-this.unsupportedElemList.add(node.getName());
-}
+public void visit(ASTLRDecayEntry node){}
public void visit(ASTLRPolicyEntry node){}
......
@@ -3,10 +3,12 @@ from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import math
+import datetime
+import logging
import os
import sys
import lmdb
class ${tc.fileNameWithoutEnding}:
module = None
@@ -27,6 +29,15 @@ class ${tc.fileNameWithoutEnding}:
return iterations_int
+def get_epoch_as_iter(self, num_epoch, batch_size, dataset_size): # To print metrics during the training process
+# Force floating point calculation
+batch_size_float = float(batch_size)
+dataset_size_float = float(dataset_size)
+epoch_float = math.ceil(dataset_size_float/batch_size_float)
+epoch_int = int(epoch_float)
+return epoch_int
def add_input(self, model, batch_size, db, db_type, device_opts):
with core.DeviceScope(device_opts):
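The new get_epoch_as_iter() helper converts one epoch into an iteration count so the training loop below can print metrics at epoch boundaries: one epoch is ceil(dataset_size / batch_size) iterations. A minimal standalone sketch of the arithmetic (illustrative, not part of the template):

import math

def iters_per_epoch(batch_size, dataset_size):
    # One epoch = enough batches to cover the dataset once.
    return int(math.ceil(float(dataset_size) / float(batch_size)))

# e.g. 60000 training samples at batch size 64 -> 938 iterations per epoch;
# the training loop prints metrics whenever i is a multiple of this.
print(iters_per_epoch(64, 60000))  # 938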
@@ -58,16 +69,20 @@ class ${tc.fileNameWithoutEnding}:
return data, label, dataset_size
-def create_model(self, model, data, device_opts):
+def create_model(self, model, data, device_opts, is_test):
with core.DeviceScope(device_opts):
${tc.include(tc.architecture.body)}
# this adds the loss and optimizer
-def add_training_operators(self, model, output, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) :
+def add_training_operators(self, model, output, label, device_opts, loss, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) :
with core.DeviceScope(device_opts):
-xent = model.LabelCrossEntropy([output, label], 'xent')
-loss = model.AveragedLoss(xent, "loss")
+if loss == 'cross_entropy':
+xent = model.LabelCrossEntropy([output, label], 'xent')
+loss = model.AveragedLoss(xent, "loss")
+elif loss == 'euclidean':
+dist = model.net.SquaredL2Distance([label, output], 'dist')
+loss = dist.AveragedLoss([], ['loss'])
model.AddGradientOperators([loss])
@@ -104,7 +119,7 @@ ${tc.include(tc.architecture.body)}
accuracy = brew.accuracy(model, [output, label], "accuracy", top_k=3)
return accuracy
-def train(self, num_epoch=1000, batch_size=64, context='gpu', eval_metric='accuracy', opt_type='adam', base_learning_rate=0.001, weight_decay=0.001, policy='fixed', stepsize=1, epsilon=1E-8, beta1=0.9, beta2=0.999, gamma=0.999, momentum=0.9) :
+def train(self, num_epoch=1000, batch_size=64, context='gpu', eval_metric='accuracy', loss='${tc.architectureLoss}', opt_type='adam', base_learning_rate=0.001, weight_decay=0.001, policy='fixed', stepsize=1, epsilon=1E-8, beta1=0.9, beta2=0.999, gamma=0.999, momentum=0.9) :
if context == 'cpu':
device_opts = core.DeviceOption(caffe2_pb2.CPU, 0)
print("CPU mode selected")
@@ -118,9 +133,10 @@ ${tc.include(tc.architecture.body)}
# == Training model ==
train_model= model_helper.ModelHelper(name="train_net", arg_scope=arg_scope)
data, label, train_dataset_size = self.add_input(train_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'train_lmdb'), db_type='lmdb', device_opts=device_opts)
-${tc.join(tc.architectureOutputs, ",", "","")} = self.create_model(train_model, data, device_opts=device_opts)
-self.add_training_operators(train_model, ${tc.join(tc.architectureOutputs, ",", "","")}, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum)
-self.add_accuracy(train_model, ${tc.join(tc.architectureOutputs, ",", "","")}, label, device_opts, eval_metric)
+${tc.join(tc.architectureOutputs, ",", "","")} = self.create_model(train_model, data, device_opts=device_opts, is_test=False)
+self.add_training_operators(train_model, ${tc.join(tc.architectureOutputs, ",", "","")}, label, device_opts, loss, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum)
+if not loss == 'euclidean':
+self.add_accuracy(train_model, ${tc.join(tc.architectureOutputs, ",", "","")}, label, device_opts, eval_metric)
with core.DeviceScope(device_opts):
brew.add_weight_decay(train_model, weight_decay)
@@ -130,38 +146,62 @@ ${tc.include(tc.architecture.body)}
# Main Training Loop
iterations = self.get_total_num_iter(num_epoch, batch_size, train_dataset_size)
-print("** Starting Training for " + str(num_epoch) + " epochs = " + str(iterations) + " iterations **")
+epoch_as_iter = self.get_epoch_as_iter(num_epoch, batch_size, train_dataset_size)
+print("\n*** Starting Training for " + str(num_epoch) + " epochs = " + str(iterations) + " iterations ***")
+start_date = datetime.datetime.now()
for i in range(iterations):
workspace.RunNet(train_model.net)
-if i % 50 == 0:
-print 'Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss")) + ' - ' + 'Accuracy ' + str(workspace.FetchBlob('accuracy'))
-print("Training done")
-print("== Running Test model ==")
+if i % 50 == 0 or i % epoch_as_iter == 0:
+if not loss == 'euclidean':
+print 'Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss")) + ' - ' + 'Accuracy ' + str(workspace.FetchBlob('accuracy'))
+else:
+print 'Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss"))
+current_time = datetime.datetime.now()
+elapsed_time = current_time - start_date
+print 'Progress: ' + str(i) + '/' + str(iterations) + ', ' + 'Current time spent: ' + str(elapsed_time)
+current_time = datetime.datetime.now()
+elapsed_time = current_time - start_date
+print 'Progress: ' + str(iterations) + '/' + str(iterations) + ' Training done' + ', ' + 'Total time spent: ' + str(elapsed_time)
+print("\n*** Running Test model ***")
# == Testing model. ==
test_model= model_helper.ModelHelper(name="test_net", arg_scope=arg_scope, init_params=False)
data, label, test_dataset_size = self.add_input(test_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'test_lmdb'), db_type='lmdb', device_opts=device_opts)
-${tc.join(tc.architectureOutputs, ",", "","")} = self.create_model(test_model, data, device_opts=device_opts)
-self.add_accuracy(test_model, predictions, label, device_opts, eval_metric)
+${tc.join(tc.architectureOutputs, ",", "","")} = self.create_model(test_model, data, device_opts=device_opts, is_test=True)
+if not loss == 'euclidean':
+self.add_accuracy(test_model, ${tc.join(tc.architectureOutputs, ",", "","")}, label, device_opts, eval_metric)
workspace.RunNetOnce(test_model.param_init_net)
workspace.CreateNet(test_model.net, overwrite=True)
# Main Testing Loop
test_accuracy = np.zeros(test_dataset_size/batch_size)
+start_date = datetime.datetime.now()
for i in range(test_dataset_size/batch_size):
# Run a forward pass of the net on the current batch
workspace.RunNet(test_model.net)
# Collect the batch accuracy from the workspace
test_accuracy[i] = workspace.FetchBlob('accuracy')