Skip to content
Snippets Groups Projects
Commit aaa745b1 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'timmermanns' into 'master'

Removed github deploy plugin, changed repository to se nexus

See merge request CNNArch2MXNet!2
parents 0e3a557a 07d96962
No related branches found
No related tags found
No related merge requests found
Showing
with 482 additions and 0 deletions
<#assign input = element.inputs[0]>
<#if element.padding??>
<#assign input = element.name>
${element.name} = mx.symbol.pad(data=${element.inputs[0]},
mode='constant',
pad_width=(${tc.join(element.padding, ",")}),
constant_value=0)
</#if>
${element.name} = mx.symbol.Convolution(data=${input},
kernel=(${tc.join(element.kernel, ",")}),
stride=(${tc.join(element.stride, ",")}),
num_filter=${element.channels?c},
no_bias=${element.noBias?string("True","False")},
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
${element.name} = mx.symbol.Dropout(data=${element.inputs[0]},
p=${element.p?c},
name="${element.name}")
${element.name} = mx.symbol.Flatten(data=${element.inputs[0]},
name="${element.name}")
\ No newline at end of file
<#assign flatten = element.element.inputTypes[0].height != 1 || element.element.inputTypes[0].width != 1>
<#assign input = element.inputs[0]>
<#if flatten>
${element.name} = mx.symbol.flatten(data=${input})
<#assign input = element.name>
</#if>
${element.name} = mx.symbol.FullyConnected(data=${input},
num_hidden=${element.units?c},
no_bias=${element.noBias?string("True","False")},
name="${element.name}")
${element.name} = ${element.inputs[element.index]}
${element.name} = mx.symbol.Pooling(data=${element.inputs[0]},
global_pool=True,
kernel=(1,1),
pool_type=${element.poolType},
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
<#assign channelIndex = element.element.outputTypes[0].channelIndex + 1>
<#assign heightIndex = element.element.outputTypes[0].heightIndex + 1>
<#assign widthIndex = element.element.outputTypes[0].widthIndex + 1>
<#assign indexList = []>
<#if channelIndex != 0><#assign indexList = indexList + [channelIndex]></#if>
<#if heightIndex != 0><#assign indexList = indexList + [heightIndex]></#if>
<#if widthIndex != 0><#assign indexList = indexList + [widthIndex]></#if>
<#assign dimensions = element.element.outputTypes[0].dimensions>
${element.name} = mx.sym.var("${element.name}",
shape=(0,${tc.join(dimensions, ",")}))
<#include "OutputShape.ftl">
<#if heightIndex != channelIndex + 1 || widthIndex != heightIndex + 1>
${element.name} = mx.symbol.transpose(data=${element.name},
axes=(0,${tc.join(indexList, ",")}))
</#if>
<#if indexList?size != 3>
${element.name} = mx.symbol.reshape(data=${element.name},
shape=(0,${element.element.outputTypes[0].channels?c},${element.element.outputTypes[0].height?c},${element.element.outputTypes[0].width?c}))
</#if>
if not data_mean is None:
assert(not data_std is None)
_data_mean_ = mx.sym.Variable("_data_mean_", shape=(${tc.join(dimensions, ",")}), init=MyConstant(value=data_mean.tolist()))
_data_mean_ = mx.sym.BlockGrad(_data_mean_)
_data_std_ = mx.sym.Variable("_data_std_", shape=(${tc.join(dimensions, ",")}), init=MyConstant(value=data_mean.tolist()))
_data_std_ = mx.sym.BlockGrad(_data_std_)
${element.name} = mx.symbol.broadcast_sub(${element.name}, _data_mean_)
${element.name} = mx.symbol.broadcast_div(${element.name}, _data_std_)
${element.name} = mx.symbol.LRN(data=${element.inputs[0]},
alpha=${element.alpha?c},
beta=${element.beta?c},
knorm=${element.knorm?c},
nsize=${element.nsize?c},
name="${element.name}")
<#if element.softmaxOutput>
${element.name} = mx.symbol.SoftmaxOutput(data=${element.inputs[0]},
name="${element.name}")
<#elseif element.logisticRegressionOutput>
${element.name} = mx.symbol.LogisticRegressionOutput(data=${element.inputs[0]},
name="${element.name}")
<#elseif element.linearRegressionOutput>
${element.name} = mx.symbol.LinearRegressionOutput(data=${element.inputs[0]},
name="${element.name}")
</#if>
\ No newline at end of file
# ${element.name}, output shape: {<#list element.element.outputTypes as type>[${tc.join(type.dimensions, ",")}]</#list>}
<#assign input = element.inputs[0]>
<#if element.padding??>
<#assign input = element.name>
${element.name} = mx.symbol.pad(data=${element.inputs[0]},
mode='constant',
pad_width=(${tc.join(element.padding, ",")}),
constant_value=0)
</#if>
${element.name} = mx.symbol.Pooling(data=${input},
kernel=(${tc.join(element.kernel, ",")}),
pool_type=${element.poolType},
stride=(${tc.join(element.stride, ",")}),
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
${element.name} = mx.symbol.Activation(data=${element.inputs[0]},
act_type='relu',
name="${element.name}")
${element.name} = mx.symbol.Activation(data=${element.inputs[0]},
act_type='sigmoid',
name="${element.name}")
<#-- This template is not used if the followiing architecture element is an output. See Output.ftl -->
${element.name} = mx.symbol.softmax(data=${element.inputs[0]},
axis=1,
name="${element.name}")
${element.name} = mx.symbol.split(data=${element.inputs[0]},
num_outputs=${element.numOutputs?c},
axis=1,
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
${element.name} = mx.symbol.Activation(data=${element.inputs[0]},
act_type='tanh',
name="${element.name}")
<#list tc.architecture.outputs as output>
<#assign shape = output.definition.type.dimensions>
vector<float> CNN_${tc.getName(output)}(<#list shape as dim>${dim?c}<#if dim?has_next>*</#if></#list>);
</#list>
_cnn_.predict(<#list tc.architecture.inputs as input>CNNTranslator::translate(${input.name}<#if input.arrayAccess.isPresent()>[${input.arrayAccess.get().intValue.get()?c}]</#if>),
</#list><#list tc.architecture.outputs as output>CNN_${tc.getName(output)}<#if output?has_next>,
</#if></#list>);
<#list tc.architecture.outputs as output>
<#assign shape = output.definition.type.dimensions>
<#if shape?size == 1>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCol(CNN_${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}});
</#if>
<#if shape?size == 2>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToMat(CNN_${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}, ${shape[1]?c}});
</#if>
<#if shape?size == 3>
${output.name}<#if output.arrayAccess.isPresent()>[${output.arrayAccess.get().intValue.get()?c}]</#if> = CNNTranslator::translateToCube(CNN_${tc.getName(output)}, std::vector<size_t> {${shape[0]?c}, ${shape[1]?c}, ${shape[2]?c}});
</#if>
</#list>
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.ModelingLanguageFamily;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import org.junit.Assert;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
* Common methods for symboltable tests
*/
public class AbstractSymtabTest {
private static final String MODEL_PATH = "src/test/resources/";
protected static Scope createSymTab(String... modelPath) {
ModelingLanguageFamily fam = new ModelingLanguageFamily();
fam.addModelingLanguage(new CNNArchLanguage());
final ModelPath mp = new ModelPath();
for (String m : modelPath) {
mp.addEntry(Paths.get(m));
}
GlobalScope scope = new GlobalScope(mp, fam);
return scope;
}
/* protected static ASTCNNArchCompilationUnit getAstNode(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull("Could not resolve model " + model, comp);
return (ASTCNNArchCompilationUnit) comp.getAstNode().get();
}*/
protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull("Could not resolve model " + model, comp);
return comp;
}
public static void checkFilesAreEqual(Path generationPath, Path resultsPath, List<String> fileNames) {
for (String fileName : fileNames){
File genFile = new File(generationPath.toString() + "/" + fileName);
File fileTarget = new File(resultsPath.toString() + "/" + fileName);
assertTrue(areBothFilesEqual(genFile, fileTarget));
}
}
public static boolean areBothFilesEqual(File file1, File file2) {
if (!file1.exists()) {
Assert.fail("file does not exist: " + file1.getAbsolutePath());
return false;
}
if (!file2.exists()) {
Assert.fail("file does not exist: " + file2.getAbsolutePath());
return false;
}
List<String> lines1;
List<String> lines2;
try {
lines1 = Files.readAllLines(file1.toPath());
lines2 = Files.readAllLines(file2.toPath());
} catch (IOException e) {
e.printStackTrace();
Assert.fail("IO error: " + e.getMessage());
return false;
}
lines1 = discardEmptyLines(lines1);
lines2 = discardEmptyLines(lines2);
if (lines1.size() != lines2.size()) {
Assert.fail(
"files have different number of lines: "
+ file1.getAbsolutePath()
+ " has " + lines1
+ " lines and " + file2.getAbsolutePath() + " has " + lines2 + " lines"
);
return false;
}
int len = lines1.size();
for (int i = 0; i < len; i++) {
String l1 = lines1.get(i);
String l2 = lines2.get(i);
Assert.assertEquals("files differ in " + i + " line: "
+ file1.getAbsolutePath()
+ " has " + l1
+ " and " + file2.getAbsolutePath() + " has " + l2,
l1,
l2
);
}
return true;
}
private static List<String> discardEmptyLines(List<String> lines) {
return lines.stream()
.map(String::trim)
.filter(l -> !l.isEmpty())
.collect(Collectors.toList());
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGeneratorCli;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import static junit.framework.TestCase.assertTrue;
public class GenerationTest extends AbstractSymtabTest{
@Before
public void setUp() {
// ensure an empty log
Log.getFindings().clear();
Log.enableFailQuick(false);
}
@Test
public void testCifar10Classifier() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/valid_tests", "-r", "CifarClassifierNetwork", "-o", "./target/generated-sources-cnnarch/"};
CNNArchGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_CifarClassifierNetwork.py",
"CNNPredictor_CifarClassifierNetwork.h",
"execute_CifarClassifierNetwork",
"CNNBufferFile.h"));
}
@Test
public void testAlexnetGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "Alexnet", "-o", "./target/generated-sources-cnnarch/"};
CNNArchGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_Alexnet.py",
"CNNPredictor_Alexnet.h",
"execute_Alexnet"));
}
@Test
public void testGeneratorVGG16() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "VGG16", "-o", "./target/generated-sources-cnnarch/"};
CNNArchGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_VGG16.py",
"CNNPredictor_VGG16.h",
"execute_VGG16"));
}
@Test
public void testThreeInputCNNGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
CNNArchGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
@Test
public void testResNeXtGeneration() throws IOException, TemplateException {
Log.getFindings().clear();;
String[] args = {"-m", "src/test/resources/architectures", "-r", "ResNeXt50"};
CNNArchGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testMultipleOutputs() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/valid_tests", "-r", "MultipleOutputs"};
CNNArchGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 3);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.lang.monticar.cnnarch._parser.CNNArchParser;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
public class SymtabTest extends AbstractSymtabTest {
@Before
public void setUp() {
// ensure an empty log
Log.getFindings().clear();
Log.enableFailQuick(false);
}
@Test
public void testParsing() throws Exception {
CNNArchParser parser = new CNNArchParser();
assertTrue(parser.parse("src/test/resources/architectures/Alexnet.cnna").isPresent());
}
@Ignore
@Test
public void testAlexnet(){
Scope symTab = createSymTab("src/test/resources/architectures");
CNNArchCompilationUnitSymbol a = symTab.<CNNArchCompilationUnitSymbol>resolve(
"Alexnet",
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
}
@Ignore
@Test
public void testResNeXt(){
Scope symTab = createSymTab("src/test/resources/architectures");
CNNArchCompilationUnitSymbol a = symTab.<CNNArchCompilationUnitSymbol>resolve(
"ResNeXt50",
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
}
@Ignore
@Test
public void test3(){
Scope symTab = createSymTab("src/test/resources/valid_tests");
CNNArchCompilationUnitSymbol a = symTab.<CNNArchCompilationUnitSymbol>resolve(
"MultipleOutputs",
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment