Commit b7c85c9d authored by Christian Fuß's avatar Christian Fuß
Browse files

fixed a problem with inferring size for OneHotLayer, when no 'size' parameter...

fixed a problem with inferring size for OneHotLayer, when no 'size' parameter is given. Progress on Unroll feature
parent 6e5d8673
Pipeline #155401 passed with stages
in 18 minutes and 32 seconds
......@@ -47,7 +47,7 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollSymbol){
UnrollDeclarationSymbol layerDeclaration = argument.getUnroll().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
if (layerDeclaration != null && argument.getUnrollParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
......
......@@ -52,12 +52,12 @@ public class CheckUnroll implements CNNArchASTUnrollCoCo{
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
UnrollDeclarationSymbol layerDeclaration = ((UnrollSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
"Existing layers: " + Joiners.COMMA.join(architecture.getUnrollDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
......
......@@ -103,6 +103,10 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return getSpannedScope().resolveLocally(LayerDeclarationSymbol.KIND);
}
public Collection<UnrollDeclarationSymbol> getUnrollDeclarations(){
return getSpannedScope().resolveLocally(UnrollDeclarationSymbol.KIND);
}
public void resolve() {
for (CompositeElementSymbol stream : streams) {
stream.checkIfResolvable();
......@@ -171,12 +175,22 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration);
}
for (UnrollDeclarationSymbol layerDeclaration : AllPredefinedLayers.createUnrollList()){
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration);
}
for (LayerDeclarationSymbol layerDeclaration : getSpannedScope().<LayerDeclarationSymbol>resolveLocally(LayerDeclarationSymbol.KIND)){
if (!layerDeclaration.isPredefined()) {
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration.deepCopy());
}
}
for (UnrollDeclarationSymbol layerDeclaration : getSpannedScope().<UnrollDeclarationSymbol>resolveLocally(UnrollDeclarationSymbol.KIND)){
if (!layerDeclaration.isPredefined()) {
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration.deepCopy());
}
}
List<SerialCompositeElementSymbol> copyStreams = new ArrayList<>();
for (SerialCompositeElementSymbol stream : streams) {
SerialCompositeElementSymbol copyStream = stream.preResolveDeepCopy();
......
......@@ -51,6 +51,16 @@ public class ArgumentSymbol extends CommonSymbol {
return parameter;
}
public VariableSymbol getUnrollParameter() {
if (parameter == null){
if (getUnroll().getDeclaration() != null){
Optional<VariableSymbol> optParam = getUnroll().getDeclaration().getParameter(getName());
optParam.ifPresent(this::setParameter);
}
}
return parameter;
}
protected void setParameter(VariableSymbol parameter) {
this.parameter = parameter;
}
......@@ -107,6 +117,14 @@ public class ArgumentSymbol extends CommonSymbol {
}
}
public void resolveUnrollExpression() throws ArchResolveException {
getRhs().resolveOrError();
boolean valid = Constraints.checkUnroll(this);
if (!valid){
throw new ArchResolveException();
}
}
public void checkConstraints(){
Constraints.check(this);
}
......
......@@ -43,6 +43,7 @@ public class CNNArchLanguage extends CNNArchLanguageTOP {
addResolvingFilter(new CNNArchCompilationUnitResolvingFilter());
addResolvingFilter(CommonResolvingFilter.create(ArchitectureSymbol.KIND));
addResolvingFilter(CommonResolvingFilter.create(LayerDeclarationSymbol.KIND));
addResolvingFilter(CommonResolvingFilter.create(UnrollDeclarationSymbol.KIND));
addResolvingFilter(CommonResolvingFilter.create(ArchitectureElementSymbol.KIND));
addResolvingFilter(CommonResolvingFilter.create(VariableSymbol.KIND));
addResolvingFilter(CommonResolvingFilter.create(IODeclarationSymbol.KIND));
......
......@@ -164,6 +164,9 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
for (LayerDeclarationSymbol sym : AllPredefinedLayers.createList()){
addToScope(sym);
}
for (UnrollDeclarationSymbol sym : AllPredefinedLayers.createUnrollList()){
addToScope(sym);
}
}
@Override
......@@ -342,7 +345,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void visit(ASTUnroll ast) {
UnrollSymbol layer = new UnrollSymbol("BeamSearchStart");
UnrollSymbol layer = new UnrollSymbol(ast.getName());
addToScopeAndLinkWithNode(layer, ast);
}
......@@ -357,6 +360,13 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
layer.setArguments(arguments);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
*/
removeCurrentScope();
}
......
......@@ -209,6 +209,16 @@ public enum Constraints {
return valid;
}
public static boolean checkUnroll(ArgumentSymbol argument){
boolean valid = true;
VariableSymbol variable = argument.getUnrollParameter();
for (Constraints constraint : variable.getConstraints()) {
valid = valid &&
constraint.check(argument.getRhs(), argument.getSourcePosition(), variable.getName());
}
return valid;
}
public boolean check(ArchExpressionSymbol exp, SourcePosition sourcePosition, String name){
if (exp instanceof ArchRangeExpressionSymbol){
ArchRangeExpressionSymbol range = (ArchRangeExpressionSymbol)exp;
......
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
......@@ -251,6 +252,9 @@ public class LayerSymbol extends ArchitectureElementSymbol {
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
if (getResolvedThis().isPresent()) {
if (getResolvedThis().get() == this) {
List<ArchTypeSymbol> inputTypes = getInputTypes();
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol;
import java.util.Arrays;
import java.util.List;
......@@ -64,7 +65,7 @@ public class AllPredefinedLayers {
public static final String PADDING_NAME = "padding";
public static final String POOL_TYPE_NAME = "pool_type";
public static final String ONE_HOT_SIZE_NAME = "size";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_MAX_LENGTH_NAME = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
......@@ -98,4 +99,9 @@ public class AllPredefinedLayers {
OneHot.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
return Arrays.asList(
BeamSearchStart.create());
}
}
......@@ -27,7 +27,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class BeamSearchStart extends PredefinedLayerDeclaration {
public class BeamSearchStart extends PredefinedUnrollDeclaration {
private BeamSearchStart() {
super(AllPredefinedLayers.BEAMSEARCH_NAME);
......@@ -35,7 +35,19 @@ public class BeamSearchStart extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) {
try {
System.err.println("TEST0: ");
//System.err.println("LastElementSize-1: " + layer.computeOutputTypes().toString());
//System.err.println("LastElementSize0: " + layer.getDeclaration().toString());
//System.err.println("LastElementSize0.5: " + layer.getDeclaration().getBody().toString());
//System.err.println("LastElementSize1: " + layer.getDeclaration().getBody().getLastAtomicElements().size());
//System.err.println("LastElementSize2: " + layer.getDeclaration().getBody().computeOutputTypes());
//System.err.println("LastElementSize3: " + layer.getDeclaration().getBody().computeOutputTypes().size());
}catch(Exception e){
e.printStackTrace();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(100) // TODO
......@@ -46,7 +58,7 @@ public class BeamSearchStart extends PredefinedLayerDeclaration {
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
public void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
......@@ -54,8 +66,9 @@ public class BeamSearchStart extends PredefinedLayerDeclaration {
BeamSearchStart declaration = new BeamSearchStart();
List<VariableSymbol> parameters = new ArrayList<>(Arrays.asList(
new VariableSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH)
.name(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(99)
.build(),
new VariableSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_WIDTH_NAME)
......
......@@ -39,7 +39,7 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
channels=layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
channels = inputTypes.get(0).getChannels();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get())
......@@ -52,7 +52,7 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputSizeUnequalToOnehotSize(inputTypes, layer);
//errorIfInputSizeUnequalToOnehotSize(inputTypes, layer);
}
public static OneHot create(){
......
......@@ -39,7 +39,7 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
unroll<t> BeamSearchStart (width=5, max_length=50){
unroll<t> BeamSearchStart (width=5, max_length = 50){
FullyConnected(units=4096) ->
Relu() ->
Dropout()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment