NeRF-PyTorch Source Code Notes

Source repository: yenchenlin/nerf-pytorch — A PyTorch implementation of NeRF (Neural Radiance Fields) that reproduces the results. (github.com)
https://github.com/yenchenlin/nerf-pytorch

run_nerf_helpers.py

# img2mse: mean squared error (MSE) between two images
img2mse = lambda x, y : torch.mean((x - y) ** 2)
# mse2psnr: convert MSE to PSNR (peak signal-to-noise ratio); higher PSNR means better image quality
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
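
A quick sanity check of the two helpers (a standalone sketch with made-up pixel values, not code from the repo): for images normalized to [0, 1], PSNR = -10 * log10(MSE), so an MSE of 0.01 corresponds to 20 dB.

import torch

img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))

pred = torch.full((4, 4, 3), 0.6)   # hypothetical rendered patch
gt   = torch.full((4, 4, 3), 0.5)   # hypothetical ground-truth patch
mse  = img2mse(pred, gt)            # (0.6 - 0.5)^2 = 0.01
print(mse.item(), mse2psnr(mse).item())  # ~0.01, ~20.0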

Embedder class

A positional encoder that maps input coordinates (e.g. 3D point positions, viewing directions) into a higher-dimensional space, strengthening the network's ability to represent high-frequency detail.

class Embedder:
# **kwargs: keyword arguments, stored as a dict
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()

# build the embedding functions
def create_embedding_fn(self):
embed_fns = [] # stores all embedding functions
d = self.kwargs['input_dims']
out_dim = 0
# optionally include the raw input itself
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d # the output dimension grows accordingly

max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']

# frequencies grow logarithmically (powers of two)
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
# frequencies grow linearly
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

# for each frequency freq, add a sine and a cosine embedding function
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
out_dim += d # increase the output dimension

self.embed_fns = embed_fns
self.out_dim = out_dim

# apply every embedding function in turn and concatenate the results along the last dimension
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

get_embedder()

Builds an embedding function together with its output dimension.

def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3

embed_kwargs = {
'include_input' : True,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}

embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim

If i = -1, the raw input is returned unchanged (no embedding).

multires: the maximum positional-encoding resolution, i.e. num_freqs.

i: controls whether the embedding is applied (skipped when i = -1).

Returns an embedding function and the dimension of its output tensor.
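
As a quick check of the output dimension, here is a standalone sketch meant to run alongside run_nerf_helpers.py (multires=10 for points and multires_views=4 for directions are the values used in the default configs); out_dim = input_dims + 2 * num_freqs * input_dims:

import torch

embed_fn, out_dim = get_embedder(10)            # point encoding: 3 + 3*2*10 = 63
embeddirs_fn, out_dim_views = get_embedder(4)   # direction encoding: 3 + 3*2*4 = 27
print(out_dim, out_dim_views)                   # 63 27

x = torch.rand(1024, 3)                         # 1024 hypothetical 3D points
print(embed_fn(x).shape)                        # torch.Size([1024, 63])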

NeRF class

Defines the model's network architecture and forward-pass logic.

# Model
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
super(NeRF, self).__init__()
self.D = D # number of MLP layers
self.W = W # hidden width (neurons per layer)
self.input_ch = input_ch # input dimension of the spatial point coordinates (3D xyz)
self.input_ch_views = input_ch_views # input dimension of the viewing direction
self.skips = skips # indices of the skip-connection layers (4 means the 4th layer)
self.use_viewdirs = use_viewdirs # whether to use viewing-direction information (gives view-dependent shading effects)

# MLP over the spatial points
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + # first linear layer: input_ch -> W
[nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)] # skip connection: concatenate the input at the specified layers
)

### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
# MLP over the viewing direction
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])

### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
# output layers
if use_viewdirs:
self.feature_linear = nn.Linear(W, W) # feature vector
self.alpha_linear = nn.Linear(W, 1) # outputs the alpha (density) value
self.rgb_linear = nn.Linear(W//2, 3) # outputs the RGB value
else:
# without viewing directions, output RGB and density directly
self.output_linear = nn.Linear(W, output_ch)

# forward pass
def forward(self, x):
# split the input into spatial points and viewing directions
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
# run the point coordinates through the fully connected layers
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h) # linear transform
h = F.relu(h) # nonlinear activation
if i in self.skips: # skip-connection layer
h = torch.cat([input_pts, h], -1) # concatenate the original point input back onto the features

# if viewing directions are used
if self.use_viewdirs:
alpha = self.alpha_linear(h) # output alpha (density) for volume rendering
feature = self.feature_linear(h) # extract the feature vector
h = torch.cat([feature, input_views], -1) # concatenate feature and viewing direction

for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)

rgb = self.rgb_linear(h) # output color
outputs = torch.cat([rgb, alpha], -1) # concatenate RGB and density
else:
outputs = self.output_linear(h) # without viewing directions, output RGB + density directly

return outputs

# load parameters into the PyTorch model from the weights of a pretrained Keras model
def load_weights_from_keras(self, weights):
assert self.use_viewdirs, "Not implemented if use_viewdirs=False"

# Load pts_linears
for i in range(self.D):
idx_pts_linears = 2 * i
self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))

# Load feature_linear
idx_feature_linear = 2 * self.D
self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))

# Load views_linears
idx_views_linears = 2 * self.D + 2
self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))

# Load rgb_linear
idx_rbg_linear = 2 * self.D + 4
self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))

# Load alpha_linear
idx_alpha_linear = 2 * self.D + 6
self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
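
To make the tensor shapes concrete, a minimal sketch of one forward pass (assuming the usual encoded sizes input_ch=63 and input_ch_views=27 produced by the positional encoders above):

import torch

model = NeRF(D=8, W=256, input_ch=63, input_ch_views=27,
             output_ch=4, skips=[4], use_viewdirs=True)
pts_embedded  = torch.rand(2048, 63)   # encoded sample points
dirs_embedded = torch.rand(2048, 27)   # encoded viewing directions
x = torch.cat([pts_embedded, dirs_embedded], -1)  # [2048, 90]
out = model(x)
print(out.shape)  # torch.Size([2048, 4]) -> RGB (3 channels) + density (1 channel)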

get_rays() and get_rays_np()

Given the camera intrinsics and extrinsics, generate the origin and direction of every ray, i.e. the ray cast from the camera through each pixel into the scene.

# PyTorch implementation (supports GPU acceleration)
def get_rays(H, W, K, c2w):
# build the image's pixel grid; i is the horizontal coordinate, j the vertical one
i, j = torch.meshgrid(
torch.linspace(0, W-1, W),
torch.linspace(0, H-1, H)
)
i = i.t()
j = j.t()
# convert pixel coordinates to normalized camera coordinates
dirs = torch.stack(
[(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)],
-1
)
# rotate the directions into the world coordinate frame
rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)
# the ray origin is the camera position in world coordinates
rays_o = c2w[:3, -1].expand(rays_d.shape)
return rays_o, rays_d

H: image height in pixels.

W: image width in pixels.

K: camera intrinsic matrix, shape [3, 3],
containing the focal lengths and the principal point:

$$ K = \begin{bmatrix} f_x & 0 & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{bmatrix} $$

f_x, f_y: focal lengths in the horizontal and vertical directions.
c_x, c_y: coordinates of the principal point (image center).

c2w: camera extrinsic matrix (camera-to-world transform), shape [3, 4],
containing the rotation matrix R and the translation vector t: $$ c2w = \begin{bmatrix} R & t \end{bmatrix} $$

# NumPy implementation
def get_rays_np(H, W, K, c2w):
i, j = np.meshgrid(
np.arange(W, dtype=np.float32),
np.arange(H, dtype=np.float32),
indexing='xy'
)
dirs = np.stack(
[(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)],
-1
)
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape)
return rays_o, rays_d
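
A tiny worked example (a standalone sketch with a made-up 4x4 camera, using get_rays_np from this file): with the identity pose, the ray through the principal point should simply look down the -z axis.

import numpy as np

H, W, focal = 4, 4, 2.0
K = np.array([[focal, 0, 0.5*W],
              [0, focal, 0.5*H],
              [0, 0, 1]], dtype=np.float32)
c2w = np.eye(4, dtype=np.float32)[:3, :4]     # camera at the origin, looking along -z
rays_o, rays_d = get_rays_np(H, W, K, c2w)
print(rays_o.shape, rays_d.shape)             # (4, 4, 3) (4, 4, 3)
print(rays_d[2, 2])                           # pixel at the principal point -> [ 0.  0. -1.]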

ndc_rays()

Converts ray origins and directions from the world frame into normalized device coordinates (NDC). In NDC, 3D space is projected into a standardized range (typically [-1, 1]), which is convenient for forward-facing scenes and image-plane processing.

def ndc_rays(H, W, focal, near, rays_o, rays_d):
# move each ray origin onto the near plane
# [..., 2] keeps all leading dimensions (here N_rays) and selects component 2 (the z component) of the last dimension
t = -(near + rays_o[...,2]) / rays_d[...,2]
# [..., None] appends a new dimension of size 1
# broadcasting: [N_rays, 1] expands to [N_rays, 3], so each ray's shift t is applied to x, y and z
rays_o = rays_o + t[...,None] * rays_d

# perspective-project the ray origins into the NDC range
# the minus signs come from the camera looking along -z (the z components are negative), following the derivation in the NeRF paper's appendix
# the z component is mapped so that the near plane goes to -1 and infinity goes to +1
o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
o2 = 1. + 2. * near / rays_o[...,2]
# project the ray directions
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
d2 = -2. * near / rays_o[...,2]
# reassemble origins and directions
rays_o = torch.stack([o0,o1,o2], -1)
rays_d = torch.stack([d0,d1,d2], -1)

return rays_o, rays_d

H: image height in pixels.

W: image width in pixels.

focal: camera focal length (in pixels).

near: distance of the near plane, used to place the ray origins.

rays_o: ray origins, shape [N_rays, 3].

rays_d: ray directions, shape [N_rays, 3].
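
Written out, the projection implemented above is (following the NDC derivation in the original NeRF paper's appendix, with focal length f and near-plane distance n):

$$ o' = \begin{pmatrix} -\frac{f}{W/2}\,\frac{o_x}{o_z} \\ -\frac{f}{H/2}\,\frac{o_y}{o_z} \\ 1 + \frac{2n}{o_z} \end{pmatrix}, \qquad d' = \begin{pmatrix} -\frac{f}{W/2}\left(\frac{d_x}{d_z} - \frac{o_x}{o_z}\right) \\ -\frac{f}{H/2}\left(\frac{d_y}{d_z} - \frac{o_y}{o_z}\right) \\ -\frac{2n}{o_z} \end{pmatrix} $$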

sample_pdf() (still not entirely clear to me)

Implements drawing samples from a probability density function (PDF) by inverting its cumulative distribution function (CDF).
It is used to place fine samples on top of the coarse samples (hierarchical sampling), improving rendering quality.

# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# build the PDF and CDF
weights = weights + 1e-5 # avoid numerical problems (division by zero)
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))

# Take uniform samples: evenly spaced (deterministic) sampling
# generate the sample positions u in [0, 1]
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
# random sampling
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)

# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)

# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

return samples

bins: the bin edges along each ray's sampling interval, shape (batch_size, num_bins).

weights: the weight of each interval between bin edges (one fewer than the number of edges), typically the volume-rendering weights from the coarse pass; they define the PDF.

N_samples: number of samples to draw from the PDF.

det (optional): if True, use evenly spaced (deterministic) samples; if False, sample randomly.

pytest (optional): for testing only; fixes the random seed so results are reproducible.
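
A small self-contained illustration of the inverse-CDF idea (made-up numbers): when almost all of the weight sits in one interval, most of the new samples land inside that interval.

import torch

bins = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]])   # 4 intervals along a single ray
weights = torch.tensor([[0.05, 0.80, 0.10, 0.05]]) # nearly all weight in [1, 2)
samples = sample_pdf(bins, weights, N_samples=8, det=True)
print(samples)   # most of the 8 samples fall between 1.0 and 2.0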

run_nerf.py

Overview

This is the main entry point for training and rendering a NeRF model. The main pieces are:
train(): the training loop, data loading and model instantiation
render(): given a batch of rays and camera parameters, computes each ray's RGB, depth and alpha values
create_nerf(): builds the NeRF MLPs, i.e. the coarse network and (when hierarchical sampling is used) the fine network
config_parser(): parses command-line arguments and config files

batchify()

def batchify(fn, chunk):
if chunk is None:
return fn
def ret(inputs):
return torch.cat(
[fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
return ret

Wraps the function fn so that it is applied to the input in chunks, avoiding out-of-memory errors.
fn: the function to be applied chunk by chunk.
chunk: the maximum input size processed at once (if None, the original function is returned unchanged).

torch.cat is the PyTorch function that concatenates tensors along a given dimension:

torch.cat(tensors, dim=0)

tensors: a list or tuple of tensors whose shapes must match in every dimension except the one being concatenated.
dim: the dimension along which to concatenate.

Note that batchify actually returns the new function ret defined inside it, so it is called as

outputs_flat = batchify(fn, netchunk)(embedded)

Because batchify returns the function ret(inputs), a second pair of parentheses is needed to pass the actual input.
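
A toy illustration (a sketch with a hypothetical fn, not from the repo):

import torch

fn = lambda x: x * 2
inputs = torch.arange(10).reshape(10, 1).float()
out = batchify(fn, 4)(inputs)   # applies fn to rows 0-3, 4-7, 8-9, then concatenates
print(out.shape)                # torch.Size([10, 1])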

run_network()

Performs the preprocessing: positional-encodes the point coordinates and viewing directions into high-frequency features,
runs one forward pass through the NeRF network, evaluating it in chunks,
and reshapes the output back to the same layout as the original input.

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
"""Prepares inputs and applies network 'fn'."""
# 1. flatten the inputs
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
# 2. positional-encode the inputs
embedded = embed_fn(inputs_flat)

# 3. if viewing directions are provided
if viewdirs is not None:
# broadcast the viewing directions so their shape matches the inputs
input_dirs = viewdirs[:, None].expand(inputs.shape)
input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
# positional-encode the viewing directions
embedded_dirs = embeddirs_fn(input_dirs_flat)
# concatenate the direction embedding onto the point embedding
embedded = torch.cat([embedded, embedded_dirs], -1)

# 4. run the network in chunks via batchify
outputs_flat = batchify(fn, netchunk)(embedded)
# 5. reshape the output to match the input
outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
return outputs

inputs: the input points (the samples along each ray), shape [N_rays, N_samples, D], where D is the point dimension (usually 3).

viewdirs: viewing directions, shape [N_rays, 3], one direction per ray.

fn: the network to evaluate (usually an instance of the NeRF class).

embed_fn: positional-encoding function for the input points.

embeddirs_fn: positional-encoding function for the viewing directions.

netchunk: chunk size for the batched network evaluation, to keep memory usage bounded.
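
A shape walk-through as a minimal sketch (assuming the default encoders with multires=10 and multires_views=4, and run in the context of run_nerf.py where the helpers are imported):

import torch

embed_fn, input_ch = get_embedder(10)           # 63-dim point encoding
embeddirs_fn, input_ch_views = get_embedder(4)  # 27-dim direction encoding
model = NeRF(input_ch=input_ch, input_ch_views=input_ch_views, use_viewdirs=True)

pts = torch.rand(1024, 64, 3)        # 1024 rays with 64 samples each
viewdirs = torch.rand(1024, 3)
out = run_network(pts, viewdirs, model, embed_fn, embeddirs_fn, netchunk=1024*64)
print(out.shape)                     # torch.Size([1024, 64, 4])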

batchify_rays()

Processes the flattened ray data (rays_flat) in chunks, to avoid running out of GPU memory (OOM).

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
"""Render rays in smaller minibatches to avoid OOM."""
all_ret = {} # dict accumulating the render results for all rays
for i in range(0, rays_flat.shape[0], chunk): # process chunk rays at a time
ret = render_rays(rays_flat[i:i+chunk], **kwargs) # render this chunk of rays
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])

all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} # concatenate the per-chunk results along dim 0 to obtain the full render output
return all_ret

rays_flat: the ray data, usually of shape [N_rays, D]; each row holds one ray's parameters (origin, direction, near/far, ...).

chunk: number of rays processed per step. The default of 1024 * 32 is an empirical value that depends on GPU memory and scene complexity.

**kwargs: extra arguments passed through to render_rays.

render()

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
"""Render rays
Args:
H: int. Height of image in pixels.
W: int. Width of image in pixels.
K: array of shape [3, 3]. Camera intrinsics (focal length and principal point).
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.
rays: array of shape [2, batch_size, 3]. Ray origin and direction for
each example in batch.
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
ndc: bool. If True, represent ray origin, direction in NDC coordinates.
near: float or array of shape [batch_size]. Nearest distance for a ray.
far: float or array of shape [batch_size]. Farthest distance for a ray.
use_viewdirs: bool. If True, use viewing direction of a point in space in model.
c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
camera while using other c2w argument for viewing directions (used to visualize the effect of a fixed camera).
Returns:
rgb_map: [batch_size, 3]. Predicted RGB values for rays.
disp_map: [batch_size]. Disparity map. Inverse of depth.
acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
extras: dict with everything returned by render_rays().
"""
if c2w is not None:
# special case to render full image
# generate the rays for the whole image
rays_o, rays_d = get_rays(H, W, K, c2w)
else:
# use provided ray batch
# the rays come from the provided ray batch
rays_o, rays_d = rays

if use_viewdirs:
# provide ray directions as input
viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) # normalize the ray directions to unit length
viewdirs = torch.reshape(viewdirs, [-1,3]).float() # reshape to [N_rays, 3]

sh = rays_d.shape # [..., 3]
if ndc:
# for forward facing scenes
# convert to NDC coordinates
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

# Create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()

# attach the near and far bounds to every ray
# rays_d is [N_rays, 3], so rays_d[..., :1] is [N_rays, 1] (only the first component of the last dimension)
near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
# concatenate the ray information into [N_rays, 8]
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
# if use_viewdirs is enabled, append the viewing directions, giving shape [N_rays, 11]
rays = torch.cat([rays, viewdirs], -1)

# Render and reshape
# render the rays in chunks
all_ret = batchify_rays(rays, chunk, **kwargs)
for k in all_ret:
# reshape each result back to the original ray-batch shape
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)

# return the render results
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
return ret_list + [ret_dict]

render_path()

Renders RGB images and disparity maps along a sequence of camera poses (render_poses).
This is what generates videos of the scene or renders it from novel viewpoints.

def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):

H, W, focal = hwf

if render_factor!=0:
# Render downsampled for speed
# downscale H, W and the focal length by render_factor, i.e. render at lower resolution for speed
H = H//render_factor
W = W//render_factor
focal = focal/render_factor

rgbs = [] # all RGB images rendered along the path
disps = [] # all disparity maps rendered along the path

t = time.time()
# render every camera pose
for i, c2w in enumerate(tqdm(render_poses)):
print(i, time.time() - t)
t = time.time()
rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
# .cpu().numpy() moves the tensor results to the CPU and converts them to NumPy
rgbs.append(rgb.cpu().numpy())
disps.append(disp.cpu().numpy())
if i==0:
print(rgb.shape, disp.shape)

"""
The commented-out code here evaluates PSNR against the ground-truth images
if gt_imgs is not None and render_factor==0:
p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
print(p)
"""
# save the image to the output directory
if savedir is not None:
rgb8 = to8b(rgbs[-1]) # to8b converts the float image in [0, 1] to 8-bit integers in [0, 255]
filename = os.path.join(savedir, '{:03d}.png'.format(i)) # build the file name
imageio.imwrite(filename, rgb8) # write the RGB image to file

# np.stack adds a new leading dimension, the frame index
rgbs = np.stack(rgbs, 0)
disps = np.stack(disps, 0)

return rgbs, disps

render_poses: camera poses along the render path, shape [N_poses, 4, 4]; each pose matrix defines the camera's position and orientation in world coordinates for one frame.

hwf: the image height (H), width (W) and focal length,
e.g. hwf = (H, W, focal).

K: camera intrinsic matrix, shape [3, 3].

chunk: number of rays rendered at a time, used to bound GPU memory usage.

render_kwargs: extra rendering arguments, passed straight through to render.

gt_imgs (optional): ground-truth images for the rendered poses, used to compute PSNR (peak signal-to-noise ratio) as an evaluation metric.

savedir (optional): if given, the rendered RGB images are saved into this directory.

render_factor: downsampling factor; rendering is faster at reduced resolution (at the cost of quality).

create_nerf()

def create_nerf(args):
"""Instantiate NeRF's MLP model."""
# args.multires: number of frequencies for the positional encoding
# args.i_embed: whether positional encoding is used (-1 disables it)
# embed_fn: the encoding function that maps input points to a higher-dimensional space
# input_ch: dimension of the encoded input points
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

# positional encoding for the viewing directions
input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)

# build the networks
# if a fine network is used (args.N_importance > 0) the output has 5 channels, otherwise 4 (RGB + density)
output_ch = 5 if args.N_importance > 0 else 4
skips = [4]
# to(device) moves the network onto the chosen device (GPU or CPU)
model = NeRF(D=args.netdepth, W=args.netwidth,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
# collect the network parameters into a list that will later be handed to the optimizer
grad_vars = list(model.parameters())

# if hierarchical sampling is enabled, build a fine network; it mirrors the coarse one, but its depth and width may differ
model_fine = None
if args.N_importance > 0:
model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
grad_vars += list(model_fine.parameters()) # add the fine network's parameters to the optimization list

# lambda used to query the network
# internally calls run_network() to do the encoding and one forward pass
network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
embed_fn=embed_fn,
embeddirs_fn=embeddirs_fn,
netchunk=args.netchunk)

# Create optimizer
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

start = 0 # global training step counter
basedir = args.basedir # experiment directory where everything produced during training is stored
expname = args.expname # experiment name, used to distinguish different runs

##########################

# Load checkpoints
# load checkpoints so training can be resumed
# if args.ft_path is given, load that checkpoint directly
if args.ft_path is not None and args.ft_path!='None':
ckpts = [args.ft_path]
# otherwise, use the newest checkpoint found under basedir/expname
else:
ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

print('Found ckpts', ckpts)
if len(ckpts) > 0 and not args.no_reload:
ckpt_path = ckpts[-1]
print('Reloading from', ckpt_path)
ckpt = torch.load(ckpt_path)

start = ckpt['global_step']
optimizer.load_state_dict(ckpt['optimizer_state_dict'])

# Load model
model.load_state_dict(ckpt['network_fn_state_dict'])
if model_fine is not None:
model_fine.load_state_dict(ckpt['network_fine_state_dict'])

##########################

# rendering parameters used during training
render_kwargs_train = {
'network_query_fn' : network_query_fn,
'perturb' : args.perturb,
'N_importance' : args.N_importance,
'network_fine' : model_fine,
'N_samples' : args.N_samples,
'network_fn' : model,
'use_viewdirs' : args.use_viewdirs,
'white_bkgd' : args.white_bkgd,
'raw_noise_std' : args.raw_noise_std,
}

# NDC only good for LLFF-style forward facing data
if args.dataset_type != 'llff' or args.no_ndc:
print('Not ndc!')
render_kwargs_train['ndc'] = False
render_kwargs_train['lindisp'] = args.lindisp

# rendering parameters used during testing (no perturbation, no noise)
render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
render_kwargs_test['perturb'] = False
render_kwargs_test['raw_noise_std'] = 0.

return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer

raw2outputs()

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
"""Transforms model's predictions to semantically meaningful values.
Args:
raw: [num_rays, num_samples along ray, 4]. Prediction from model.
z_vals: [num_rays, num_samples along ray]. Integration time.
rays_d: [num_rays, 3]. Direction of each ray.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
disp_map: [num_rays]. Disparity map. Inverse of depth map.
acc_map: [num_rays]. Sum of weights along each ray.
weights: [num_rays, num_samples]. Weights assigned to each sampled color.
depth_map: [num_rays]. Estimated distance to object.
"""
# raw is the predicted density, dists the spacing between samples, act_fn the activation (ReLU keeps the density non-negative); alpha = 1 - exp(-sigma * delta)
raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

# distances between adjacent samples
dists = z_vals[...,1:] - z_vals[...,:-1]
# append a huge value (1e10, effectively "infinitely far") so every sample has an associated distance
dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
# scale by the norm of the ray direction to turn the spacing into real 3D distances
dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

# sigmoid maps the model's first three output channels to RGB values in [0, 1]
rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
noise = 0.
# if raw_noise_std is set, add random noise to the density to regularize training
if raw_noise_std > 0.:
noise = torch.randn(raw[...,3].shape) * raw_noise_std

# Overwrite randomly sampled data if pytest
# if pytest is enabled, use fixed random noise so results stay reproducible
if pytest:
np.random.seed(0)
noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
noise = torch.Tensor(noise)

# compute the per-sample alpha (opacity)
alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
# weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
# weight of each sample: alpha_i times the accumulated transmittance prod_{j<i}(1 - alpha_j)
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]

# render the RGB image: each ray's (pixel's) color is the weighted sum of its sample colors
rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]

# compute (render) the depth and disparity maps
depth_map = torch.sum(weights * z_vals, -1)
# disparity is the inverse depth; torch.max guards against dividing by (near-)zero
disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
# accumulated opacity
# the accumulated opacity is the sum of the weights, i.e. the total opacity met along the ray
acc_map = torch.sum(weights, -1)

# white background: fill the part of each ray not occupied by the object with white
if white_bkgd:
rgb_map = rgb_map + (1.-acc_map[...,None])

return rgb_map, disp_map, acc_map, weights, depth_map
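
The code above is the discrete volume-rendering quadrature from the NeRF paper; written out (with density sigma_i, sample spacing delta_i and color c_i, up to the 1e-10 stabilizer in the cumulative product):

$$ \alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad T_i = \prod_{j<i} (1 - \alpha_j), \qquad w_i = T_i\,\alpha_i, \qquad \hat{C}(\mathbf{r}) = \sum_i w_i\,\mathbf{c}_i, \qquad \hat{D}(\mathbf{r}) = \sum_i w_i\,z_i $$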

render_rays()

def render_rays(ray_batch,
network_fn,
network_query_fn,
N_samples,
retraw=False,
lindisp=False,
perturb=0.,
N_importance=0,
network_fine=None,
white_bkgd=False,
raw_noise_std=0.,
verbose=False,
pytest=False):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn: function. Model for predicting RGB and density at each point
in space.
network_query_fn: function used for passing queries to network_fn.
N_samples: int. Number of different times to sample along each ray.
retraw: bool. If True, include model's raw, unprocessed predictions.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray (importance sampling).
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn; it receives the importance-sampled points.
white_bkgd: bool. If True, assume a white background.
raw_noise_std: ...
verbose: bool. If True, print more debugging info.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
z_std: [num_rays]. Standard deviation of distances along ray for each
sample.
"""
# unpack the ray information
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) # near and far bounds
near, far = bounds[...,0], bounds[...,1] # [-1,1]

# initial sample positions along each ray
t_vals = torch.linspace(0., 1., steps=N_samples)
if not lindisp:
z_vals = near * (1.-t_vals) + far * (t_vals)
# if lindisp=True, the samples are spaced linearly in inverse depth instead
else:
z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

z_vals = z_vals.expand([N_rays, N_samples])

# stratified random perturbation of the sample positions
if perturb > 0.:
# get intervals between samples
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)

# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
t_rand = np.random.rand(*list(z_vals.shape))
t_rand = torch.Tensor(t_rand)

z_vals = lower + (upper - lower) * t_rand

# 3D coordinates of the sample points: pts = rays_o + z * rays_d
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]


# raw = run_network(pts)
# query the coarse network for the color + density of the sample points
raw = network_query_fn(pts, viewdirs, network_fn)
# convert the raw predictions into the RGB map, disparity map, accumulated opacity, weights and depth map
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

# hierarchical (fine) sampling
if N_importance > 0:

rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
z_samples = z_samples.detach()

z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]

run_fn = network_fn if network_fine is None else network_fine
# raw = run_network(pts, fn=run_fn)
raw = network_query_fn(pts, viewdirs, run_fn)

rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

# assemble the return values
ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
if retraw:
# the model's raw predictions (RGB + density)
ret['raw'] = raw
if N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]

for k in ret:
if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
print(f"! [Numerical Error] {k} contains nan or inf.")

return ret

train()

Calls the functions above and ties the whole training and rendering pipeline together.

def train():

# parse the command-line arguments with config_parser; args holds every setting needed for the experiment and training
parser = config_parser()
args = parser.parse_args()

# Load data
"""
llff 数据集:调用 load_llff_data 函数,加载图像、相机姿态、边界信息等。
blender 数据集:调用 load_blender_data,加载图像、相机姿态以及训练/验证/测试集分割信息。
LINEMOD 数据集:加载图像、姿态、相机内参以及最近和最远边界值。
deepvoxels 数据集:加载图像和姿态,计算球形半径的近远边界值。
未知数据类型直接退出。
"""
K = None
if args.dataset_type == 'llff':
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
recenter=True, bd_factor=.75,
spherify=args.spherify)
hwf = poses[0,:3,-1]
poses = poses[:,:3,:4]
print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
if not isinstance(i_test, list):
i_test = [i_test]

if args.llffhold > 0:
print('Auto LLFF holdout,', args.llffhold)
i_test = np.arange(images.shape[0])[::args.llffhold]

i_val = i_test
i_train = np.array([i for i in np.arange(int(images.shape[0])) if
(i not in i_test and i not in i_val)])

print('DEFINING BOUNDS')
if args.no_ndc:
near = np.ndarray.min(bds) * .9
far = np.ndarray.max(bds) * 1.

else:
near = 0.
far = 1.
print('NEAR FAR', near, far)

elif args.dataset_type == 'blender':
images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split

near = 2.
far = 6.

if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]

elif args.dataset_type == 'LINEMOD':
images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
print(f'[CHECK HERE] near: {near}, far: {far}.')
i_train, i_val, i_test = i_split

if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]

elif args.dataset_type == 'deepvoxels':

images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
basedir=args.datadir,
testskip=args.testskip)

print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split

hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
near = hemi_R-1.
far = hemi_R+1.

else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return

# Cast intrinsics to right types
# image resolution and camera intrinsics
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]

# if the intrinsic matrix K was not provided, build it from the focal length and the image center
if K is None:
K = np.array([
[focal, 0, 0.5*W],
[0, focal, 0.5*H],
[0, 0, 1]
])

if args.render_test:
render_poses = np.array(poses[i_test])

# Create log dir and copy the config file
# create the log directory and save the arguments / config file there
basedir = args.basedir
expname = args.expname
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(basedir, expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())

# Create nerf model
# instantiate the NeRF networks by calling create_nerf()
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
global_step = start

# near/far bounds that delimit the sampling range along each ray
bds_dict = {
'near' : near,
'far' : far,
}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)

# Move testing data to GPU
render_poses = torch.Tensor(render_poses).to(device)

# Short circuit if only rendering out from trained model
# render-only mode: skip training and directly render the test poses or the render path via render_path
# the rendered frames are also written out as an .mp4 video
if args.render_only:
print('RENDER ONLY')
with torch.no_grad():
if args.render_test:
# render_test switches to test poses
images = images[i_test]
else:
# Default is smoother render_poses path
images = None

testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', render_poses.shape)

rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
print('Done rendering', testsavedir)
imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

return

# Prepare raybatch tensor if batching random rays
# prepare the ray batches
N_rand = args.N_rand # number of random rays per batch (e.g. 1024)
use_batching = not args.no_batching
if use_batching:
# For random ray batching
print('get rays')
# stack the rays of all cameras together into an array of shape [N, ro+rd, H, W, 3]
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
print('done, concats')
rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
rays_rgb = rays_rgb.astype(np.float32)
print('shuffle rays')
np.random.shuffle(rays_rgb)

print('done')
i_batch = 0

# Move training data to GPU
if use_batching:
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
if use_batching:
rays_rgb = torch.Tensor(rays_rgb).to(device)


N_iters = 200000 + 1
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)

# Summary writers
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

start = start + 1
# main training loop
for i in trange(start, N_iters):
time0 = time.time()

# Sample random ray batch
if use_batching:
# Random over all images
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1) # swap dimensions 0 and 1
batch_rays, target_s = batch[:2], batch[2]

i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0

else:
# Random from one image
img_i = np.random.choice(i_train)
target = images[img_i]
target = torch.Tensor(target).to(device)
pose = poses[img_i, :3,:4]

if N_rand is not None:
rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)

if i < args.precrop_iters:
dH = int(H//2 * args.precrop_frac)
dW = int(W//2 * args.precrop_frac)
coords = torch.stack(
torch.meshgrid(
torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
), -1)
if i == start:
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
else:
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)

coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)

##### Core optimization loop #####
# render the batch of rays through the scene: predicted color, disparity and accumulated opacity
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)

# backpropagation: the optimizer updates the parameters
optimizer.zero_grad()
# the loss is the MSE between rendered and target colors; PSNR is derived from it as a quality metric
img_loss = img2mse(rgb, target_s)
trans = extras['raw'][...,-1]
loss = img_loss
psnr = mse2psnr(img_loss)

if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)

loss.backward() # backpropagate
optimizer.step() # optimizer step

# NOTE: IMPORTANT!
### update learning rate ###
# exponential decay schedule for the learning rate
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
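# e.g. assuming the defaults from config_parser (lrate = 5e-4, lrate_decay = 250, so decay_steps = 250k),
# the learning rate decays smoothly from 5e-4 towards 5e-5, which would be reached after 250k steps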
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
################################

dt = time.time()-time0
# print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
##### end #####

# Rest is logging
if i%args.i_weights==0:
path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
torch.save({
'global_step': global_step,
'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, path)
print('Saved checkpoints at', path)

if i%args.i_video==0 and i > 0:
# Turn on testing mode
with torch.no_grad():
rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
print('Done, saving', rgbs.shape, disps.shape)
moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

# if args.use_viewdirs:
# render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
# with torch.no_grad():
# rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
# render_kwargs_test['c2w_staticcam'] = None
# imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

if i%args.i_testset==0 and i > 0:
testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
print('test poses shape', poses[i_test].shape)
with torch.no_grad():
render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
print('Saved test set')



if i%args.i_print==0:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))

with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr)
tf.contrib.summary.histogram('tran', trans)
if args.N_importance > 0:
tf.contrib.summary.scalar('psnr0', psnr0)


if i%args.i_img==0:

# Log a rendered validation view to Tensorboard
img_i=np.random.choice(i_val)
target = images[img_i]
pose = poses[img_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
**render_kwargs_test)

psnr = mse2psnr(img2mse(rgb, target))

with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):

tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])

tf.contrib.summary.scalar('psnr_holdout', psnr)
tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])


if args.N_importance > 0:

with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
"""

global_step += 1


if __name__=='__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')

train()