Browse Source

pygrass.GridModule: make tile size optional (#2294)

When width or height or both are not provided, compute the size based on number of process.
Use 1 column x N rows tile schema to make patching faster.
Anna Petrasova 2 years ago
parent
commit
4def144dbf

+ 27 - 1
python/grass/pygrass/modules/grid/grid.py

@@ -12,6 +12,7 @@ import sys
 import multiprocessing as mltp
 import subprocess as sub
 import shutil as sht
+from math import ceil
 
 from grass.script.setup import write_gisrc
 from grass.script import append_node_pid, legalize_vector_name
@@ -497,6 +498,7 @@ class GridModule(object):
             self.patch_backend = patch_backend
         self.gisrc_src = os.environ["GISRC"]
         self.n_mset, self.gisrc_dst = None, None
+        self.estimate_tile_size()
         if self.move:
             self.n_mset = copy_mapset(self.mset, self.move)
             self.gisrc_dst = write_gisrc(
@@ -514,7 +516,7 @@ class GridModule(object):
             if groups:
                 copy_groups(groups, self.gisrc_src, self.gisrc_dst, region=self.region)
         self.bboxes = split_region_tiles(
-            region=region, width=width, height=height, overlap=overlap
+            region=region, width=self.width, height=self.height, overlap=overlap
         )
         if mapset_prefix:
             self.mapset_prefix = mapset_prefix
@@ -564,6 +566,30 @@ class GridModule(object):
             inlist[inm.value] = sorted(self.mset.glist(type="raster", pattern=patt))
         self.inlist = inlist
 
+    def estimate_tile_size(self):
+        """Estimates tile width and height based on number of processes.
+
+        Keeps width and height if provided by user. If one dimension
+        is provided the other is computed as the minimum number of tiles
+        to keep all requested processes working (initially).
+        If no dimensions are provided, tiling is 1 column x N rows
+        which speeds up patching.
+        """
+        region = Region()
+        if self.width and self.height:
+            return
+        if self.width:
+            n_tiles_x = ceil(region.cols / self.width)
+            n_tiles_y = ceil(self.processes / n_tiles_x)
+            self.height = ceil(region.rows / n_tiles_y)
+        elif self.height:
+            n_tiles_y = ceil(region.rows / self.height)
+            n_tiles_x = ceil(self.processes / n_tiles_y)
+            self.width = ceil(region.cols / n_tiles_x)
+        else:
+            self.width = region.cols
+            self.height = ceil(region.rows / self.processes)
+
     def get_works(self):
         """Return a list of tuble with the parameters for cmd_exe function"""
         works = []

+ 41 - 0
python/grass/pygrass/modules/tests/grass_pygrass_grid_test.py

@@ -212,3 +212,44 @@ def test_patching_backend(tmp_path, patch_backend):
         mean_ref = float(gs.parse_command("r.univar", map=reference, flags="g")["mean"])
         mean = float(gs.parse_command("r.univar", map="output", flags="g")["mean"])
         assert abs(mean - mean_ref) < 0.0001
+
+
+@pytest.mark.parametrize(
+    "width, height, processes",
+    [
+        (None, None, max_processes()),
+        (10, None, max_processes()),
+        (None, 5, max_processes()),
+    ],
+)
+def test_tiling(tmp_path, width, height, processes):
+    """Check auto adjusted tile size based on processes"""
+    location = "test"
+    gs.core._create_location_xy(tmp_path, location)  # pylint: disable=protected-access
+    with grass_setup.init(tmp_path / location):
+        gs.run_command("g.region", s=0, n=50, w=0, e=50, res=1)
+
+        surface = "surface"
+        gs.run_command("r.surf.fractal", output=surface)
+
+        def run_grid_module():
+            # modules/shortcuts calls get_commands which requires GISBASE.
+            # pylint: disable=import-outside-toplevel
+            from grass.pygrass.modules.grid import GridModule
+
+            grid = GridModule(
+                "r.slope.aspect",
+                width=width,
+                height=height,
+                overlap=2,
+                processes=processes,
+                elevation=surface,
+                slope="slope",
+                aspect="aspect",
+            )
+            grid.run()
+
+        run_in_subprocess(run_grid_module)
+
+        info = gs.raster_info("slope")
+        assert info["min"] > 0