grid.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Mar 28 11:06:00 2013
  4. @author: pietro
  5. """
  6. import os
  7. import multiprocessing as mltp
  8. import subprocess as sub
  9. import shutil as sht
  10. from grass.script.setup import write_gisrc
  11. from grass.pygrass.gis import Mapset, Location, make_mapset
  12. from grass.pygrass.gis.region import Region
  13. from grass.pygrass.modules import Module
  14. from grass.pygrass.functions import get_mapset_raster
  15. from split import split_region_tiles
  16. from patch import patch_map
# Shared ``g.region`` module wrapper: cmd_exe() calls it with ``env_=`` to set
# the computational region of each worker mapset.
_GREG = Module('g.region')
  18. def select(parms, ptype):
  19. """Select only a certain type of parameters. ::
  20. >>> slp = Module('r.slope.aspect',
  21. ... elevation='ele', slope='slp', aspect='asp',
  22. ... run_=False)
  23. >>> for rast in select(slp.outputs, 'raster'):
  24. ... print rast
  25. ...
  26. slp
  27. asp
  28. """
  29. for k in parms:
  30. par = parms[k]
  31. if par.type == ptype or par.typedesc == ptype and par.value:
  32. if par.multiple:
  33. for p in par.value:
  34. yield p
  35. else:
  36. yield par.value
  37. def copy_mapset(mapset, path):
  38. """Copy mapset to another place without copying raster and vector data.
  39. """
  40. per_old = os.path.join(mapset.gisdbase, mapset.location, 'PERMANENT')
  41. per_new = os.path.join(path, 'PERMANENT')
  42. map_old = mapset.path()
  43. map_new = os.path.join(path, mapset.name)
  44. if not os.path.isdir(per_new):
  45. os.makedirs(per_new)
  46. if not os.path.isdir(map_new):
  47. os.mkdir(map_new)
  48. for f in (fi for fi in os.listdir(per_old) if fi.isupper()):
  49. sht.copy(os.path.join(per_old, f), per_new)
  50. for f in (fi for fi in os.listdir(map_old) if fi.isupper()):
  51. sht.copy(os.path.join(map_old, f), map_new)
  52. gisdbase, location = os.path.split(path)
  53. return Mapset(mapset.name, location, gisdbase)
  54. def copy_raster(rasters, src, dst, region=None):
  55. """Copy raster from one mapset to another, crop the raster to the region.
  56. """
  57. # set region
  58. if region:
  59. region.set_current()
  60. nam = "copy%d__%s" % (id(dst), '%s')
  61. expr = "%s=%s"
  62. # instantiate modules
  63. mpclc = Module('r.mapcalc')
  64. rpck = Module('r.pack')
  65. rupck = Module('r.unpack')
  66. rm = Module('g.remove')
  67. # get and set GISRC
  68. gisrc_src = os.environ['GISRC']
  69. gisrc_dst = write_gisrc(dst.gisdbase, dst.location, dst.name)
  70. pdst = dst.path()
  71. for rast in rasters:
  72. # change gisdbase to src
  73. os.environ['GISRC'] = gisrc_src
  74. src.current()
  75. name = nam % rast
  76. mpclc(expression=expr % (name, rast), overwrite=True)
  77. file_dst = "%s.pack" % os.path.join(pdst, name)
  78. rpck(input=name, output=file_dst, overwrite=True)
  79. rm(rast=name)
  80. # change gisdbase to dst
  81. os.environ['GISRC'] = gisrc_dst
  82. dst.current()
  83. rupck(input=file_dst, output=rast, overwrite=True)
  84. os.remove(file_dst)
  85. return gisrc_src, gisrc_dst
  86. def get_cmd(cmdd):
  87. """Transforma a cmd dictionary to a list of parameters"""
  88. cmd = [cmdd['name'], ]
  89. cmd.extend(("%s=%s" % (k, v) for k, v in cmdd['inputs']
  90. if not isinstance(v, list)))
  91. cmd.extend(("%s=%s" % (k, ','.join(vals if isinstance(vals[0], str)
  92. else map(repr, vals)))
  93. for k, vals in cmdd['inputs']
  94. if isinstance(vals, list)))
  95. cmd.extend(("%s=%s" % (k, v) for k, v in cmdd['outputs']
  96. if not isinstance(v, list)))
  97. cmd.extend(("%s=%s" % (k, ','.join(map(repr, vals)))
  98. for k, vals in cmdd['outputs']
  99. if isinstance(vals, list)))
  100. cmd.extend(("%s" % (flg) for flg in cmdd['flags'] if len(flg) == 1))
  101. cmd.extend(("--%s" % (flg[0]) for flg in cmdd['flags'] if len(flg) > 1))
  102. return cmd
  103. def cmd_exe(args):
  104. """Create a mapset, and execute a cmd inside."""
  105. bbox, mapnames, msetname, cmd = args
  106. mset = Mapset()
  107. try:
  108. make_mapset(msetname)
  109. except:
  110. pass
  111. ms = Mapset(msetname)
  112. ms.visible.extend(mset.visible)
  113. env = os.environ.copy()
  114. env['GISRC'] = write_gisrc(mset.gisdbase, mset.location, msetname)
  115. if mapnames:
  116. inputs = dict(cmd['inputs'])
  117. # reset the inputs to
  118. for key in mapnames:
  119. inputs[key] = mapnames[key]
  120. cmd['inputs'] = inputs.items()
  121. # set the region to the tile
  122. _GREG(env_=env, rast=key)
  123. else:
  124. #reg = Region() nsres=reg.nsres, ewres=reg.ewres,
  125. # set the computational region
  126. _GREG(env_=env, **bbox)
  127. # run the grass command
  128. #import ipdb; ipdb.set_trace()
  129. sub.Popen(get_cmd(cmd), env=env).wait()
  130. class GridModule(object):
  131. """Run GRASS raster commands in a multiproccessing mode.
  132. Parameters
  133. -----------
  134. cmd: raster GRASS command
  135. Only command staring with r.* are valid.
  136. width: integer
  137. Width of the tile, in pixel.
  138. height: integer
  139. Height of the tile, in pixel.
  140. overlap: integer
  141. Overlap between tiles, in pixel.
  142. nthreads: number of threads
  143. Default value is equal to the number of processor available.
  144. split: boolean
  145. If True use r.tile to split all the inputs.
  146. run_: boolean
  147. If False only instantiate the object.
  148. args and kargs: cmd parameters
  149. Give all the parameters to the command.
  150. Examples
  151. --------
  152. ::
  153. >>> grd = GridModule('r.slope.aspect',
  154. ... width=500, height=500, overlap=2,
  155. ... processes=None, split=True,
  156. ... elevation='elevation',
  157. ... slope='slope', aspect='aspect', overwrite=True)
  158. >>> grd.run()
  159. """
  160. def __init__(self, cmd, width=None, height=None, overlap=0, processes=None,
  161. split=False, debug=False, region=None, move=None,
  162. start_row=0, start_col=0, out_prefix='',
  163. *args, **kargs):
  164. kargs['run_'] = False
  165. self.mset = Mapset()
  166. self.module = Module(cmd, *args, **kargs)
  167. self.width = width
  168. self.height = height
  169. self.overlap = overlap
  170. self.processes = processes
  171. self.region = region if region else Region()
  172. self.start_row = start_row
  173. self.start_col = start_col
  174. self.out_prefix = out_prefix
  175. self.n_mset = None
  176. self.gisrc_src = self.gisrc_dst = None
  177. if move:
  178. self.n_mset = copy_mapset(self.mset, move)
  179. rasters = select(self.module.inputs, 'raster')
  180. self.gisrc_src, self.gisrc_dst = copy_raster(rasters,
  181. self.mset,
  182. self.n_mset,
  183. region=self.region)
  184. self.bboxes = split_region_tiles(region=region,
  185. width=width, height=height,
  186. overlap=overlap)
  187. self.msetstr = cmd.replace('.', '') + "_%03d_%03d"
  188. self.inlist = None
  189. if split:
  190. self.split()
  191. self.debug = debug
  192. def clean_location(self):
  193. """Remove all created mapsets."""
  194. mapsets = Location().mapsets(self.msetstr.split('_')[0] + '_*')
  195. for mset in mapsets:
  196. Mapset(mset).delete()
  197. def split(self):
  198. """Split all the raster inputs using r.tile"""
  199. rtile = Module('r.tile')
  200. inlist = {}
  201. for inm in select(self.module.inputs, 'raster'):
  202. rtile(input=inm.value, output=inm.value,
  203. width=self.width, height=self.height,
  204. overlap=self.overlap)
  205. patt = '%s-*' % inm.value
  206. inlist[inm.value] = sorted(self.mset.glist(type='rast',
  207. pattern=patt))
  208. self.inlist = inlist
  209. def get_works(self):
  210. """Return a list of tuble with the parameters for cmd_exe function"""
  211. works = []
  212. reg = Region()
  213. cmd = self.module.get_dict()
  214. for row, box_row in enumerate(self.bboxes):
  215. for col, box in enumerate(box_row):
  216. inms = None
  217. if self.inlist:
  218. inms = {}
  219. cols = len(box_row)
  220. for key in self.inlist:
  221. indx = row * cols + col
  222. inms[key] = "%s@%s" % (self.inlist[key][indx],
  223. self.mset.name)
  224. # set the computational region, prepare the region parameters
  225. bbox = dict([(k[0], str(v)) for k, v in box.items()[:-2]])
  226. bbox['nsres'] = '%f' % reg.nsres
  227. bbox['ewres'] = '%f' % reg.ewres
  228. works.append((bbox, inms,
  229. self.msetstr % (self.start_row + row,
  230. self.start_col + col),
  231. cmd))
  232. return works
  233. def define_mapset_inputs(self):
  234. for inmap in self.module.inputs:
  235. inm = self.module.inputs[inmap]
  236. # (inm.type == 'raster' or inm.typedesc == 'group') and inm.value:
  237. if inm.type == 'raster' and inm.value:
  238. if '@' not in inm.value:
  239. mset = get_mapset_raster(inm.value)
  240. inm.value = inm.value + '@%s' % mset
  241. def run(self, patch=True, clean=True):
  242. """Run the GRASS command."""
  243. self.module.flags.overwrite = True
  244. self.define_mapset_inputs()
  245. if self.debug:
  246. for wrk in self.get_works():
  247. cmd_exe(wrk)
  248. else:
  249. pool = mltp.Pool(processes=self.processes)
  250. result = pool.map_async(cmd_exe, self.get_works())
  251. result.wait()
  252. if not result.successful():
  253. raise RuntimeError
  254. if patch:
  255. self.patch()
  256. if clean:
  257. self.clean_location()
  258. self.rm_tiles()
  259. def patch(self):
  260. """Patch the final results."""
  261. # patch all the outputs
  262. for otmap in self.module.outputs:
  263. otm = self.module.outputs[otmap]
  264. if otm.typedesc == 'raster' and otm.value:
  265. patch_map(self.out_prefix + otm.value,
  266. self.mset.name, self.msetstr,
  267. split_region_tiles(width=self.width,
  268. height=self.height),
  269. self.module.flags.overwrite,
  270. self.start_row, self.start_col)
  271. def rm_tiles(self):
  272. """Remove all the tiles."""
  273. # if split, remove tiles
  274. if self.inlist:
  275. grm = Module('g.remove')
  276. for key in self.inlist:
  277. grm(rast=self.inlist[key])