Commit 3471e711 authored by Carlos Alfredo Yeverino Rodriguez's avatar Carlos Alfredo Yeverino Rodriguez
Browse files

Added new loss parameter

parent b69e7693
Pipeline #103578 failed with stages
in 7 minutes and 30 seconds
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.2.5</version>
<version>0.2.6-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -23,6 +23,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| crossEntropy:"cross_entropy"
......@@ -32,6 +33,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
LossValue implements ConfigValue =(euclidean:"euclidean" | crossEntropy:"cross_entropy");
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......
......@@ -178,6 +178,21 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentEuclidean()){
value.setValue(Loss.EUCLIDEAN);
}
else if (node.getValue().isPresentCrossEntropy()){
value.setValue(Loss.CROSS_ENTROPY);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLRPolicyValue node) {
OptimizerParamValueSymbol value = new OptimizerParamValueSymbol();
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum Loss {
EUCLIDEAN{
@Override
public String toString() {
return "euclidean";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "cross_entropy";
}
}
}
......@@ -3,6 +3,7 @@ configuration FullConfig{
batch_size : 100
load_checkpoint : true
eval_metric : mse
loss: cross_entropy
context : gpu
normalize : true
optimizer : rmsprop{
......
......@@ -4,6 +4,7 @@ configuration FullConfig2{
load_checkpoint : false
context : gpu
eval_metric : top_k_accuracy
loss: euclidean
normalize : false
optimizer : adam{
learning_rate : 0.001
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment