Commit cd421d5f authored by Sebastian Nickels's avatar Sebastian Nickels

Removed optional onehot for now, as it does not work

parent 5689235f
Pipeline #157630 passed with stages
in 18 minutes and 14 seconds
......@@ -325,6 +325,50 @@ public class LayerSymbol extends ArchitectureElementSymbol {
}
}
public void setIntValue(String parameterName, int value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setIntTupleValue(String parameterName, List<Object> tupleValues) {
setTValue(parameterName, tupleValues, ArchSimpleExpressionSymbol::of);
}
public void setBooleanValue(String parameterName, boolean value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setStringValue(String parameterName, String value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setDoubleValue(String parameterName, double value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setValue(String parameterName, Object value) {
ArchSimpleExpressionSymbol res = new ArchSimpleExpressionSymbol();
res.setValue(value);
setTValue(parameterName, res, Function.identity());
}
public <T> void setTValue(String parameterName, T value, Function<T, ArchSimpleExpressionSymbol> of) {
Optional<VariableSymbol> param = getDeclaration().getParameter(parameterName);
if (param.isPresent()) {
Optional<ArgumentSymbol> arg = getArgument(parameterName);
ArchSimpleExpressionSymbol expression = of.apply(value);
if (arg.isPresent()) {
arg.get().setRhs(expression);
}
else {
arg = Optional.of(new ArgumentSymbol(parameterName));
arg.get().setRhs(expression);
arguments.add(arg.get());
}
}
}
@Override
public Optional<Integer> getParallelLength(){
int length = -1;
......
......@@ -31,8 +31,6 @@ import java.util.*;
public class OneHot extends PredefinedLayerDeclaration {
private static int channels;
private OneHot() {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
......@@ -40,14 +38,23 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
channels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
}else{
channels = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
// TODO: Execute this code somewhere before checkInput(), for now size parameter is required
/*if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
int outputChannels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}*/
int size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
/*if (size == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'size' is in this case required. "
, layer.getSourcePosition());
}*/
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(channels)
.channels(size)
.height(1)
.width(1)
.elementType("0", "1")
......@@ -110,7 +117,6 @@ public class OneHot extends PredefinedLayerDeclaration {
new VariableSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(channels)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment