def train():
    parser = config_parser()
    args = parser.parse_args()

""" llff 数据集:调用 load_llff_data 函数,加载图像、相机姿态、边界信息等。 blender 数据集:调用 load_blender_data,加载图像、相机姿态以及训练/验证/测试集分割信息。 LINEMOD 数据集:加载图像、姿态、相机内参以及最近和最远边界值。 deepvoxels 数据集:加载图像和姿态,计算球形半径的近远边界值。 未知数据类型直接退出。 """ K = None if args.dataset_type == 'llff': images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses[0,:3,-1] poses = poses[:,:3,:4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) if not isinstance(i_test, list): i_test = [i_test]
        # Hold out every args.llffhold-th image for test/val.
        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0]))
                            if (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            # NDC (normalized device coordinate) rays use a fixed [0, 1] depth range.
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            # Composite RGBA onto a white background: rgb * alpha + (1 - alpha).
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'LINEMOD':
        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
        print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
        print(f'[CHECK HERE] near: {near}, far: {far}.')
        i_train, i_val, i_test = i_split

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
                                                                 basedir=args.datadir,
                                                                 testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        # Cameras lie on a hemisphere; use its mean radius to set the near/far bounds.
        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to the right types and build a default camera matrix if none was loaded.
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        K = np.array([
            [focal, 0, 0.5 * W],
            [0, focal, 0.5 * H],
            [0, 0, 1]
        ])
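    # The default K above is a standard pinhole intrinsics matrix: square pixels (fx = fy = focal)
    # and the principal point assumed to sit at the image center (cx = W/2, cy = H/2).
    # In the reference nerf-pytorch ray helpers this means a pixel (u, v) is mapped to the
    # camera-space direction ((u - cx) / fx, -(v - cy) / fy, -1) before being rotated into
    # world space by the camera-to-world pose.
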
    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create the log dir and dump the run arguments / config file for reproducibility.
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create the NeRF model (coarse + optional fine network), optimizer, and render kwargs.
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move the render poses to the GPU.
    render_poses = torch.Tensor(render_poses).to(device)

    # Short-circuit if we are only rendering from a trained model.
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to the held-out test poses
                images = images[i_test]
            else:
                # Default is the smoother render_poses path (no ground truth available)
                images = None

            testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test,
                                  gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

            return

    # Prepare the ray batch tensor if batching random rays across all images.
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # keep training images only
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])  # [N_train*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Move training data to the GPU (poses are always needed; rays/images only when batching).
    if use_batching:
        images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)
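    # Note: with use_batching, every optimization step draws N_rand rays that were pre-shuffled
    # across all training views, so each gradient update mixes many viewpoints. Without batching,
    # each step samples N_rand pixels from a single randomly chosen training image (optionally
    # restricted to a center crop for the first precrop_iters steps, see below).
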
    N_iters = 200000 + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample a random batch of rays.
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, ro+rd+rgb, 3]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random rays from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3, :4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    # Restrict sampling to a center crop during the first precrop_iters steps.
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                            torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)
                        ), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
                else:
                    coords = torch.stack(torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
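        # At this point batch_rays is [2, N_rand, 3] (ray origins and directions) and
        # target_s is [N_rand, 3] (ground-truth pixel colors for those rays).
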
        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                        verbose=i < 10, retraw=True,
                                        **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # Also supervise the coarse network's prediction.
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

        # Update the learning rate with exponential decay.
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
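        # Since decay_steps = args.lrate_decay * 1000, the learning rate is multiplied by
        # decay_rate (0.1) once every lrate_decay * 1000 iterations. For example, with
        # lrate_decay = 250 the rate has dropped by 10x after 250k steps:
        #   new_lrate = lrate * 0.1 ** (global_step / 250000)
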
        dt = time.time() - time0

        # The rest of the loop body is logging and checkpointing.
        if i % args.i_weights == 0:
            # Save a training checkpoint (assumes a fine network exists, i.e. N_importance > 0).
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Render a spiral video along render_poses.
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

        if i % args.i_testset == 0 and i > 0:
            # Render and save the held-out test views.
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test,
                            gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")

        # Legacy TensorBoard logging from the original TensorFlow implementation,
        # kept here (disabled) for reference:
        """
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))

            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)

            if i%args.i_img==0:

                # Log a rendered validation view to Tensorboard
                img_i=np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3,:4]
                with torch.no_grad():
                    rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
                                                    **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])

                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])

                if args.N_importance > 0:

                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
                        tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
        """

        global_step += 1

if __name__ == '__main__':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
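
# Typical invocation (assuming the standard nerf-pytorch layout, where a scene config such as
# configs/lego.txt sets datadir, dataset_type, and the training hyperparameters):
#
#   python run_nerf.py --config configs/lego.txt
#
# The config path and file name above are illustrative; any config consumed by config_parser() works.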