Commit 7574b6e6 authored by Julian Dierkes's avatar Julian Dierkes

extended CNNTrain with constraint losses

parent 8531caff
Pipeline #226086 passed with stages
in 9 minutes
......@@ -69,7 +69,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface OptimizerParamEntry extends Entry;
interface LossValue extends ConfigValue;
interface LossValue extends MultiParamValue;
L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?;
......@@ -281,5 +281,10 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" distribution:NoiseDistributionValue;
ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:NoiseDistributionValue;
// Constraint losses
ConstraintLossEntry implements MultiParamValueMapConfigEntry = name:"constraint_losses" ":" value:ConstraintLossValue;
ConstraintLossValue implements MultiParamValueMapParamValue = ("{" params:ConstraintLossParam* "}")?;
ConstraintLossParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:LossValue;
}
\ No newline at end of file
......@@ -16,7 +16,7 @@ import java.util.Map;
*
*/
public interface ASTMultiParamValueMapTupleValue extends ASTMultiParamValueMapTupleValueTOP {
default ASTNoiseDistributionValue getDistribution() {
default ASTMultiParamValue getMultiParamValue() {
return null;
}
default ASTStringValue getName() {
......
......@@ -122,7 +122,8 @@ class ParameterAlgorithmMapping {
ASTDiscriminatorNetworkEntry.class,
ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class,
ASTConstraintDistributionEntry.class
ASTConstraintDistributionEntry.class,
ASTConstraintLossEntry.class
);
ParameterAlgorithmMapping() {
......
......@@ -573,6 +573,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamMapConfigEndVisit(node);
}
@Override
public void visit(ASTConstraintLossEntry node) {
processMultiParamMapConfigVisit(node, node.getName());
}
@Override
public void endVisit(ASTConstraintLossEntry node) {
processMultiParamMapConfigEndVisit(node);
}
@Override
public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution;
......@@ -689,14 +699,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
for (ASTConfigValue nodeParam : ((ASTMultiParamValueMapParamValue)node.getValue()).getParamsList()) {
ASTMultiParamValueMapTupleValue tuple = ((ASTMultiParamValueMapTupleValue)nodeParam);
ASTStringValue name = tuple.getName();
ASTMultiParamValue distribution = tuple.getDistribution();
String distrName = distribution.getName();
multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), distrName);
ASTMultiParamValue multiValue = tuple.getMultiParamValue();
String valueName = multiValue.getName();
multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), valueName);
HashMap<String, Object> mapEntry = new HashMap<>();
for (ASTEntry param : distribution.getParamsList()) {
String distrEntryName = param.getName();
for (ASTEntry param : multiValue.getParamsList()) {
String valueEntryName = param.getName();
Object res = retrievePrimitiveValueByConfigValue(param.getValue());
mapEntry.put(distrEntryName, res);
mapEntry.put(valueEntryName, res);
}
multiParamValueMapSymbol.addParameter(getStringFromStringValue(name), mapEntry);
}
......
......@@ -51,4 +51,5 @@ public class ConfigEntryNameConstants {
public static final String PREPROCESSING_NAME = "preprocessing_name";
public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions";
public static final String CONSTRAINT_LOSS = "constraint_losses";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment