grid.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Mar 28 11:06:00 2013
  4. @author: pietro
  5. """
  6. import os
  7. import multiprocessing as mltp
  8. import subprocess as sub
  9. import shutil as sht
  10. from grass.script.setup import write_gisrc
  11. from grass.pygrass.gis import Mapset, Location, make_mapset
  12. from grass.pygrass.gis.region import Region
  13. from grass.pygrass.modules import Module
  14. from grass.pygrass.functions import get_mapset_raster
  15. from split import split_region_tiles
  16. from patch import rpatch_map
  17. def select(parms, ptype):
  18. """Select only a certain type of parameters. ::
  19. >>> slp = Module('r.slope.aspect',
  20. ... elevation='ele', slope='slp', aspect='asp',
  21. ... run_=False)
  22. >>> for rast in select(slp.outputs, 'raster'):
  23. ... print rast
  24. ...
  25. slp
  26. asp
  27. """
  28. for k in parms:
  29. par = parms[k]
  30. if par.type == ptype or par.typedesc == ptype and par.value:
  31. if par.multiple:
  32. for p in par.value:
  33. yield p
  34. else:
  35. yield par.value
  36. def copy_mapset(mapset, path):
  37. """Copy mapset to another place without copying raster and vector data.
  38. """
  39. per_old = os.path.join(mapset.gisdbase, mapset.location, 'PERMANENT')
  40. per_new = os.path.join(path, 'PERMANENT')
  41. map_old = mapset.path()
  42. map_new = os.path.join(path, mapset.name)
  43. if not os.path.isdir(per_new):
  44. os.makedirs(per_new)
  45. if not os.path.isdir(map_new):
  46. os.mkdir(map_new)
  47. for f in (fi for fi in os.listdir(per_old) if fi.isupper()):
  48. sht.copy(os.path.join(per_old, f), per_new)
  49. for f in (fi for fi in os.listdir(map_old) if fi.isupper()):
  50. sht.copy(os.path.join(map_old, f), map_new)
  51. gisdbase, location = os.path.split(path)
  52. return Mapset(mapset.name, location, gisdbase)
  53. def copy_groups(groups, src, dst, gisrc_dst=None, region=None):
  54. """Copy groupd from one mapset to another, crop the raster to the region.
  55. """
  56. env = os.environ.copy()
  57. # set region
  58. if region:
  59. region.set_current()
  60. # instantiate modules
  61. get_grp = Module('i.group', flags='lg', stdout_=sub.PIPE, run_=False)
  62. set_grp = Module('i.group')
  63. get_grp.run_ = True
  64. # get and set GISRC
  65. gisrc_src = env['GISRC']
  66. gisrc_dst = gisrc_dst if gisrc_dst else write_gisrc(dst.gisdbase,
  67. dst.location,
  68. dst.name)
  69. for grp in groups:
  70. # change gisdbase to src
  71. env['GISRC'] = gisrc_src
  72. get_grp(group=grp, env_=env)
  73. rasts = get_grp.outputs.stdout.split()
  74. # change gisdbase to dst
  75. env['GISRC'] = gisrc_dst
  76. set_grp(group=grp, input=rasts, env_=env)
  77. return gisrc_src, gisrc_dst
  78. def copy_rasters(rasters, src, dst, gisrc_dst=None, region=None):
  79. """Copy rasters from one mapset to another, crop the raster to the region.
  80. """
  81. env = os.environ.copy()
  82. # set region
  83. if region:
  84. region.set_current()
  85. nam = "copy%d__%s" % (id(dst), '%s')
  86. expr = "%s=%s"
  87. # instantiate modules
  88. mpclc = Module('r.mapcalc')
  89. rpck = Module('r.pack')
  90. rupck = Module('r.unpack')
  91. rm = Module('g.remove')
  92. # get and set GISRC
  93. gisrc_src = env['GISRC']
  94. gisrc_dst = gisrc_dst if gisrc_dst else write_gisrc(dst.gisdbase,
  95. dst.location,
  96. dst.name)
  97. pdst = dst.path()
  98. for rast in rasters:
  99. # change gisdbase to src
  100. env['GISRC'] = gisrc_src
  101. name = nam % rast
  102. mpclc(expression=expr % (name, rast), overwrite=True, env_=env)
  103. file_dst = "%s.pack" % os.path.join(pdst, name)
  104. rpck(input=name, output=file_dst, overwrite=True, env_=env)
  105. rm(rast=name, env_=env)
  106. # change gisdbase to dst
  107. env['GISRC'] = gisrc_dst
  108. rupck(input=file_dst, output=rast, overwrite=True, env_=env)
  109. os.remove(file_dst)
  110. return gisrc_src, gisrc_dst
  111. def copy_vectors(vectors, src, dst, gisrc_dst=None, region=None):
  112. """Copy vectors from one mapset to another, crop the raster to the region.
  113. """
  114. env = os.environ.copy()
  115. # set region
  116. if region:
  117. region.set_current()
  118. nam = "copy%d__%s" % (id(dst), '%s')
  119. # instantiate modules
  120. vpck = Module('v.pack')
  121. vupck = Module('v.unpack')
  122. rm = Module('g.remove')
  123. # get and set GISRC
  124. gisrc_src = env['GISRC']
  125. gisrc_dst = gisrc_dst if gisrc_dst else write_gisrc(dst.gisdbase,
  126. dst.location,
  127. dst.name)
  128. pdst = dst.path()
  129. for vect in vectors:
  130. # change gisdbase to src
  131. env['GISRC'] = gisrc_src
  132. name = nam % vect
  133. file_dst = "%s.pack" % os.path.join(pdst, name)
  134. vpck(input=name, output=file_dst, overwrite=True, env_=env)
  135. rm(vect=name, env_=env)
  136. # change gisdbase to dst
  137. env['GISRC'] = gisrc_dst
  138. vupck(input=file_dst, output=vect, overwrite=True, env_=env)
  139. os.remove(file_dst)
  140. return gisrc_src, gisrc_dst
  141. def get_cmd(cmdd):
  142. """Transforma a cmd dictionary to a list of parameters"""
  143. cmd = [cmdd['name'], ]
  144. cmd.extend(("%s=%s" % (k, v) for k, v in cmdd['inputs']
  145. if not isinstance(v, list)))
  146. cmd.extend(("%s=%s" % (k, ','.join(vals if isinstance(vals[0], str)
  147. else map(repr, vals)))
  148. for k, vals in cmdd['inputs']
  149. if isinstance(vals, list)))
  150. cmd.extend(("%s=%s" % (k, v) for k, v in cmdd['outputs']
  151. if not isinstance(v, list)))
  152. cmd.extend(("%s=%s" % (k, ','.join(map(repr, vals)))
  153. for k, vals in cmdd['outputs']
  154. if isinstance(vals, list)))
  155. cmd.extend(("%s" % (flg) for flg in cmdd['flags'] if len(flg) == 1))
  156. cmd.extend(("--%s" % (flg[0]) for flg in cmdd['flags'] if len(flg) > 1))
  157. return cmd
  158. def cmd_exe(args):
  159. """Create a mapset, and execute a cmd inside."""
  160. bbox, mapnames, msetname, cmd, groups = args
  161. mset = Mapset()
  162. try:
  163. make_mapset(msetname)
  164. except:
  165. pass
  166. ms = Mapset(msetname)
  167. ms.visible.extend(mset.visible)
  168. env = os.environ.copy()
  169. env['GISRC'] = write_gisrc(mset.gisdbase, mset.location, msetname)
  170. if mapnames:
  171. inputs = dict(cmd['inputs'])
  172. # reset the inputs to
  173. for key in mapnames:
  174. inputs[key] = mapnames[key]
  175. cmd['inputs'] = inputs.items()
  176. # set the region to the tile
  177. sub.Popen(['g,region', 'rast=%s' % key], env=env).wait()
  178. else:
  179. #reg = Region() nsres=reg.nsres, ewres=reg.ewres,
  180. # set the computational region
  181. lcmd = ['g.region', ]
  182. lcmd.extend(["%s=%s" % (k, v) for k, v in bbox.iteritems()])
  183. sub.Popen(lcmd, env=env).wait()
  184. if groups:
  185. src, dst = copy_groups(groups, mset, ms, env['GISRC'])
  186. # run the grass command
  187. sub.Popen(get_cmd(cmd), env=env).wait()
  188. # remove temp GISRC
  189. os.remove(env['GISRC'])
  190. class GridModule(object):
  191. """Run GRASS raster commands in a multiproccessing mode.
  192. Parameters
  193. -----------
  194. cmd: raster GRASS command
  195. Only command staring with r.* are valid.
  196. width: integer
  197. Width of the tile, in pixel.
  198. height: integer
  199. Height of the tile, in pixel.
  200. overlap: integer
  201. Overlap between tiles, in pixel.
  202. processes: number of threads
  203. Default value is equal to the number of processor available.
  204. split: boolean
  205. If True use r.tile to split all the inputs.
  206. run_: boolean
  207. If False only instantiate the object.
  208. args and kargs: cmd parameters
  209. Give all the parameters to the command.
  210. Examples
  211. --------
  212. ::
  213. >>> grd = GridModule('r.slope.aspect',
  214. ... width=500, height=500, overlap=2,
  215. ... processes=None, split=True,
  216. ... elevation='elevation',
  217. ... slope='slope', aspect='aspect', overwrite=True)
  218. >>> grd.run()
  219. """
  220. def __init__(self, cmd, width=None, height=None, overlap=0, processes=None,
  221. split=False, debug=False, region=None, move=None, log=False,
  222. start_row=0, start_col=0, out_prefix='',
  223. *args, **kargs):
  224. kargs['run_'] = False
  225. self.mset = Mapset()
  226. self.module = Module(cmd, *args, **kargs)
  227. self.width = width
  228. self.height = height
  229. self.overlap = overlap
  230. self.processes = processes
  231. self.region = region if region else Region()
  232. self.start_row = start_row
  233. self.start_col = start_col
  234. self.out_prefix = out_prefix
  235. self.n_mset = None
  236. self.gisrc_src = self.gisrc_dst = None
  237. self.log = log
  238. self.move = move
  239. if self.move:
  240. self.n_mset = copy_mapset(self.mset, self.move)
  241. rasters = select(self.module.inputs, 'raster')
  242. self.gisrc_src, self.gisrc_dst = copy_rasters(rasters,
  243. self.mset,
  244. self.n_mset,
  245. region=self.region)
  246. vectors = select(self.module.inputs, 'vector')
  247. copy_vectors(vectors, self.mset, self.n_mset,
  248. gisrc_dst=self.gisrc_dst, region=self.region)
  249. self.bboxes = split_region_tiles(region=region,
  250. width=width, height=height,
  251. overlap=overlap)
  252. self.msetstr = cmd.replace('.', '') + "_%03d_%03d"
  253. self.inlist = None
  254. if split:
  255. self.split()
  256. self.debug = debug
  257. def clean_location(self, location=None):
  258. """Remove all created mapsets."""
  259. location = location if location else Location()
  260. mapsets = location.mapsets(self.msetstr.split('_')[0] + '_*')
  261. for mset in mapsets:
  262. Mapset(mset).delete()
  263. def split(self):
  264. """Split all the raster inputs using r.tile"""
  265. rtile = Module('r.tile')
  266. inlist = {}
  267. for inm in select(self.module.inputs, 'raster'):
  268. rtile(input=inm.value, output=inm.value,
  269. width=self.width, height=self.height,
  270. overlap=self.overlap)
  271. patt = '%s-*' % inm.value
  272. inlist[inm.value] = sorted(self.mset.glist(type='rast',
  273. pattern=patt))
  274. self.inlist = inlist
  275. def get_works(self):
  276. """Return a list of tuble with the parameters for cmd_exe function"""
  277. works = []
  278. reg = Region()
  279. cmd = self.module.get_dict()
  280. groups = [g for g in select(self.module.inputs, 'group')]
  281. for row, box_row in enumerate(self.bboxes):
  282. for col, box in enumerate(box_row):
  283. inms = None
  284. if self.inlist:
  285. inms = {}
  286. cols = len(box_row)
  287. for key in self.inlist:
  288. indx = row * cols + col
  289. inms[key] = "%s@%s" % (self.inlist[key][indx],
  290. self.mset.name)
  291. # set the computational region, prepare the region parameters
  292. bbox = dict([(k[0], str(v)) for k, v in box.items()[:-2]])
  293. bbox['nsres'] = '%f' % reg.nsres
  294. bbox['ewres'] = '%f' % reg.ewres
  295. works.append((bbox, inms,
  296. self.msetstr % (self.start_row + row,
  297. self.start_col + col),
  298. cmd, groups))
  299. return works
  300. def define_mapset_inputs(self):
  301. for inmap in self.module.inputs:
  302. inm = self.module.inputs[inmap]
  303. # (inm.type == 'raster' or inm.typedesc == 'group') and inm.value:
  304. if inm.type == 'raster' and inm.value:
  305. if '@' not in inm.value:
  306. mset = get_mapset_raster(inm.value)
  307. inm.value = inm.value + '@%s' % mset
  308. def run(self, patch=True, clean=True):
  309. """Run the GRASS command."""
  310. self.module.flags.overwrite = True
  311. self.define_mapset_inputs()
  312. if self.debug:
  313. for wrk in self.get_works():
  314. cmd_exe(wrk)
  315. else:
  316. pool = mltp.Pool(processes=self.processes)
  317. result = pool.map_async(cmd_exe, self.get_works())
  318. result.wait()
  319. if not result.successful():
  320. raise RuntimeError
  321. self.mset.current()
  322. if patch:
  323. self.patch()
  324. if self.n_mset is not None:
  325. # move the raster outputs to the original mapset
  326. routputs = [self.out_prefix + o
  327. for o in select(self.module.outputs, 'raster')]
  328. copy_rasters(routputs, self.n_mset, self.mset,
  329. self.gisrc_src, self.region)
  330. if self.log:
  331. # record in the temp directory
  332. from grass.lib.gis import G_tempfile
  333. tmp, dummy = os.path.split(G_tempfile())
  334. tmpdir = os.path.join(tmp, self.module.name)
  335. for k in self.module.outputs:
  336. par = self.module.outputs[k]
  337. if par.typedesc == 'raster' and par.value:
  338. dirpath = os.path.join(tmpdir, par.name)
  339. if not os.path.isdir(dirpath):
  340. os.makedirs(dirpath)
  341. fp = open(os.path.join(dirpath,
  342. self.out_prefix + par.value), 'w+')
  343. fp.close()
  344. if clean:
  345. self.mset.current()
  346. self.clean_location()
  347. self.rm_tiles()
  348. if self.n_mset:
  349. gisdbase, location = os.path.split(self.move)
  350. self.clean_location(Location(location, gisdbase))
  351. # rm temporary gis_rc
  352. os.remove(self.gisrc_dst)
  353. self.gisrc_dst = None
  354. sht.rmtree(os.path.join(self.move, 'PERMANENT'))
  355. sht.rmtree(os.path.join(self.move, self.mset.name))
  356. def patch(self):
  357. """Patch the final results."""
  358. # patch all the outputs
  359. bboxes = split_region_tiles(width=self.width, height=self.height)
  360. for otmap in self.module.outputs:
  361. otm = self.module.outputs[otmap]
  362. if otm.typedesc == 'raster' and otm.value:
  363. rpatch_map(otm.value,
  364. self.mset.name, self.msetstr, bboxes,
  365. self.module.flags.overwrite,
  366. self.start_row, self.start_col, self.out_prefix)
  367. def rm_tiles(self):
  368. """Remove all the tiles."""
  369. # if split, remove tiles
  370. if self.inlist:
  371. grm = Module('g.remove')
  372. for key in self.inlist:
  373. grm(rast=self.inlist[key])