Commit 79b73c98 authored by Kirhan, Cihad's avatar Kirhan, Cihad
Browse files

ConfLang integration

parent 88eb588d
Pipeline #452317 passed with stage
in 1 minute and 5 seconds
......@@ -10,6 +10,7 @@ import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.se_rwth.commons.logging.Log;
import schemalang._symboltable.SchemaLangDefinitionSymbol;
import java.io.IOException;
import java.nio.file.Path;
......@@ -42,7 +43,14 @@ public class CNNTrain2MxNet extends CNNTrainGenerator {
@Override
public void generate(Path modelsDirPath, String rootModelName) {
ConfLangConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName);
List<FileContent> fileContents = generateFileContents(configuration, Maps.newHashMap());
SchemaLangDefinitionSymbol schema = getSchemaDefinitionSymbol(modelsDirPath, rootModelName);
schema.validateConfiguration(configuration, modelsDirPath.toString());
if (Log.getErrorCount() > 0) {
throw new RuntimeException("Generation not possible!");
}
List<FileContent> fileContents = generateFileContents(configuration, schema, Maps.newHashMap());
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
try {
......@@ -70,9 +78,9 @@ public class CNNTrain2MxNet extends CNNTrainGenerator {
// }
@Override
public List<FileContent> generateFileContents(ConfLangConfigurationSymbol configuration, Map<String, ArchitectureSymbol> architectureAdapterMap) {
public List<FileContent> generateFileContents(ConfLangConfigurationSymbol configuration, SchemaLangDefinitionSymbol schema, Map<String, ArchitectureSymbol> architectureAdapterMap) {
TemplateConfiguration templateConfiguration = new MxNetTemplateConfiguration();
ConfigurationDataConfLang configData = new ConfigurationDataConfLang(configuration, getInstanceName());
ConfigurationDataConfLang configData = new ConfigurationDataConfLang(configuration, schema, getInstanceName());
List<ConfigurationDataConfLang> configDataList = Lists.newArrayList(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
......
......@@ -14,6 +14,7 @@ import conflangliterals._ast.ASTVectorLiteral;
import de.monticore.literals.literals._ast.*;
import de.monticore.symboltable.Symbol;
import de.monticore.symboltable.SymbolKind;
import schemalang._symboltable.SchemaLangDefinitionSymbol;
import java.util.List;
import java.util.Map;
......@@ -24,10 +25,12 @@ import static de.monticore.lang.monticar.cnnarch.mxnetgenerator.TrainingParamete
public class ConfigurationDataConfLang {
private ConfLangConfigurationSymbol configuration;
private SchemaLangDefinitionSymbol schema;
private String instanceName;
public ConfigurationDataConfLang(ConfLangConfigurationSymbol configuration, String instanceName) {
public ConfigurationDataConfLang(ConfLangConfigurationSymbol configuration, SchemaLangDefinitionSymbol schema, String instanceName) {
this.configuration = configuration;
this.schema = schema;
this.instanceName = instanceName;
}
......@@ -654,21 +657,27 @@ public class ConfigurationDataConfLang {
if (signedLiteral instanceof ASTSignedIntLiteral) {
ASTSignedIntLiteral signedIntLiteral = (ASTSignedIntLiteral) signedLiteral;
return signedIntLiteral.getValue();
} else if (signedLiteral instanceof ASTSignedDoubleLiteral) {
ASTSignedDoubleLiteral signedDoubleLiteral = (ASTSignedDoubleLiteral) signedLiteral;
return signedDoubleLiteral.getValue();
Double doubleValue = (Double) signedDoubleLiteral.getValue();
return doubleValue.toString();
} else if (signedLiteral instanceof ASTBooleanLiteral) {
ASTBooleanLiteral booleanLiteral = (ASTBooleanLiteral) signedLiteral;
if (booleanLiteral.getValue()) {
return "True";
}
return "False";
} else if (signedLiteral instanceof ASTStringLiteral) {
ASTStringLiteral stringLiteral = (ASTStringLiteral) signedLiteral;
return stringLiteral.getValue();
} else if (signedLiteral instanceof ASTTypelessLiteral) {
ASTTypelessLiteral typelessLiteral = (ASTTypelessLiteral) signedLiteral;
return typelessLiteral.getValue();
return "'".concat(typelessLiteral.getValue()).concat("'");
} else if (signedLiteral instanceof ASTVectorLiteral) {
ASTVectorLiteral vectorLiteral = (ASTVectorLiteral) signedLiteral;
List<Object> vectorEntries = Lists.newArrayList();
......
/* (c) https://github.com/MontiCore/monticore */
schema EmptyConfig {
}
\ No newline at end of file
......@@ -19,11 +19,11 @@ configuration FullConfig {
learning_rate_policy = step
step_size = 1000
rescale_grad = 1.1
clip_gradient = 10
clip_gradient = 10.0
gamma1 = 0.9
gamma2 = 0.9
epsilon = 0.000001
centered = true
clip_weights = 10
clip_weights = 10.0
}
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
schema FullConfig {
context: enum {
cpu,
gpu;
}
num_epoch: N0
batch_size: N0
load_checkpoint: B
normalize: B
optimizer: complex<optimizer>
eval_metric: complex<eval_metric>
loss: complex<loss>
complex optimizer {
instances:
rmsprop;
define rmsprop {
learning_rate: Q
learning_rate_minimum: Q
rescale_grad: Q
clip_gradient: Q
weight_decay: Q
learning_rate_decay: Q
learning_rate_policy: enum {
step;
}
step_size: N0
gamma1: Q
gamma2: Q
centered: B
epsilon: Q
clip_weights: Q
}
}
complex eval_metric {
instances:
mse;
}
complex loss {
instances:
softmax_cross_entropy;
define softmax_cross_entropy {
loss_axis: Z
sparse_label: B
from_logits: B
}
}
}
\ No newline at end of file
/* (c) https://github.com/MontiCore/monticore */
schema SimpleConfig {
num_epoch: N0
batch_size: N0
loss: complex<loss>
optimizer: complex<optimizer>
complex loss {
instances:
cross_entropy;
}
complex optimizer {
instances:
adam;
define adam {
learning_rate: Q
}
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment