Commit 170c13e6 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns
Browse files

Implemented generics.

parent cfc2ae43
......@@ -30,8 +30,8 @@ component LeNet{
}
```
```
component VGG16(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
component VGG16<Z(1:oo) channels=3, Z(1:oo) height=224, Z(1:oo) width=224, Z(2:oo) classes=1000>{
ports in Z(0:255)^{channels, height, width} image,
out Q(0:1)^{classes} predictions;
implementation CNN {
......@@ -64,8 +64,8 @@ component VGG16(Z(2:oo) classes){
}
```
```
component ResNet34(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
component ResNet34<Z(1:oo) channels=3, Z(1:oo) height=224, Z(1:oo) width=224, Z(2:oo) classes=1000>{
ports in Z(0:255)^{channels, height, width} image,
out Q(0:1)^{classes} predictions;
implementation CNN {
......@@ -107,8 +107,8 @@ component ResNet34(Z(2:oo) classes){
}
```
```
component Alexnet(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
component Alexnet<Z(1:oo) channels=3, Z(1:oo) height=224, Z(1:oo) width=224, Z(2:oo) classes=1000>{
ports in Z(0:255)^{channels, height, width} image,
out Q(0:1)^{classes} predictions;
implementation CNN {
......@@ -153,8 +153,8 @@ component Alexnet(Z(2:oo) classes){
}
```
```
component ResNeXt50(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
component ResNeXt50<Z(1:oo) channels=3, Z(1:oo) height=224, Z(1:oo) width=224, Z(2:oo) classes=1000>{
ports in Z(0:255)^{channels, height, width} image,
out Q(0:1)^{classes} predictions;
implementation CNN {
......
......@@ -41,7 +41,7 @@
<se-commons.version>1.7.7</se-commons.version>
<mc.grammars.assembly.version>0.0.6-SNAPSHOT</mc.grammars.assembly.version>
<SIUnit.version>0.0.10-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.10-SNAPSHOT</Common-MontiCar.version>
<Common-MontiCar.version>0.0.11-SNAPSHOT</Common-MontiCar.version>
<Embedded-MontiArc.version>0.0.11-SNAPSHOT</Embedded-MontiArc.version>
<Embedded-MontiArc-Behaviour.version>0.0.11-SNAPSHOT</Embedded-MontiArc-Behaviour.version>
<CNNArch.version>0.2.0-SNAPSHOT</CNNArch.version>
......
......@@ -24,6 +24,7 @@ import de.monticore.EmbeddingModelingLanguage;
import de.monticore.antlr4.MCConcreteParser;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.EmbeddedMontiArcLanguage;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarcmath.adapter.PortArraySymbol2MathVariableDeclarationSymbolTypeFilter;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarcmath.adapter.ResolutionDeclarationSymbol2MathVariableDeclarationTypeFilter;
import de.monticore.lang.math.math._symboltable.MathLanguage;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.emadl._parser.EMADLParser;
......@@ -58,10 +59,10 @@ public class EMADLLanguage extends EmbeddingModelingLanguage {
List<ResolvingFilter<? extends Symbol>> ret =
new ArrayList<>(super.getResolvingFilters());
ret.add(new PortArraySymbol2IODeclarationSymbolTypeFilter());
ret.add(new ResolutionDeclarationSymbol2VariableSymbolTypeFilter());
//ret.add(new ResolutionDeclarationSymbol2VariableSymbolTypeFilter());
ret.add(new PortArraySymbol2MathVariableDeclarationSymbolTypeFilter());
ret.add(new ResolutionDeclarationSymbol2VariableSymbolTypeFilter());
ret.add(new ResolutionDeclarationSymbol2MathVariableDeclarationTypeFilter());
return ret;
}
......
......@@ -20,14 +20,21 @@
*/
package de.monticore.lang.monticar.emadl._symboltable;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._ast.ASTComponent;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._ast.ASTEMACompilationUnit;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.EmbeddedMontiArcSymbolTableCreator;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.*;
import de.monticore.lang.embeddedmontiarc.helper.ArcTypePrinter;
import de.monticore.lang.monticar.types2._ast.ASTReferenceType;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.ResolvingConfiguration;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.monticore.symboltable.modifiers.BasicAccessModifier;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
import java.util.Deque;
import java.util.Optional;
public class ModifiedEMASymbolTableCreator extends EmbeddedMontiArcSymbolTableCreator {
......
......@@ -26,7 +26,9 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchSimpleExpressionSymbo
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableType;
import de.monticore.lang.monticar.si._symboltable.ResolutionDeclarationSymbol;
import de.monticore.lang.monticar.ts.MCFieldSymbol;
import de.monticore.lang.monticar.types2._ast.ASTUnitNumberResolution;
import de.se_rwth.commons.logging.Log;
import java.util.*;
......@@ -53,23 +55,44 @@ public class ModifiedExpandedComponentInstanceBuilder extends ExpandedComponentI
}
public void addVariableSymbolsToInstance(ExpandedComponentInstanceSymbol instance){
for (int i = 0; i < instance.getArguments().size(); i++){
if (!(instance.getArguments().get(i) instanceof ASTMathNumberExpression)){
//add generics
for (ResolutionDeclarationSymbol sym : instance.getResolutionDeclarationSymbols()){
if (sym.getASTResolution() instanceof ASTUnitNumberResolution){
ASTUnitNumberResolution numberResolution = (ASTUnitNumberResolution) sym.getASTResolution();
VariableSymbol genericsParam = new VariableSymbol.Builder()
.name(sym.getNameToResolve())
.type(VariableType.ARCHITECTURE_PARAMETER)
.build();
genericsParam.setExpression(ArchSimpleExpressionSymbol.of(numberResolution.getNumber().get()));
instance.getSpannedScope().getAsMutableScope().add(genericsParam);
}
else {
Log.error("Argument type error. Arguments of a CNN component " +
"that are not numbers are not supported."
, instance.getArguments().get(i).get_SourcePositionStart());
"that are not numbers are currently not supported."
, sym.getSourcePosition());
}
ASTMathNumberExpression exp = (ASTMathNumberExpression) instance.getArguments().get(i);
}
//add configuration parameters
for (int i = 0; i < instance.getArguments().size(); i++){
if (instance.getArguments().get(i) instanceof ASTMathNumberExpression){
ASTMathNumberExpression exp = (ASTMathNumberExpression) instance.getArguments().get(i);
MCFieldSymbol emaParam = instance.getComponentType().getConfigParameters().get(i);
VariableSymbol archParam = new VariableSymbol.Builder()
.name(emaParam.getName())
.type(VariableType.ARCHITECTURE_PARAMETER)
.build();
archParam.setExpression(ArchSimpleExpressionSymbol.of(
exp.getNumber().getUnitNumber().get().getNumber().get()));
MCFieldSymbol emaParam = instance.getComponentType().getConfigParameters().get(i);
VariableSymbol archParam = new VariableSymbol.Builder()
.name(emaParam.getName())
.type(VariableType.ARCHITECTURE_PARAMETER)
.build();
archParam.setExpression(ArchSimpleExpressionSymbol.of(
exp.getNumber().getUnitNumber().get().getNumber().get()));
instance.getSpannedScope().getAsMutableScope().add(archParam);
instance.getSpannedScope().getAsMutableScope().add(archParam);
}
else {
Log.error("Argument type error. Arguments of a CNN component " +
"that are not numbers are currently not supported."
, instance.getArguments().get(i).get_SourcePositionStart());
}
}
}
......
......@@ -154,10 +154,15 @@ public class EMADLGenerator {
ExpandedComponentInstanceSymbol componentInstanceSymbol,
Scope symtab){
allInstances.add(componentInstanceSymbol);
ASTComponent astComponent = (ASTComponent) componentInstanceSymbol.getComponentType().getReferencedSymbol().getAstNode().get();
ComponentSymbol componentSymbol = componentInstanceSymbol.getComponentType().getReferencedSymbol();
/* remove the following two lines if the component symbol full name bug with generic variables is fixed */
componentSymbol.setFullName(null);
componentSymbol.getFullName();
/* */
Optional<ArchitectureSymbol> architecture = componentInstanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
Optional<MathStatementsSymbol> mathStatements = astComponent.getSpannedScope().get().resolve("MathStatements", MathStatementsSymbol.KIND);
Optional<MathStatementsSymbol> mathStatements = componentSymbol.getSpannedScope().resolve("MathStatements", MathStatementsSymbol.KIND);
EMADLCocos.checkAll(componentInstanceSymbol);
......@@ -218,16 +223,6 @@ public class EMADLGenerator {
component = component.replaceFirst("public:",
"public:\n" + predictorClassName + " " + networkVariableName + ";");
/*
Pattern initPattern = Pattern.compile("void init\\(.*\\)\n\\{");
Matcher matcher = initPattern.matcher(component);
matcher.find();
String initMethodString = matcher.group(0);
//insert attribute initialization
component = component.replaceFirst("\\Q" + initMethodString,
initMethodString + "\n" + networkVariableName + " = " + predictorClassName + "();");*/
//insert execute method
component = component.replaceFirst("void execute\\(\\)\\s\\{\\s\\}",
"void execute(){\n" + executeMethod + "\n}");
......
......@@ -24,6 +24,8 @@ import de.monticore.ModelingLanguageFamily;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.embeddedmontiarc.LogConfig;
import de.monticore.lang.monticar.emadl._symboltable.EMADLLanguageFamily;
import de.monticore.lang.monticar.emadl.generator.AbstractSymtab;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
......@@ -40,17 +42,9 @@ import java.util.stream.Collectors;
import static org.junit.Assert.assertTrue;
public class AbstractSymtabTest {
protected static Scope createSymTab(String... modelPath) {
ModelingLanguageFamily fam = new EMADLLanguageFamily();
protected static TaggingResolver createSymTab(String... modelPath) {
final ModelPath mp = new ModelPath();
for (String m : modelPath) {
mp.addEntry(Paths.get(m));
}
GlobalScope scope = new GlobalScope(mp, fam);
LogConfig.init();
return scope;
return AbstractSymtab.createSymTabAndTaggingResolver(modelPath);
}
public static void checkFilesAreEqual(Path generationPath, Path resultsPath, List<String> fileNames) {
......
package InstanceTest;
component CalculateClassB{
ports in Q(0:1)^{1,10} probabilities,
ports in Q(0:1)^{10} probabilities,
out Z(0:9) digit;
implementation Math{
......
......@@ -7,8 +7,8 @@ component MainB{
out Z(0:9) digit1,
out Z(0:9) digit2;
instance NetworkB(10,20) net1;
instance NetworkB(10,40) net2;
instance NetworkB<10> (20) net1;
instance NetworkB<10> (40) net2;
instance CalculateClassB outCalc1;
instance CalculateClassB outCalc2;
......
package InstanceTest;
component NetworkB(Z classes, Z convChannels){
component NetworkB<Z classes = 10> (Z convChannels){
ports in Z(0:255)^{1,28,28} data,
out Q(0:1)^{classes,1,1} predictions;
out Q(0:1)^{classes} predictions;
implementation CNN {
......
package cifar10;
component ArgMax(Z(1:oo) n){
component ArgMax<Z(1:oo) n = 2>{
ports in Q^{n} inputVector,
out Z(0:oo) maxIndex;
......
......@@ -6,9 +6,9 @@ component Cifar10Classifier{
ports in Z(0:255)^{3, 32, 32} image,
out Z(0:9) classIndex;
instance CifarNetwork(10) net;
instance CifarNetwork<10> net;
instance ArgMax(10) calculateClass;
instance ArgMax<10> calculateClass;
connect image -> net.data;
connect net.softmax -> calculateClass.inputVector;
......
package cifar10;
component CifarNetwork(Z(2:oo) classes){
component CifarNetwork<Z(2:oo) classes = 10>{
ports in Z(0:255)^{3, 32, 32} data,
out Q(0:1)^{classes} softmax;
......
......@@ -16,8 +16,8 @@ cifar10_cifar10Classifier_calculateClass calculateClass;
void init()
{
image = cube(3, 32, 32);
net.init(10);
calculateClass.init(10);
net.init();
calculateClass.init();
}
void execute()
{
......
......@@ -6,13 +6,12 @@
#include "armadillo"
using namespace arma;
class cifar10_cifar10Classifier_calculateClass{
const int n = 10;
public:
double n;
colvec inputVector;
double maxIndex;
void init(double n)
void init()
{
this->n = n;
inputVector=colvec(n);
}
void execute()
......
......@@ -8,14 +8,13 @@
#include "CNNTranslator.h"
using namespace arma;
class cifar10_cifar10Classifier_net{
const int classes = 10;
public:
CNNPredictor_cifar10_cifar10Classifier_net _cnn_;
double classes;
cube data;
colvec softmax;
void init(double classes)
void init()
{
this->classes = classes;
data = cube(3, 32, 32);
softmax=colvec(classes);
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment