瀏覽代碼

GridModule: use parallel r.patch as alternative backend when overlap=0 to potentially speed up patching tiles (#2249)

Anna Petrasova 3 年之前
父節點
當前提交
1f48487068

+ 44 - 11
python/grass/pygrass/modules/grid/grid.py

@@ -22,7 +22,7 @@ from grass.pygrass.modules import Module
 from grass.pygrass.utils import get_mapset_raster, findmaps
 
 from grass.pygrass.modules.grid.split import split_region_tiles
-from grass.pygrass.modules.grid.patch import rpatch_map
+from grass.pygrass.modules.grid.patch import rpatch_map, rpatch_map_r_patch_backend
 
 
 def select(parms, ptype):
@@ -424,11 +424,17 @@ class GridModule(object):
     :type split: bool
     :param mapset_prefix: if specified created mapsets start with this prefix
     :type mapset_prefix: str
+    :param patch_backend: "r.patch", "RasterRow", or None for for default
+    :type patch_backend: None or str
     :param run_: if False only instantiate the object
     :type run_: bool
     :param args: give all the parameters to the command
     :param kargs: give all the parameters to the command
 
+    When patch_backend is None, the RasterRow method is used for patching the result.
+    When patch_backend is "r.patch", r.patch is used with nprocs=processes.
+    r.patch can only be used when overlap is 0.
+
     >>> grd = GridModule('r.slope.aspect',
     ...                  width=500, height=500, overlap=2,
     ...                  processes=None, split=False,
@@ -458,6 +464,7 @@ class GridModule(object):
         start_col=0,
         out_prefix="",
         mapset_prefix=None,
+        patch_backend=None,
         *args,
         **kargs,
     ):
@@ -474,6 +481,20 @@ class GridModule(object):
         self.out_prefix = out_prefix
         self.log = log
         self.move = move
+        # by default RasterRow is used as previously
+        # if overlap > 0, r.patch won't work properly
+        if not patch_backend:
+            self.patch_backend = "RasterRow"
+        elif patch_backend not in ("r.patch", "RasterRow"):
+            raise RuntimeError(
+                _("Parameter patch_backend must be 'r.patch' or 'RasterRow'")
+            )
+        elif patch_backend == "r.patch" and self.overlap:
+            raise RuntimeError(
+                _("Patching backend 'r.patch' doesn't work for overlap > 0")
+            )
+        else:
+            self.patch_backend = patch_backend
         self.gisrc_src = os.environ["GISRC"]
         self.n_mset, self.gisrc_dst = None, None
         if self.move:
@@ -665,16 +686,28 @@ class GridModule(object):
         for otmap in self.module.outputs:
             otm = self.module.outputs[otmap]
             if otm.typedesc == "raster" and otm.value:
-                rpatch_map(
-                    otm.value,
-                    self.mset.name,
-                    self.msetstr,
-                    bboxes,
-                    self.module.flags.overwrite,
-                    self.start_row,
-                    self.start_col,
-                    self.out_prefix,
-                )
+                if self.patch_backend == "RasterRow":
+                    rpatch_map(
+                        raster=otm.value,
+                        mapset=self.mset.name,
+                        mset_str=self.msetstr,
+                        bbox_list=bboxes,
+                        overwrite=self.module.flags.overwrite,
+                        start_row=self.start_row,
+                        start_col=self.start_col,
+                        prefix=self.out_prefix,
+                    )
+                else:
+                    rpatch_map_r_patch_backend(
+                        raster=otm.value,
+                        mset_str=self.msetstr,
+                        bbox_list=bboxes,
+                        overwrite=self.module.flags.overwrite,
+                        start_row=self.start_row,
+                        start_col=self.start_col,
+                        prefix=self.out_prefix,
+                        processes=self.processes,
+                    )
                 noutputs += 1
         if noutputs < 1:
             msg = "No raster output option defined for <{}>".format(self.module.name)

+ 45 - 0
python/grass/pygrass/modules/grid/patch.py

@@ -15,6 +15,7 @@ from __future__ import (
 from grass.pygrass.gis.region import Region
 from grass.pygrass.raster import RasterRow
 from grass.pygrass.utils import coor2pixel
+from grass.pygrass.modules import Module
 
 
 def get_start_end_index(bbox_list):
@@ -111,3 +112,47 @@ def rpatch_map(
             del rst
 
     rast.close()
+
+
+def rpatch_map_r_patch_backend(
+    raster,
+    mset_str,
+    bbox_list,
+    overwrite=False,
+    start_row=0,
+    start_col=0,
+    prefix="",
+    processes=1,
+):
+    """Patch raster using a r.patch. Only use with overlap=0.
+    Will be faster than rpatch_map, since r.patch is parallelized.
+
+    :param raster: the name of output raster
+    :type raster: str
+    :param mset_str:
+    :type mset_str: str
+    :param bbox_list: a list of BBox object to convert
+    :type bbox_list: list of BBox object
+    :param overwrite: overwrite existing raster
+    :type overwrite: bool
+    :param start_row: the starting row of original raster
+    :type start_row: int
+    :param start_col: the starting column of original raster
+    :type start_col: int
+    :param prefix: the prefix of output raster
+    :type prefix: str
+    :param processes: number of parallel process for r.patch
+    :type processes: int
+    """
+    rasts = []
+    for row, rbbox in enumerate(bbox_list):
+        for col in range(len(rbbox)):
+            mapset = mset_str % (start_row + row, start_col + col)
+            rasts.append(f"{raster}@{mapset}")
+    Module(
+        "r.patch",
+        input=rasts,
+        output=prefix + raster,
+        overwrite=overwrite,
+        nprocs=processes,
+    )

+ 41 - 0
python/grass/pygrass/modules/tests/grass_pygrass_grid_test.py

@@ -171,3 +171,44 @@ def test_cleans(tmp_path, clean):
                     prefixed += int(item.name.startswith(mapset_prefix))
         if not clean:
             assert prefixed, "Not even one prefixed mapset"
+
+
+@pytest.mark.parametrize("patch_backend", [None, "r.patch", "RasterRow"])
+def test_patching_backend(tmp_path, patch_backend):
+    """Check patching backend works"""
+    location = "test"
+    gs.core._create_location_xy(tmp_path, location)  # pylint: disable=protected-access
+    with grass_setup.init(tmp_path / location):
+        gs.run_command("g.region", s=0, n=50, w=0, e=50, res=1)
+
+        points = "points"
+        reference = "reference"
+        gs.run_command("v.random", output=points, npoints=100)
+        gs.run_command(
+            "v.to.rast", input=points, output=reference, type="point", use="cat"
+        )
+
+        def run_grid_module():
+            # modules/shortcuts calls get_commands which requires GISBASE.
+            # pylint: disable=import-outside-toplevel
+            from grass.pygrass.modules.grid import GridModule
+
+            grid = GridModule(
+                "v.to.rast",
+                width=10,
+                height=5,
+                overlap=0,
+                patch_backend=patch_backend,
+                processes=max_processes(),
+                input=points,
+                output="output",
+                type="point",
+                use="cat",
+            )
+            grid.run()
+
+        run_in_subprocess(run_grid_module)
+
+        mean_ref = float(gs.parse_command("r.univar", map=reference, flags="g")["mean"])
+        mean = float(gs.parse_command("r.univar", map="output", flags="g")["mean"])
+        assert abs(mean - mean_ref) < 0.0001