Commit 45b946ba authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'feature/loss_function' into 'master'

Feature/loss function

See merge request !18
parents 8247510f fd3b2605
Pipeline #152959 passed with stages
in 7 minutes and 51 seconds
......@@ -75,7 +75,7 @@ configuration ReinforcementConfig {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......@@ -126,7 +126,7 @@ configuration ReinforcementConfig {
|discount_factor | Float | 0.9 | No | All | Discount factor |
| target_score | Float | None | No | All | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. |
| training_interval | Integer | 1 | No | All | Number of steps between two trainings |
| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | DQN | Selects the loss function
| loss | l2, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber | l2 | No | DQN | Selects the loss function
| use_fix_target_network | bool | false | No | DQN | If set, an extra network with fixed parameters is used to estimate the Q values |
| target_network_update_interval | Integer | / | DQN | Yes, if fixed target network is true | If *use_fix_target_network* is set, it determines the number of steps after the target network is updated (Minh et. al. "Human Level Control through Deep Reinforcement Learning")|
| use_double_dqn | bool | false | No | If set, two value functions are used to determine the action values (Hasselt et. al. "Deep Reinforcement Learning with Double Q Learning") |
......
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.2-SNAPSHOT</version>
<version>0.3.4-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -41,11 +41,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LossValue implements ConfigValue =(euclidean:"euclidean"
| l1: "l1"
| crossEntropy:"cross_entropy"
| huberLoss: "huber_loss");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......@@ -55,6 +50,43 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
interface OptimizerParamEntry extends Entry;
interface LossValue extends ConfigValue;
L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?;
L2Loss implements LossValue = name:"l2" ("{" params:Entry* "}")?;
LogCoshLoss implements LossValue = name:"log_cosh" ("{" params:Entry* "}")?;
interface HuberEntry extends Entry;
HuberLoss implements LossValue = name:"huber" ("{" params:HuberEntry* "}")?;
interface CrossEntropyEntry extends Entry;
CrossEntropyLoss implements LossValue = name:"cross_entropy" ("{" params:CrossEntropyEntry* "}")?;
interface SoftmaxCrossEntropyEntry extends Entry;
SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?;
SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?;
interface HingeEntry extends Entry;
HingeLoss implements LossValue = name:"hinge" ("{" params:HingeEntry* "}")?;
interface SquaredHingeEntry extends Entry;
SquaredHingeLoss implements LossValue = name:"squared_hinge" ("{" params:SquaredHingeEntry* "}")?;
interface LogisticEntry extends Entry;
LogisticLoss implements LossValue = name:"logistic" ("{" params:LogisticEntry* "}")?;
interface KullbackLeiblerEntry extends Entry;
KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
interface OptimizerValue extends ConfigValue;
interface SGDEntry extends OptimizerParamEntry;
SGDOptimizer implements OptimizerValue = name:"sgd" ("{" params:SGDEntry* "}")?;
......@@ -93,7 +125,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
Gamma2Entry implements RmsPropEntry = name:"gamma2" ":" value:NumberValue;
CenteredEntry implements RmsPropEntry = name:"centered" ":" value:BooleanValue;
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry = name:"rho" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry, HuberEntry = name:"rho" ":" value:NumberValue;
// Reinforcement Extensions
interface MultiParamValue extends ConfigValue;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.List;
public interface ASTLossValue extends ASTLossValueTOP{
String getName();
List<? extends ASTEntry> getParamsList();
}
......@@ -53,7 +53,12 @@ class ParameterAlgorithmMapping {
ASTEvalMetricEntry.class,
ASTNormalizeEntry.class,
ASTNumEpochEntry.class,
ASTLossEntry.class
ASTLossEntry.class,
ASTSparseLabelEntry.class,
ASTFromLogitsEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class
);
private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList(
......
......@@ -199,23 +199,25 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentEuclidean()){
value.setValue(Loss.EUCLIDEAN);
}
else if (node.getValue().isPresentCrossEntropy()){
value.setValue(Loss.CROSS_ENTROPY);
} else if (node.getValue().isPresentHuberLoss()) {
value.setValue(Loss.HUBER_LOSS);
} else if (node.getValue().isPresentL1()) {
value.setValue(Loss.L1);
LossSymbol loss = new LossSymbol(node.getValue().getName());
configuration.setLoss(loss);
addToScopeAndLinkWithNode(loss, node);
}
@Override
public void endVisit(ASTLossEntry node) {
for (ASTEntry nodeParam : node.getValue().getParamsList()) {
LossParamSymbol param = new LossParamSymbol();
OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get();
LossParamValueSymbol lossParamValue = new LossParamValueSymbol();
lossParamValue.setValue(valueSymbol.getValue());
param.setValue(lossParamValue);
configuration.getLoss().getLossParamMap().put(nodeParam.getName(), param);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLRPolicyValue node) {
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
......
......@@ -32,6 +32,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private OptimizerSymbol optimizer;
private OptimizerSymbol criticOptimizer;
private LossSymbol loss;
private RewardFunctionSymbol rlRewardFunctionSymbol;
private TrainedArchitecture trainedArchitecture;
......@@ -59,6 +60,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(criticOptimizer);
}
public LossSymbol getLoss() {
return loss;
}
public void setLoss(LossSymbol loss) {
this.loss = loss;
}
protected void setRlRewardFunction(RewardFunctionSymbol rlRewardFunctionSymbol) {
this.rlRewardFunctionSymbol = rlRewardFunctionSymbol;
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class LossParamSymbol extends CommonSymbol {
public static final EntryKind KIND = new EntryKind();
private LossParamValueSymbol value;
public LossParamSymbol() {
super("", KIND);
}
public LossParamSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public LossParamValueSymbol getValue() {
return value;
}
public void setValue(LossParamValueSymbol value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
......@@ -20,25 +20,19 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum Loss {
EUCLIDEAN{
@Override
public String toString() {
return "euclidean";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "cross_entropy";
}
},
L1 {
@Override
public String toString() { return "l1";}
},
HUBER_LOSS {
@Override
public String toString() { return "huber_loss";}
import de.monticore.symboltable.SymbolKind;
public class LossParamSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.LossParamSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class LossParamValueSymbol extends CommonSymbol {
public static final LossParamValueSymbolKind KIND = new LossParamValueSymbolKind();
private Object value;
public LossParamValueSymbol() {
super("", KIND);
}
public LossParamValueSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class LossParamValueSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.LossParamValueSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
public class LossSymbol extends de.monticore.symboltable.CommonSymbol {
private Map<String, LossParamSymbol> lossParamMap = new HashMap<>();
public static final LossSymbolKind KIND = LossSymbolKind.INSTANCE;
public LossSymbol(String name) {
super(name, KIND);
}
public Map<String, LossParamSymbol> getLossParamMap() {
return lossParamMap;
}
public void setLossParamMap(Map<String, LossParamSymbol> lossParamMap) {
this.lossParamMap = lossParamMap;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class LossSymbolKind implements SymbolKind {
public static final LossSymbolKind INSTANCE = new LossSymbolKind();
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
......@@ -13,7 +13,7 @@ configuration CheckLearningParameterCombination1 {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -11,7 +11,7 @@ configuration CheckReinforcementRequiresEnvironment {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -21,7 +21,7 @@ configuration CheckRosEnvironmentHasOnlyOneRewardSpecification {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -18,7 +18,7 @@ configuration CheckRosEnvironmentRequiresRewardFunction {
target_score : 1000
training_interval : 10
loss : huber_loss
loss : huber
use_fix_target_network : true
target_network_update_interval : 100
......
......@@ -7,7 +7,7 @@ configuration FixTargetNetworkRequiresInterval1 {
context : cpu
loss : huber_loss
loss : huber
use_fix_target_network : true
}
\ No newline at end of file
......@@ -7,7 +7,7 @@ configuration FixTargetNetworkRequiresInterval2 {
context : cpu
loss : huber_loss
loss : huber
target_network_update_interval : 100
}
\ No newline at end of file
......@@ -3,7 +3,10 @@ configuration FullConfig{
batch_size : 100
load_checkpoint : true
eval_metric : mse
loss: cross_entropy
loss: softmax_cross_entropy{
sparse_label: true
from_logits: true
}
context : gpu
normalize : true
optimizer : rmsprop{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment