Commit e111d230 authored by Julian Treiber's avatar Julian Treiber

added load_pretrained flag, removed obsolete dice_weight

parent 5400ae7c
Pipeline #248020 failed with stages
This diff is collapsed.
This diff is collapsed.
......@@ -30,6 +30,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LoadPretrainedEntry implements ConfigEntry = name:"load_pretrained" ":" value:BooleanValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
......@@ -115,7 +116,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
DiceWeightEntry implements DiceEntry = name: "dice_weight" ":" value:DoubleVectorValue;
interface OptimizerValue extends ConfigValue;
interface SGDEntry extends OptimizerParamEntry;
......
......@@ -38,6 +38,7 @@ class ParameterAlgorithmMapping {
ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class,
ASTCheckpointPeriodEntry.class,
ASTLoadPretrainedEntry.class,
ASTLogPeriodEntry.class,
ASTEvalMetricEntry.class,
ASTEvalTrainEntry.class,
......@@ -51,7 +52,6 @@ class ParameterAlgorithmMapping {
ASTBatchAxisEntry.class,
ASTFromLogitsEntry.class,
ASTIgnoreIndicesEntry.class,
ASTDiceWeightEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class,
......
......@@ -54,7 +54,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
CNNTrainCompilationUnitSymbol compilationUnitSymbol = new CNNTrainCompilationUnitSymbol(compilationUnit.getName());
addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit);
}
@Override
public void endVisit(ASTCNNTrainCompilationUnit ast) {
......@@ -142,6 +142,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLoadPretrainedEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLogPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -231,7 +239,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDoubleVector(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
......@@ -285,8 +293,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
addToScopeAndLinkWithNode(value, node);
}
private ValueSymbol getValueSymbolForInteger(ASTIntegerValue astIntegerValue) {
ValueSymbol value = new ValueSymbol();
Integer value_as_int = getIntegerFromNumber(astIntegerValue);
......@@ -321,7 +329,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
value.setValue(value_as_double_list);
return value;
}
private ValueSymbol getValueSymbolForComponentName(ASTComponentNameValue astComponentNameValue) {
ValueSymbol value = new ValueSymbol();
List<String> valueAsList = astComponentNameValue.getNameList();
......@@ -381,7 +389,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTLearningMethodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......
......@@ -4,12 +4,9 @@ configuration FullConfig{
batch_size : 100
load_checkpoint : true
eval_metric : mse
loss: dice_loss{
loss: softmax_cross_entropy{
sparse_label: true
from_logits: true
loss_axis : -1
batch_axis: 0
dice_weight: (0.2, 0.8)
}
context : gpu
normalize : true
......
/* (c) https://github.com/MontiCore/monticore */
configuration FullConfig{
num_epoch : 5
batch_size : 100
load_pretrained : true
eval_metric : mse
loss: dice_loss{
sparse_label: true
from_logits: true
loss_axis : -1
batch_axis : 0
}
context : gpu
normalize : true
optimizer : rmsprop{
learning_rate : 0.001
learning_rate_minimum : 0.00001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : step
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10
gamma1 : 0.9
gamma2 : 0.9
epsilon : 0.000001
centered : true
clip_weights : 10
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment