import os from typing import Optional,Tuple,List,Union,Callable import numpy as np import torch from torch import nn import matplotlib.pyplot as plt from mpl\_toolkits.mplot3d import axes3d from tqdm import trange # 设置GPU还是CPU设备 device = torch.device\('cuda' if torch.cuda.is\_available\(\) else 'cpu' \)
1 输入
这项工作中使用的小型乐高数据集由 106 幅乐高推土机的图像组成,并附有每幅图像对应的相机位姿数据,以及所有图像共用的一个焦距数值。与其他数据集一样,这里使用前 100 张图像进行训练,并从其余图像中保留一张测试图像用于验证,具体的加载数据操作如下:
# Load the tiny Lego dataset: images, per-image camera poses and one focal length.
data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']

print(f'Images shape: {images.shape}')
print(f'Poses shape: {poses.shape}')
print(f'Focal length: {focal}')

height, width = images.shape[1:3]
near = 2.
far = 6.

n_training = 100   # number of images used for training
testimg_idx = 101  # index of the held-out test image
testimg = images[testimg_idx]
testpose = poses[testimg_idx]

plt.imshow(testimg)
print('Pose')
print(testpose)
2 数据处理
回顾 NeRF 相关论文,模型的输入是一个单独的 5D 坐标(空间位置 $(x, y, z)$ 和视图方向 $(\theta, \phi)$),输出是该点视图相关的颜色与体积密度 $(\mathrm{RGB}, \sigma)$。因此,我们需要针对上面使用的小型乐高数据做一个处理操作:根据每个相机的位姿求出射线的原点和方向。
# Rays start at -z in camera space (see get_rays); rotate that axis into world
# space for every pose to get each camera's viewing direction.
forward_axis = [0, 0, -1]
dirs = np.stack([np.sum(forward_axis * pose[:3, :3], axis=-1) for pose in poses])

# Camera centres are the translation column of each pose matrix.
origins = poses[:, :3, -1]

# Draw one arrow per camera: from its origin along its viewing direction.
ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d')
arrow_tails = [origins[..., k].flatten() for k in range(3)]
arrow_dirs = [dirs[..., k].flatten() for k in range(3)]
_ = ax.quiver(*arrow_tails, *arrow_dirs, length=0.5, normalize=True)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('z')
plt.show()
def get_rays(
    height: int,          # image height in pixels
    width: int,           # image width in pixels
    focal_length: float,  # pinhole focal length (pixels)
    c2w: torch.Tensor     # camera-to-world pose matrix
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute one ray per pixel for a pinhole camera.

    Each pixel's direction is built in camera space as
    ``((x - W/2) / f, -(y - H/2) / f, -1)`` and rotated into world space by
    ``c2w[:3, :3]``. All rays share the camera centre ``c2w[:3, -1]`` as origin.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``(height, width, 3)``.
    """
    cols = torch.arange(width, dtype=torch.float32).to(c2w)
    rows = torch.arange(height, dtype=torch.float32).to(c2w)
    # 'xy' indexing yields (height, width) grids: px[r, c] = c, py[r, c] = r.
    px, py = torch.meshgrid(cols, rows, indexing='xy')

    x_dir = (px - width * .5) / focal_length
    y_dir = -(py - height * .5) / focal_length  # image y grows downward
    z_dir = -torch.ones_like(px)                # camera looks down -z
    cam_dirs = torch.stack([x_dir, y_dir, z_dir], dim=-1)

    # Rotate every direction by the camera orientation (row-wise dot product).
    rotation = c2w[:3, :3]
    rays_d = (cam_dirs[..., None, :] * rotation).sum(dim=-1)

    # Every ray starts at the camera centre.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
# Convert the numpy arrays to PyTorch tensors on the compute device.
def _to_device(array):
    return torch.from_numpy(array).to(device)


images = _to_device(data['images'][:n_training])
poses = _to_device(data['poses'])
focal = _to_device(data['focal'])
testimg = _to_device(data['images'][testimg_idx])
testpose = _to_device(data['poses'][testimg_idx])

# Compute the rays of the test view.
height, width = images.shape[1:3]
with torch.no_grad():
    ray_origin, ray_direction = get_rays(height, width, focal, testpose)

# Print shape and centre-pixel value of both ray tensors.
mid_row, mid_col = height // 2, width // 2
for heading, ray_tensor in (('Ray Origin', ray_origin), ('Ray Direction', ray_direction)):
    print(heading)
    print(ray_tensor.shape)
    print(ray_tensor[mid_row, mid_col, :])
    print('')
2.1 分层采样
# Stratified sampling along rays.
def sample_stratified(
    rays_o: torch.Tensor,            # ray origins
    rays_d: torch.Tensor,            # ray directions
    near: float,
    far: float,
    n_samples: int,                  # number of samples per ray
    perturb: Optional[bool] = True,  # jitter samples inside their bins
    inverse_depth: bool = False      # space samples linearly in inverse depth
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw stratified depth samples along each ray.

    ``n_samples`` depths are spaced evenly from ``near`` to ``far``, either
    linearly in depth or, when ``inverse_depth`` is set, linearly in
    ``1 / depth``. With ``perturb`` each depth is moved uniformly at random
    within its bin, whose boundaries are the midpoints between neighbouring
    depths. Each ray gets its own independent jitter, as in the NeRF paper.

    Returns:
        ``(pts, z_vals)``. ``z_vals`` has shape
        ``rays_o.shape[:-1] + (n_samples,)``, and
        ``pts = rays_o + rays_d * z`` has one extra trailing dim of size
        ``rays_o.shape[-1]`` (presumably 3).
    """
    t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if not inverse_depth:
        # Linear in depth, from near to far.
        z_vals = near * (1. - t_vals) + far * t_vals
    else:
        # Linear in inverse depth (disparity).
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)

    # One copy of the bin layout per ray: [..., n_samples].
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])

    if perturb:
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
        lower = torch.cat([z_vals[..., :1], mids], dim=-1)
        # Independent random offset for every ray and every bin.
        t_rand = torch.rand(z_vals.shape, device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals
接着就到了对这些采样点做可视化分析的步骤。如图5所示,未受扰动的蓝色点是各个 bin 的“中心”,而红点对应扰动后的采样点。请注意,红点相对于上方的蓝点略有偏移,但所有点都位于近、远采样边界(near 与 far)之间。具体代码如下:
# Draw an example set of stratified samples along the test-view rays computed
# above. These names were previously used here without ever being defined.
rays_o = ray_origin.reshape([-1, 3])
rays_d = ray_direction.reshape([-1, 3])
n_samples = 8
perturb = True
inverse_depth = False
with torch.no_grad():
    pts, z_vals = sample_stratified(rays_o, rays_d, near, far, n_samples,
                                    perturb=perturb, inverse_depth=inverse_depth)

# Unperturbed samples for comparison; plotted one unit above the perturbed ones.
y_vals = torch.zeros_like(z_vals)
_, z_vals_unperturbed = sample_stratified(rays_o, rays_d, near, far, n_samples,
                                          perturb=False, inverse_depth=inverse_depth)

# Plot the first ray: blue = bin layout, red = perturbed samples.
plt.plot(z_vals_unperturbed[0].cpu().numpy(), 1 + y_vals[0].cpu().numpy(), 'b-o')
plt.plot(z_vals[0].cpu().numpy(), y_vals[0].cpu().numpy(), 'r-o')
plt.ylim([-1, 2])
plt.title('Stratified Sampling (blue) with Perturbation (red)')
ax = plt.gca()
ax.axes.yaxis.set_visible(False)
plt.grid(True)
3 位置编码
这一环节将会为位置编码器建立一个简单的 torch.nn.Module 模块,相同的编码器可同时用于对输入样本和视图方向的编码操作。注意,这些输入被指定了不同的参数。代码如下所示:
# Positional encoder.
class PositionalEncoder(nn.Module):
    """Sine/cosine positional encoding of input points.

    The output concatenates, along the last dimension,
    ``[x, sin(f_1 x), cos(f_1 x), ..., sin(f_n x), cos(f_n x)]``,
    giving ``d_input * (1 + 2 * n_freqs)`` features.

    Args:
        d_input: Size of the last dimension of the input.
        n_freqs: Number of frequency bands ``f_k``.
        log_space: If True, ``f_k = 2**k`` for ``k = 0..n_freqs-1``.
            Otherwise the frequencies are spaced linearly from ``1`` to
            ``2**(n_freqs-1)``.
    """

    def __init__(
        self,
        d_input: int,
        n_freqs: int,
        log_space: bool = False
    ):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        self.d_output = d_input * (1 + 2 * self.n_freqs)
        # The identity term keeps the raw input in the encoding.
        self.embed_fns = [lambda x: x]

        # Frequencies spaced linearly or on a log scale.
        if self.log_space:
            freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)

        # Bind `freq` as a default argument so each lambda keeps its own band
        # (avoids the late-binding closure pitfall).
        for freq in freq_bands:
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))

    def forward(
        self,
        x
    ) -> torch.Tensor:
        """Apply the encoding to ``x`` and concatenate on the last dimension."""
        return torch.cat([fn(x) for fn in self.embed_fns], dim=-1)
4 NeRF模型
在此定义一个 NeRF 模型——主体是一个由线性层构成的模块列表,层与层之间使用非线性激活函数,并可在指定层之后通过跳跃连接(skip connection)重新拼接原始输入。该模型有一个可选的视图方向输入:如果在实例化时指定了视图方向的输入维度(d_viewdirs),模型结构会随之改变。
(本实现基于原始论文NeRF:Representing Scenes as Neural Radiance Fields for View Synthesis 的第3节,并使用相同的默认设置)
# NeRF model.
class NeRF(nn.Module):
    """Neural radiance field MLP.

    A stack of ReLU-activated linear layers maps an (encoded) point to four
    outputs. When ``d_viewdirs`` is given, the final stage splits into a
    density head and an RGB head that is also conditioned on the encoded view
    direction. The output is ``[r, g, b, alpha]`` on the last dimension in
    both configurations.

    Args:
        d_input: Size of the last dimension of the point input.
        n_layers: Number of linear layers before the bottleneck.
        d_filter: Width of the hidden layers.
        skip: Layer indices after which the original input is concatenated
            back onto the hidden features.
        d_viewdirs: Size of the encoded view-direction input, or None to build
            a model without view dependence.
    """

    def __init__(
        self,
        d_input: int = 3,
        n_layers: int = 8,
        d_filter: int = 256,
        skip: Tuple[int] = (4,),
        d_viewdirs: Optional[int] = None
    ):
        super().__init__()
        self.d_input = d_input
        self.skip = skip
        self.act = nn.functional.relu
        self.d_viewdirs = d_viewdirs

        # Layer 0 consumes the input. Layer i+1 takes the extra d_input features
        # when i is in `skip`, because forward() concatenates after layer i.
        self.layers = nn.ModuleList(
            [nn.Linear(self.d_input, d_filter)] +
            [nn.Linear(d_filter + self.d_input, d_filter) if i in skip
             else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)]
        )

        # Bottleneck layers.
        if self.d_viewdirs is not None:
            # With view directions: separate alpha head and view-conditioned RGB head.
            self.alpha_out = nn.Linear(d_filter, 1)
            self.rgb_filters = nn.Linear(d_filter, d_filter)
            self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
            self.output = nn.Linear(d_filter // 2, 3)
        else:
            # Without view directions: a single linear layer emits all 4 values.
            self.output = nn.Linear(d_filter, 4)

    def forward(
        self,
        x: torch.Tensor,
        viewdirs: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""Forward pass, optionally conditioned on view directions.

        Raises:
            ValueError: if ``viewdirs`` is passed to a model built without
                ``d_viewdirs``, or omitted for a model built with it.
        """
        if self.d_viewdirs is None and viewdirs is not None:
            raise ValueError('Cannot input x_direction if d_viewdirs was not given.')
        if self.d_viewdirs is not None and viewdirs is None:
            raise ValueError('viewdirs must be given when the model was built with d_viewdirs.')

        # Layers before the bottleneck, with skip concatenations of the input.
        x_input = x
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x))
            if i in self.skip:
                x = torch.cat([x, x_input], dim=-1)

        if self.d_viewdirs is not None:
            # Density does not depend on view direction.
            alpha = self.alpha_out(x)

            # RGB branch is conditioned on the view direction.
            x = self.rgb_filters(x)
            x = torch.cat([x, viewdirs], dim=-1)
            x = self.act(self.branch(x))
            x = self.output(x)

            # Append alpha after RGB.
            x = torch.cat([x, alpha], dim=-1)
        else:
            x = self.output(x)
        return x
5 体积渲染
# Volume rendering.
def cumprod_exclusive(
    tensor: torch.Tensor
) -> torch.Tensor:
    """Exclusive cumulative product along the last dimension.

    (Courtesy of https://github.com/krrish94/nerf-pytorch)
    Behaves like ``tf.math.cumprod(..., exclusive=True)``: element ``k`` of the
    result is the product of input elements ``0..k-1``, and element 0 is 1.

    Args:
        tensor: Tensor whose exclusive cumprod along dim=-1 is computed.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    # Regular (inclusive) cumprod, shifted right by one position...
    cumprod = torch.cumprod(tensor, -1)
    cumprod = torch.roll(cumprod, 1, -1)
    # ...and the wrapped-around first element replaced by 1.
    cumprod[..., 0] = 1.
    return cumprod


# Convert raw network output into rendered maps.
def raw2outputs(
    raw: torch.Tensor,
    z_vals: torch.Tensor,
    rays_d: torch.Tensor,
    raw_noise_std: float = 0.0,
    white_bkgd: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Volume-render raw NeRF predictions into per-ray outputs.

    ``raw[..., :3]`` are RGB logits and ``raw[..., 3]`` is the density.
    Presumably ``raw`` is [n_rays, n_samples, 4] and ``z_vals`` is
    [n_rays, n_samples] (TODO confirm against callers).

    Args:
        raw: Raw model output per sample.
        z_vals: Sample depths along each ray.
        rays_d: Ray directions; their norms scale the inter-sample distances.
        raw_noise_std: Std-dev of Gaussian noise added to the density (0 disables).
        white_bkgd: If True, composite the result onto a white background.

    Returns:
        ``(rgb_map, depth_map, acc_map, weights)``:
        rgb_map sums over the sample axis, keeping the 3 colour channels.
        depth_map and acc_map are the weighted depth and the total weight per ray.
        weights has the same shape as ``z_vals``.
    """
    # Distances between consecutive samples; the last interval is "infinite".
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)

    # Scale by the ray direction norm to get world-space distances
    # (directions are not necessarily unit length).
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    # Optional density noise, used as a training regularizer against floaters.
    # Created on raw's device/dtype so this also works on GPU.
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape, device=raw.device, dtype=raw.dtype) * raw_noise_std

    # Opacity of each sample: higher density -> more likely absorbed here.
    alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists)

    # Contribution of each sample: its opacity times the transmittance
    # accumulated before it (1e-10 avoids an exactly-zero product).
    weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

    # Weighted colour per ray.
    rgb = torch.sigmoid(raw[..., :3])
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)

    # Expected depth per ray.
    depth_map = torch.sum(weights * z_vals, dim=-1)

    # Accumulated opacity per ray.
    acc_map = torch.sum(weights, dim=-1)

    # Composite onto white using the accumulated alpha.
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, depth_map, acc_map, weights
6 分层体积采样
# Sample from a piecewise-constant PDF.
def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = False
) -> torch.Tensor:
    """Inverse-transform sampling from a set of weighted bins.

    ``weights`` define a piecewise-constant PDF over the intervals between
    consecutive ``bins`` edges, so ``bins`` needs one more entry than
    ``weights`` on the last dimension. Samples are placed linearly in ``[0, 1]``
    (``perturb=False``) or uniformly at random (``perturb=True``) in CDF space.
    They are then mapped back to positions between the bin edges.

    Returns:
        Tensor of shape ``weights.shape[:-1] + (n_samples,)``.
    """
    # Normalize the weights into a PDF; 1e-5 keeps every bin reachable.
    pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdim=True)

    # PDF -> CDF with a leading 0.
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)

    # Positions in CDF space. Use cdf's dtype so searchsorted sees matching dtypes.
    if not perturb:
        u = torch.linspace(0., 1., n_samples, device=cdf.device, dtype=cdf.dtype)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device, dtype=cdf.dtype)

    # Find, for every u, the CDF interval it falls into.
    u = u.contiguous()  # searchsorted requires a contiguous query tensor
    inds = torch.searchsorted(cdf, u, right=True)

    # Clamp out-of-range indices at both ends.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], dim=-1)

    # Gather the CDF values and bin edges bracketing each sample.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g)

    # Linearly interpolate inside each interval. A (near-)zero-width CDF
    # interval would give 0/0 = NaN, so its denominator is replaced by 1.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
7 整体的前向传播流程
def get_chunks(
    inputs: torch.Tensor,
    chunksize: int = 2**15
) -> List[torch.Tensor]:
    """Split ``inputs`` along dim 0 into consecutive chunks of at most ``chunksize`` rows."""
    return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]


def prepare_chunks(
    points: torch.Tensor,
    encoding_function: Callable[[torch.Tensor], torch.Tensor],
    chunksize: int = 2**15
) -> List[torch.Tensor]:
    """Flatten points to (-1, 3), encode them, and split into chunks for the model."""
    points = points.reshape((-1, 3))
    points = encoding_function(points)
    points = get_chunks(points, chunksize=chunksize)
    return points


def prepare_viewdirs_chunks(
    points: torch.Tensor,
    rays_d: torch.Tensor,
    encoding_function: Callable[[torch.Tensor], torch.Tensor],
    chunksize: int = 2**15
) -> List[torch.Tensor]:
    r"""Encode normalized view directions, one copy per sample point, and chunk them.

    ``rays_d`` is indexed as ``[:, None, ...]`` and expanded to
    ``points.shape``. Presumably rays_d is [n_rays, 3] and points is
    [n_rays, n_samples, 3] (TODO confirm).
    """
    viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    viewdirs = viewdirs[:, None, ...].expand(points.shape).reshape((-1, 3))
    viewdirs = encoding_function(viewdirs)
    viewdirs = get_chunks(viewdirs, chunksize=chunksize)
    return viewdirs


def _query_model(
    query_points: torch.Tensor,
    rays_d: torch.Tensor,
    model: nn.Module,
    encoding_fn: Callable[[torch.Tensor], torch.Tensor],
    viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]],
    chunksize: int
) -> torch.Tensor:
    """Run ``model`` over all query points in chunks; reshape to query_points.shape[:2] + (C,)."""
    batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize)
    if viewdirs_encoding_fn is not None:
        batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d,
                                                   viewdirs_encoding_fn,
                                                   chunksize=chunksize)
    else:
        batches_viewdirs = [None] * len(batches)

    predictions = [model(batch, viewdirs=batch_viewdirs)
                   for batch, batch_viewdirs in zip(batches, batches_viewdirs)]
    raw = torch.cat(predictions, dim=0)
    return raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]])


def nerf_forward(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    near: float,
    far: float,
    encoding_fn: Callable[[torch.Tensor], torch.Tensor],
    coarse_model: nn.Module,
    kwargs_sample_stratified: Optional[dict] = None,
    n_samples_hierarchical: int = 0,
    kwargs_sample_hierarchical: Optional[dict] = None,
    fine_model=None,
    viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    chunksize: int = 2**15
) -> dict:
    """Compute one full NeRF forward pass (coarse, then optionally fine).

    Returns:
        A dict with ``'rgb_map'``, ``'depth_map'``, ``'acc_map'``, ``'weights'``
        and ``'z_vals_stratified'``. When ``n_samples_hierarchical > 0`` it also
        holds ``'z_vals_hierarchical'`` and the coarse-pass maps
        ``'rgb_map_0'``, ``'depth_map_0'`` and ``'acc_map_0'``.
    """
    if kwargs_sample_stratified is None:
        kwargs_sample_stratified = {}
    if kwargs_sample_hierarchical is None:
        kwargs_sample_hierarchical = {}

    # Coarse pass: stratified samples along each ray.
    query_points, z_vals = sample_stratified(
        rays_o, rays_d, near, far, **kwargs_sample_stratified)
    raw = _query_model(query_points, rays_d, coarse_model, encoding_fn,
                       viewdirs_encoding_fn, chunksize)

    # Differentiable volume rendering of the coarse samples.
    rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d)
    outputs = {
        'z_vals_stratified': z_vals
    }

    if n_samples_hierarchical > 0:
        # Keep the coarse results to return alongside the fine ones.
        rgb_map_0, depth_map_0, acc_map_0 = rgb_map, depth_map, acc_map

        # Fine pass: resample guided by the coarse weights.
        # NOTE(review): sample_hierarchical is not defined in the visible part
        # of this file; confirm it exists before enabling hierarchical sampling.
        query_points, z_vals_combined, z_hierarch = sample_hierarchical(
            rays_o, rays_d, z_vals, weights, n_samples_hierarchical,
            **kwargs_sample_hierarchical)

        # Fall back to the coarse model when no fine model is supplied.
        fine_model = fine_model if fine_model is not None else coarse_model
        raw = _query_model(query_points, rays_d, fine_model, encoding_fn,
                           viewdirs_encoding_fn, chunksize)

        # Differentiable volume rendering of the combined samples.
        rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals_combined, rays_d)

        outputs['z_vals_hierarchical'] = z_hierarch
        outputs['rgb_map_0'] = rgb_map_0
        outputs['depth_map_0'] = depth_map_0
        outputs['acc_map_0'] = acc_map_0

    # Final (fine if available, else coarse) outputs.
    outputs['rgb_map'] = rgb_map
    outputs['depth_map'] = depth_map
    outputs['acc_map'] = acc_map
    outputs['weights'] = weights
    return outputs
7.1 超参数
# Encoders
d_input = 3             # Number of input dimensions (x, y, z)
n_freqs = 10            # Number of frequency bands used to encode sample positions
log_space = True        # If set, frequencies are spaced in log space
use_viewdirs = True     # If set, view directions are used as an extra input
n_freqs_views = 4       # Number of frequency bands used to encode view directions

# Stratified sampling
n_samples = 64          # Number of spatial samples per ray
perturb = True          # If set, jitter the stratified sample positions
inverse_depth = False   # If set, space samples linearly in inverse depth

# Model
d_filter = 128          # Width of the linear layers
n_layers = 2            # Number of linear layers before the bottleneck
skip = []               # Layer indices after which the encoded input is re-concatenated
use_fine_model = True   # If set, create a separate fine model
# NOTE(review): init_models as shown builds the fine model from d_filter /
# n_layers, not from the two values below -- confirm which is intended.
d_filter_fine = 128     # Width of the fine model's linear layers
n_layers_fine = 6       # Number of fine-model layers before the bottleneck

# Hierarchical sampling
n_samples_hierarchical = 64   # Number of hierarchical samples per ray
perturb_hierarchical = False  # If set, jitter the hierarchical sample positions

# Optimizer
lr = 5e-4               # Learning rate

# Training
n_iters = 10000
batch_size = 2**14          # Number of rays per gradient step (power of 2)
one_image_per_step = True   # One image per gradient step (disables batching)
chunksize = 2**14           # Tune to fit GPU memory
center_crop = True          # Crop the centre of each image (one per image)
center_crop_iters = 50      # Stop centre-cropping after this many epochs
display_rate = 25           # Show test output every X epochs

# Early stopping
warmup_iters = 100          # Number of iterations in the warm-up phase
warmup_min_fitness = 10.0   # Minimum PSNR required at warmup_iters to continue
n_restarts = 10             # Number of restarts when training stalls

# Argument bundles passed through to the sampling functions.
kwargs_sample_stratified = {
    'n_samples': n_samples,
    'perturb': perturb,
    'inverse_depth': inverse_depth
}
# Use the hierarchical-specific flag; it was previously defined but ignored.
kwargs_sample_hierarchical = {
    'perturb': perturb_hierarchical
}
7.2 训练类和函数
# 绘制采样函数 def plot\_samples\( z\_vals: torch.Tensor, z\_hierarch: Optional\[torch.Tensor\] = None, ax: Optional\[np.ndarray\] = None\): r""" 绘制分层样本和(可选)分级样本。 """ y\_vals = 1 + np.zeros\_like\(z\_vals\) if ax is None: ax = plt.subplot\(\) ax.plot\(z\_vals, y\_vals, 'b-o' \) if z\_hierarch is not None: y\_hierarch = np.zeros\_like\(z\_hierarch\) ax.plot\(z\_hierarch, y\_hierarch, 'r-o' \) ax.set\_ylim\(\[-1, 2\]\) ax.set\_title\('Stratified Samples \(blue\) and Hierarchical Samples \(red\)' \) ax.axes.yaxis.set\_visible\(False\) ax.grid\(True\) return ax def crop\_center\( img: torch.Tensor, frac: float = 0.5 \) -> torch.Tensor: r""" 从图像中裁剪中心方形。 """ h\_offset = round\(img.shape\[0\] \* \(frac / 2\)\) w\_offset = round\(img.shape\[1\] \* \(frac / 2\)\) return img\[h\_offset:-h\_offset, w\_offset:-w\_offset\] class EarlyStopping: r""" 基于适配标准的早期停止辅助器 """ def \_\_init\_\_\( self, patience: int = 30, margin: float = 1e-4 \): self.best\_fitness = 0.0 self.best\_iter = 0 self.margin = margin self.patience = patience or float \('inf' \) # 在epoch停止提高后等待的停止时间 def \_\_call\_\_\( self, iter: int, fitness: float \): r""" 检查是否符合停止标准。 """ if \(fitness - self.best\_fitness\) > self.margin: self.best\_iter = iter self.best\_fitness = fitness delta = iter - self.best\_iter stop = delta >= self.patience # 超过耐性则停止训练 return stop def init\_models\(\): r""" 为 NeRF 训练初始化模型、编码器和优化器。 """ # 编码器 encoder = PositionalEncoder\(d\_input, n\_freqs, log \_space=log \_space\) encode = lambda x: encoder\(x\) # 视图方向编码 if use\_viewdirs: encoder\_viewdirs = PositionalEncoder\(d\_input, n\_freqs\_views, log \_space=log \_space\) encode\_viewdirs = lambda x: encoder\_viewdirs\(x\) d\_viewdirs = encoder\_viewdirs.d\_output else : encode\_viewdirs = None d\_viewdirs = None # 模型 model = NeRF\(encoder.d\_output, n\_layers=n\_layers, d\_filter=d\_filter, skip=skip, d\_viewdirs=d\_viewdirs\) model.to\(device\) model\_params = list\(model.parameters\(\)\) if use\_fine\_model: fine\_model = NeRF\(encoder.d\_output, 
n\_layers=n\_layers, d\_filter=d\_filter, skip=skip, d\_viewdirs=d\_viewdirs\) fine\_model.to\(device\) model\_params = model\_params + list\(fine\_model.parameters\(\)\) else : fine\_model = None # 优化器 optimizer = torch.optim.Adam\(model\_params, lr=lr\) # 早停