Commit d58bda9c authored by Svetlana Pavlitskaya's avatar Svetlana Pavlitskaya Committed by Thomas Michael Timmermanns

added normalize parameter and learning_rate_minimum in optimizer params

parent 21f68b7d
......@@ -19,9 +19,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.N
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......@@ -60,6 +60,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.N
interface GeneralOptimizerEntry extends SGDEntry,AdamEntry,RmsPropEntry,AdaGradEntry,AdaDeltaEntry;
MinimumLearningRateEntry implements GeneralOptimizerEntry = name:"learning_rate_minimum" ":" value:NumberValue;
LearningRateEntry implements GeneralOptimizerEntry = name:"learning_rate" ":" value:NumberValue;
WeightDecayEntry implements GeneralOptimizerEntry = name:"weight_decay" ":" value:NumberValue;
LRDecayEntry implements GeneralOptimizerEntry = name:"learning_rate_decay" ":" value:NumberValue;
......
......@@ -131,6 +131,18 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.setLoadCheckpoint(symbol);
}
@Override
public void endVisit(ASTNormalizeEntry node) {
NormalizeSymbol symbol = new NormalizeSymbol();
if (node.getValue().getTRUE().isPresent()){
symbol.setValue(true);
}
else if (node.getValue().getFALSE().isPresent()){
symbol.setValue(false);
}
configuration.setNormalize(symbol);
}
@Override
public void endVisit(ASTLRPolicyValue node) {
......
......@@ -29,6 +29,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private NumEpochSymbol numEpoch;
private BatchSizeSymbol batchSize;
private LoadCheckpointSymbol loadCheckpoint;
private NormalizeSymbol normalize;
private OptimizerSymbol optimizer;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
......@@ -69,4 +70,13 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
this.loadCheckpoint = loadCheckpoint;
}
public NormalizeSymbol getNormalize() {
return normalize;
}
public void setNormalize(NormalizeSymbol normalize) {
this.normalize = normalize;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
public class NormalizeSymbol extends CommonSymbol {
public static final NormalizeSymbolKind KIND = new NormalizeSymbolKind();
private boolean value = false;
public NormalizeSymbol() {
super("", KIND);
}
public NormalizeSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public boolean getValue() {
return value;
}
public void setValue(boolean value) {
this.value = value;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class NormalizeSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.NormalizeSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
......@@ -97,6 +97,13 @@ public class CNNTrainTemplateController {
return getConfiguration().getLoadCheckpoint();
}
public NormalizeSymbol getNormalize() {
if (getConfiguration().getNormalize() == null) {
return null;
}
return getConfiguration().getNormalize();
}
public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
return null;
......
......@@ -7,6 +7,9 @@ num_epoch = ${tc.numEpoch},
<#if (tc.loadCheckpoint)??>
load_checkpoint = ${tc.loadCheckpoint.value?string("True","False")},
</#if>
<#if (tc.normalize)??>
normalize = ${tc.normalize.value?string("True","False")},
</#if>
<#if (tc.configuration.optimizer)??>
optimizer = '${tc.optimizerName}',
optimizer_params = {
......
......@@ -2,8 +2,10 @@ configuration FullConfig{
num_epoch : 5
batch_size : 100
load_checkpoint: true
normalize: true
optimizer:rmsprop{
learning_rate:0.001
learning_rate_minimum: 0.00001
weight_decay:0.01
learning_rate_decay:0.9
learning_rate_policy:step
......
......@@ -2,8 +2,10 @@ configuration FullConfig2{
num_epoch:10
batch_size:100
load_checkpoint: false
normalize: false
optimizer:adam{
learning_rate:0.001
learning_rate_minimum: 0.001
weight_decay:0.01
learning_rate_decay:0.9
learning_rate_policy:exp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment