Commit aff7ca07 authored by Sebastian N.'s avatar Sebastian N.

Changed TrainParamSupportChecker

parent d8b24a45
......@@ -3,7 +3,6 @@ package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.generator.TrainParamSupportChecker;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
import de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
......@@ -20,6 +19,9 @@ import java.nio.file.Path;
import java.util.*;
public abstract class CNNTrainGenerator {
protected TrainParamSupportChecker trainParamSupportChecker;
private String generationTargetPath;
private String instanceName;
......@@ -33,34 +35,32 @@ public abstract class CNNTrainGenerator {
}
private void checkEntryParams(ConfigurationSymbol configuration){
TrainParamSupportChecker funcChecker = new TrainParamSupportChecker();
Iterator it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
ASTCNNTrainNode astTrainEntryNode = (ASTCNNTrainNode) configuration.getEntryMap().get(key).getAstNode().get();
astTrainEntryNode.accept(funcChecker);
astTrainEntryNode.accept(trainParamSupportChecker);
}
it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) {
if (trainParamSupportChecker.getUnsupportedElemList().contains(key)) {
it.remove();
}
}
}
private void checkOptimizerParams(ConfigurationSymbol configuration){
TrainParamSupportChecker funcChecker = new TrainParamSupportChecker();
if (configuration.getOptimizer() != null) {
ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
astOptimizer.accept(funcChecker);
if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
astOptimizer.accept(trainParamSupportChecker);
if (trainParamSupportChecker.getUnsupportedElemList().contains(trainParamSupportChecker.unsupportedOptFlag)) {
configuration.setOptimizer(null);
}else {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) {
if (trainParamSupportChecker.getUnsupportedElemList().contains(key)) {
it.remove();
}
}
......
......@@ -11,15 +11,15 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
private List<String> unsupportedElemList = new ArrayList<>();
private void printUnsupportedEntryParam(String nodeName){
Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored.");
Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored.");
}
private void printUnsupportedOptimizer(String nodeName){
Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored.");
Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored.");
}
private void printUnsupportedOptimizerParam(String nodeName){
Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored.");
Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored.");
}
public TrainParamSupportChecker() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment