Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
a5e9d403
Commit
a5e9d403
authored
Aug 17, 2019
by
Sebastian Nickels
Browse files
Outputs now can be used as inputs
parent
94d3b4a8
Changes
7
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
View file @
a5e9d403
...
...
@@ -37,4 +37,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
return
true
;
}
@Override
protected
boolean
checkOutputAsInput
(
ArchitectureSymbol
architecture
)
{
return
true
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
View file @
a5e9d403
...
...
@@ -173,7 +173,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
Map
<
String
,
List
<
String
>>
inputs
=
new
LinkedHashMap
<>();
for
(
ArchitectureElementSymbol
element
:
stream
.
getFirstAtomicElements
())
{
if
(
element
.
isInput
())
{
if
(
element
.
isInput
()
||
element
.
isOutput
()
)
{
List
<
Integer
>
intDimensions
=
element
.
getOutputTypes
().
get
(
0
).
getDimensions
();
List
<
String
>
dimensions
=
new
ArrayList
<>();
...
...
src/main/resources/templates/gluon/elements/Output.ftl
View file @
a5e9d403
<#
if
element.inputs?size gte 1>
<#
assign
input = element.inputs[0]>
<#
if
mode == "FORWARD_FUNCTION">
$
{
element
.name
}
= $
{
input
}
...
...
@@ -6,3 +7,4 @@
<#
elseif
mode == "CPP_INLINE">
$
{
element
.name
}
= $
{
input
}
;
</#
if
>
</#
if
>
\ No newline at end of file
src/main/resources/templates/gluon/pythonExecute.ftl
View file @
a5e9d403
<#
list
tc.getLayerVariableMembers("batch_size")?keys as member>
$
{
member
}
= mx.nd.zeroes(($
{
tc
.join
(
tc
.getLayerVariableMembers
(
"batch_size"
)[
member
]
, ", ")
}
,), ctx=mx_context)
</#
list
>
<#
list
tc.architecture.outputs as output>
$
{
tc
.getName
(
output
)}
= mx.nd.zeroes((($
{
tc
.join
(
output
.ioDeclaration.type.dimensions
,
", "
)}
,), ctx=mx_context)
</#
list
>
<#
list
tc.architecture.streams as stream>
<#
if
stream.isTrainable()>
...
...
src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
View file @
a5e9d403
...
...
@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_Alexnet:
predictions_label
=
batch
.
label
[
0
].
as_in_context
(
mx_context
)
with
autograd
.
record
():
predictions_
=
mx
.
nd
.
zeroes
(((
10
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
...
...
@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_Alexnet:
]
if
True
:
predictions_
=
mx
.
nd
.
zeroes
(((
10
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
...
...
@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_Alexnet:
]
if
True
:
predictions_
=
mx
.
nd
.
zeroes
(((
10
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
...
...
src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
View file @
a5e9d403
...
...
@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
softmax_label
=
batch
.
label
[
0
].
as_in_context
(
mx_context
)
with
autograd
.
record
():
softmax_
=
mx
.
nd
.
zeroes
(((
10
,),
ctx
=
mx_context
)
softmax_
=
self
.
_networks
[
0
](
data_
)
...
...
@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
]
if
True
:
softmax_
=
mx
.
nd
.
zeroes
(((
10
,),
ctx
=
mx_context
)
softmax_
=
self
.
_networks
[
0
](
data_
)
...
...
@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
]
if
True
:
softmax_
=
mx
.
nd
.
zeroes
(((
10
,),
ctx
=
mx_context
)
softmax_
=
self
.
_networks
[
0
](
data_
)
...
...
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
View file @
a5e9d403
...
...
@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_VGG16:
predictions_label
=
batch
.
label
[
0
].
as_in_context
(
mx_context
)
with
autograd
.
record
():
predictions_
=
mx
.
nd
.
zeroes
(((
1000
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
...
...
@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_VGG16:
]
if
True
:
predictions_
=
mx
.
nd
.
zeroes
(((
1000
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
...
...
@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_VGG16:
]
if
True
:
predictions_
=
mx
.
nd
.
zeroes
(((
1000
,),
ctx
=
mx_context
)
predictions_
=
self
.
_networks
[
0
](
data_
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment