Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
fba1a685
Commit
fba1a685
authored
Aug 21, 2019
by
Christian Fuß
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
changed .ftl templates to work with Unrolls
parent
c68409e0
Pipeline
#174526
failed with stages
in 2 minutes and 4 seconds
Changes
25
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
25 changed files
with
205 additions
and
94 deletions
+205
-94
_gitignore
_gitignore
+0
-10
_gitlab-ci.yml
_gitlab-ci.yml
+0
-50
_travis.yml
_travis.yml
+0
-5
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
...narch/gluongenerator/CNNArch2GluonTemplateController.java
+0
-1
src/main/resources/templates/gluon/CNNCreator.ftl
src/main/resources/templates/gluon/CNNCreator.ftl
+16
-0
src/main/resources/templates/gluon/CNNNet.ftl
src/main/resources/templates/gluon/CNNNet.ftl
+16
-0
src/main/resources/templates/gluon/CNNPredictor.ftl
src/main/resources/templates/gluon/CNNPredictor.ftl
+108
-0
src/main/resources/templates/gluon/execute.ftl
src/main/resources/templates/gluon/execute.ftl
+8
-0
src/main/resources/templates/gluon/pythonExecute.ftl
src/main/resources/templates/gluon/pythonExecute.ftl
+8
-0
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
.../lang/monticar/cnnarch/gluongenerator/GenerationTest.java
+10
-22
src/test/resources/target_code/CNNCreator_Alexnet.py
src/test/resources/target_code/CNNCreator_Alexnet.py
+3
-0
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
...esources/target_code/CNNCreator_CifarClassifierNetwork.py
+3
-0
src/test/resources/target_code/CNNCreator_VGG16.py
src/test/resources/target_code/CNNCreator_VGG16.py
+3
-0
src/test/resources/target_code/CNNNet_Alexnet.py
src/test/resources/target_code/CNNNet_Alexnet.py
+2
-0
src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
...st/resources/target_code/CNNNet_CifarClassifierNetwork.py
+2
-0
src/test/resources/target_code/CNNNet_VGG16.py
src/test/resources/target_code/CNNNet_VGG16.py
+2
-0
src/test/resources/target_code/CNNPredictor_Alexnet.h
src/test/resources/target_code/CNNPredictor_Alexnet.h
+2
-0
src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
...sources/target_code/CNNPredictor_CifarClassifierNetwork.h
+2
-0
src/test/resources/target_code/CNNPredictor_VGG16.h
src/test/resources/target_code/CNNPredictor_VGG16.h
+2
-0
src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
...est/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+5
-2
src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
...arget_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+5
-2
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+5
-2
src/test/resources/target_code/execute_Alexnet
src/test/resources/target_code/execute_Alexnet
+1
-0
src/test/resources/target_code/execute_CifarClassifierNetwork
...test/resources/target_code/execute_CifarClassifierNetwork
+1
-0
src/test/resources/target_code/execute_VGG16
src/test/resources/target_code/execute_VGG16
+1
-0
No files found.
_gitignore
deleted
100644 → 0
View file @
c68409e0
target
nppBackup
.project
.settings
.classpath
.idea
.git
*.iml
_gitlab-ci.yml
deleted
100644 → 0
View file @
c68409e0
#
#
# ******************************************************************************
# MontiCAR Modeling Family, www.se-rwth.de
# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
# All rights reserved.
#
# This project is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this project. If not, see <http://www.gnu.org/licenses/>.
# *******************************************************************************
#
stages
:
-
windows
-
linux
masterJobLinux
:
stage
:
linux
image
:
maven:3-jdk-8
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
-
cat target/site/jacoco/index.html
-
mvn package sonar:sonar -s settings.xml
only
:
-
master
masterJobWindows
:
stage
:
windows
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags
:
-
Windows10
BranchJobLinux
:
stage
:
linux
image
:
maven:3-jdk-8
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
-
cat target/site/jacoco/index.html
except
:
-
master
_travis.yml
deleted
100644 → 0
View file @
c68409e0
script
:
-
git checkout ${TRAVIS_BRANCH}
-
mvn clean install cobertura:cobertura org.eluder.coveralls:coveralls-maven-plugin:report --settings "settings.xml"
after_success
:
-
if [ "${TRAVIS_BRANCH}" == "master" ]; then mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B deploy --debug --settings "./settings.xml"; fi
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
View file @
fba1a685
...
...
@@ -40,7 +40,6 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
public
void
include
(
String
relativePath
,
String
templateWithoutFileEnding
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
System
.
err
.
println
(
"include called. templateName: "
+
templateWithoutFileEnding
);
String
templatePath
=
relativePath
+
templateWithoutFileEnding
+
FTL_FILE_ENDING
;
Map
<
String
,
Object
>
ftlContext
=
new
HashMap
<>();
ftlContext
.
put
(
TEMPLATE_CONTROLLER_KEY
,
this
);
...
...
src/main/resources/templates/gluon/CNNCreator.ftl
View file @
fba1a685
import mxnet as mx
import logging
import os
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${stream?index}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${unroll?index}
</#if>
</#list>
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "model"
...
...
@@ -58,6 +65,15 @@ class ${tc.fileNameWithoutEnding}:
self.networks[${stream?index}].hybridize()
self.networks[${stream?index}](<#list tc.getStreamInputDimensions(stream) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
self.networks[${unroll?index}] = Net_${unroll?index}(data_mean=data_mean, data_std=data_std)
self.networks[${unroll?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${unroll?index}].hybridize()
self.networks[${unroll?index}](<#list tc.getUnrollInputDimensions(unroll) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
if not os.path.exists(self._model_dir_):
...
...
src/main/resources/templates/gluon/CNNNet.ftl
View file @
fba1a685
...
...
@@ -94,6 +94,22 @@ ${tc.include(stream, "FORWARD_FUNCTION")}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
class Net_${unroll?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${unroll?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(unroll, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getUnrollInputNames(unroll), ", ")}):
${tc.include(unroll, "FORWARD_FUNCTION")}
return ${tc.join(tc.getUnrollOutputNames(unroll), ", ")}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
class Net_${unroll?index}(gluon.HybridBlock):
...
...
src/main/resources/templates/gluon/CNNPredictor.ftl
View file @
fba1a685
...
...
@@ -116,4 +116,112 @@ public:
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
class ${tc.fileNameWithoutEnding}_${unroll?index}{
public:
const std::string json_file = "model/${tc.componentName}/model_${unroll?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${unroll?index}_newest-0000.params";
const std::vector<std::string> input_keys = {
<#if tc.getUnrollInputNames(unroll)?size == 1>
"data"
<#else>
<#list tc.getUnrollInputNames(unroll) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getUnrollInputDimensions(unroll) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const bool use_gpu = false;
PredictorHandle handle;
explicit ${tc.fileNameWithoutEnding}_${unroll?index}(){
init(json_file, param_file, input_keys, input_shapes, use_gpu);
}
~${tc.fileNameWithoutEnding}_${unroll?index}(){
if(handle) MXPredFree(handle);
}
void predict(${tc.join(tc.getUnrollInputNames(unroll), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getUnrollOutputNames(unroll), ", ", "std::vector<float> &out_", "")}){
<#list tc.getUnrollInputNames(unroll) as variable>
MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
</#list>
MXPredForward(handle);
mx_uint output_index;
mx_uint *shape = 0;
mx_uint shape_len;
size_t size;
<#list tc.getUnrollOutputNames(unroll) as variable>
output_index = ${variable?index?c};
MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1;
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
assert(size == out_${variable}.size());
MXPredGetOutput(handle, ${variable?index?c}, &(out_${variable}[0]), out_${variable}.size());
</#list>
}
void init(const std::string &json_file,
const std::string ¶m_file,
const std::vector<std::string> &input_keys,
const std::vector<std::vector<mx_uint>> &input_shapes,
const bool &use_gpu){
BufferFile json_data(json_file);
BufferFile param_data(param_file);
int dev_type = use_gpu ? 2 : 1;
int dev_id = 0;
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
std::exit(-1);
}
const mx_uint num_input_nodes = input_keys.size();
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
shape_data_size += input_shapes[i].size();
input_shape_indptr[i+1] = shape_data_size;
}
mx_uint input_shape_data[shape_data_size];
mx_uint index = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
for(mx_uint j = 0; j < input_shapes[i].size(); j++){
input_shape_data[index] = input_shapes[i][j];
index++;
}
}
MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
static_cast<const char*>(param_data.GetBuffer()),
static_cast<size_t>(param_data.GetLength()),
dev_type,
dev_id,
num_input_nodes,
input_keys_ptr,
input_shape_indptr,
input_shape_data,
&handle);
assert(handle);
}
};
</#if>
</#list>
#endif // ${tc.fileNameWithoutEnding?upper_case}
src/main/resources/templates/gluon/execute.ftl
View file @
fba1a685
...
...
@@ -16,6 +16,14 @@ ${tc.include(stream, "CPP_INLINE")}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
_predictor_${unroll?index}_.predict(${tc.join(tc.getUnrollInputNames(unroll), ", ")}, ${tc.join(tc.getUnrollOutputNames(unroll), ", ")});
<#else>
${tc.include(unroll, "CPP_INLINE")}
</#if>
</#list>
<#list tc.architecture.outputs as output>
<#assign shape = output.ioDeclaration.type.dimensions>
<#if shape?size == 1>
...
...
src/main/resources/templates/gluon/pythonExecute.ftl
View file @
fba1a685
...
...
@@ -11,4 +11,12 @@
<#else>
${tc.include(stream, "PYTHON_INLINE")}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
${tc.join(tc.getUnrollOutputNames(unroll), ", ")} = self._networks[${unroll?index}](${tc.join(tc.getUnrollInputNames(unroll), ", ")})
<#else>
${tc.include(unroll, "PYTHON_INLINE")}
</#if>
</#list>
\ No newline at end of file
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
View file @
fba1a685
...
...
@@ -80,27 +80,7 @@ public class GenerationTest extends AbstractSymtabTest {
CNNArch2GluonCli
.
main
(
args
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
/*checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_Alexnet.py",
"CNNNet_Alexnet.py",
"CNNDataLoader_Alexnet.py",
"CNNSupervisedTrainer_Alexnet.py",
"CNNPredictor_Alexnet.h",
"execute_Alexnet"));*/
}
@Test
public
void
testRNNencdecGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
String
[]
args
=
{
"-m"
,
"src/test/resources/valid_tests"
,
"-r"
,
"RNNencdec"
,
"-o"
,
"./target/generated-sources-cnnarch/"
};
CNNArch2GluonCli
.
main
(
args
);
// assertTrue(Log.getFindings().isEmpty());
/*checkFilesAreEqual(
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
...
...
@@ -109,7 +89,7 @@ public class GenerationTest extends AbstractSymtabTest {
"CNNDataLoader_Alexnet.py"
,
"CNNSupervisedTrainer_Alexnet.py"
,
"CNNPredictor_Alexnet.h"
,
"execute_Alexnet"));
*/
"execute_Alexnet"
));
}
@Test
...
...
@@ -171,6 +151,14 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue
(
Log
.
getFindings
().
isEmpty
());
}
@Test
public
void
testRNNencdec
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
String
[]
args
=
{
"-m"
,
"src/test/resources/valid_tests"
,
"-r"
,
"RNNencdec"
,
"-o"
,
"./target/generated-sources-cnnarch/"
};
CNNArch2GluonCli
.
main
(
args
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
}
@Test
public
void
testFullCfgGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
...
...
src/test/resources/target_code/CNNCreator_Alexnet.py
View file @
fba1a685
import
mxnet
as
mx
import
logging
import
os
from
CNNNet_Alexnet
import
Net_0
class
CNNCreator_Alexnet
:
_model_dir_
=
"model/Alexnet/"
_model_prefix_
=
"model"
...
...
@@ -52,6 +54,7 @@ class CNNCreator_Alexnet:
self
.
networks
[
0
].
hybridize
()
self
.
networks
[
0
](
mx
.
nd
.
zeros
((
1
,
3
,
224
,
224
,),
ctx
=
context
))
if
not
os
.
path
.
exists
(
self
.
_model_dir_
):
os
.
makedirs
(
self
.
_model_dir_
)
...
...
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
View file @
fba1a685
import
mxnet
as
mx
import
logging
import
os
from
CNNNet_CifarClassifierNetwork
import
Net_0
class
CNNCreator_CifarClassifierNetwork
:
_model_dir_
=
"model/CifarClassifierNetwork/"
_model_prefix_
=
"model"
...
...
@@ -52,6 +54,7 @@ class CNNCreator_CifarClassifierNetwork:
self
.
networks
[
0
].
hybridize
()
self
.
networks
[
0
](
mx
.
nd
.
zeros
((
1
,
3
,
32
,
32
,),
ctx
=
context
))
if
not
os
.
path
.
exists
(
self
.
_model_dir_
):
os
.
makedirs
(
self
.
_model_dir_
)
...
...
src/test/resources/target_code/CNNCreator_VGG16.py
View file @
fba1a685
import
mxnet
as
mx
import
logging
import
os
from
CNNNet_VGG16
import
Net_0
class
CNNCreator_VGG16
:
_model_dir_
=
"model/VGG16/"
_model_prefix_
=
"model"
...
...
@@ -52,6 +54,7 @@ class CNNCreator_VGG16:
self
.
networks
[
0
].
hybridize
()
self
.
networks
[
0
](
mx
.
nd
.
zeros
((
1
,
3
,
224
,
224
,),
ctx
=
context
))
if
not
os
.
path
.
exists
(
self
.
_model_dir_
):
os
.
makedirs
(
self
.
_model_dir_
)
...
...
src/test/resources/target_code/CNNNet_Alexnet.py
View file @
fba1a685
...
...
@@ -275,3 +275,5 @@ class Net_0(gluon.HybridBlock):
return
predictions_
src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
View file @
fba1a685
...
...
@@ -468,3 +468,5 @@ class Net_0(gluon.HybridBlock):
return
softmax_
src/test/resources/target_code/CNNNet_VGG16.py
View file @
fba1a685
...
...
@@ -295,3 +295,5 @@ class Net_0(gluon.HybridBlock):
return
predictions_
src/test/resources/target_code/CNNPredictor_Alexnet.h
View file @
fba1a685
...
...
@@ -104,4 +104,6 @@ public:
}
};
#endif // CNNPREDICTOR_ALEXNET
src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
View file @
fba1a685
...
...
@@ -104,4 +104,6 @@ public:
}
};
#endif // CNNPREDICTOR_CIFARCLASSIFIERNETWORK
src/test/resources/target_code/CNNPredictor_VGG16.h
View file @
fba1a685
...
...
@@ -104,4 +104,6 @@ public:
}
};
#endif // CNNPREDICTOR_VGG16
src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
View file @
fba1a685
...
...
@@ -140,6 +140,7 @@ class CNNSupervisedTrainer_Alexnet:
predictions_
=
self
.
_networks
[
0
](
data_
)
loss
=
\
loss_function
(
predictions_
,
predictions_label
)
...
...
@@ -172,11 +173,12 @@ class CNNSupervisedTrainer_Alexnet:
batch
.
label
[
0
].
as_in_context
(
mx_context
)
]
if
True
:
if
True
:
predictions_
=
mx
.
nd
.
zeros
((
10
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
predictions
=
[
mx
.
nd
.
argmax
(
predictions_
,
axis
=
1
)
]
...
...
@@ -193,11 +195,12 @@ class CNNSupervisedTrainer_Alexnet:
batch
.
label
[
0
].
as_in_context
(
mx_context
)
]
if
True
:
if
True
:
predictions_
=
mx
.
nd
.
zeros
((
10
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
predictions
=
[
mx
.
nd
.
argmax
(
predictions_
,
axis
=
1
)
]
...
...
src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
View file @
fba1a685
...
...
@@ -140,6 +140,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
softmax_
=
self
.
_networks
[
0
](
data_
)
loss
=
\
loss_function
(
softmax_
,
softmax_label
)
...
...
@@ -172,11 +173,12 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
batch
.
label
[
0
].
as_in_context
(
mx_context
)
]
if
True
:
if
True
:
softmax_
=
mx
.
nd
.
zeros
((
10
,),
ctx
=
mx_context
)
softmax_
=
self
.
_networks
[
0
](
data_
)
predictions
=
[
mx
.
nd
.
argmax
(
softmax_
,
axis
=
1
)
]
...
...
@@ -193,11 +195,12 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
batch
.
label
[
0
].
as_in_context
(
mx_context
)
]
if
True
:
if
True
:
softmax_
=
mx
.
nd
.
zeros
((
10
,),
ctx
=
mx_context
)
softmax_
=
self
.
_networks
[
0
](
data_
)
predictions
=
[
mx
.
nd
.
argmax
(
softmax_
,
axis
=
1
)
]
...
...
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
View file @
fba1a685
...
...
@@ -140,6 +140,7 @@ class CNNSupervisedTrainer_VGG16:
predictions_
=
self
.
_networks
[
0
](
data_
)
loss
=
\
loss_function
(
predictions_
,
predictions_label
)
...
...
@@ -172,11 +173,12 @@ class CNNSupervisedTrainer_VGG16:
batch
.
label
[
0
].
as_in_context
(
mx_context
)
]
if
True
:
if
True
:
predictions_
=
mx
.
nd
.
zeros
((
1000
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
predictions
=
[
mx
.
nd
.
argmax
(
predictions_
,
axis
=
1
)
]
...
...
@@ -193,11 +195,12 @@ class CNNSupervisedTrainer_VGG16:
batch
.
label
[
0
].
as_in_context
(
mx_context
)
]
if
True
:
if
True
:
predictions_
=
mx
.
nd
.
zeros
((
1000
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
predictions
=
[
mx
.
nd
.
argmax
(
predictions_
,
axis
=
1
)
]
...
...
src/test/resources/target_code/execute_Alexnet
View file @
fba1a685
...
...
@@ -3,4 +3,5 @@
_predictor_0_.predict(data_, predictions_);
predictions = CNNTranslator::translateToCol(predictions_, std::vector<size_t> {10});
src/test/resources/target_code/execute_CifarClassifierNetwork
View file @
fba1a685
...
...
@@ -3,4 +3,5 @@
_predictor_0_.predict(data_, softmax_);
softmax = CNNTranslator::translateToCol(softmax_, std::vector<size_t> {10});
src/test/resources/target_code/execute_VGG16
View file @
fba1a685
...
...
@@ -3,4 +3,5 @@
_predictor_0_.predict(data_, predictions_);
predictions = CNNTranslator::translateToCol(predictions_, std::vector<size_t> {1000});
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment