Commit 20215ef4 authored by lr119628

[update] restructured model design

parent 1cf71de8
@@ -535,13 +535,21 @@ from mxnet.ndarray import zeros
<#if networkInstruction.body.containsAdaNet()>
${tc.include(networkInstruction.body, "ADANET_CONSTRUCTION")}
#class Model(gluon.HybridBlock): THIS IS THE ORIGINAL NAME, MUST BE RENAMED IN THE OTHER PARTS
class Net_${networkInstruction?index}(gluon.HybridBlock):
    def __init__(self, operations: dict, **kwargs):
        super(Net_${networkInstruction?index}, self).__init__(**kwargs)
        self.AdaNet = True
        self.op_names = []
        self.candidate_complexities = {}
<#assign outblock = networkInstruction.body.getElements()[1].getDeclaration().getBlock("outBlock")>
        with self.name_scope():
<#if outblock.isPresent()>
            self.fout = ${tc.include(outblock.get(),"ADANET_CONSTRUCTION")}
<#else>
            self.fout = None
</#if>
            self.finalout = None
            #if operations is None:
            #    operations={'dummy':nn.Dense(units = 10)}
            self.data_shape = ${tc.getDefinedOutputDimension()}
@@ -554,6 +562,8 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
                self.op_names.append(name)
                self.candidate_complexities[name] = operation.get_complexity()
            self.out = nn.Dense(units=self.classes, activation=None, flatten=False)
            if self.fout:
                self.finalout = self.fout()

    def hybrid_forward(self, F, x):
        res_list = []
@@ -564,6 +574,8 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
        res = tuple(res_list)
        y = F.concat(*res, dim=1)
        y = self.out(y)
        if self.finalout:
            y = self.finalout(y)
        return y

    def get_candidate_complexity(self):
@@ -487,6 +487,8 @@ def fit(loss: gluon.loss.Loss,
    c0_model = model_template(operations=c0_work_op)
    c0_model.out.initialize(ctx=ctx)
    if c0_model.finalout:
        c0_model.finalout.initialize(ctx=ctx)
    c0_model.hybridize()

    # create model with candidate 1 added -> c1_model
@@ -495,15 +497,21 @@ def fit(loss: gluon.loss.Loss,
    c1_model = model_template(operations=c1_work_op)
    c1_model.out.initialize(ctx=ctx)
    if c1_model.finalout:
        c1_model.finalout.initialize(ctx=ctx)
    c1_model.hybridize()

    # train c0_model
    c0_out_trainer = get_trainer(optimizer, c0_model.out.collect_params(), optimizer_params)
    params = c0_model.out.collect_params()
    params.update(c0_model.finalout.collect_params())
    c0_out_trainer = get_trainer(optimizer, params, optimizer_params)
    fitComponent(trainIter=train_iter, trainer=c0_out_trainer, epochs=epochs, component=c0_model,
                 loss_class=AdaLoss, loss_params={'loss': loss, 'model': c0_model})

    # train c1_model
    c1_out_trainer = get_trainer(optimizer, c1_model.out.collect_params(), optimizer_params)
    params = c1_model.out.collect_params()
    params.update(c1_model.finalout.collect_params())
    c1_out_trainer = get_trainer(optimizer, params, optimizer_params)
    fitComponent(trainIter=train_iter, trainer=c1_out_trainer, epochs=epochs, component=c1_model,
                 loss_class=AdaLoss, loss_params={'loss': loss, 'model': c1_model})
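For context, a minimal sketch of the initialization pattern introduced above, assuming MXNet Gluon; the helper name init_and_collect is illustrative only and not part of the generated code. It initializes the out layer, conditionally initializes the optional finalout block built from the outBlock, and merges both parameter sets so a single trainer updates them together:

import mxnet as mx
from mxnet import gluon

def init_and_collect(model, ctx):
    # Initialize the linear output layer of a generated Net_* instance.
    model.out.initialize(ctx=ctx)
    params = model.out.collect_params()
    if model.finalout:
        # The optional outBlock component needs its own initialization;
        # its parameters are trained together with those of `out`.
        model.finalout.initialize(ctx=ctx)
        params.update(model.finalout.collect_params())
    model.hybridize()
    return params

# Illustrative usage (c0_model, optimizer, and hyperparameters are assumptions):
# params = init_and_collect(c0_model, ctx=mx.cpu())
# c0_out_trainer = gluon.Trainer(params, 'adam', {'learning_rate': 1e-3})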