Commit 396a8a37 authored by Christian Fuß's avatar Christian Fuß
Browse files

added Dot and Repeat Layer; added 'dim' parameter to Concatenate layer

parent 1dd77007
Pipeline #186551 passed with stages
in 19 minutes and 29 seconds
......@@ -53,6 +53,8 @@ public class AllPredefinedLayers {
public static final String GRU_NAME = "GRU";
public static final String EMBEDDING_NAME = "Embedding";
public static final String ARG_MAX_NAME = "ArgMax";
public static final String DOT_NAME = "Dot";
public static final String REPEAT_NAME = "Repeat";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -74,10 +76,13 @@ public class AllPredefinedLayers {
public static final String LAYERS_NAME = "layers";
public static final String INPUT_DIM_NAME = "input_dim";
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String DIM_NAME = "dim";
public static final String BIDIRECTIONAL_NAME = "bidirectional";
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String MAX_LENGTH_NAME = "max_length";
public static final String WIDTH_NAME = "width";
public static final String REPEATS_NAME = "n";
public static final String AXIS_NAME = "axis";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -111,7 +116,9 @@ public class AllPredefinedLayers {
LSTM.create(),
GRU.create(),
Embedding.create(),
ArgMax.create());
ArgMax.create(),
Dot.create(),
Repeat.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
......@@ -20,15 +20,13 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.PredefinedLayerDeclaration;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
......@@ -83,7 +81,13 @@ public class Concatenate extends PredefinedLayerDeclaration {
public static Concatenate create(){
Concatenate declaration = new Concatenate();
declaration.setParameters(new ArrayList<>());
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Dot extends PredefinedLayerDeclaration {
private Dot() {
super(AllPredefinedLayers.DOT_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
}
public static Dot create(){
Dot declaration = new Dot();
declaration.setParameters(new ArrayList<>());
return declaration;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Repeat extends PredefinedLayerDeclaration {
private Repeat() {
super(AllPredefinedLayers.REPEAT_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int repeats = layer.getIntValue(AllPredefinedLayers.REPEATS_NAME).get();
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
boolean axisIsNone = layer.getStringValue(AllPredefinedLayers.AXIS_NAME).isPresent();
int channels = layer.getInputTypes().get(0).getChannels();
int height = layer.getInputTypes().get(0).getHeight();
int width = layer.getInputTypes().get(0).getWidth();
if(axisIsNone){
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels *= repeats)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}else {
if(axis == 0){
height *= repeats;
}else if(axis == 1){
width *= repeats;
}else{
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid axis in Repeat layer. Axis for " +
getName() + " layer must be None, 0 or 1"
, layer.getSourcePosition());
}
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType("-oo", "oo")
.build());
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static Repeat create(){
Repeat declaration = new Repeat();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPEATS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
// no constraints in order to allow 'None' value
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.AXIS_NAME)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment