Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
604bf5b8
Commit
604bf5b8
authored
Sep 05, 2019
by
Sebastian Nickels
Browse files
Temporarily removed some code
parent
15a612d3
Pipeline
#180355
failed with stages
in 60 minutes
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
View file @
604bf5b8
...
...
@@ -28,6 +28,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cmake.CMakeConfig
;
import
de.monticore.lang.monticar.generator.cmake.CMakeFindModule
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.*
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
View file @
604bf5b8
...
...
@@ -23,7 +23,6 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList
.
add
(
AllPredefinedLayers
.
CONCATENATE_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
FLATTEN_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
ONE_HOT_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
BEAMSEARCH_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
RNN_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
LSTM_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
GRU_NAME
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
View file @
604bf5b8
...
...
@@ -20,13 +20,11 @@
*/
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
de.monticore.lang.monticar.cnnarch._ast.ASTStream
;
import
de.monticore.lang.monticar.cnnarch.generator.ArchitectureElementData
;
import
de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration
;
import
de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers
;
import
java.io.Writer
;
import
java.util.*
;
...
...
@@ -48,7 +46,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
getTemplateConfiguration
().
processTemplate
(
ftlContext
,
templatePath
,
writer
);
}
public
void
include
(
VariableSymbol
element
,
boolean
partOfUnroll
,
int
unrollIndex
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
public
void
include
(
VariableSymbol
element
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
element
);
...
...
@@ -66,13 +64,13 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else
{
include
((
ArchitectureElementSymbol
)
element
.
getResolvedThis
().
get
(),
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
ArchitectureElementSymbol
)
element
.
getResolvedThis
().
get
(),
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
ConstantSymbol
constant
,
boolean
partOfUnroll
,
int
unrollIndex
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
)
{
public
void
include
(
ConstantSymbol
constant
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
)
{
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
constant
);
...
...
@@ -80,106 +78,72 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Const"
,
writer
,
netDefinitionMode
);
}
else
{
include
((
ArchitectureElementSymbol
)
constant
.
getResolvedThis
().
get
(),
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
ArchitectureElementSymbol
)
constant
.
getResolvedThis
().
get
(),
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
LayerSymbol
layer
,
boolean
partOfUnroll
,
int
unrollIndex
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
public
void
include
(
LayerSymbol
layer
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
layer
);
getCurrentElement
().
setPartOfUnroll
(
partOfUnroll
);
getCurrentElement
().
setUnrollIndex
(
unrollIndex
);
if
(
layer
.
isAtomic
()){
String
templateName
=
layer
.
getDeclaration
().
getName
();
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
,
netDefinitionMode
);
}
else
{
include
((
ArchitectureElementSymbol
)
layer
.
getResolvedThis
().
get
(),
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
ArchitectureElementSymbol
)
layer
.
getResolvedThis
().
get
(),
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
UnrollSymbol
unrollElement
,
boolean
partOfUnroll
,
int
unrollIndex
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
include
(
unrollElement
.
getBody
(),
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
String
templateName
=
unrollElement
.
getDeclaration
().
getName
();
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
,
netDefinitionMode
);
}
public
void
include
(
CompositeElementSymbol
compositeElement
,
boolean
partOfUnroll
,
int
unrollIndex
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
public
void
include
(
CompositeElementSymbol
compositeElement
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
compositeElement
);
for
(
ArchitectureElementSymbol
element
:
compositeElement
.
getElements
()){
include
(
element
,
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
(
element
,
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
boolean
partOfUnroll
,
int
unrollIndex
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
if
(
architectureElement
instanceof
CompositeElementSymbol
){
include
((
CompositeElementSymbol
)
architectureElement
,
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
CompositeElementSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
else
if
(
architectureElement
instanceof
LayerSymbol
){
include
((
LayerSymbol
)
architectureElement
,
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
LayerSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
else
if
(
architectureElement
instanceof
ConstantSymbol
)
{
include
((
ConstantSymbol
)
architectureElement
,
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
ConstantSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
else
{
include
((
VariableSymbol
)
architectureElement
,
partOfUnroll
,
unrollIndex
,
writer
,
netDefinitionMode
);
include
((
VariableSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
}
public
void
include
(
ArchitectureElementSymbol
architectureElementSymbol
,
String
netDefinitionMode
)
{
include
(
architectureElementSymbol
,
false
,
-
1
,
NetDefinitionMode
.
fromString
(
netDefinitionMode
));
}
public
void
include
(
ArchitectureElementSymbol
architectureElementSymbol
,
boolean
partOfUnroll
,
int
unrollIndex
,
String
netDefinitionMode
)
{
int
layerIndex
=
-
1
;
include
(
architectureElementSymbol
,
partOfUnroll
,
unrollIndex
,
NetDefinitionMode
.
fromString
(
netDefinitionMode
));
}
public
void
include
(
UnrollSymbol
unrollSymbol
,
boolean
partOfUnroll
,
int
unrollIndex
,
String
netDefinitionMode
)
{
include
(
unrollSymbol
,
partOfUnroll
,
unrollIndex
,
NetDefinitionMode
.
fromString
(
netDefinitionMode
));
}
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
boolean
partOfUnroll
,
int
unrollIndex
,
NetDefinitionMode
netDefinitionMode
){
if
(
getWriter
()
==
null
){
throw
new
IllegalStateException
(
"missing writer"
);
}
include
(
architectureElement
,
partOfUnroll
,
unrollIndex
,
getWriter
(),
netDefinitionMode
);
include
(
architectureElementSymbol
,
NetDefinitionMode
.
fromString
(
netDefinitionMode
));
}
public
void
include
(
UnrollSymbol
unroll
,
boolean
partOfUnroll
,
int
unrollIndex
,
NetDefinitionMode
netDefinitionMode
){
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
NetDefinitionMode
netDefinitionMode
){
if
(
getWriter
()
==
null
){
throw
new
IllegalStateException
(
"missing writer"
);
}
include
(
unroll
,
partOfUnroll
,
unrollIndex
,
getWriter
(),
netDefinitionMode
);
include
(
architectureElement
,
getWriter
(),
netDefinitionMode
);
}
public
Set
<
String
>
getStreamInputNames
(
SerialCompositeElementSymbol
stream
)
{
return
getStreamInputs
(
stream
).
keySet
();
}
public
Set
<
String
>
getUnrollInputNames
(
UnrollSymbol
unroll
)
{
return
getUnrollInputs
(
unroll
).
keySet
();
}
public
Collection
<
List
<
String
>>
getStreamInputDimensions
(
SerialCompositeElementSymbol
stream
)
{
return
getStreamInputs
(
stream
).
values
();
}
public
Collection
<
List
<
String
>>
getUnrollInputDimensions
(
UnrollSymbol
unroll
)
{
return
getUnrollInputs
(
unroll
).
values
();
}
public
Set
<
String
>
getStreamOutputNames
(
SerialCompositeElementSymbol
stream
)
{
Set
<
String
>
outputNames
=
new
LinkedHashSet
<>();
...
...
@@ -194,20 +158,6 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return
outputNames
;
}
public
Set
<
String
>
getUnrollOutputNames
(
UnrollSymbol
unroll
)
{
Set
<
String
>
outputNames
=
new
LinkedHashSet
<>();
for
(
ArchitectureElementSymbol
element
:
unroll
.
getBody
().
getElements
())
{
if
(
element
.
isOutput
())
{
outputNames
.
add
(
getName
(
element
));
}
}
outputNames
.
addAll
(
getStreamLayerVariableMembers
(
unroll
.
getBody
(),
"1"
,
true
).
keySet
());
return
outputNames
;
}
// Used to initialize all layer variable members which are passed through the networks
public
Map
<
String
,
List
<
String
>>
getLayerVariableMembers
(
String
batchSize
)
{
Map
<
String
,
List
<
String
>>
members
=
new
LinkedHashMap
<>();
...
...
@@ -243,30 +193,6 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return
inputs
;
}
private
Map
<
String
,
List
<
String
>>
getUnrollInputs
(
UnrollSymbol
unroll
)
{
Map
<
String
,
List
<
String
>>
inputs
=
new
LinkedHashMap
<>();
for
(
ArchitectureElementSymbol
element
:
unroll
.
getBody
().
getFirstAtomicElements
())
{
if
(
element
.
isInput
()
||
element
.
isOutput
())
{
List
<
Integer
>
intDimensions
=
element
.
getOutputTypes
().
get
(
0
).
getDimensions
();
List
<
String
>
dimensions
=
new
ArrayList
<>();
for
(
Integer
intDimension
:
intDimensions
)
{
dimensions
.
add
(
intDimension
.
toString
());
}
// Add batch size dimension
dimensions
.
add
(
0
,
"1"
);
inputs
.
put
(
getName
(
element
),
dimensions
);
}
}
inputs
.
putAll
(
getStreamLayerVariableMembers
(
unroll
.
getBody
(),
"1"
,
false
));
return
inputs
;
}
private
Map
<
String
,
List
<
String
>>
getStreamLayerVariableMembers
(
SerialCompositeElementSymbol
stream
,
String
batchSize
,
boolean
includeOutput
)
{
Map
<
String
,
List
<
String
>>
members
=
new
HashMap
<>();
...
...
src/main/resources/templates/gluon/CNNNet.ftl
View file @
604bf5b8
...
...
@@ -94,24 +94,19 @@ ${tc.include(stream, "FORWARD_FUNCTION")}
</#
if
>
</#
list
>
<#
list
tc.architecture.unrolls as unroll>
<#
list
unroll.getBodiesForAllTimesteps() as body>
<#
if
body.isTrainable()>
<#
if
body?index == 0>
<#
assign
partOfUnroll = false>
<#
else
>
<#
assign
partOfUnroll = true>
</#
if
>
class
Net_$
{
tc
.architecture.streams
?
size
+
body
?
index
}
(gluon.HybridBlock):
<#
if
unroll.body.isTrainable()>
class
Net_$
{
unroll
?
index
}
(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None,
**
kwargs):
super(Net_$
{
tc
.architecture.streams
?
size
+
body
?
index
}
, self).__init__(
**
kwargs)
super(Net_$
{
unroll
?
index
}
, self).__init__(
**
kwargs)
self.last_layers =
{}
with self.name_scope():
${
tc
.include
(
body
,
partOfUnroll, unroll?index, "ARCHITECTURE_DEFINITION")
}
${
tc
.include
(
unroll
.body
,
"ARCHITECTURE_DEFINITION")
}
def hybrid_forward(self, F, $
{
tc
.join
(
tc
.getStreamInputNames
(
unroll
.body
),
", "
)}
):
${
tc
.include
(
unroll
.body
,
"FORWARD_FUNCTION")
}
return $
{
tc
.join
(
tc
.getStreamOutputNames
(
unroll
.body
),
", "
)}
def hybrid_forward(self, F, $
{
tc
.join
(
tc
.getStreamInputNames
(
body
),
", "
)}
):
${
tc
.include
(
body
,
partOfUnroll, unroll?index, "FORWARD_FUNCTION")
}
return $
{
tc
.join
(
tc
.getStreamOutputNames
(
body
),
", "
)}
</#
if
>
</#
list
>
</#
list
>
src/main/resources/templates/gluon/CNNPredictor.ftl
View file @
604bf5b8
...
...
@@ -116,114 +116,4 @@ public:
</#
if
>
</#
list
>
<#
list
tc.architecture.unrolls as unroll>
<#
list
unroll.getBodiesForAllTimesteps() as body>
<#
if
body.isTrainable()>
class
$
{
tc
.fileNameWithoutEnding
}
_$
{
tc
.architecture.streams
?
size
+
body
?
index
}{
public
:
const
std
::
string
json_file
=
"model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-symbol.json"
;
const
std
::
string
param_file
=
"model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-0000.params"
;
const
std
::
vector
<
std
::
string
>
input_keys
=
{
<#
if
tc
.getStreamInputNames
(
body
)
?
size
==
1
>
"data"
<#
else
>
<#
list
tc
.getStreamInputNames
(
body
)
as
variable
>
"data${variable?index}"
<#
sep
>
,
</#
list
>
</#
if
>
}
;
const
std
::
vector
<
std
::
vector
<
mx_uint
>>
input_shapes
=
{
<#
list
tc
.getStreamInputDimensions
(
body
)
as
dimensions
>
{
$
{
tc
.join
(
dimensions
,
", "
)}}
<#
sep
>
,
</#
list
>
}
;
const
bool
use_gpu
=
false
;
P
redictorHandle
handle
;
explicit
$
{
tc
.fileNameWithoutEnding
}
_$
{
tc
.architecture.streams
?
size
+
body
?
index
}(){
init
(
json_file
,
param_file
,
input_keys
,
input_shapes
,
use_gpu
)
;
}
~$
{
tc
.fileNameWithoutEnding
}
_$
{
tc
.architecture.streams
?
size
+
body
?
index
}(){
if
(
handle
)
MXP
redFree
(
handle
)
;
}
void
predict
(
$
{
tc
.join
(
tc
.getStreamInputNames
(
body
),
", "
,
"const std::vector<float> &in_"
,
""
)},
$
{
tc
.join
(
tc
.getStreamOutputNames
(
body
),
", "
,
"std::vector<float> &out_"
,
""
)}){
<#
list
tc
.getStreamInputNames
(
body
)
as
variable
>
MXP
redSetInput
(
handle
,
input_keys
[$
{
variable
?
index
}
]
.c_str
(),
in_
$
{
variable
}
.data
(),
static_cast
<
mx_uint
>
(
in_
$
{
variable
}
.size
()))
;
</#
list
>
MXP
redForward
(
handle
)
;
mx_uint
output_index
;
mx_uint
*
shape
=
0
;
mx_uint
shape_len
;
size_t
size
;
<#
list
tc
.getStreamOutputNames
(
body
)
as
variable
>
output_index
=
$
{
variable
?
index
?
c
}
;
MXP
redGetOutputShape
(
handle
,
output_index
,
&
shape
,
&
shape_len
)
;
size
=
1
;
for
(
mx_uint
i
=
0
;
i
<
shape_len
;
++
i
)
size
*
=
shape
[
i
];
assert
(
size
==
out_
$
{
variable
}
.size
())
;
MXP
redGetOutput
(
handle
,
$
{
variable
?
index
?
c
},
&
(
out_
$
{
variable
}[
0
]
), out_$
{
variable
}
.size());
</#
list
>
}
void
init
(
const
std
::
string
&
json_file
,
const
std
::
string
&
param_file
,
const
std
::
vector
<
std
::
string
>
&
input_keys
,
const
std
::
vector
<
std
::
vector
<
mx_uint
>>
&
input_shapes
,
const
bool
&
use_gpu
){
B
ufferFile
json_data
(
json_file
)
;
B
ufferFile
param_data
(
param_file
)
;
int
dev_type
=
use_gpu
?
2
:
1
;
int
dev_id
=
0
;
if
(
json_data
.GetLength
()
==
0
||
param_data
.GetLength
()
==
0
)
{
std
::
exit
(
-1
)
;
}
const
mx_uint
num_input_nodes
=
input_keys
.size
()
;
const
char
*
input_keys_ptr
[
num_input_nodes
]
;
for
(
mx_uint
i
=
0
;
i
<
num_input_nodes
;
i
++
){
input_keys_ptr
[
i
]
=
input_keys
[
i
]
.c_str
()
;
}
mx_uint
shape_data_size
=
0
;
mx_uint
input_shape_indptr
[
input_shapes
.size
()
+
1
];
input_shape_indptr
[
0
]
= 0;
for
(
mx_uint
i
=
0
;
i
<
input_shapes
.size
()
;
i
++
){
shape_data_size
+=
input_shapes
[
i
]
.size
()
;
input_shape_indptr
[
i
+
1
]
=
shape_data_size
;
}
mx_uint
input_shape_data
[
shape_data_size
]
;
mx_uint
index
=
0
;
for
(
mx_uint
i
=
0
;
i
<
input_shapes
.size
()
;
i
++
){
for
(
mx_uint
j
=
0
;
j
<
input_shapes
[
i
]
.size
()
;
j
++
){
input_shape_data
[
index
]
= input_shapes[i][j];
index
++;
}
}
MXP
redCreate
(
static_cast
<
const
char
*
>
(
json_data
.GetBuffer
()),
static_cast
<
const
char
*
>
(
param_data
.GetBuffer
()),
static_cast
<
size_t
>
(
param_data
.GetLength
()),
dev_type
,
dev_id
,
num_input_nodes
,
input_keys_ptr
,
input_shape_indptr
,
input_shape_data
,
&
handle
)
;
assert
(
handle
)
;
}
}
;
</#
if
>
</#
list
>
</#
list
>
#
endif
// $
{
tc
.fileNameWithoutEnding
?
upper_case
}
src/main/resources/templates/gluon/elements/BeamSearch
Start
.ftl
→
src/main/resources/templates/gluon/elements/BeamSearch.ftl
View file @
604bf5b8
File moved
src/main/resources/templates/gluon/execute.ftl
View file @
604bf5b8
...
...
@@ -21,16 +21,6 @@ ${tc.include(stream, "CPP_INLINE")}
</#
if
>
</#
list
>
<#
list
tc.architecture.unrolls as unroll>
<#
list
unroll.getBodiesForAllTimesteps() as body>
<#
if
body.isTrainable()>
_predictor_$
{
tc
.architecture.streams
?
size
+
body
?
index
}
_.predict($
{
tc
.join
(
tc
.getStreamInputNames
(
body
),
", "
)}
, $
{
tc
.join
(
tc
.getStreamOutputNames
(
body
),
", "
)}
);
<#
else
>
${
tc
.include
(
unroll
,
true, "CPP_INLINE")
}
</#
if
>
</#
list
>
</#
list
>
<#
list
tc.architecture.outputs as output>
<#
if
tc.getName(output)??>
<#
assign
shape = output.ioDeclaration.type.dimensions>
...
...
src/main/resources/templates/gluon/pythonExecute.ftl
View file @
604bf5b8
...
...
@@ -13,14 +13,4 @@
<#
else
>
${
tc
.include
(
stream
,
"PYTHON_INLINE")
}
</#
if
>
</#
list
>
<#
list
tc.architecture.unrolls as unroll>
<#
list
unroll.getBodiesForAllTimesteps() as body>
<#
if
body.isTrainable()>
$
{
tc
.join
(
tc
.getStreamOutputNames
(
body
),
", "
)}
= self._networks[$
{
tc
.architecture.streams
?
size
+
body
?
index
}
]($
{
tc
.join
(
tc
.getStreamInputNames
(
body
),
", "
)}
)
<#
else
>
${
tc
.include
(
unroll
,
true, "PYTHON_INLINE")
}
</#
if
>
</#
list
>
</#
list
>
\ No newline at end of file
src/test/resources/valid_tests/RNNencdec.cnna
View file @
604bf5b8
...
...
@@ -14,7 +14,7 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
encoder.state -> decoder.state;
timed<t>
Greedy
Search(max_length=50) {
timed<t>
Beam
Search(max_length=50) {
target[t-1] ->
Embedding(output_dim=hidden_size) ->
decoder ->
...
...
src/test/resources/valid_tests/RNNtest.cnna
View file @
604bf5b8
...
...
@@ -2,13 +2,13 @@ architecture RNNtest(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target[5]
source -> Softmax() -> target[0];
source -> Softmax() -> target[0];
timed
<t>
Beam
Search(max_length=5){
(target[0] | target[t-1]) ->
Concatenate() ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
timed<t>
Greedy
Search(max_length=5)
{
(target[0] | target[t-1]) ->
Concatenate() ->
FullyConnected(units=30000) ->
Softmax() ->
target[t]
};
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment