Commit 292498dc authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'tensorflow_group' into 'master'

Tensorflow group 2

See merge request !26
parents 6f92839d fb75a171
Pipeline #187341 passed with stages
in 19 minutes and 47 seconds
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
......@@ -170,6 +177,32 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
//output type function for upconvolution
protected static List<ArchTypeSymbol> computeUpConvOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
String borderModeSetting = method.getStringValue(AllPredefinedLayers.PADDING_NAME).get();
if (borderModeSetting.equals(AllPredefinedLayers.PADDING_SAME)){
return computeOutputUpConvShapeWithSamePadding(inputType, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedLayers.PADDING_VALID)){
return computeOutputUpConvShapeWithValidPadding(inputType, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedLayers.PADDING_NO_LOSS)){
return computeOutputUpConvShapeWithNoLossPadding(inputType, method, channels);
}
else{
throw new IllegalStateException("border_mode is " + borderModeSetting + ". This should never happen.");
}
}
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().get() instanceof VariableSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......@@ -235,7 +268,66 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
.elementType("-oo", "oo")
.build());
}
private static List<ArchTypeSymbol> computeOutputUpConvShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
int kernelHeight = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
int kernelWidth = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
int inputHeight = inputType.getHeight();
int inputWidth = inputType.getWidth();
int outputWidth;
int outputHeight;
outputWidth = (inputWidth - 1) * strideWidth + kernelWidth;
outputHeight = (inputHeight - 1) * strideHeight + kernelHeight;
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.elementType("-oo", "oo")
.build());
}
private static List<ArchTypeSymbol> computeOutputUpConvShapeWithNoLossPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
int kernelHeight = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
int kernelWidth = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
int inputHeight = inputType.getHeight();
int inputWidth = inputType.getWidth();
int outputWidth = Math.max(0, ((inputWidth - 1) * strideWidth - strideWidth + kernelWidth + 1) );
int outputHeight = Math.max(0, ((inputHeight - 1) * strideHeight - strideHeight + kernelHeight + 1));
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.elementType("-oo", "oo")
.build());
}
private static List<ArchTypeSymbol> computeOutputUpConvShapeWithSamePadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
int inputHeight = inputType.getHeight();
int inputWidth = inputType.getWidth();
//no -strideWidth+1 at end as sugested by rearanging the corresponding formula for convolution. Tensorflow calculates it like this.
int outputWidth = inputWidth * strideWidth;
int outputHeight = inputHeight * strideHeight;
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.elementType("-oo", "oo")
.build());
}
protected List<String> computeStartAndEndValue(List<ArchTypeSymbol> inputTypes, BinaryOperator<Rational> startValAccumulator, BinaryOperator<Rational> endValAccumulator){
Stream.Builder<Rational> startValues = Stream.builder();
Stream.Builder<Rational> endValues = Stream.builder();
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.helper;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.helper;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.helper;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......@@ -11,6 +18,7 @@ public class AllPredefinedLayers {
//predefined layer names
public static final String FULLY_CONNECTED_NAME = "FullyConnected";
public static final String CONVOLUTION_NAME = "Convolution";
public static final String UP_CONVOLUTION_NAME = "UpConvolution";
public static final String SOFTMAX_NAME = "Softmax";
public static final String SIGMOID_NAME = "Sigmoid";
public static final String TANH_NAME = "Tanh";
......@@ -70,6 +78,7 @@ public class AllPredefinedLayers {
return Arrays.asList(
FullyConnected.create(),
Convolution.create(),
UpConvolution.create(),
Softmax.create(),
Sigmoid.create(),
Tanh.create(),
......@@ -88,7 +97,7 @@ public class AllPredefinedLayers {
RNN.create(),
LSTM.create(),
GRU.create(),
Embedding.create());
Embedding.create(),
RNN.create());
}
}
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment