Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
1802f7d0
Commit
1802f7d0
authored
Aug 31, 2018
by
Svetlana Pavlitskaya
Browse files
A new generator class for CNNTrain language impelemnting corresponding interface
parent
cfb99d88
Changes
15
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
1802f7d0
...
...
@@ -8,15 +8,15 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-mxnet-generator
</artifactId>
<version>
0.2.
4
-SNAPSHOT
</version>
<version>
0.2.
5
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.2.
5
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.2.
4
-SNAPSHOT
</CNNTrain.version>
<CNNArch.version>
0.2.
6
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.2.
5
-SNAPSHOT
</CNNTrain.version>
<embedded-montiarc-math-generator>
0.0.25-SNAPSHOT
</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNArch2MxNet.java
View file @
1802f7d0
...
...
@@ -26,19 +26,19 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cmake.CMakeConfig
;
import
de.monticore.lang.monticar.generator.cmake.CMakeFindModule
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.symboltable.GlobalScope
;
import
de.monticore.symboltable.Scope
;
import
de.se_rwth.commons.logging.Log
;
import
java.io.File
;
import
java.io.FileWriter
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.util.*
;
import
java.util.HashMap
;
import
java.util.Map
;
import
java.util.Optional
;
public
class
CNNArch2MxNet
implements
CNNArchGenerator
{
...
...
@@ -87,24 +87,6 @@ public class CNNArch2MxNet implements CNNArchGenerator {
}
}
@Override
public
Map
<
String
,
String
>
generateTrainer
(
List
<
ConfigurationSymbol
>
configurations
,
List
<
String
>
instanceNames
,
String
mainComponentName
)
{
int
numberOfNetworks
=
configurations
.
size
();
if
(
configurations
.
size
()
!=
instanceNames
.
size
()){
throw
new
IllegalStateException
(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. "
+
"This should have been checked previously."
);
}
List
<
ConfigurationData
>
configDataList
=
new
ArrayList
<>();
for
(
int
i
=
0
;
i
<
numberOfNetworks
;
i
++){
configDataList
.
add
(
new
ConfigurationData
(
configurations
.
get
(
i
),
instanceNames
.
get
(
i
)));
}
Map
<
String
,
Object
>
ftlContext
=
Collections
.
singletonMap
(
"configurations"
,
configDataList
);
return
Collections
.
singletonMap
(
"CNNTrainer_"
+
mainComponentName
+
".py"
,
TemplateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public
Map
<
String
,
String
>
generateStrings
(
ArchitectureSymbol
architecture
){
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
...
...
@@ -163,19 +145,10 @@ public class CNNArch2MxNet implements CNNArchGenerator {
}
private
void
generateFromFilecontentsMap
(
Map
<
String
,
String
>
fileContentMap
)
throws
IOException
{
GeneratorCPP
genCPP
=
new
GeneratorCPP
();
genCPP
.
setGenerationTargetPath
(
getGenerationTargetPath
());
for
(
String
fileName
:
fileContentMap
.
keySet
()){
File
f
=
new
File
(
getGenerationTargetPath
()
+
fileName
);
Log
.
info
(
f
.
getName
(),
"FileCreation:"
);
if
(!
f
.
exists
())
{
f
.
getParentFile
().
mkdirs
();
if
(!
f
.
createNewFile
())
{
Log
.
error
(
"File could not be created"
);
}
}
FileWriter
writer
=
new
FileWriter
(
f
);
writer
.
write
(
fileContentMap
.
get
(
fileName
));
writer
.
close
();
genCPP
.
generateFile
(
new
FileContent
(
fileContentMap
.
get
(
fileName
),
fileName
));
}
}
...
...
@@ -186,7 +159,7 @@ public class CNNArch2MxNet implements CNNArchGenerator {
CMakeConfig
cMakeConfig
=
new
CMakeConfig
(
rootModelName
);
cMakeConfig
.
addModuleDependency
(
new
CMakeFindModule
(
"Armadillo"
,
true
));
cMakeConfig
.
addCMakeCommand
End
(
"set(LIBS ${LIBS} mxnet)"
);
cMakeConfig
.
addCMakeCommand
(
"set(LIBS ${LIBS} mxnet)"
);
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
for
(
FileContent
fileContent
:
cMakeConfig
.
generateCMakeFiles
()){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/CNNTrain2MxNet.java
0 → 100644
View file @
1802f7d0
package
de.monticore.lang.monticar.cnnarch.mxnetgenerator
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnntrain.CNNTrainGenerator
;
import
de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.symboltable.GlobalScope
;
import
de.se_rwth.commons.logging.Log
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.util.*
;
public
class
CNNTrain2MxNet
implements
CNNTrainGenerator
{
private
String
generationTargetPath
;
private
String
instanceName
;
public
CNNTrain2MxNet
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
public
String
getInstanceName
()
{
String
parsedInstanceName
=
this
.
instanceName
.
replace
(
'.'
,
'_'
).
replace
(
'['
,
'_'
).
replace
(
']'
,
'_'
);
parsedInstanceName
=
parsedInstanceName
.
substring
(
0
,
1
).
toLowerCase
()
+
parsedInstanceName
.
substring
(
1
);
return
parsedInstanceName
;
}
public
void
setInstanceName
(
String
instanceName
)
{
this
.
instanceName
=
instanceName
;
}
public
String
getGenerationTargetPath
()
{
if
(
generationTargetPath
.
charAt
(
generationTargetPath
.
length
()
-
1
)
!=
'/'
)
{
this
.
generationTargetPath
=
generationTargetPath
+
"/"
;
}
return
generationTargetPath
;
}
public
void
setGenerationTargetPath
(
String
generationTargetPath
)
{
this
.
generationTargetPath
=
generationTargetPath
;
}
public
ConfigurationSymbol
getConfigurationSymbol
(
Path
modelsDirPath
,
String
rootModelName
)
{
final
ModelPath
mp
=
new
ModelPath
(
modelsDirPath
);
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
Optional
<
CNNTrainCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNTrainCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
())
{
Log
.
error
(
"could not resolve training configuration "
+
rootModelName
);
System
.
exit
(
1
);
}
setInstanceName
(
compilationUnit
.
get
().
getFullName
());
CNNTrainCocos
.
checkAll
(
compilationUnit
.
get
());
return
compilationUnit
.
get
().
getConfiguration
();
}
public
void
generate
(
Path
modelsDirPath
,
String
rootModelName
)
{
ConfigurationSymbol
configuration
=
getConfigurationSymbol
(
modelsDirPath
,
rootModelName
);
Map
<
String
,
String
>
fileContents
=
generateStrings
(
configuration
);
GeneratorCPP
genCPP
=
new
GeneratorCPP
();
genCPP
.
setGenerationTargetPath
(
getGenerationTargetPath
());
try
{
for
(
String
fileName
:
fileContents
.
keySet
()){
genCPP
.
generateFile
(
new
FileContent
(
fileContents
.
get
(
fileName
),
fileName
));
}
}
catch
(
IOException
e
)
{
e
.
printStackTrace
();
}
}
public
Map
<
String
,
String
>
generateStrings
(
ConfigurationSymbol
configuration
)
{
ConfigurationData
configData
=
new
ConfigurationData
(
configuration
,
getInstanceName
());
List
<
ConfigurationData
>
configDataList
=
new
ArrayList
<>();
configDataList
.
add
(
configData
);
Map
<
String
,
Object
>
ftlContext
=
Collections
.
singletonMap
(
"configurations"
,
configDataList
);
String
templateContent
=
TemplateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
);
return
Collections
.
singletonMap
(
"CNNTrainer_"
+
getInstanceName
()
+
".py"
,
templateContent
);
}
}
\ No newline at end of file
src/test/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/AbstractSymtabTest.java
View file @
1802f7d0
...
...
@@ -60,15 +60,6 @@ public class AbstractSymtabTest {
return
scope
;
}
/* protected static ASTCNNArchCompilationUnit getAstNode(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull("Could not resolve model " + model, comp);
return (ASTCNNArchCompilationUnit) comp.getAstNode().get();
}*/
protected
static
CNNArchCompilationUnitSymbol
getCompilationUnitSymbol
(
String
modelPath
,
String
model
)
{
Scope
symTab
=
createSymTab
(
MODEL_PATH
+
modelPath
);
CNNArchCompilationUnitSymbol
comp
=
symTab
.<
CNNArchCompilationUnitSymbol
>
resolve
(
...
...
src/test/java/de/monticore/lang/monticar/cnnarch/mxnetgenerator/GenerationTest.java
View file @
1802f7d0
...
...
@@ -20,19 +20,13 @@
*/
package
de.monticore.lang.monticar.cnnarch.mxnetgenerator
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.symboltable.GlobalScope
;
import
de.se_rwth.commons.logging.Log
;
import
freemarker.template.TemplateException
;
import
org.junit.Before
;
import
org.junit.Test
;
import
java.io.FileWriter
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.util.*
;
...
...
@@ -121,143 +115,50 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue
(
Log
.
getFindings
().
size
()
==
3
);
}
@Test
public
void
testCNNTrainerGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceNames
=
Arrays
.
asList
(
"main_net1"
,
"main_net2"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"Network1"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"Network2"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2MxNet
generator
=
new
CNNArch2MxNet
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceNames
,
"main"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_main.py"
));
}
@Test
public
void
testFullCfgGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceName
=
Arrays
.
asList
(
"main_net1"
,
"main_net2"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"FullConfig"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"FullConfig2"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2MxNet
generator
=
new
CNNArch2MxNet
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceName
,
"mainFull"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
String
sourcePath
=
"src/test/resources/valid_tests"
;
CNNTrain2MxNet
trainGenerator
=
new
CNNTrain2MxNet
();
trainGenerator
.
generate
(
Paths
.
get
(
sourcePath
),
"FullConfig"
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_
mainFull
.py"
));
"CNNTrainer_
fullConfig
.py"
));
}
@Test
public
void
testSimpleCfgGeneration
()
throws
IOException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceName
=
Arrays
.
asList
(
"main_net1"
,
"main_net2"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"SimpleConfig1"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
Path
modelPath
=
Paths
.
get
(
"src/test/resources/valid_tests"
);
CNNTrain2MxNet
trainGenerator
=
new
CNNTrain2MxNet
();
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"SimpleConfig2"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2MxNet
generator
=
new
CNNArch2MxNet
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceName
,
"mainSimple"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
trainGenerator
.
generate
(
modelPath
,
"SimpleConfig"
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_
mainSimple
.py"
));
"CNNTrainer_
simpleConfig
.py"
));
}
@Test
public
void
testEmptyCfgGeneration
()
throws
IOException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceName
=
Arrays
.
asList
(
"main_net1"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"EmptyConfig"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2MxNet
generator
=
new
CNNArch2MxNet
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceName
,
"mainEmpty"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
Path
modelPath
=
Paths
.
get
(
"src/test/resources/valid_tests"
);
CNNTrain2MxNet
trainGenerator
=
new
CNNTrain2MxNet
();
trainGenerator
.
generate
(
modelPath
,
"EmptyConfig"
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_
mainEmpty
.py"
));
"CNNTrainer_
emptyConfig
.py"
));
}
...
...
src/test/resources/target_code/CMakeLists.txt
View file @
1802f7d0
...
...
@@ -12,6 +12,7 @@ set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set
(
LIBS
${
LIBS
}
${
Armadillo_LIBRARIES
}
)
# additional commands
set
(
LIBS
${
LIBS
}
mxnet
)
# create static library
include_directories
(
${
INCLUDE_DIRS
}
)
...
...
@@ -24,4 +25,3 @@ set_target_properties(alexnet PROPERTIES LINKER_LANGUAGE CXX)
export
(
TARGETS alexnet FILE alexnet.cmake
)
# additional commands end
set
(
LIBS
${
LIBS
}
mxnet
)
src/test/resources/target_code/CNNTrainer_
mainEmpty
.py
→
src/test/resources/target_code/CNNTrainer_
emptyConfig
.py
View file @
1802f7d0
import
logging
import
mxnet
as
mx
import
CNNCreator_
main_net1
import
CNNCreator_
emptyConfig
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
@@ -8,6 +8,6 @@ if __name__ == "__main__":
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_
main_net1
.
CNNCreator_
main_net1
()
main_net1
.
train
(
emptyConfig
=
CNNCreator_
emptyConfig
.
CNNCreator_
emptyConfig
()
emptyConfig
.
train
(
)
src/test/resources/target_code/CNNTrainer_
mainFull
.py
→
src/test/resources/target_code/CNNTrainer_
fullConfig
.py
View file @
1802f7d0
import
logging
import
mxnet
as
mx
import
CNNCreator_main_net1
import
CNNCreator_main_net2
import
CNNCreator_fullConfig
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
@@ -9,8 +8,8 @@ if __name__ == "__main__":
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_
main_net1
.
CNNCreator_
main_net1
()
main_net1
.
train
(
fullConfig
=
CNNCreator_
fullConfig
.
CNNCreator_
fullConfig
()
fullConfig
.
train
(
batch_size
=
100
,
num_epoch
=
5
,
load_checkpoint
=
True
,
...
...
@@ -33,25 +32,3 @@ if __name__ == "__main__":
'learning_rate'
:
0.001
,
'step_size'
:
1000
}
)
main_net2
=
CNNCreator_main_net2
.
CNNCreator_main_net2
()
main_net2
.
train
(
batch_size
=
100
,
num_epoch
=
10
,
load_checkpoint
=
False
,
context
=
'gpu'
,
normalize
=
False
,
eval_metric
=
'topKAccuracy'
,
optimizer
=
'adam'
,
optimizer_params
=
{
'epsilon'
:
1.0E-6
,
'weight_decay'
:
0.01
,
'rescale_grad'
:
1.1
,
'beta1'
:
0.9
,
'clip_gradient'
:
10.0
,
'beta2'
:
0.9
,
'learning_rate_minimum'
:
0.001
,
'learning_rate_policy'
:
'exp'
,
'learning_rate'
:
0.001
,
'learning_rate_decay'
:
0.9
,
'step_size'
:
1000
}
)
src/test/resources/target_code/CNNTrainer_main.py
deleted
100644 → 0
View file @
cfb99d88
import
logging
import
mxnet
as
mx
import
CNNCreator_main_net1
import
CNNCreator_main_net2
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logger
=
logging
.
getLogger
()
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_main_net1
.
CNNCreator_main_net1
()
main_net1
.
train
(
batch_size
=
64
,
num_epoch
=
10
,
load_checkpoint
=
False
,
context
=
'gpu'
,
normalize
=
True
,
optimizer
=
'adam'
,
optimizer_params
=
{
'weight_decay'
:
1.0E-4
,
'learning_rate'
:
0.01
,
'learning_rate_decay'
:
0.8
,
'step_size'
:
1000
}
)
main_net2
=
CNNCreator_main_net2
.
CNNCreator_main_net2
()
main_net2
.
train
(
batch_size
=
32
,
num_epoch
=
10
,
load_checkpoint
=
False
,
context
=
'gpu'
,
normalize
=
True
,
optimizer
=
'adam'
,
optimizer_params
=
{
'weight_decay'
:
1.0E-4
,
'learning_rate'
:
0.01
,
'learning_rate_decay'
:
0.8
,
'step_size'
:
1000
}
)
src/test/resources/target_code/CNNTrainer_
mainSimple
.py
→
src/test/resources/target_code/CNNTrainer_
simpleConfig
.py
View file @
1802f7d0
import
logging
import
mxnet
as
mx
import
CNNCreator_main_net1
import
CNNCreator_main_net2
import
CNNCreator_simpleConfig
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
@@ -9,19 +8,11 @@ if __name__ == "__main__":
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_
main_net1
.
CNNCreator_main_net1
()
main_net1
.
train
(
simpleConfig
=
CNNCreator_
simpleConfig
.
CNNCreator_simpleConfig
()
simpleConfig
.
train
(
batch_size
=
100
,
num_epoch
=
50
,
optimizer
=
'adam'
,
optimizer_params
=
{
'learning_rate'
:
0.001
}
)
main_net2
=
CNNCreator_main_net2
.
CNNCreator_main_net2
()
main_net2
.
train
(
batch_size
=
100
,
num_epoch
=
5
,
optimizer
=
'sgd'
,
optimizer_params
=
{
'learning_rate'
:
0.1
}
)
src/test/resources/valid_tests/FullConfig2.cnnt
deleted
100644 → 0
View file @
cfb99d88
configuration FullConfig2{
num_epoch : 10
batch_size : 100
load_checkpoint : false
context : gpu
eval_metric : top_k_accuracy
normalize : false
optimizer : adam{
learning_rate : 0.001
learning_rate_minimum : 0.001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : exp
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10
beta1 : 0.9
beta2 : 0.9
epsilon : 0.000001
}
}
src/test/resources/valid_tests/Network1.cnnt
deleted
100644 → 0
View file @
cfb99d88
configuration Network1{
num_epoch:10
batch_size:64
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
src/test/resources/valid_tests/Network2.cnnt
deleted
100644 → 0
View file @
cfb99d88
configuration Network2{
num_epoch:10
batch_size:32
normalize:true
context:gpu
load_checkpoint:false