grid.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Mar 28 11:06:00 2013
  4. @author: pietro
  5. """
  6. import os
  7. import multiprocessing as mltp
  8. import subprocess as sub
  9. from grass.script.setup import write_gisrc
  10. from grass.pygrass.gis import Mapset, Location, make_mapset
  11. from grass.pygrass.modules import Module
  12. from grass.pygrass.functions import get_mapset_raster
  13. from split import split_region_tiles
  14. from patch import patch_map
  15. _GREG = Module('g.region')
  16. def get_cmd(cmdd):
  17. """Transforma a cmd dictionary to a list of parameters"""
  18. cmd = [cmdd['name'], ]
  19. cmd.extend(("%s=%s" % (k, v) for k, v in cmdd['inputs']))
  20. cmd.extend(("%s=%s" % (k, v) for k, v in cmdd['outputs']))
  21. cmd.extend(("%s" % (flg) for flg in cmdd['flags'] if len(flg) == 1))
  22. cmd.extend(("--%s" % (flg[0]) for flg in cmdd['flags'] if len(flg) > 1))
  23. return cmd
  24. def cmd_exe((bbox, mapnames, msetname, cmd)):
  25. """Create a mapset, and execute a cmd inside."""
  26. mset = Mapset()
  27. try:
  28. make_mapset(msetname)
  29. except:
  30. pass
  31. env = os.environ.copy()
  32. env['GISRC'] = write_gisrc(mset.gisdbase, mset.location, msetname)
  33. if mapnames:
  34. inputs = dict(cmd['inputs'])
  35. # reset the inputs to
  36. for key in mapnames:
  37. inputs[key] = mapnames[key]
  38. cmd['inputs'] = inputs.items()
  39. # set the region to the tile
  40. _GREG(env_=env, rast=key)
  41. else:
  42. # set the computational region
  43. _GREG(env_=env, **bbox)
  44. # run the grass command
  45. sub.Popen(get_cmd(cmd), env=env).wait()
  46. class GridModule(object):
  47. """Run GRASS raster commands in a multiproccessing mode.
  48. Parameters
  49. -----------
  50. cmd: raster GRASS command
  51. Only command staring with r.* are valid.
  52. width: integer
  53. Width of the tile, in pixel.
  54. height: integer
  55. Height of the tile, in pixel.
  56. overlap: integer
  57. Overlap between tiles, in pixel.
  58. nthreads: number of threads
  59. Default value is equal to the number of processor available.
  60. split: boolean
  61. If True use r.tile to split all the inputs.
  62. run_: boolean
  63. If False only instantiate the object.
  64. args and kargs: cmd parameters
  65. Give all the parameters to the command.
  66. Examples
  67. --------
  68. ::
  69. >>> grd = GridModule('r.slope.aspect',
  70. ... width=500, height=500, overlap=2,
  71. ... processes=None, split=True,
  72. ... elevation='elevation',
  73. ... slope='slope', aspect='aspect', overwrite=True)
  74. >>> grd.run()
  75. """
  76. def __init__(self, cmd, width=None, height=None, overlap=0, processes=None,
  77. split=False, debug=False, *args, **kargs):
  78. kargs['run_'] = False
  79. self.mset = Mapset()
  80. self.module = Module(cmd, *args, **kargs)
  81. self.width = width
  82. self.height = height
  83. self.overlap = overlap
  84. self.processes = processes
  85. self.bboxes = split_region_tiles(width=width, height=height,
  86. overlap=overlap)
  87. self.msetstr = cmd.replace('.', '') + "_%03d_%03d"
  88. self.inlist = None
  89. if split:
  90. self.split()
  91. self.debug = debug
  92. def clean_location(self):
  93. """Remove all created mapsets."""
  94. mapsets = Location().mapsets(self.msetstr.split('_')[0] + '_*')
  95. for mset in mapsets:
  96. Mapset(mset).delete()
  97. def split(self):
  98. """Split all the raster inputs using r.tile"""
  99. rtile = Module('r.tile')
  100. inlist = {}
  101. #import pdb; pdb.set_trace()
  102. for inmap in self.module.inputs:
  103. inm = self.module.inputs[inmap]
  104. if inm.type == 'raster' and inm.value:
  105. rtile(input=inm.value, output=inm.value,
  106. width=self.width, height=self.height,
  107. overlap=self.overlap)
  108. patt = '%s-*' % inm.value
  109. inlist[inm.value] = sorted(self.mset.glist(type='rast',
  110. pattern=patt))
  111. self.inlist = inlist
  112. def get_works(self):
  113. """Return a list of tuble with the parameters for cmd_exe function"""
  114. works = []
  115. cmd = self.module.get_dict()
  116. for row, box_row in enumerate(self.bboxes):
  117. for col, box in enumerate(box_row):
  118. inms = None
  119. if self.inlist:
  120. inms = {}
  121. cols = len(box_row)
  122. for key in self.inlist:
  123. indx = row * cols + col
  124. inms[key] = "%s@%s" % (self.inlist[key][indx],
  125. self.mset.name)
  126. # set the computational region, prepare the region parameters
  127. bbox = dict([(k[0], str(v)) for k, v in box.items()[:-2]])
  128. works.append((bbox, inms, self.msetstr % (row, col), cmd))
  129. return works
  130. def define_mapset_inputs(self):
  131. for inmap in self.module.inputs:
  132. inm = self.module.inputs[inmap]
  133. if inm.type == 'raster' and inm.value:
  134. if '@' not in inm.value:
  135. mset = get_mapset_raster(inm.value)
  136. inm.value = inm.value + '@%s' % mset
  137. def run(self, patch=True, clean=True):
  138. """Run the GRASS command."""
  139. self.module.flags.overwrite = True
  140. self.define_mapset_inputs()
  141. if self.debug:
  142. for wrk in self.get_works():
  143. cmd_exe(wrk)
  144. else:
  145. pool = mltp.Pool(processes=self.processes)
  146. result = pool.map_async(cmd_exe, self.get_works())
  147. result.wait()
  148. if not result.successful():
  149. raise RuntimeError
  150. if patch:
  151. self.patch()
  152. if clean:
  153. self.clean_location()
  154. self.rm_tiles()
  155. def patch(self):
  156. """Patch the final results."""
  157. # patch all the outputs
  158. for otmap in self.module.outputs:
  159. otm = self.module.outputs[otmap]
  160. if otm.type == 'raster' and otm.value:
  161. patch_map(otm.value, self.mset.name, self.msetstr,
  162. split_region_tiles(width=self.width,
  163. height=self.height),
  164. self.module.flags.overwrite)
  165. def rm_tiles(self):
  166. """Remove all the tiles."""
  167. # if split, remove tiles
  168. if self.inlist:
  169. grm = Module('g.remove')
  170. for key in self.inlist:
  171. grm(rast=self.inlist[key])