simple_viewer.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. """A simple example to render (large-scale) Gaussian Splats.
  2. ```bash
  3. python examples/simple_viewer.py --scene_grid 13
  4. ```
  5. """
  6. import argparse
  7. import math
  8. import os
  9. import time
  10. from typing import Tuple
  11. import imageio
  12. import nerfview
  13. import numpy as np
  14. import torch
  15. import torch.nn.functional as F
  16. import tqdm
  17. import viser
  18. from gsplat._helper import load_test_data
  19. from gsplat.distributed import cli
  20. from gsplat.rendering import rasterization
  21. def main(local_rank: int, world_rank, world_size: int, args):
  22. torch.manual_seed(42)
  23. device = torch.device("cuda", local_rank)
  24. if args.ckpt is None:
  25. (
  26. means,
  27. quats,
  28. scales,
  29. opacities,
  30. colors,
  31. viewmats,
  32. Ks,
  33. width,
  34. height,
  35. ) = load_test_data(device=device, scene_grid=args.scene_grid)
  36. assert world_size <= 2
  37. means = means[world_rank::world_size].contiguous()
  38. means.requires_grad = True
  39. quats = quats[world_rank::world_size].contiguous()
  40. quats.requires_grad = True
  41. scales = scales[world_rank::world_size].contiguous()
  42. scales.requires_grad = True
  43. opacities = opacities[world_rank::world_size].contiguous()
  44. opacities.requires_grad = True
  45. colors = colors[world_rank::world_size].contiguous()
  46. colors.requires_grad = True
  47. viewmats = viewmats[world_rank::world_size][:1].contiguous()
  48. Ks = Ks[world_rank::world_size][:1].contiguous()
  49. sh_degree = None
  50. C = len(viewmats)
  51. N = len(means)
  52. print("rank", world_rank, "Number of Gaussians:", N, "Number of Cameras:", C)
  53. # batched render
  54. for _ in tqdm.trange(1):
  55. render_colors, render_alphas, meta = rasterization(
  56. means, # [N, 3]
  57. quats, # [N, 4]
  58. scales, # [N, 3]
  59. opacities, # [N]
  60. colors, # [N, S, 3]
  61. viewmats, # [C, 4, 4]
  62. Ks, # [C, 3, 3]
  63. width,
  64. height,
  65. render_mode="RGB+D",
  66. packed=False,
  67. distributed=world_size > 1,
  68. )
  69. C = render_colors.shape[0]
  70. assert render_colors.shape == (C, height, width, 4)
  71. assert render_alphas.shape == (C, height, width, 1)
  72. render_colors.sum().backward()
  73. render_rgbs = render_colors[..., 0:3]
  74. render_depths = render_colors[..., 3:4]
  75. render_depths = render_depths / render_depths.max()
  76. # dump batch images
  77. os.makedirs(args.output_dir, exist_ok=True)
  78. canvas = (
  79. torch.cat(
  80. [
  81. render_rgbs.reshape(C * height, width, 3),
  82. render_depths.reshape(C * height, width, 1).expand(-1, -1, 3),
  83. render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3),
  84. ],
  85. dim=1,
  86. )
  87. .detach()
  88. .cpu()
  89. .numpy()
  90. )
  91. imageio.imsave(
  92. f"{args.output_dir}/render_rank{world_rank}.png",
  93. (canvas * 255).astype(np.uint8),
  94. )
  95. else:
  96. means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
  97. for ckpt_path in args.ckpt:
  98. ckpt = torch.load(ckpt_path, map_location=device)["splats"]
  99. means.append(ckpt["means"])
  100. quats.append(F.normalize(ckpt["quats"], p=2, dim=-1))
  101. scales.append(torch.exp(ckpt["scales"]))
  102. opacities.append(torch.sigmoid(ckpt["opacities"]))
  103. sh0.append(ckpt["sh0"])
  104. shN.append(ckpt["shN"])
  105. means = torch.cat(means, dim=0)
  106. quats = torch.cat(quats, dim=0)
  107. scales = torch.cat(scales, dim=0)
  108. opacities = torch.cat(opacities, dim=0)
  109. sh0 = torch.cat(sh0, dim=0)
  110. shN = torch.cat(shN, dim=0)
  111. colors = torch.cat([sh0, shN], dim=-2)
  112. sh_degree = int(math.sqrt(colors.shape[-2]) - 1)
  113. # # crop
  114. # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
  115. # edges = aabb[3:] - aabb[:3]
  116. # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
  117. # sel = torch.where(sel)[0]
  118. # means, quats, scales, colors, opacities = (
  119. # means[sel],
  120. # quats[sel],
  121. # scales[sel],
  122. # colors[sel],
  123. # opacities[sel],
  124. # )
  125. # # repeat the scene into a grid (to mimic a large-scale setting)
  126. # repeats = args.scene_grid
  127. # gridx, gridy = torch.meshgrid(
  128. # [
  129. # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
  130. # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
  131. # ],
  132. # indexing="ij",
  133. # )
  134. # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
  135. # -1, 3
  136. # )
  137. # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
  138. # means = means.reshape(-1, 3)
  139. # quats = quats.repeat(repeats**2, 1)
  140. # scales = scales.repeat(repeats**2, 1)
  141. # colors = colors.repeat(repeats**2, 1, 1)
  142. # opacities = opacities.repeat(repeats**2)
  143. print("Number of Gaussians:", len(means))
  144. # register and open viewer
  145. @torch.no_grad()
  146. def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]):
  147. width, height = img_wh
  148. c2w = camera_state.c2w
  149. K = camera_state.get_K(img_wh)
  150. c2w = torch.from_numpy(c2w).float().to(device)
  151. K = torch.from_numpy(K).float().to(device)
  152. viewmat = c2w.inverse()
  153. if args.backend == "gsplat":
  154. rasterization_fn = rasterization
  155. elif args.backend == "inria":
  156. from gsplat import rasterization_inria_wrapper
  157. rasterization_fn = rasterization_inria_wrapper
  158. else:
  159. raise ValueError
  160. render_colors, render_alphas, meta = rasterization_fn(
  161. means, # [N, 3]
  162. quats, # [N, 4]
  163. scales, # [N, 3]
  164. opacities, # [N]
  165. colors, # [N, S, 3]
  166. viewmat[None], # [1, 4, 4]
  167. K[None], # [1, 3, 3]
  168. width,
  169. height,
  170. sh_degree=sh_degree,
  171. render_mode="RGB",
  172. # this is to speedup large-scale rendering by skipping far-away Gaussians.
  173. radius_clip=3,
  174. )
  175. render_rgbs = render_colors[0, ..., 0:3].cpu().numpy()
  176. return render_rgbs
  177. server = viser.ViserServer(port=args.port, verbose=False)
  178. _ = nerfview.Viewer(
  179. server=server,
  180. render_fn=viewer_render_fn,
  181. mode="rendering",
  182. )
  183. print("Viewer running... Ctrl+C to exit.")
  184. time.sleep(100000)
  185. if __name__ == "__main__":
  186. """
  187. # Use single GPU to view the scene
  188. CUDA_VISIBLE_DEVICES=0 python simple_viewer.py \
  189. --ckpt results/garden/ckpts/ckpt_3499_rank0.pt results/garden/ckpts/ckpt_3499_rank1.pt \
  190. --port 8081
  191. """
  192. parser = argparse.ArgumentParser()
  193. parser.add_argument(
  194. "--output_dir", type=str, default="results/", help="where to dump outputs"
  195. )
  196. parser.add_argument(
  197. "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN"
  198. )
  199. parser.add_argument(
  200. "--ckpt", type=str, nargs="+", default=None, help="path to the .pt file"
  201. )
  202. parser.add_argument(
  203. "--port", type=int, default=8080, help="port for the viewer server"
  204. )
  205. parser.add_argument("--backend", type=str, default="gsplat", help="gsplat, inria")
  206. args = parser.parse_args()
  207. assert args.scene_grid % 2 == 1, "scene_grid must be odd"
  208. cli(main, args, verbose=True)