Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Go, and more
# NOTE: the imports, class name and constructor signature below are assumed
# (the excerpt begins inside the constructor body); the body matches an
# MSG-Net-style "Inspiration" layer that tunes a feature map with a target Gram matrix.
import mxnet as mx
from mxnet import symbol
from mxnet.gluon import HybridBlock

class Inspiration(HybridBlock):
    def __init__(self, C, B=1, ctx=mx.cpu()):
        super(Inspiration, self).__init__()
        # B is equal to 1 or input mini_batch
        self.C = C
        self.B = B
        # Learnable (C x C) transform applied to the target Gram matrix.
        self.weight = self.params.get('weight', shape=(1, self.C, self.C),
                                      init=mx.initializer.Uniform(),
                                      allow_deferred_init=True)
        # Target Gram matrix; lr_mult=0 keeps it fixed during training.
        self.gram = self.params.get('gram', shape=(self.B, self.C, self.C),
                                    init=mx.initializer.Uniform(),
                                    allow_deferred_init=True,
                                    lr_mult=0)
        self.weight.initialize(ctx=ctx)
        self.gram.initialize(ctx=ctx)
    def setTarget(self, target):
        # Store the target Gram matrix (shape (B, C, C)).
        self.gram.set_data(target)
    def hybrid_forward(self, F, X, gram, weight):
        # Transform each target Gram matrix with the learnable (C x C) weight.
        P = F.batch_dot(F.broadcast_to(weight, shape=self.gram.shape), gram)
        if not isinstance(X, symbol.Symbol):
            # Imperative (NDArray) path: concrete shapes are available on X.
            return F.batch_dot(
                F.SwapAxis(P, 1, 2).broadcast_to((X.shape[0], self.C, self.C)),
                X.reshape((0, 0, X.shape[2] * X.shape[3]))).reshape(X.shape)
        else:
            # Symbolic (hybridized) path: a Symbol has no .shape, so recover the
            # concrete shape of X through shape inference, seeded with the stored
            # gram shape as in the original excerpt (assumes inference succeeds).
            _, out_shapes, _ = X.infer_shape(self.gram.shape)
            x_shape = out_shapes[0]  # inferred (N, C, H, W) shape of X
            return F.batch_dot(
                F.SwapAxis(P, 1, 2).broadcast_to((x_shape[0], self.C, self.C)),
                X.reshape((0, 0, x_shape[2] * x_shape[3]))).reshape(x_shape)
    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.C) + ')'
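
A minimal usage sketch, assuming the reconstructed Inspiration class and constructor above; the gram_matrix helper, channel count, and shapes are illustrative and not part of the excerpt:

import mxnet as mx

def gram_matrix(y):
    # Hypothetical helper: (B, C, H, W) feature map -> (B, C, C) Gram matrix.
    b, c, h, w = y.shape
    features = y.reshape((b, c, h * w))
    return mx.nd.batch_dot(features, features, transpose_b=True) / (c * h * w)

ctx = mx.cpu()
C = 64
layer = Inspiration(C, B=1, ctx=ctx)

style_feat = mx.nd.random.uniform(shape=(1, C, 32, 32), ctx=ctx)    # style feature map (e.g. from VGG)
layer.setTarget(gram_matrix(style_feat))                            # store the target Gram matrix

content_feat = mx.nd.random.uniform(shape=(4, C, 32, 32), ctx=ctx)  # content feature maps
out = layer(content_feat)
print(out.shape)  # (4, 64, 32, 32): same shape as the input feature map

This exercises the imperative (NDArray) branch; the symbolic branch only runs after hybridize(), when X arrives as a Symbol and its concrete shape has to be recovered through infer_shape.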