-
Notifications
You must be signed in to change notification settings - Fork 20
/
val.py
76 lines (58 loc) · 2.22 KB
/
val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import time
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils
from torch.autograd import Variable
from data import *
from networks import *
from misc import *
# Command-line interface for the generation script.
parser = argparse.ArgumentParser()

# Value written to CUDA_VISIBLE_DEVICES by main().
parser.add_argument(
    '--gpu_ids',
    type=str,
    default='0,1',
)
# Width of the sampled noise vector; also forwarded to define_G.
parser.add_argument(
    '--hdim',
    type=int,
    default=128,
)
# Checkpoint file loaded into the generator.
parser.add_argument(
    '--pre_model',
    type=str,
    default='./model/netG_model_epoch_50_iter_0.pth',
)
# Output directory for channels 0-2 of each generated sample.
parser.add_argument(
    '--out_path_1',
    type=str,
    default='fake_images/nir_noise/',
)
# Output directory for channels 3-5 of each generated sample.
parser.add_argument(
    '--out_path_2',
    type=str,
    default='fake_images/vis_noise/',
)
def _save_chw_image(chw, path):
    """Write one channel-first float image to ``path`` as an 8-bit image.

    Args:
        chw: NumPy array indexed (channel, height, width). Values are scaled
            by 255 before the 8-bit conversion.
        path: Destination file path, passed to ``PIL.Image.Image.save``.
    """
    # Clamp before the cast: without it an out-of-range float wraps around
    # when converted to uint8. For values already inside [0, 1] this is a
    # no-op, so in-range output is unchanged.
    # NOTE(review): assumes the generator emits values nominally in [0, 1] - confirm.
    pixels = np.clip(255 * chw, 0, 255).astype('uint8')
    # Reorder to (height, width, channel) before handing the array to PIL.
    Image.fromarray(np.transpose(pixels, (1, 2, 0))).save(path)


def main():
    """Sample images from the pretrained generator and save them as JPEGs.

    Parses the command-line options, loads the generator weights from
    ``--pre_model`` and draws 1000 batches of 100 noise vectors. For every
    generated sample, channels 0-2 are written to ``--out_path_1`` and
    channels 3-5 to ``--out_path_2``. Both files are named ``<num>.jpg``,
    where ``num`` is a running counter starting at 1.
    """
    args = parser.parse_args()
    print(args)

    # Restrict the visible GPUs; this happens before any tensor is moved to CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
    cudnn.benchmark = True

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.out_path_1, exist_ok=True)
    os.makedirs(args.out_path_2, exist_ok=True)

    # generator (the first two objects returned by define_G are not used here)
    _, _, netG = define_G(hdim=args.hdim)
    load_model(netG, args.pre_model)
    netG.eval()

    num_batches = 1000
    batch_size = 100

    num = 0
    # Pure inference: disable autograd so no graph is built per forward pass.
    with torch.no_grad():
        for _ in range(num_batches):
            noise = torch.randn(batch_size, args.hdim)
            # The same noise is concatenated with itself along dim 1, so the
            # generator input has width 2 * hdim.
            noise = torch.cat((noise, noise), dim=1).cuda()

            fake = netG(noise)

            nir = fake[:, 0:3, :, :].detach().cpu().numpy()
            vis = fake[:, 3:6, :, :].detach().cpu().numpy()

            for nir_img, vis_img in zip(nir, vis):
                num += 1
                save_name = str(num) + '.jpg'
                _save_chw_image(nir_img, os.path.join(args.out_path_1, save_name))
                _save_chw_image(vis_img, os.path.join(args.out_path_2, save_name))

            # Progress report once per batch.
            # NOTE(review): the original indentation was lost; confirm this
            # was not meant to print once per image.
            print(num)
# Script entry point: run the generation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()