Commit 7fb61c34 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by Thomas Michael Timmermanns
Browse files

Fixed code generation errors

parent 6704cd5e
......@@ -52,6 +52,8 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class Generator {
......@@ -106,6 +108,8 @@ public class Generator {
fileContents.addAll(SimulatorIntegrationHelper.getSimulatorIntegrationHelperFileContent());
}
fixArmadilloImports(fileContents);
return fileContents;
}
......@@ -133,13 +137,21 @@ public class Generator {
}
}
private void fixArmadilloImports(List<FileContent> fileContents){
for (FileContent fileContent : fileContents){
fileContent.setFileContent(fileContent.getFileContent()
.replaceFirst("#include \"armadillo.h\"",
"#include \"armadillo\""));
}
}
public void generateCNN(List<FileContent> fileContents, TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol instance, ArchitectureSymbol architecture){
CNNArchGenerator cnnArchGenerator = new CNNArchGenerator();
Map<String,String> contentMap = cnnArchGenerator.generateStrings(architecture);
String fullName = instance.getFullName();
String fullName = instance.getFullName().replaceAll("\\.", "_");
//get the components execute method
String executeKey = "execute_" + fullName.replaceAll("\\.", "_");
String executeKey = "execute_" + fullName;
String executeMethod = contentMap.get(executeKey);
if (executeMethod == null){
throw new IllegalStateException("execute method of " + fullName + " not found");
......@@ -159,7 +171,8 @@ public class Generator {
}
protected String transformComponent(String component, String predictorClassName, String executeMethod){
String networkVariableName = "cnn_";
String networkVariableName = "_cnn_";
//insert includes
component = component.replaceFirst("using namespace",
"#include \"" + predictorClassName + ".h" + "\"\n" +
......@@ -170,9 +183,15 @@ public class Generator {
component = component.replaceFirst("public:",
"public:\n" + predictorClassName + " " + networkVariableName + ";");
/*
Pattern initPattern = Pattern.compile("void init\\(.*\\)\n\\{");
Matcher matcher = initPattern.matcher(component);
matcher.find();
String initMethodString = matcher.group(0);
//insert attribute initialization
component = component.replaceFirst("void init\\(\\)\\s\\{",
"void init()\n{\n" + networkVariableName + "=" + predictorClassName + "();");
component = component.replaceFirst("\\Q" + initMethodString,
initMethodString + "\n" + networkVariableName + " = " + predictorClassName + "();");*/
//insert execute method
component = component.replaceFirst("void execute\\(\\)\\s\\{\\s\\}",
......
import logging
<#list componentNames as name>
import CNNCreator_${name?replace(".", "_")}
import mxnet as mx
<#list instances as instance>
import CNNCreator_${instance.fullName?replace(".", "_")}
</#list>
if __name__ == "__main__":
......@@ -11,7 +11,7 @@ if __name__ == "__main__":
logger.addHandler(handler)
<#list instances as instance>
${instance.fullName?replace(".", "_")} = CNNCreator_${instance.componentType.fullName?replace(".", "_")}.CNNCreator_${instance.componentType.fullName?replace(".", "_")}()
${instance.fullName?replace(".", "_")} = CNNCreator_${instance.fullName?replace(".", "_")}.CNNCreator_${instance.fullName?replace(".", "_")}()
${instance.fullName?replace(".", "_")}.train(
<#if (trainParams[instance_index])??>
${trainParams[instance_index]}
......
......@@ -51,7 +51,14 @@ public class GenerationTest {
@Test
public void testMnistGeneration() throws IOException, TemplateException {
generate("mnist.Main");
generate("mnist.MnistClassifier");
assertTrue(Log.getFindings().isEmpty());
}
@Test
@Ignore
public void testCifar10Generation() throws IOException, TemplateException {
generate("cifar10.Cifar10Classifier");
assertTrue(Log.getFindings().isEmpty());
}
......
......@@ -40,7 +40,7 @@ public class AllCoCoTest extends AbstractCoCoTest {
@Test
public void testCoCosSimulator() throws IOException {
checkValid("", "mnist.Main");
checkValid("", "mnist.MnistClassifier");
checkValid("", "Alexnet");
checkValid("", "VGG16");
checkValid("", "ThreeInputCNN_M14");
......
component Alexnet{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{10,1,1} predictions;
out Q(0:1)^{1000} predictions;
implementation CNN {
......@@ -39,7 +39,7 @@ component Alexnet{
split2(i=[0|1]) ->
Concatenate() ->
fc(->=2) ->
FullyConnected(units=10) ->
FullyConnected(units=1000) ->
Softmax() ->
predictions
......
component MultipleOutputs{
ports in Q(-oo:+oo)^{10,1,1} data,
out Q(0:1)^{4,1,1} pred[2];
out Q(0:1)^{4} pred[2];
implementation CNN {
......
component ResNeXt50{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000,1,1} predictions;
out Q(0:1)^{1000} predictions;
implementation CNN {
def conv(kernel, channels, stride=1, act=true){
......
component ResNet152{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000,1,1} predictions;
out Q(0:1)^{1000} predictions;
implementation CNN {
def conv(kernel, channels, stride=1, act=true){
......
component ResNet34{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000,1,1} predictions;
out Q(0:1)^{1000} predictions;
implementation CNN {
def conv(kernel, channels, stride=1, act=true){
......
component ThreeInputCNN_M14{
ports in Z(0:255)^{3, 224, 224} image[3],
out Q(0:1)^{10,1,1} predictions;
out Q(0:1)^{10} predictions;
implementation CNN {
......
component VGG16{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000,1,1} predictions;
out Q(0:1)^{1000} predictions;
implementation CNN {
......
package mnist;
component ArgMax(Z(1:oo) n){
ports in Q^{n} inputVector,
out Z(0:oo) maxIndex;
implementation Math{
maxIndex = 0;
Q maxValue = inputVector(0);
for i = 1:(n - 1)
if inputVector(i) > maxValue
maxIndex = i;
maxValue = inputVector(i);
end
end
}
}
\ No newline at end of file
package mnist;
component CalculateClass{
ports in Q(0:1)^{1,10} probabilities,
out Z(0:9) digit;
implementation Math{
Q(0:1:9) max = 0;
Q maxValue = 0.0;
for i = 1:10
Q prob = probabilities(1, i);
if prob > maxValue
max = i - 1;
maxValue = prob;
end
end
digit = max;
}
}
\ No newline at end of file
package mnist;
component Network(Z classes){
ports in Z(0:255)^{1,28,28} data,
out Q(0:1)^{classes,1,1} predictions;
component LeNet(Z(1:oo) channels, Z(1:oo) height, Z(1:oo) width, Z(2:oo) classes){
ports in Z(0:255)^{channels, height, width} data,
out Q(0:1)^{classes} predictions;
implementation CNN {
......
configuration LeNetConfig{
num_epoch:12
batch_size:100
load_checkpoint: false
optimizer:adam{
learning_rate:0.002
learning_rate_decay:0.9
step_size:500
}
}
......@@ -2,17 +2,17 @@ package mnist;
import Network;
import CalculateClass;
component Main{
component MnistClassifier{
ports in Z(0:255)^{1, 28, 28} image,
out Z(0:9) digit;
instance Network(10) net;
instance LeNet(1,28,28,10) net;
instance CalculateClass outCalc;
instance ArgMax(10) calculateClass;
connect image -> net.data;
connect net.predictions -> outCalc.probabilities;
connect outCalc.digit -> digit;
connect net.predictions -> calculateClass.inputVector;
connect calculateClass.maxIndex -> digit;
}
}
\ No newline at end of file
configuration NetworkConfig{
num_epoch:10
batch_size:100
load_checkpoint: false
optimizer:adam{
learning_rate:0.001
weight_decay:0.01
learning_rate_decay:0.9
learning_rate_policy:exp
step_size:1000
rescale_grad:1.1
clip_gradient:10
beta1:0.9
beta2:0.9
epsilon:0.000001
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment