Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
d3498240
Commit
d3498240
authored
Oct 18, 2018
by
Carlos Alfredo Yeverino Rodriguez
Browse files
Added checker for training parameter support. Visitor pattern used.
parent
2bfbffa9
Changes
2
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNTrain2MxNet.java
View file @
d3498240
...
...
@@ -2,6 +2,8 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnntrain.CNNTrainGenerator
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode
;
import
de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry
;
import
de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
...
...
@@ -19,6 +21,43 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
private
String
generationTargetPath
;
private
String
instanceName
;
private
void
supportCheck
(
ConfigurationSymbol
configuration
){
checkEntryParams
(
configuration
);
checkOptimizerParams
(
configuration
);
}
private
void
checkEntryParams
(
ConfigurationSymbol
configuration
){
TrainParamSupportChecker
funcChecker
=
new
TrainParamSupportChecker
();
Iterator
it
=
configuration
.
getEntryMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
ASTCNNTrainNode
astTrainEntryNode
=
(
ASTCNNTrainNode
)
configuration
.
getEntryMap
().
get
(
key
).
getAstNode
().
get
();
astTrainEntryNode
.
accept
(
funcChecker
);
}
it
=
configuration
.
getEntryMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
it
.
remove
();
}
}
private
void
checkOptimizerParams
(
ConfigurationSymbol
configuration
){
TrainParamSupportChecker
funcChecker
=
new
TrainParamSupportChecker
();
if
(
configuration
.
getOptimizer
()
!=
null
)
{
ASTOptimizerEntry
astOptimizer
=
(
ASTOptimizerEntry
)
configuration
.
getOptimizer
().
getAstNode
().
get
();
astOptimizer
.
accept
(
funcChecker
);
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
funcChecker
.
unsupportedOptFlag
))
{
configuration
.
setOptimizer
(
null
);
}
else
{
Iterator
it
=
configuration
.
getOptimizer
().
getOptimizerParamMap
().
keySet
().
iterator
();
while
(
it
.
hasNext
())
{
String
key
=
it
.
next
().
toString
();
if
(
funcChecker
.
getUnsupportedElemList
().
contains
(
key
))
it
.
remove
();
}
}
}
}
public
CNNTrain2MxNet
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
...
...
@@ -54,6 +93,7 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
}
setInstanceName
(
compilationUnit
.
get
().
getFullName
());
CNNTrainCocos
.
checkAll
(
compilationUnit
.
get
());
supportCheck
(
compilationUnit
.
get
().
getConfiguration
());
return
compilationUnit
.
get
().
getConfiguration
();
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/TrainParamSupportChecker.java
0 → 100644
View file @
d3498240
package
de.monticore.lang.monticar.cnnarch.mxnetgenerator
;
import
de.monticore.lang.monticar.cnntrain._ast.*
;
import
de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.ArrayList
;
import
java.util.List
;
public
class
TrainParamSupportChecker
implements
CNNTrainVisitor
{
private
List
<
String
>
unsupportedElemList
=
new
ArrayList
();
private
void
printUnsupportedEntryParam
(
String
nodeName
){
Log
.
warn
(
"Unsupported training parameter "
+
"'"
+
nodeName
+
"'"
+
" for the backend MXNet. It will be ignored."
);
}
private
void
printUnsupportedOptimizer
(
String
nodeName
){
Log
.
warn
(
"Unsupported optimizer parameter "
+
"'"
+
nodeName
+
"'"
+
" for the backend MXNet. It will be ignored."
);
}
private
void
printUnsupportedOptimizerParam
(
String
nodeName
){
Log
.
warn
(
"Unsupported training optimizer parameter "
+
"'"
+
nodeName
+
"'"
+
" for the backend MXNet. It will be ignored."
);
}
public
TrainParamSupportChecker
()
{
}
public
String
unsupportedOptFlag
=
"unsupported_optimizer"
;
public
List
getUnsupportedElemList
(){
return
this
.
unsupportedElemList
;
}
public
void
visit
(
ASTNumEpochEntry
node
){}
public
void
visit
(
ASTBatchSizeEntry
node
){}
public
void
visit
(
ASTLoadCheckpointEntry
node
){}
public
void
visit
(
ASTNormalizeEntry
node
){}
public
void
visit
(
ASTTrainContextEntry
node
){}
public
void
visit
(
ASTEvalMetricEntry
node
){}
public
void
visit
(
ASTSGDOptimizer
node
){}
public
void
visit
(
ASTAdamOptimizer
node
){}
public
void
visit
(
ASTRmsPropOptimizer
node
){}
public
void
visit
(
ASTAdaGradOptimizer
node
){}
public
void
visit
(
ASTNesterovOptimizer
node
){}
public
void
visit
(
ASTAdaDeltaOptimizer
node
){}
public
void
visit
(
ASTLearningRateEntry
node
){}
public
void
visit
(
ASTMinimumLearningRateEntry
node
){}
public
void
visit
(
ASTWeightDecayEntry
node
){}
public
void
visit
(
ASTLRDecayEntry
node
){}
public
void
visit
(
ASTLRPolicyEntry
node
){}
public
void
visit
(
ASTRescaleGradEntry
node
){}
public
void
visit
(
ASTClipGradEntry
node
){}
public
void
visit
(
ASTStepSizeEntry
node
){}
public
void
visit
(
ASTMomentumEntry
node
){}
public
void
visit
(
ASTBeta1Entry
node
){}
public
void
visit
(
ASTBeta2Entry
node
){}
public
void
visit
(
ASTEpsilonEntry
node
){}
public
void
visit
(
ASTGamma1Entry
node
){}
public
void
visit
(
ASTGamma2Entry
node
){}
public
void
visit
(
ASTCenteredEntry
node
){}
public
void
visit
(
ASTClipWeightsEntry
node
){}
public
void
visit
(
ASTRhoEntry
node
){}
}
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment