mstksg on master
benchmarks updated (compare)
mstksg on master
benchmarks updated (compare)
-- | Evaluate the backprop-lifted product @(BP.<>)@ on 'mk' and its
-- transpose, yielding a 2x2 result.
testMul :: H.L 2 2
testMul = evalBP2 (BP.<>) mk transposed
  where
    -- Right-hand operand: 'mk' passed through 'H.tr'.
    transposed = H.tr mk
-- | A 2x2 matrix built with 'H.build', where each entry is the sum of the
-- two arguments 'H.build' supplies to the generator function.
mk :: H.L 2 2
mk = H.build (+)
-- | Return @(rows, columns)@ of a sized value by extracting its underlying
-- 'LA.Matrix' and querying that matrix directly.
getDim :: H.Sized Double s LA.Matrix => s -> (Int, Int)
getDim s = (LA.rows raw, LA.cols raw)
  where
    -- The unsized matrix wrapped by @s@.
    raw :: LA.Matrix Double
    raw = H.extract s
-- | Print the @(rows, columns)@ pair reported by 'getDim' for 'testMul'.
main :: IO ()
main = print (getDim testMul)
-- | Signature of a network expressed over backprop variables: it takes the
-- inputs and a list of weights and produces the outputs.
network
  :: Reifies s W
  => BVar s (Acc (Matrix Float))   -- ^ Inputs
  -> BVar s [Acc (Matrix Float)]   -- ^ Weights
  -> BVar s (Acc (Matrix Float))   -- ^ Outputs
Applying something like gradBP network
is expected to produce [Acc (Matrix Float)]
, a list of gradients. Now, in a gradient-descent-like scenario, one would iteratively subtract those gradients from the initial weights, thus obtaining a new list of weights (or an equivalent). In terms of Accelerate, a GPU algorithm is constructed. To obtain the result of an Acc a,
a function run :: Arrays a => Acc a -> a
is applied. How do I represent [Acc (Matrix Float)]
as Arrays a => a
so that I can take advantage of the backprop library?
Backprop
instance
-- | Left-fold a list of backprop variables into a ratio of two running
-- accumulators.
--
-- Starting from @T2 0 0@, each element @x@ updates the pair to
-- @(r * (total + f x), r * (count + 1))@; the result is @total / count@
-- of the final pair. For an empty list this is @0 / 0@.
foldB
  :: (Reifies s W)
  => (BVar s Double -> BVar s Double)
  -> BVar s Double
  -> BVar s [Double]
  -> BVar s Double
foldB f r xs = ratio (PB.foldl' advance (T2 0 0) xs)
  where
    -- One fold step: add the new contribution, then scale both
    -- accumulators by @r@.
    advance (T2 total count) x = T2 (r * (total + f x)) (r * (count + 1))
    -- Collapse the accumulator pair into the final quotient.
    ratio (T2 total count) = total / count
-- | Standard deviation of @xs@ under the weighting implemented by 'foldB':
-- @sqrt (meanSq - mean ** 2)@, where @mean@ is the 'foldB' average of the
-- elements and @meanSq@ is the 'foldB' average of their squares.
--
-- The first argument @r@ is passed unchanged to both 'foldB' calls.
stdB :: Reifies s W => BVar s Double -> BVar s [Double] -> BVar s Double
stdB r xs = sqrt (meanSq - mean ** 2)
  where
    -- Weighted average of the raw values.
    mean = foldB id r xs
    -- Weighted average of the squared values.
    meanSq = foldB (\x -> x * x) r xs
-- | Build a 'Fold' that keeps a running @(total, count)@ pair and finishes
-- with @total / count@.
--
-- On every input @x@ the pair becomes @(g (total + f x), g (count + 1))@:
-- @f@ maps the input into the accumulator type and @g@ is applied to both
-- accumulators after each update. The fold starts from @(0, 0)@.
online
  :: (Reifies s W, Fractional b)
  => (BVar s a -> BVar s b)
  -> (BVar s b -> BVar s b)
  -> Fold (BVar s a) (BVar s b)
online f g = Fold advance (0, 0) finish
  where
    -- Fold step: accumulate, then post-process both components with @g@.
    advance (total, count) x = (g (total + f x), g (count + 1))
    -- Final extraction: quotient of the two accumulators.
    finish (total, count) = total / count
-- | An 'online' fold over the inputs themselves (@f = id@), multiplying
-- both accumulators by @r@ after every step.
ma' :: (Reifies s W, Fractional b) => BVar s b -> Fold (BVar s b) (BVar s b)
ma' r = online id decay
  where
    -- Scale an accumulator by the rate @r@.
    decay acc = acc * r
-- | An 'online' fold over the squares of the inputs, multiplying both
-- accumulators by @r@ after every step.
sqma' :: (Reifies s W, Fractional b) => BVar s b -> Fold (BVar s b) (BVar s b)
sqma' r = online square decay
  where
    -- Square an input before it is accumulated.
    square v = v * v
    -- Scale an accumulator by the rate @r@.
    decay acc = acc * r
-- | Combine 'ma'' and 'sqma'' applicatively into a single 'Fold' whose
-- result is @sqrt (meanSq - mean ** 2)@, where @mean@ comes from 'ma'' and
-- @meanSq@ from 'sqma''. Both folds receive the same rate @r@.
std' :: (Reifies s W, Floating b) => BVar s b -> Fold (BVar s b) (BVar s b)
std' r = spread <$> ma' r <*> sqma' r
  where
    -- Argument order matters: mean first, mean of squares second.
    spread mean meanSq = sqrt (meanSq - mean ** 2)
> backprop2 (\xs r -> L.fold (std' r) (sequenceVar xs)) [1..10::Double] 0.99
(2.8715489528256772,([-0.15247607523147438,-0.12040952611722658,-8.767961073424568e-2,-5.4276199506764516e-2,-2.0189025904650604e-2,1.4592315289823798e-2,5.007836973625318e-2,8.627982531165954e-2,0.1232075139075938,0.16087241324903193],0.14717238787841228))