Added parameter 'context' (gpu or cpu for training)

parent de4e527b
Pipeline #52300 passed with stage
in 4 minutes and 33 seconds
#
#
# ******************************************************************************
# MontiCAR Modeling Family, www.se-rwth.de
# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
# All rights reserved.
#
# This project is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this project. If not, see <http://www.gnu.org/licenses/>.
# *******************************************************************************
#
image: maven:3-jdk-8
build:
......
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.2.0-SNAPSHOT</version>
<version>0.2.1-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -41,7 +41,7 @@
<se-commons.version>1.7.7</se-commons.version>
<mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version>
<SIUnit.version>0.0.10-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.10-SNAPSHOT</Common-MontiCar.version>
<Common-MontiCar.version>0.0.11-SNAPSHOT</Common-MontiCar.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......
......@@ -21,6 +21,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.N
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
......@@ -28,6 +29,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.lang.N
| inv:"inv"
| poly:"poly"
| sigmoid:"sigmoid");
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
DataVariable implements VariableReference = Name&;
......
......@@ -143,6 +143,17 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.setNormalize(symbol);
}
@Override
public void visit(ASTTrainContextEntry node) {
TrainContextSymbol symbol = new TrainContextSymbol();
if (node.getValue().cpuIsPresent()){
symbol.setValue(node.getValue().getCpu().get());
}
else {
symbol.setValue(node.getValue().getGpu().get());
}
configuration.setTrainContext(symbol);
}
@Override
public void endVisit(ASTLRPolicyValue node) {
......
......@@ -31,6 +31,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private LoadCheckpointSymbol loadCheckpoint;
private NormalizeSymbol normalize;
private OptimizerSymbol optimizer;
private TrainContextSymbol trainContext;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
......@@ -78,5 +79,12 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
this.normalize = normalize;
}
public TrainContextSymbol getTrainContext() {
return trainContext;
}
public void setTrainContext(TrainContextSymbol trainContext) {
this.trainContext = trainContext;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class TrainContextKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.TrainContextKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class TrainContextSymbol extends CommonSymbol {
public static final TrainContextKind KIND = new TrainContextKind();
private String value;
public TrainContextSymbol() {
super("", KIND);
}
public TrainContextSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public String getValue() {
return value;
}
public void setValue(String value) {
this.value = value;
}
}
......@@ -104,6 +104,13 @@ public class CNNTrainTemplateController {
return getConfiguration().getNormalize();
}
public TrainContextSymbol getContext() {
if (getConfiguration().getTrainContext() == null) {
return null;
}
return getConfiguration().getTrainContext();
}
public String getOptimizerName() {
if (getConfiguration().getOptimizer() == null) {
return null;
......
......@@ -7,6 +7,9 @@ num_epoch = ${tc.numEpoch},
<#if (tc.loadCheckpoint)??>
load_checkpoint = ${tc.loadCheckpoint.value?string("True","False")},
</#if>
<#if (tc.context)??>
context = '${tc.context.value}',
</#if>
<#if (tc.normalize)??>
normalize = ${tc.normalize.value?string("True","False")},
</#if>
......
......@@ -2,6 +2,7 @@ configuration FullConfig{
num_epoch : 5
batch_size : 100
load_checkpoint: true
context:cpu
normalize: true
optimizer:rmsprop{
learning_rate:0.001
......
......@@ -2,6 +2,7 @@ configuration FullConfig2{
num_epoch:10
batch_size:100
load_checkpoint: false
context:gpu
normalize: false
optimizer:adam{
learning_rate:0.001
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment