Aufgrund einer Störung des S3-Storage könnten in nächster Zeit folgende GitLab-Funktionen nicht zur Verfügung stehen: LFS, Container Registry, Job Artifacts, Uploads (Wiki, Bilder, Projekt-Exporte). Wir bitten um Verständnis. Es wird mit Hochdruck an der Behebung des Problems gearbeitet. Weitere Informationen zur Störung des Object Storage finden Sie hier: https://maintenance.itc.rwth-aachen.de/ticket/status/messages/59-object-storage-pilot

Commit 21239311 authored by Christopher Jan-Steffen Brix
Browse files

Merge branch 'master' into oneclick_nn_training

parents 905b93d0 abbd6de0
......@@ -8,19 +8,19 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.2.5-SNAPSHOT</version>
<version>0.2.9</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.3</emadl.version>
<CNNTrain.version>0.2.5</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.5-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.7-SNAPSHOT</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-opt-generator>0.1.0</embedded-montiarc-math-opt-generator>
<emadl.version>0.2.6</emadl.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.14-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.9</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
<junit.version>4.12</junit.version>
......
......@@ -23,7 +23,7 @@ package de.monticore.lang.monticar.emadl.generator;
import de.monticore.ModelingLanguageFamily;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.embeddedmontiarc.LogConfig;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ConstantPortSymbol;
import de.monticore.lang.embeddedmontiarc.helper.ConstantPortHelper;
import de.monticore.lang.monticar.emadl._symboltable.EMADLLanguage;
import de.monticore.lang.monticar.enumlang._symboltable.EnumLangLanguage;
import de.monticore.lang.monticar.generator.cpp.converter.MathConverter;
......@@ -62,7 +62,7 @@ public class EMADLAbstractSymtab {
}
public static Scope createSymTab(String... modelPath) {
ConstantPortSymbol.resetLastID();
ConstantPortHelper.resetLastID();
MathConverter.resetIDs();
ThreadingOptimizer.resetID();
ModelingLanguageFamily fam = new ModelingLanguageFamily();
......
......@@ -24,8 +24,8 @@ import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.io.Resources;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
......@@ -111,7 +111,7 @@ public class EMADLGenerator {
public void generate(String modelPath, String qualifiedName, String pythonPath, String forced) throws IOException, TemplateException {
setModelsPath( modelPath );
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
ComponentSymbol component = symtab.<ComponentSymbol>resolve(qualifiedName, ComponentSymbol.KIND).orElse(null);
EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
String componentName = splitName.get(splitName.size() - 1);
......@@ -122,7 +122,7 @@ public class EMADLGenerator {
System.exit(1);
}
ExpandedComponentInstanceSymbol instance = component.getEnclosingScope().<ExpandedComponentInstanceSymbol>resolve(instanceName, ExpandedComponentInstanceSymbol.KIND).get();
EMAComponentInstanceSymbol instance = component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab, pythonPath, forced);
......@@ -187,9 +187,9 @@ public class EMADLGenerator {
}
}
public void generateFiles(TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentSymbol, Scope symtab, String pythonPath, String forced) throws IOException {
Set<ExpandedComponentInstanceSymbol> allInstances = new HashSet<>();
List<FileContent> fileContents = generateStrings(taggingResolver, componentSymbol, symtab, allInstances, forced);
public void generateFiles(TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, Scope symtab, String pythonPath, String forced) throws IOException {
Set<EMAComponentInstanceSymbol> allInstances = new HashSet<>();
List<FileContent> fileContents = generateStrings(taggingResolver, EMAComponentSymbol, symtab, allInstances, forced);
for (FileContent fileContent : fileContents) {
emamGen.generateFile(fileContent);
......@@ -203,7 +203,7 @@ public class EMADLGenerator {
List<FileContent> fileContentsTrainingHashes = new ArrayList<>();
List<String> newHashes = new ArrayList<>();
for (ExpandedComponentInstanceSymbol componentInstance : allInstances) {
for (EMAComponentInstanceSymbol componentInstance : allInstances) {
Optional<ArchitectureSymbol> architecture = componentInstance.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if(!architecture.isPresent()) {
......@@ -288,9 +288,9 @@ public class EMADLGenerator {
return stringBuffer.toString();
}
private boolean isAlreadyTrained(String trainingHash, ExpandedComponentInstanceSymbol componentInstance) {
private boolean isAlreadyTrained(String trainingHash, EMAComponentInstanceSymbol componentInstance) {
try {
ComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
EMAComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/");
String checkFilePathString = getGenerationTargetPath() + componentConfigFilename + ".training_hash";
......@@ -311,7 +311,7 @@ public class EMADLGenerator {
}
}
public List<FileContent> generateStrings(TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<ExpandedComponentInstanceSymbol> allInstances, String forced){
public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<EMAComponentInstanceSymbol> allInstances, String forced){
List<FileContent> fileContents = new ArrayList<>();
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
......@@ -343,30 +343,30 @@ public class EMADLGenerator {
}
protected void generateComponent(List<FileContent> fileContents,
Set<ExpandedComponentInstanceSymbol> allInstances,
Set<EMAComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
ExpandedComponentInstanceSymbol componentInstanceSymbol,
EMAComponentInstanceSymbol componentInstanceSymbol,
Scope symtab){
allInstances.add(componentInstanceSymbol);
ComponentSymbol componentSymbol = componentInstanceSymbol.getComponentType().getReferencedSymbol();
EMAComponentSymbol EMAComponentSymbol = componentInstanceSymbol.getComponentType().getReferencedSymbol();
/* remove the following two lines if the component symbol full name bug with generic variables is fixed */
componentSymbol.setFullName(null);
componentSymbol.getFullName();
EMAComponentSymbol.setFullName(null);
EMAComponentSymbol.getFullName();
/* */
Optional<ArchitectureSymbol> architecture = componentInstanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
Optional<MathStatementsSymbol> mathStatements = componentSymbol.getSpannedScope().resolve("MathStatements", MathStatementsSymbol.KIND);
Optional<MathStatementsSymbol> mathStatements = EMAComponentSymbol.getSpannedScope().resolve("MathStatements", MathStatementsSymbol.KIND);
EMADLCocos.checkAll(componentInstanceSymbol);
if (architecture.isPresent()){
DataPathConfigParser newParserConfig = new DataPathConfigParser(getModelsPath() + "data_paths.txt");
String dPath = newParserConfig.getDataPath(componentSymbol.getFullName());
String dPath = newParserConfig.getDataPath(EMAComponentSymbol.getFullName());
/*String dPath = DataPathConfigParser.getDataPath(getModelsPath() + "data_paths.txt", componentSymbol.getFullName());*/
architecture.get().setDataPath(dPath);
architecture.get().setComponentName(componentSymbol.getFullName());
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
}
else if (mathStatements.isPresent()){
......@@ -385,7 +385,7 @@ public class EMADLGenerator {
}
}
public void generateCNN(List<FileContent> fileContents, TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol instance, ArchitectureSymbol architecture){
public void generateCNN(List<FileContent> fileContents, TaggingResolver taggingResolver, EMAComponentInstanceSymbol instance, ArchitectureSymbol architecture){
Map<String,String> contentMap = cnnArchGenerator.generateStrings(architecture);
String fullName = instance.getFullName().replaceAll("\\.", "_");
......@@ -428,16 +428,16 @@ public class EMADLGenerator {
return component;
}
public void generateMathComponent(List<FileContent> fileContents, TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentSymbol, MathStatementsSymbol mathStatementsSymbol){
public void generateMathComponent(List<FileContent> fileContents, TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, MathStatementsSymbol mathStatementsSymbol){
fileContents.add(new FileContent(
emamGen.generateString(taggingResolver, componentSymbol, mathStatementsSymbol),
componentSymbol));
emamGen.generateString(taggingResolver, EMAComponentSymbol, mathStatementsSymbol),
EMAComponentSymbol));
}
public void generateSubComponents(List<FileContent> fileContents, Set<ExpandedComponentInstanceSymbol> allInstances, TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentInstanceSymbol, Scope symtab){
public void generateSubComponents(List<FileContent> fileContents, Set<EMAComponentInstanceSymbol> allInstances, TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab){
fileContents.add(new FileContent(emamGen.generateString(taggingResolver, componentInstanceSymbol, (MathStatementsSymbol) null), componentInstanceSymbol));
String lastNameWithoutArrayPart = "";
for (ExpandedComponentInstanceSymbol instanceSymbol : componentInstanceSymbol.getSubComponents()) {
for (EMAComponentInstanceSymbol instanceSymbol : componentInstanceSymbol.getSubComponents()) {
int arrayBracketIndex = instanceSymbol.getName().indexOf("[");
boolean generateComponentInstance = true;
if (arrayBracketIndex != -1) {
......@@ -478,10 +478,10 @@ public class EMADLGenerator {
return trainConfigFilename;
}
public List<FileContent> generateCNNTrainer(Set<ExpandedComponentInstanceSymbol> allInstances, String mainComponentName) {
public List<FileContent> generateCNNTrainer(Set<EMAComponentInstanceSymbol> allInstances, String mainComponentName) {
List<FileContent> fileContents = new ArrayList<>();
for (ExpandedComponentInstanceSymbol componentInstance : allInstances) {
ComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
for (EMAComponentInstanceSymbol componentInstance : allInstances) {
EMAComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
Optional<ArchitectureSymbol> architecture = component.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (architecture.isPresent()) {
......
......@@ -81,6 +81,7 @@ public class GenerationTest extends AbstractSymtabTest {
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"cifar10_cifar10Classifier.cpp",
"cifar10_cifar10Classifier.h",
"CNNCreator_cifar10_cifar10Classifier_net.py",
"CNNBufferFile.h",
......@@ -160,6 +161,27 @@ public class GenerationTest extends AbstractSymtabTest {
}
}
/**
 * Runs the generator CLI on the {@code mnist.MnistClassifier} model with the
 * {@code CAFFE2} backend and compares a fixed list of generated files against
 * the reference copies under {@code src/test/resources/target_code}.
 *
 * <p>The run must not leave any findings in {@link Log}.
 */
@Test
public void testMnistClassifier() throws IOException, TemplateException {
    // Findings are global state; start from an empty list so the assertion
    // below only sees findings produced by this run.
    Log.getFindings().clear();

    EMADLGeneratorCli.main(new String[] {
        "-m", "src/test/resources/models/",
        "-r", "mnist.MnistClassifier",
        "-b", "CAFFE2",
        "-f", "n"
    });
    assertTrue(Log.getFindings().isEmpty());

    // Files that must match their reference copy.
    String[] expectedFileNames = {
        "mnist_mnistClassifier.cpp",
        "mnist_mnistClassifier.h",
        "CNNCreator_mnist_mnistClassifier_net.py",
        "CNNPredictor_mnist_mnistClassifier_net.h",
        "mnist_mnistClassifier_net.h",
        "CNNTranslator.h",
        "mnist_mnistClassifier_calculateClass.h",
        "CNNTrainer_mnist_mnistClassifier_net.py"
    };

    checkFilesAreEqual(
        Paths.get("./target/generated-sources-emadl"),
        Paths.get("./src/test/resources/target_code"),
        Arrays.asList(expectedFileNames));
}
@Test
public void testHashFunction() {
EMADLGenerator tester = new EMADLGenerator(Backend.MXNET);
......
......@@ -20,8 +20,8 @@
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.emadl._parser.EMADLParser;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
......@@ -49,8 +49,8 @@ public class SymtabTest extends AbstractSymtabTest {
@Test
public void testAlexnet(){
Scope symTab = createSymTab("src/test/resources/models");
ComponentSymbol a = symTab.<ComponentSymbol>resolve("ResNet34", ComponentSymbol.KIND).orElse(null);
ExpandedComponentInstanceSymbol c = symTab.<ExpandedComponentInstanceSymbol>resolve("resNet34", ExpandedComponentInstanceSymbol.KIND).orElse(null);
EMAComponentSymbol a = symTab.<EMAComponentSymbol>resolve("ResNet34", EMAComponentSymbol.KIND).orElse(null);
EMAComponentInstanceSymbol c = symTab.<EMAComponentInstanceSymbol>resolve("resNet34", EMAComponentInstanceSymbol.KIND).orElse(null);
assertNotNull(a);
}
......
cifar10.CifarNetwork src/test/resources/training_data
MultipleOutputs data/MultipleOutputs
instanceTest.NetworkB data/InstanceTest.NetworkB
InstanceTest.NetworkB data/InstanceTest.NetworkB
Alexnet data/Alexnet
ThreeInputCNN_M14 data/ThreeInputCNN_M14
VGG16 data/VGG16
ResNeXt50 data/ResNeXt50
instanceTestCifar.CifarNetwork src/test/resources/training_data
\ No newline at end of file
instanceTestCifar.CifarNetwork src/test/resources/training_data
mnist.LeNetNetwork data/mnist.LeNetNetwork
\ No newline at end of file
package mnist;
import Network;
import CalculateClass;
component MnistClassifier{
ports in Z(0:255)^{1, 28, 28} image,
......
from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import math
import datetime
import logging
import os
import sys
import lmdb
class CNNCreator_mnist_mnistClassifier_net:
module = None
_current_dir_ = os.path.join('./')
_data_dir_ = os.path.join(_current_dir_, 'data', 'mnist_mnistClassifier_net')
_model_dir_ = os.path.join(_current_dir_, 'model', 'mnist_mnistClassifier_net')
_init_net_ = os.path.join(_model_dir_, 'init_net.pb')
_predict_net_ = os.path.join(_model_dir_, 'predict_net.pb')
def get_total_num_iter(self, num_epoch, batch_size, dataset_size):
    """Return the total number of training iterations for ``num_epoch`` epochs.

    The result is ``num_epoch * (dataset_size / batch_size)`` rounded up to
    the next whole number and returned as an ``int``.
    """
    # Explicit float conversion keeps the division from being an integer division.
    batches_per_epoch = float(dataset_size) / float(batch_size)
    return int(math.ceil(num_epoch * batches_per_epoch))
def get_epoch_as_iter(self, num_epoch, batch_size, dataset_size):
    """Return the number of iterations that make up one epoch.

    This is ``dataset_size / batch_size`` rounded up, as an ``int``.
    ``num_epoch`` is accepted but not used in the calculation.
    """
    # Explicit float conversion keeps the division from being an integer division.
    return int(math.ceil(float(dataset_size) / float(batch_size)))
def add_input(self, model, batch_size, db, db_type, device_opts):
    """Attach a database input to ``model`` and return ``(data, label, dataset_size)``.

    ``db`` must be a directory containing ``data.mdb`` and ``lock.mdb``;
    otherwise an error is logged and the process exits with status 1.
    ``dataset_size`` is the ``entries`` value reported by the lmdb
    environment opened on ``db``.
    """
    with core.DeviceScope(device_opts):
        # Validate the database directory before any operator is added.
        if not os.path.isdir(db):
            logging.error("Data loading failure. Directory '" + os.path.abspath(db) + "' does not exist.")
            sys.exit(1)
        if not os.path.isfile(os.path.join(db, 'data.mdb')) or not os.path.isfile(os.path.join(db, 'lock.mdb')):
            logging.error("Data loading failure. Directory '" + os.path.abspath(db) + "' does not contain lmdb files.")
            sys.exit(1)

        # Read raw uint8 samples and their labels from the database.
        raw_data, label = brew.db_input(
            model,
            blobs_out=["data_uint8", "label"],
            batch_size=batch_size,
            db=db,
            db_type=db_type,
        )

        # Convert to float, then multiply by 1/256 to bring byte values below 1.
        data = model.Cast(raw_data, "data", to=core.DataType.FLOAT)
        data = model.Scale(data, data, scale=float(1./256))

        # The input needs no gradient in the backward pass.
        data = model.StopGradient(data, data)

        db_stats = lmdb.open(db).stat()
        dataset_size = int(db_stats['entries'])

        return data, label, dataset_size
def create_model(self, model, data, device_opts, is_test):
    """Add the classifier layers to ``model`` and return the softmax output blob.

    Layer sequence: conv -> max-pool -> conv -> max-pool -> fc -> relu -> fc
    -> softmax. ``is_test`` is accepted but not used in this body.
    """
    with core.DeviceScope(device_opts):
        net_input = data
        # input, output shape: {[1,28,28]}
        first_conv = brew.conv(model, net_input, 'conv1_', dim_in=1, dim_out=20, kernel=5, stride=1)
        # conv1_, output shape: {[20,24,24]}
        first_pool = brew.max_pool(model, first_conv, 'pool1_', kernel=2, stride=2)
        # pool1_, output shape: {[20,12,12]}
        second_conv = brew.conv(model, first_pool, 'conv2_', dim_in=20, dim_out=50, kernel=5, stride=1)
        # conv2_, output shape: {[50,8,8]}
        second_pool = brew.max_pool(model, second_conv, 'pool2_', kernel=2, stride=2)
        # pool2_, output shape: {[50,4,4]}
        hidden_fc = brew.fc(model, second_pool, 'fc2_', dim_in=50 * 4 * 4, dim_out=500)
        # fc2_, output shape: {[500,1,1]}
        # The ReLU writes its result back into the fc2_ blob.
        hidden_relu = brew.relu(model, hidden_fc, hidden_fc)
        output_fc = brew.fc(model, hidden_relu, 'fc3_', dim_in=500, dim_out=10)
        # fc3_, output shape: {[10,1,1]}
        predictions = brew.softmax(model, output_fc, 'predictions')

        return predictions
# this adds the loss and optimizer
# Adds a loss blob, the gradient operators for it, and one optimizer to `model`.
#   loss:     'cross_entropy' -> LabelCrossEntropy + AveragedLoss;
#             'euclidean'     -> SquaredL2Distance + AveragedLoss.
#   opt_type: 'adam', 'sgd', 'rmsprop' or 'adagrad' selects the optimizer builder.
#   policy:   'step' additionally passes `stepsize`; 'fixed' and 'inv' do not.
# The function has no return statement; the built optimizer is assigned to `opt` and not used further here.
# NOTE(review): indentation was lost in this view, so the exact nesting of the statements below is not visible.
# NOTE(review): for any other `loss` value, `loss` appears to stay the caller's string when passed to
# AddGradientOperators — confirm callers only pass the two supported values.
def add_training_operators(self, model, output, label, device_opts, loss, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum) :
with core.DeviceScope(device_opts):
# The name `loss` is rebound from the selector string to the loss blob in both branches.
if loss == 'cross_entropy':
xent = model.LabelCrossEntropy([output, label], 'xent')
loss = model.AveragedLoss(xent, "loss")
elif loss == 'euclidean':
dist = model.net.SquaredL2Distance([label, output], 'dist')
loss = dist.AveragedLoss([], ['loss'])
model.AddGradientOperators([loss])
# Adam: beta1, beta2 and epsilon are forwarded; gamma and momentum are not.
if opt_type == 'adam':
if policy == 'step':
opt = optimizer.build_adam(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, beta1=beta1, beta2=beta2, epsilon=epsilon)
elif policy == 'fixed' or policy == 'inv':
opt = optimizer.build_adam(model, base_learning_rate=base_learning_rate, policy=policy, beta1=beta1, beta2=beta2, epsilon=epsilon)
print("adam optimizer selected")
# SGD: gamma and momentum are forwarded.
elif opt_type == 'sgd':
if policy == 'step':
opt = optimizer.build_sgd(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, gamma=gamma, momentum=momentum)
elif policy == 'fixed' or policy == 'inv':
opt = optimizer.build_sgd(model, base_learning_rate=base_learning_rate, policy=policy, gamma=gamma, momentum=momentum)
print("sgd optimizer selected")
# RMSProp: `gamma` is passed as the builder's `decay` argument.
elif opt_type == 'rmsprop':
if policy == 'step':
opt = optimizer.build_rms_prop(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, decay=gamma, momentum=momentum, epsilon=epsilon)
elif policy == 'fixed' or policy == 'inv':
opt = optimizer.build_rms_prop(model, base_learning_rate=base_learning_rate, policy=policy, decay=gamma, momentum=momentum, epsilon=epsilon)
print("rmsprop optimizer selected")
# Adagrad: `gamma` is passed as the builder's `decay` argument; momentum is not forwarded.
elif opt_type == 'adagrad':
if policy == 'step':
opt = optimizer.build_adagrad(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, decay=gamma, epsilon=epsilon)
elif policy == 'fixed' or policy == 'inv':
opt = optimizer.build_adagrad(model, base_learning_rate=base_learning_rate, policy=policy, decay=gamma, epsilon=epsilon)
print("adagrad optimizer selected")
def add_accuracy(self, model, output, label, device_opts, eval_metric):
    """Add an accuracy operator named ``"accuracy"`` to ``model`` and return it.

    ``eval_metric`` selects the variant: ``'accuracy'`` uses the operator's
    defaults, ``'top_k_accuracy'`` passes ``top_k=3``. No other value is
    handled by this body.
    """
    with core.DeviceScope(device_opts):
        if eval_metric == 'accuracy':
            accuracy = brew.accuracy(model, [output, label], "accuracy")
        elif eval_metric == 'top_k_accuracy':
            accuracy = brew.accuracy(model, [output, label], "accuracy", top_k=3)
    return accuracy
# Trains, evaluates and exports the network in one call.
# Steps visible in this body: choose a device from `context`, reset the workspace, build and run the
# training net, build and run a test net, then build a deploy net and pass it to save_net.
# All hyper-parameters are keyword arguments with the defaults shown in the signature.
# NOTE(review): the `print '...'` statements below are Python 2 statement syntax, so this body only parses
# under Python 2.
# NOTE(review): indentation was lost in this view, so the exact extent of the if/for bodies is not visible.
def train(self, num_epoch=1000, batch_size=64, context='gpu', eval_metric='accuracy', loss='cross_entropy', opt_type='adam', base_learning_rate=0.001, weight_decay=0.001, policy='fixed', stepsize=1, epsilon=1E-8, beta1=0.9, beta2=0.999, gamma=0.999, momentum=0.9) :
# Only 'cpu' and 'gpu' assign device_opts.
# NOTE(review): any other `context` value appears to leave device_opts unassigned before its first use — confirm.
if context == 'cpu':
device_opts = core.DeviceOption(caffe2_pb2.CPU, 0)
print("CPU mode selected")
elif context == 'gpu':
device_opts = core.DeviceOption(caffe2_pb2.CUDA, 0)
print("GPU mode selected")
workspace.ResetWorkspace(self._model_dir_)
arg_scope = {"order": "NCHW"}
# == Training model ==
train_model= model_helper.ModelHelper(name="train_net", arg_scope=arg_scope)
# Training samples are read from the 'train_lmdb' directory below _data_dir_.
data, label, train_dataset_size = self.add_input(train_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'train_lmdb'), db_type='lmdb', device_opts=device_opts)
predictions = self.create_model(train_model, data, device_opts=device_opts, is_test=False)
self.add_training_operators(train_model, predictions, label, device_opts, loss, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum)
# No accuracy operator is added when the loss is 'euclidean'.
if not loss == 'euclidean':
self.add_accuracy(train_model, predictions, label, device_opts, eval_metric)
with core.DeviceScope(device_opts):
brew.add_weight_decay(train_model, weight_decay)
# Initialize and create the training network
workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net, overwrite=True)
# Main Training Loop
iterations = self.get_total_num_iter(num_epoch, batch_size, train_dataset_size)
epoch_as_iter = self.get_epoch_as_iter(num_epoch, batch_size, train_dataset_size)
print("\n*** Starting Training for " + str(num_epoch) + " epochs = " + str(iterations) + " iterations ***")
start_date = datetime.datetime.now()
for i in range(iterations):
workspace.RunNet(train_model.net)
# Metrics are printed on every 50th iteration and on every multiple of epoch_as_iter.
if i % 50 == 0 or i % epoch_as_iter == 0:
if not loss == 'euclidean':
print 'Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss")) + ' - ' + 'Accuracy ' + str(workspace.FetchBlob('accuracy'))
else:
print 'Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss"))
current_time = datetime.datetime.now()
elapsed_time = current_time - start_date
print 'Progress: ' + str(i) + '/' + str(iterations) + ', ' +'Current time spent: ' + str(elapsed_time)
current_time = datetime.datetime.now()
elapsed_time = current_time - start_date
print 'Progress: ' + str(iterations) + '/' + str(iterations) + ' Training done' + ', ' + 'Total time spent: ' + str(elapsed_time)
print("\n*** Running Test model ***")
# == Testing model. ==
# The test net is created with init_params=False and reads from 'test_lmdb' below _data_dir_.
test_model= model_helper.ModelHelper(name="test_net", arg_scope=arg_scope, init_params=False)
data, label, test_dataset_size = self.add_input(test_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'test_lmdb'), db_type='lmdb', device_opts=device_opts)
predictions = self.create_model(test_model, data, device_opts=device_opts, is_test=True)
if not loss == 'euclidean':
self.add_accuracy(test_model, predictions, label, device_opts, eval_metric)
workspace.RunNetOnce(test_model.param_init_net)
workspace.CreateNet(test_model.net, overwrite=True)
# Main Testing Loop
# NOTE(review): `test_dataset_size/batch_size` is used as an array length and a range bound, so it must
# evaluate to an integer — presumably relying on Python 2 integer division; confirm.
test_accuracy = np.zeros(test_dataset_size/batch_size)
start_date = datetime.datetime.now()
for i in range(test_dataset_size/batch_size):
# Run a forward pass of the net on the current batch
workspace.RunNet(test_model.net)
# Collect the batch accuracy from the workspace
if not loss == 'euclidean':
test_accuracy[i] = workspace.FetchBlob('accuracy')
print 'Iter ' + str(i) + ': ' + 'Accuracy ' + str(workspace.FetchBlob("accuracy"))
else:
# In the 'euclidean' case the array stores the "loss" blob instead of an accuracy value.
# NOTE(review): no loss operator is added to test_model in this body, so this blob presumably
# still holds the value written by the training net — confirm.
test_accuracy[i] = workspace.FetchBlob("loss")
print 'Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss"))
current_time = datetime.datetime.now()
elapsed_time = current_time - start_date
print 'Progress: ' + str(i) + '/' + str(test_dataset_size/batch_size) + ', ' +'Current time spent: ' + str(elapsed_time)
current_time = datetime.datetime.now()
elapsed_time = current_time - start_date
print 'Progress: ' + str(test_dataset_size/batch_size) + '/' + str(test_dataset_size/batch_size) + ' Testing done' + ', ' + 'Total time spent: ' + str(elapsed_time)
# The printed mean is taken over whatever was stored in test_accuracy above (accuracy or loss values).
print('Test accuracy mean: {:.9f}'.format(test_accuracy.mean()))
# == Deployment model. ==
# We simply need the main AddModel part.
deploy_model = model_helper.ModelHelper(name="deploy_net", arg_scope=arg_scope, init_params=False)
# The deploy net takes the blob name "data" as its input instead of a db_input operator.
self.create_model(deploy_model, "data", device_opts, is_test=True)
print("\n*** Saving deploy model ***")
self.save_net(self._init_net_, self._predict_net_, deploy_model)
def save_net(self, init_net_path, predict_net_path, model):
    """Export ``model`` and write it to disk in binary and text form.

    Binary files go to ``predict_net_path`` and ``init_net_path``. Text dumps
    go to the same paths with ``'.pb'`` replaced by ``'.pbtxt'``. The binary
    predict net is serialized from ``model.net._net``; the predict net
    returned by the exporter is only used for the text dump.
    """
    init_net, predict_net = mobile_exporter.Export(workspace, model.net, model.params)

    # Create the model directory. An OSError is ignored only when the
    # directory exists afterwards; otherwise it is re-raised.
    try:
        os.makedirs(self._model_dir_)
    except OSError:
        if not os.path.isdir(self._model_dir_):
            raise

    print("Save the model to init_net.pb and predict_net.pb")
    with open(predict_net_path, 'wb') as predict_file:
        predict_file.write(model.net._net.SerializeToString())
    with open(init_net_path, 'wb') as init_file:
        init_file.write(init_net.SerializeToString())

    print("Save the model to init_net.pbtxt and predict_net.pbtxt as additional information")
    # Same write order as the binary step's counterpart files: init first, then predict.
    for binary_path, net_def in ((init_net_path, init_net), (predict_net_path, predict_net)):
        with open(binary_path.replace('.pb','.pbtxt'), 'w') as text_file:
            text_file.write(str(net_def))

    print("== Saved init_net and predict_net ==")
def load_net(self, init_net_path, predict_net_path, device_opts):
if not os.path.isfile(init_net_path):
logging.error("Network loading failure. File '" + os.path.abspath(init_net_path) + "' does not exist.")
sys.exit(1)
elif not os.path.isfile(predict_net_path):
logging.error("Network loading failure. File '" + os.path.abspath(predict_net_path) + "' does not exist.")
sys.exit(1)
init_def = caffe2_pb2.NetDef()
with open(init_net_path, 'rb') as f:
init_def.ParseFromString(f.read())