Commit fd32e1c9 authored by eyuhar's avatar eyuhar

Merge branch 'feature/loss_function' of /home/eyuhar/Dokumente/CNNTrainLang with conflicts.

parent 8247510f
......@@ -41,11 +41,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LossValue implements ConfigValue =(euclidean:"euclidean"
| l1: "l1"
| crossEntropy:"cross_entropy"
| huberLoss: "huber_loss");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......@@ -55,6 +50,47 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
interface OptimizerParamEntry extends Entry;
DataVariable implements VariableReference = Name&;
IntegerValue implements ConfigValue = NumberWithUnit;
NumberValue implements ConfigValue = NumberWithUnit;
StringValue implements ConfigValue = StringLiteral;
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
interface LossValue extends ConfigValue;
L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?;
L2Loss implements LossValue = name:"l2" ("{" params:Entry* "}")?;
interface HuberEntry extends Entry;
HuberLoss implements LossValue = name:"huber" ("{" params:HuberEntry* "}")?;
interface CrossEntropyEntry extends Entry;
CrossEntropyLoss implements LossValue = name:"cross_entropy" ("{" params:CrossEntropyEntry* "}")?;
interface SoftmaxCrossEntropyEntry extends Entry;
SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?;
SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?;
interface HingeEntry extends Entry;
HingeLoss implements LossValue = name:"hinge" ("{" params:HingeEntry* "}")?;
interface SquaredHingeEntry extends Entry;
SquaredHingeLoss implements LossValue = name:"squared_hinge" ("{" params:SquaredHingeEntry* "}")?;
interface LogisticEntry extends Entry;
LogisticLoss implements LossValue = name:"logistic" ("{" params:LogisticEntry* "}")?;
interface KullbackLeiblerEntry extends Entry;
KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
interface OptimizerValue extends ConfigValue;
interface SGDEntry extends OptimizerParamEntry;
SGDOptimizer implements OptimizerValue = name:"sgd" ("{" params:SGDEntry* "}")?;
......@@ -93,7 +129,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
Gamma2Entry implements RmsPropEntry = name:"gamma2" ":" value:NumberValue;
CenteredEntry implements RmsPropEntry = name:"centered" ":" value:BooleanValue;
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry = name:"rho" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry, HuberEntry = name:"rho" ":" value:NumberValue;
// Reinforcement Extensions
interface MultiParamValue extends ConfigValue;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.List;
public interface ASTLossValue extends ASTLossValueTOP{
String getName();
List<? extends ASTEntry> getParamsList();
}
......@@ -34,6 +34,88 @@ import java.util.Set;
*
*/
public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
private final static List<Class> ALLOWED_SUPERVISED_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTBatchSizeEntry.class,
ASTOptimizerEntry.class,
ASTLearningRateEntry.class,
ASTLoadCheckpointEntry.class,
ASTEvalMetricEntry.class,
ASTLossEntry.class,
ASTSparseLabelEntry.class,
ASTFromLogitsEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class,
ASTNormalizeEntry.class,
ASTMinimumLearningRateEntry.class,
ASTLRDecayEntry.class,
ASTWeightDecayEntry.class,
ASTLRPolicyEntry.class,
ASTStepSizeEntry.class,
ASTRescaleGradEntry.class,
ASTClipGradEntry.class,
ASTGamma1Entry.class,
ASTGamma2Entry.class,
ASTEpsilonEntry.class,
ASTCenteredEntry.class,
ASTClipWeightsEntry.class,
ASTBeta1Entry.class,
ASTBeta2Entry.class,
ASTNumEpochEntry.class
);
private final static List<Class> ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTRLAlgorithmEntry.class,
ASTCriticNetworkEntry.class,
ASTOptimizerEntry.class,
ASTRewardFunctionEntry.class,
ASTMinimumLearningRateEntry.class,
ASTLRDecayEntry.class,
ASTWeightDecayEntry.class,
ASTLRPolicyEntry.class,
ASTGamma1Entry.class,
ASTGamma2Entry.class,
ASTEpsilonEntry.class,
ASTClipGradEntry.class,
ASTRescaleGradEntry.class,
ASTStepSizeEntry.class,
ASTCenteredEntry.class,
ASTClipWeightsEntry.class,
ASTLearningRateEntry.class,
ASTDiscountFactorEntry.class,
ASTNumMaxStepsEntry.class,
ASTTargetScoreEntry.class,
ASTTrainingIntervalEntry.class,
ASTUseFixTargetNetworkEntry.class,
ASTTargetNetworkUpdateIntervalEntry.class,
ASTSnapshotIntervalEntry.class,
ASTAgentNameEntry.class,
ASTGymEnvironmentNameEntry.class,
ASTEnvironmentEntry.class,
ASTUseDoubleDQNEntry.class,
ASTLossEntry.class,
ASTSparseLabelEntry.class,
ASTFromLogitsEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class,
ASTReplayMemoryEntry.class,
ASTMemorySizeEntry.class,
ASTSampleSizeEntry.class,
ASTActionSelectionEntry.class,
ASTGreedyEpsilonEntry.class,
ASTMinEpsilonEntry.class,
ASTEpsilonDecayEntry.class,
ASTEpsilonDecayMethodEntry.class,
ASTNumEpisodesEntry.class,
ASTRosEnvironmentActionTopicEntry.class,
ASTRosEnvironmentStateTopicEntry.class,
ASTRosEnvironmentMetaTopicEntry.class,
ASTRosEnvironmentResetTopicEntry.class,
ASTRosEnvironmentTerminalStateTopicEntry.class,
ASTRosEnvironmentGreetingTopicEntry.class
);
private final ParameterAlgorithmMapping parameterAlgorithmMapping;
private Set<ASTEntry> allEntries;
......
......@@ -199,23 +199,25 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentEuclidean()){
value.setValue(Loss.EUCLIDEAN);
}
else if (node.getValue().isPresentCrossEntropy()){
value.setValue(Loss.CROSS_ENTROPY);
} else if (node.getValue().isPresentHuberLoss()) {
value.setValue(Loss.HUBER_LOSS);
} else if (node.getValue().isPresentL1()) {
value.setValue(Loss.L1);
LossSymbol loss = new LossSymbol(node.getValue().getName());
configuration.setLoss(loss);
addToScopeAndLinkWithNode(loss, node);
}
@Override
public void endVisit(ASTLossEntry node) {
for (ASTEntry nodeParam : node.getValue().getParamsList()) {
LossParamSymbol param = new LossParamSymbol();
OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get();
LossParamValueSymbol lossParamValue = new LossParamValueSymbol();
lossParamValue.setValue(valueSymbol.getValue());
param.setValue(lossParamValue);
configuration.getLoss().getLossParamMap().put(nodeParam.getName(), param);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLRPolicyValue node) {
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
......
......@@ -32,6 +32,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private OptimizerSymbol optimizer;
private OptimizerSymbol criticOptimizer;
private LossSymbol loss;
private RewardFunctionSymbol rlRewardFunctionSymbol;
private TrainedArchitecture trainedArchitecture;
......@@ -59,6 +60,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(criticOptimizer);
}
public LossSymbol getLoss() {
return loss;
}
public void setLoss(LossSymbol loss) {
this.loss = loss;
}
protected void setRlRewardFunction(RewardFunctionSymbol rlRewardFunctionSymbol) {
this.rlRewardFunctionSymbol = rlRewardFunctionSymbol;
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class LossParamSymbol extends CommonSymbol {
public static final EntryKind KIND = new EntryKind();
private LossParamValueSymbol value;
public LossParamSymbol() {
super("", KIND);
}
public LossParamSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public LossParamValueSymbol getValue() {
return value;
}
public void setValue(LossParamValueSymbol value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
......@@ -20,25 +20,19 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum Loss {
EUCLIDEAN{
@Override
public String toString() {
return "euclidean";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "cross_entropy";
}
},
L1 {
@Override
public String toString() { return "l1";}
},
HUBER_LOSS {
@Override
public String toString() { return "huber_loss";}
import de.monticore.symboltable.SymbolKind;
public class LossParamSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.LossParamSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class LossParamValueSymbol extends CommonSymbol {
public static final LossParamValueSymbolKind KIND = new LossParamValueSymbolKind();
private Object value;
public LossParamValueSymbol() {
super("", KIND);
}
public LossParamValueSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class LossParamValueSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.LossParamValueSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
public class LossSymbol extends de.monticore.symboltable.CommonSymbol {
private Map<String, LossParamSymbol> lossParamMap = new HashMap<>();
public static final LossSymbolKind KIND = LossSymbolKind.INSTANCE;
public LossSymbol(String name) {
super(name, KIND);
}
public Map<String, LossParamSymbol> getLossParamMap() {
return lossParamMap;
}
public void setLossParamMap(Map<String, LossParamSymbol> lossParamMap) {
this.lossParamMap = lossParamMap;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class LossSymbolKind implements SymbolKind {
public static final LossSymbolKind INSTANCE = new LossSymbolKind();
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
......@@ -13,7 +13,7 @@ configuration CheckLearningParameterCombination1 {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -11,7 +11,7 @@ configuration CheckReinforcementRequiresEnvironment {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -18,7 +18,7 @@ configuration CheckRosEnvironmentRequiresRewardFunction {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -7,7 +7,7 @@ configuration FixTargetNetworkRequiresInterval1 {
context : cpu
loss : huber_loss
loss : huber
use_fix_target_network : true
}
\ No newline at end of file
......@@ -7,7 +7,7 @@ configuration FixTargetNetworkRequiresInterval2 {
context : cpu
loss : huber_loss