-
Notifications
You must be signed in to change notification settings - Fork 85
/
model-demo-ConvLSTM.lua
73 lines (61 loc) · 1.82 KB
/
model-demo-ConvLSTM.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
require 'nn'
require 'rnn'
-- Layer backend used for the convolution / tanh modules built below.
-- `backend_name` is fixed to 'cudnn' here; any other value keeps the plain
-- nn implementations.
local backend_name = 'cudnn'
local backend = nn
if backend_name == 'cudnn' then
  -- cudnn is only loaded when it is actually selected
  require 'cudnn'
  backend = cudnn
end
-- Load exactly one ConvLSTM implementation, chosen by the global option
-- `opt.untied` (both file names are non-empty strings, so and/or is safe).
require(opt.untied and 'UntiedConvLSTM' or 'ConvLSTM')
-- initialization from MSR
-- Re-initializes, in place, every SpatialConvolution module found in `net`
-- (both the cudnn and the nn variants). Weights are redrawn from a normal
-- distribution with mean 0 and standard deviation sqrt(2 / n), where
-- n = kW * kH * nOutputPlane; biases, when present, are set to zero.
-- @param net  a module/container exposing `findModules(typename)`
-- Returns nothing.
local function MSRinit(net)
  local function init(name)
    local convs = net:findModules(name)
    for i = 1, #convs do
      local conv = convs[i]
      local n = conv.kW * conv.kH * conv.nOutputPlane
      conv.weight:normal(0, math.sqrt(2 / n))
      -- A bias-free convolution has no bias tensor; skip it rather than
      -- failing on a nil index.
      if conv.bias then
        conv.bias:zero()
      end
    end
  end
  -- have to do for both backends
  init('cudnn.SpatialConvolution')
  init('nn.SpatialConvolution')
end
-- Per-frame network; the stages built below are appended to it in order.
net = nn.Sequential()

-- Spatial encoder: convolution -> tanh -> 2x2 max pooling
encoder = nn.Sequential()
do
  local k, pad = opt.kernelSize, opt.padding
  local conv = backend.SpatialConvolution(
    opt.nFilters[1], opt.nFilters[2], -- input / output planes
    k, k,                             -- square kernel
    1, 1,                             -- stride 1 in both directions
    pad, pad)                         -- same padding in both directions
  encoder:add(conv)
  encoder:add(backend.Tanh())
  encoder:add(nn.SpatialMaxPooling(2, 2))
end
net:add(encoder)
-- Temporal encoder
-- Only one of 'UntiedConvLSTM' / 'ConvLSTM' is required at the top of this
-- file (selected by opt.untied), so each branch must instantiate the class
-- that was actually loaded. Both constructors receive the same arguments.
if opt.untied then
  net:add(nn.UntiedConvLSTM(opt.nFiltersMemory[1], opt.nFiltersMemory[2], opt.nSeq, opt.kernelSize, opt.kernelSizeMemory, opt.stride))
else
  net:add(nn.ConvLSTM(opt.nFiltersMemory[1], opt.nFiltersMemory[2], opt.nSeq, opt.kernelSize, opt.kernelSizeMemory, opt.stride))
end
-- Spatial decoder: nearest-neighbour upsampling by a factor of 2, then a
-- convolution from opt.nFiltersMemory[2] planes to opt.nFilters[1] planes
decoder = nn.Sequential()
do
  local k, pad = opt.kernelSize, opt.padding
  local upsample = nn.SpatialUpSamplingNearest(2)
  local conv = backend.SpatialConvolution(
    opt.nFiltersMemory[2], opt.nFilters[1],
    k, k,
    1, 1,
    pad, pad)
  decoder:add(upsample)
  decoder:add(conv)
end
net:add(decoder)
-- Final non-linearity of the per-frame network
net:add(nn.Sigmoid())

-- Init model
-- 1) MSR-style init of every SpatialConvolution reachable from net
MSRinit(net)
-- 2) net.modules[2] is the ConvLSTM (second stage added to net): redraw all of
--    its parameters uniformly in [-0.08, 0.08], then reset its biases via the
--    module's own initBias(0, 0) (exact meaning of the two arguments is
--    defined by the ConvLSTM class -- TODO confirm).
local lstm = net.modules[2]
local lstm_paramsE, lstm_gradsE = lstm:getParameters()
lstm_paramsE:uniform(-0.08, 0.08)
lstm:initBias(0, 0)
-- Unroll over time using sequencer: the per-frame net is wrapped so it is
-- applied to each element of an input sequence
model = nn.Sequencer(net)
model:remember('both')
model:training()
-- move the model to gpu
model:cuda()

-- Loss module: binary cross-entropy wrapped for sequences, also moved to gpu
criterion = nn.SequencerCriterion(nn.BCECriterion())
criterion:cuda()