iscatt_core.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. """
  2. @package iscatt.iscatt_core
  3. @brief Non GUI functions.
  4. Classes:
  5. - iscatt_core::Core
  6. - iscatt_core::CatRastUpdater
  7. - iscatt_core::AnalyzedData
  8. - iscatt_core::ScattPlotsCondsData
  9. - iscatt_core::ScattPlotsData
  10. (C) 2013 by the GRASS Development Team
  11. This program is free software under the GNU General Public License
  12. (>=v2). Read the file COPYING that comes with GRASS for details.
  13. @author Stepan Turek <stepan.turek seznam.cz> (mentor: Martin Landa)
  14. """
  15. import os
  16. import six
  17. import numpy as np
  18. # used iclass perimeters algorithm instead of convolve2d
  19. # from matplotlib.path import Path
  20. # from scipy.signal import convolve2d
  21. from math import sqrt, ceil, floor
  22. from core.gcmd import GException, RunCommand
  23. import grass.script as grass
  24. from iscatt.core_c import CreateCatRast, ComputeScatts, UpdateCatRast, Rasterize
  25. MAX_SCATT_SIZE = 4100 * 4100
  26. WARN_SCATT_SIZE = 2000 * 2000
  27. MAX_NCELLS = 65536 * 65536
  28. WARN_NCELLS = 12000 * 12000
  29. class Core:
  30. """Represents scatter plot backend."""
  31. def __init__(self):
  32. self.an_data = AnalyzedData()
  33. self.scatts_dt = ScattPlotsData(self.an_data)
  34. self.scatt_conds_dt = ScattPlotsCondsData(self.an_data)
  35. self.cat_rast_updater = CatRastUpdater(self.scatts_dt, self.an_data, self)
  36. def SetData(self, bands):
  37. """Set bands for analysis."""
  38. ret = self.an_data.Create(bands)
  39. if not ret:
  40. return False
  41. n_bands = len(self.GetBands())
  42. self.scatts_dt.Create(n_bands)
  43. self.scatt_conds_dt.Create(n_bands)
  44. return True
  45. def AddCategory(self, cat_id):
  46. self.scatts_dt.AddCategory(cat_id)
  47. return self.scatt_conds_dt.AddCategory(cat_id)
  48. def DeleteCategory(self, cat_id):
  49. self.scatts_dt.DeleteCategory(cat_id)
  50. self.scatt_conds_dt.DeleteCategory(cat_id)
  51. def CleanUp(self):
  52. self.scatts_dt.CleanUp()
  53. self.scatt_conds_dt.CleanUp()
  54. def GetBands(self):
  55. return self.an_data.GetBands()
  56. def GetScattsData(self):
  57. return self.scatts_dt, self.scatt_conds_dt
  58. def GetRegion(self):
  59. return self.an_data.GetRegion()
  60. def GetCatRast(self, cat_id):
  61. return self.scatts_dt.GetCatRast(cat_id)
  62. def AddScattPlots(self, scatt_ids):
  63. for s_id in scatt_ids:
  64. self.scatts_dt.AddScattPlot(scatt_id=s_id)
  65. cats_ids = self.scatts_dt.GetCategories()
  66. self.ComputeCatsScatts(cats_ids)
  67. def SetEditCatData(self, cat_id, scatt_id, bbox, value):
  68. if cat_id not in self.scatts_dt.GetCategories():
  69. raise GException(_("Select category for editing."))
  70. if self.scatt_conds_dt.AddScattPlot(cat_id, scatt_id) < 0:
  71. return None
  72. arr = self.scatt_conds_dt.GetValuesArr(cat_id, scatt_id)
  73. for k, v in six.iteritems(bbox):
  74. bbox[k] = self._validExtend(v)
  75. arr[bbox["btm_y"] : bbox["up_y"], bbox["btm_x"] : bbox["up_x"]] = value
  76. self.ComputeCatsScatts([cat_id])
  77. return cat_id
  78. def ComputeCatsScatts(self, cats_ids):
  79. requested_dt = {}
  80. requested_dt_conds = {}
  81. for c in cats_ids:
  82. requested_dt_conds[c] = self.scatt_conds_dt.GetCatScatts(c)
  83. requested_dt[c] = self.scatts_dt.GetCatScatts(c)
  84. scatt_conds = self.scatt_conds_dt.GetData(requested_dt_conds)
  85. scatts = self.scatts_dt.GetData(requested_dt)
  86. bands = self.an_data.GetBands()
  87. cats_rasts = self.scatts_dt.GetCatsRasts()
  88. cats_rasts_conds = self.scatts_dt.GetCatsRastsConds()
  89. returncode, scatts = ComputeScatts(
  90. self.an_data.GetRegion(),
  91. scatt_conds,
  92. bands,
  93. len(self.GetBands()),
  94. scatts,
  95. cats_rasts_conds,
  96. cats_rasts,
  97. )
  98. if returncode < 0:
  99. GException(_("Computing of scatter plots failed."))
  100. def CatRastUpdater(self):
  101. return self.cat_rast_updater
  102. def UpdateCategoryWithPolygons(self, cat_id, scatts_pols, value):
  103. if cat_id not in self.scatts_dt.GetCategories():
  104. raise GException(_("Select category for editing."))
  105. for scatt_id, coords in six.iteritems(scatts_pols):
  106. if self.scatt_conds_dt.AddScattPlot(cat_id, scatt_id) < 0:
  107. return False
  108. b1, b2 = idScattToidBands(scatt_id, len(self.an_data.GetBands()))
  109. b = self.scatts_dt.GetBandsInfo(scatt_id)
  110. region = {}
  111. region["s"] = b["b2"]["min"] - 0.5
  112. region["n"] = b["b2"]["max"] + 0.5
  113. region["w"] = b["b1"]["min"] - 0.5
  114. region["e"] = b["b1"]["max"] + 0.5
  115. arr = self.scatt_conds_dt.GetValuesArr(cat_id, scatt_id)
  116. arr = Rasterize(polygon=coords, rast=arr, region=region, value=value)
  117. # previous way of rasterization / used scipy
  118. # raster_pol = RasterizePolygon(coords, b['b1']['range'], b['b1']['min'],
  119. # b['b2']['range'], b['b2']['min'])
  120. # raster_ind = np.where(raster_pol > 0)
  121. # arr = self.scatt_conds_dt.GetValuesArr(cat_id, scatt_id)
  122. # arr[raster_ind] = value
  123. # arr.flush()
  124. self.ComputeCatsScatts([cat_id])
  125. return cat_id
  126. def ExportCatRast(self, cat_id, rast_name):
  127. cat_rast = self.scatts_dt.GetCatRast(cat_id)
  128. if not cat_rast:
  129. return 1
  130. return RunCommand(
  131. "g.copy",
  132. raster=cat_rast + "," + rast_name,
  133. getErrorMsg=True,
  134. overwrite=True,
  135. )
  136. def _validExtend(self, val):
  137. # TODO do it general
  138. if val > 255:
  139. val = 255
  140. elif val < 0:
  141. val = 0
  142. return val
  143. class CatRastUpdater:
  144. """Update backend data structures according to selected areas in mapwindow."""
  145. def __init__(self, scatts_dt, an_data, core):
  146. self.scatts_dt = scatts_dt
  147. self.an_data = an_data # TODO may be confusing
  148. self.core = core
  149. self.vectMap = None
  150. def SetVectMap(self, vectMap):
  151. self.vectMap = vectMap
  152. def SyncWithMap(self):
  153. # TODO possible optimization - bbox only of vertex and its two
  154. # neighbours
  155. region = self.an_data.GetRegion()
  156. bbox = {}
  157. bbox["maxx"] = region["e"]
  158. bbox["minx"] = region["w"]
  159. bbox["maxy"] = region["n"]
  160. bbox["miny"] = region["s"]
  161. updated_cats = []
  162. for cat_id in self.scatts_dt.GetCategories():
  163. if cat_id == 0:
  164. continue
  165. cat = [{1: [cat_id]}]
  166. self._updateCatRast(bbox, cat, updated_cats)
  167. return updated_cats
  168. def EditedFeature(self, new_bboxs, new_areas_cats, old_bboxs, old_areas_cats):
  169. # TODO possible optimization - bbox only of vertex and its two
  170. # neighbours
  171. bboxs = old_bboxs + new_bboxs
  172. areas_cats = old_areas_cats + new_areas_cats
  173. updated_cats = []
  174. for i in range(len(areas_cats)):
  175. self._updateCatRast(bboxs[i], areas_cats[i], updated_cats)
  176. return updated_cats
  177. def _updateCatRast(self, bbox, areas_cats, updated_cats):
  178. rasterized_cats = []
  179. for c in range(len(areas_cats)):
  180. if not areas_cats[c]:
  181. continue
  182. layer = list(areas_cats[c])[0]
  183. cat = areas_cats[c][layer][0]
  184. if cat in rasterized_cats:
  185. continue
  186. rasterized_cats.append(cat)
  187. updated_cats.append(cat)
  188. grass_region = self._create_grass_region_env(bbox)
  189. # TODO hack check if raster exists?
  190. patch_rast = "temp_scatt_patch_%d" % (os.getpid())
  191. self._rasterize(grass_region, layer, cat, patch_rast)
  192. region = self.an_data.GetRegion()
  193. ret = UpdateCatRast(patch_rast, region, self.scatts_dt.GetCatRastCond(cat))
  194. if ret < 0:
  195. GException(_("Patching category raster conditions file failed."))
  196. RunCommand("g.remove", flags="f", type="raster", name=patch_rast)
  197. def _rasterize(self, grass_region, layer, cat, out_rast):
  198. # TODO different thread may be problem when user edits map
  199. environs = os.environ.copy()
  200. environs["GRASS_VECTOR_TEMPORARY"] = "1"
  201. ret, text, msg = RunCommand(
  202. "v.category",
  203. input=self.vectMap,
  204. getErrorMsg=True,
  205. option="report",
  206. read=True,
  207. env=environs,
  208. )
  209. ret, text, msg = RunCommand(
  210. "v.build", map=self.vectMap, getErrorMsg=True, read=True, env=environs
  211. )
  212. if ret != 0:
  213. GException(_("v.build failed:\n%s" % msg))
  214. environs = os.environ.copy()
  215. environs["GRASS_REGION"] = grass_region["GRASS_REGION"]
  216. environs["GRASS_VECTOR_TEMPORARY"] = "1"
  217. ret, text, msg = RunCommand(
  218. "v.to.rast",
  219. input=self.vectMap,
  220. use="cat",
  221. layer=str(layer),
  222. cat=str(cat),
  223. output=out_rast,
  224. getErrorMsg=True,
  225. read=True,
  226. overwrite=True,
  227. env=environs,
  228. )
  229. if ret != 0:
  230. GException(_("v.to.rast failed:\n%s" % msg))
  231. def _create_grass_region_env(self, bbox):
  232. r = self.an_data.GetRegion()
  233. new_r = {}
  234. if bbox["maxy"] <= r["s"]:
  235. return 0
  236. elif bbox["maxy"] >= r["n"]:
  237. new_r["n"] = bbox["maxy"]
  238. else:
  239. new_r["n"] = (
  240. ceil((bbox["maxy"] - r["s"]) / r["nsres"]) * r["nsres"] + r["s"]
  241. )
  242. if bbox["miny"] >= r["n"]:
  243. return 0
  244. elif bbox["miny"] <= r["s"]:
  245. new_r["s"] = bbox["miny"]
  246. else:
  247. new_r["s"] = (
  248. floor((bbox["miny"] - r["s"]) / r["nsres"]) * r["nsres"] + r["s"]
  249. )
  250. if bbox["maxx"] <= r["w"]:
  251. return 0
  252. elif bbox["maxx"] >= r["e"]:
  253. new_r["e"] = bbox["maxx"]
  254. else:
  255. new_r["e"] = (
  256. ceil((bbox["maxx"] - r["w"]) / r["ewres"]) * r["ewres"] + r["w"]
  257. )
  258. if bbox["minx"] >= r["e"]:
  259. return 0
  260. elif bbox["minx"] <= r["w"]:
  261. new_r["w"] = bbox["minx"]
  262. else:
  263. new_r["w"] = (
  264. floor((bbox["minx"] - r["w"]) / r["ewres"]) * r["ewres"] + r["w"]
  265. )
  266. # TODO check regions resolution
  267. new_r["nsres"] = r["nsres"]
  268. new_r["ewres"] = r["ewres"]
  269. return {"GRASS_REGION": grass.region_env(**new_r)}
  270. class AnalyzedData:
  271. """Represents analyzed data (bands, region)."""
  272. def __init__(self):
  273. self.bands = []
  274. self.bands_info = {}
  275. self.region = None
  276. def GetRegion(self):
  277. return self.region
  278. def Create(self, bands):
  279. self.bands = bands[:]
  280. self.region = None
  281. self.region = GetRegion()
  282. if self.region["rows"] * self.region["cols"] > MAX_NCELLS:
  283. GException("too big region")
  284. self.bands_info = {}
  285. for b in self.bands[:]:
  286. i = GetRasterInfo(b)
  287. if i is None:
  288. GException("raster %s is not CELL type" % (b))
  289. self.bands_info[b] = i
  290. # TODO check size of raster
  291. return True
  292. def GetBands(self):
  293. return self.bands
  294. def GetBandInfo(self, band_id):
  295. band = self.bands[band_id]
  296. return self.bands_info[band]
  297. class ScattPlotsCondsData:
  298. """Data structure for selected areas in scatter plot(conditions)."""
  299. def __init__(self, an_data):
  300. self.an_data = an_data
  301. # TODO
  302. self.max_n_cats = 10
  303. self.dtype = "uint8"
  304. self.type = 1
  305. self.CleanUp()
  306. def CleanUp(self):
  307. self.cats = {}
  308. self.n_scatts = -1
  309. self.n_bands = -1
  310. for cat_id in self.cats.keys():
  311. self.DeleteCategory(cat_id)
  312. def Create(self, n_bands):
  313. self.CleanUp()
  314. self.n_scatts = (n_bands - 1) * n_bands / 2
  315. self.n_bands = n_bands
  316. self.AddCategory(cat_id=0)
  317. def AddCategory(self, cat_id):
  318. if cat_id not in self.cats.keys():
  319. self.cats[cat_id] = {}
  320. return cat_id
  321. return -1
  322. def DeleteCategory(self, cat_id):
  323. if cat_id not in self.cats.keys():
  324. return False
  325. for scatt in six.itervalues(self.cats[cat_id]):
  326. grass.try_remove(scatt["np_vals"])
  327. del scatt["np_vals"]
  328. del self.cats[cat_id]
  329. return True
  330. def GetCategories(self):
  331. return self.cats.keys()
  332. def GetCatScatts(self, cat_id):
  333. if cat_id not in self.cats:
  334. return False
  335. return self.cats[cat_id].keys()
  336. def AddScattPlot(self, cat_id, scatt_id):
  337. if cat_id not in self.cats:
  338. return -1
  339. if scatt_id in self.cats[cat_id]:
  340. return 0
  341. b_i = self.GetBandsInfo(scatt_id)
  342. shape = (
  343. b_i["b2"]["max"] - b_i["b2"]["min"] + 1,
  344. b_i["b1"]["max"] - b_i["b1"]["min"] + 1,
  345. )
  346. np_vals = np.memmap(grass.tempfile(), dtype=self.dtype, mode="w+", shape=shape)
  347. self.cats[cat_id][scatt_id] = {"np_vals": np_vals}
  348. return 1
  349. def GetBandsInfo(self, scatt_id):
  350. b1, b2 = idScattToidBands(scatt_id, len(self.an_data.GetBands()))
  351. b1_info = self.an_data.GetBandInfo(b1)
  352. b2_info = self.an_data.GetBandInfo(b2)
  353. bands_info = {"b1": b1_info, "b2": b2_info}
  354. return bands_info
  355. def DeleScattPlot(self, cat_id, scatt_id):
  356. if cat_id not in self.cats:
  357. return False
  358. if scatt_id not in self.cats[cat_id]:
  359. return False
  360. del self.cats[cat_id][scatt_id]
  361. return True
  362. def GetValuesArr(self, cat_id, scatt_id):
  363. if cat_id not in self.cats:
  364. return None
  365. if scatt_id not in self.cats[cat_id]:
  366. return None
  367. return self.cats[cat_id][scatt_id]["np_vals"]
  368. def GetData(self, requested_dt):
  369. cats = {}
  370. for cat_id, scatt_ids in six.iteritems(requested_dt):
  371. if cat_id not in cats:
  372. cats[cat_id] = {}
  373. for scatt_id in scatt_ids:
  374. # if key is missing condition is always True (full scatter plor
  375. # is computed)
  376. if scatt_id in self.cats[cat_id]:
  377. cats[cat_id][scatt_id] = {
  378. "np_vals": self.cats[cat_id][scatt_id]["np_vals"],
  379. "bands_info": self.GetBandsInfo(scatt_id),
  380. }
  381. return cats
  382. def SetData(self, cats):
  383. for cat_id, scatt_ids in six.iteritems(cats):
  384. for scatt_id in scatt_ids:
  385. # if key is missing condition is always True (full scatter plor
  386. # is computed)
  387. if scatt_id in self.cats[cat_id]:
  388. self.cats[cat_id][scatt_id]["np_vals"] = cats[cat_id][scatt_id][
  389. "np_vals"
  390. ]
  391. def GetScatt(self, scatt_id, cats_ids=None):
  392. scatts = {}
  393. for cat_id in six.iterkeys(self.cats):
  394. if cats_ids and cat_id not in cats_ids:
  395. continue
  396. if scatt_id not in self.cats[cat_id]:
  397. continue
  398. scatts[cat_id] = {
  399. "np_vals": self.cats[cat_id][scatt_id]["np_vals"],
  400. "bands_info": self.GetBandsInfo(scatt_id),
  401. }
  402. return scatts
  403. class ScattPlotsData(ScattPlotsCondsData):
  404. """Data structure for computed points (classes) in scatter plots.\
  405. """
  406. def __init__(self, an_data):
  407. self.cats_rasts = {}
  408. self.cats_rasts_conds = {}
  409. self.scatts_ids = []
  410. ScattPlotsCondsData.__init__(self, an_data)
  411. self.dtype = "uint32"
  412. # TODO
  413. self.type = 0
  414. def AddCategory(self, cat_id):
  415. cat_id = ScattPlotsCondsData.AddCategory(self, cat_id)
  416. if cat_id < 0:
  417. return cat_id
  418. for scatt_id in self.scatts_ids:
  419. ScattPlotsCondsData.AddScattPlot(self, cat_id, scatt_id)
  420. if cat_id == 0:
  421. self.cats_rasts_conds[cat_id] = None
  422. self.cats_rasts[cat_id] = None
  423. else:
  424. self.cats_rasts_conds[cat_id] = grass.tempfile()
  425. self.cats_rasts[cat_id] = "temp_cat_rast_%d_%d" % (cat_id, os.getpid())
  426. region = self.an_data.GetRegion()
  427. CreateCatRast(region, self.cats_rasts_conds[cat_id])
  428. return cat_id
  429. def DeleteCategory(self, cat_id):
  430. ScattPlotsCondsData.DeleteCategory(self, cat_id)
  431. grass.try_remove(self.cats_rasts_conds[cat_id])
  432. del self.cats_rasts_conds[cat_id]
  433. RunCommand("g.remove", flags="f", type="raster", name=self.cats_rasts[cat_id])
  434. del self.cats_rasts[cat_id]
  435. return True
  436. def AddScattPlot(self, scatt_id):
  437. if scatt_id in self.scatts_ids:
  438. return False
  439. self.scatts_ids.append(scatt_id)
  440. for cat_id in six.iterkeys(self.cats):
  441. ScattPlotsCondsData.AddScattPlot(self, cat_id, scatt_id)
  442. self.cats[cat_id][scatt_id]["ellipse"] = None
  443. return True
  444. def DeleteScatterPlot(self, scatt_id):
  445. if scatt_id not in self.scatts_ids:
  446. return False
  447. self.scatts_ids.remove(scatt_id)
  448. for cat_id in six.iterkeys(self.cats):
  449. ScattPlotsCondsData.DeleteScattPlot(self, cat_id, scatt_id)
  450. return True
  451. def GetEllipses(self, scatt_id, styles):
  452. if scatt_id not in self.scatts_ids:
  453. return False
  454. scatts = {}
  455. for cat_id in six.iterkeys(self.cats):
  456. if cat_id == 0:
  457. continue
  458. nstd = styles[cat_id]["nstd"]
  459. scatts[cat_id] = self._getEllipse(cat_id, scatt_id, nstd)
  460. return scatts
  461. def _getEllipse(self, cat_id, scatt_id, nstd):
  462. # Joe Kington
  463. # http://stackoverflow.com/questions/12301071/multidimensional-confidence-intervals
  464. data = np.copy(self.cats[cat_id][scatt_id]["np_vals"])
  465. b = self.GetBandsInfo(scatt_id)
  466. sel_pts = np.where(data > 0)
  467. x = sel_pts[1]
  468. y = sel_pts[0]
  469. flatten_data = data.reshape([-1])
  470. flatten_sel_pts = np.nonzero(flatten_data)
  471. weights = flatten_data[flatten_sel_pts]
  472. if len(weights) == 0:
  473. return None
  474. x_avg = np.average(x, weights=weights)
  475. y_avg = np.average(y, weights=weights)
  476. pos = np.array([x_avg + b["b1"]["min"], y_avg + b["b2"]["min"]])
  477. x_diff = x - x_avg
  478. y_diff = y - y_avg
  479. x_diff = x - x_avg
  480. y_diff = y - y_avg
  481. diffs = x_diff * y_diff.T
  482. cov = np.dot(diffs, weights) / (np.sum(weights) - 1)
  483. diffs = x_diff * x_diff.T
  484. var_x = np.dot(diffs, weights) / (np.sum(weights) - 1)
  485. diffs = y_diff * y_diff.T
  486. var_y = np.dot(diffs, weights) / (np.sum(weights) - 1)
  487. cov = np.array([[var_x, cov], [cov, var_y]])
  488. def eigsorted(cov):
  489. vals, vecs = np.linalg.eigh(cov)
  490. order = vals.argsort()[::-1]
  491. return vals[order], vecs[:, order]
  492. vals, vecs = eigsorted(cov)
  493. theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
  494. # Width and height are "full" widths, not radius
  495. width, height = 2 * nstd * np.sqrt(vals)
  496. ellipse = {"pos": pos, "width": width, "height": height, "theta": theta}
  497. del data
  498. del flatten_data
  499. del flatten_sel_pts
  500. del weights
  501. del sel_pts
  502. return ellipse
  503. def CleanUp(self):
  504. ScattPlotsCondsData.CleanUp(self)
  505. for tmp in six.itervalues(self.cats_rasts_conds):
  506. grass.try_remove(tmp)
  507. for tmp in six.itervalues(self.cats_rasts):
  508. RunCommand("g.remove", flags="f", type="raster", name=tmp, getErrorMsg=True)
  509. self.cats_rasts = {}
  510. self.cats_rasts_conds = {}
  511. def GetCatRast(self, cat_id):
  512. if cat_id in self.cats_rasts:
  513. return self.cats_rasts[cat_id]
  514. return None
  515. def GetCatRastCond(self, cat_id):
  516. return self.cats_rasts_conds[cat_id]
  517. def GetCatsRastsConds(self):
  518. max_cat_id = max(self.cats_rasts_conds.keys())
  519. cats_rasts_conds = [""] * (max_cat_id + 1)
  520. for i_cat_id, i_rast in six.iteritems(self.cats_rasts_conds):
  521. cats_rasts_conds[i_cat_id] = i_rast
  522. return cats_rasts_conds
  523. def GetCatsRasts(self):
  524. max_cat_id = max(self.cats_rasts.keys())
  525. cats_rasts = [""] * (max_cat_id + 1)
  526. for i_cat_id, i_rast in six.iteritems(self.cats_rasts):
  527. cats_rasts[i_cat_id] = i_rast
  528. return cats_rasts
# Not used: polygon rasterization is now done by ``Rasterize`` from
# iscatt.core_c (per the note at the imports, the iclass perimeter algorithm)
# instead of matplotlib Path + scipy convolve2d. The former implementation is
# kept below only for reference; it is a bare string literal, never executed.
"""
def RasterizePolygon(pol, height, min_h, width, min_w):
    # Joe Kington
    # http://stackoverflow.com/questions/3654289/scipy-create-2d-polygon-mask
    #poly_verts = [(1,1), (1,4), (4,4),(4,1), (1,1)]
    nx = width
    ny = height
    x, y = np.meshgrid(np.arange(-0.5 + min_w, nx + 0.5 + min_w, dtype=float),
                       np.arange(-0.5 + min_h, ny + 0.5 + min_h, dtype=float))
    x, y = x.flatten(), y.flatten()
    points = np.vstack((x,y)).T
    p = Path(pol)
    grid = p.contains_points(points)
    grid = grid.reshape((ny + 1, nx + 1))
    raster = np.zeros((height, width), dtype=np.uint8)#TODO bool
    #TODO shift by 0.5
    B = np.ones((2,2))/4
    raster = convolve2d(grid, B, 'valid')
    return raster
"""
  550. def idScattToidBands(scatt_id, n_bands):
  551. """Get bands ids from scatter plot id."""
  552. n_b1 = n_bands - 1
  553. band_1 = (int)(
  554. (2 * n_b1 + 1 - sqrt(((2 * n_b1 + 1) * (2 * n_b1 + 1) - 8 * scatt_id))) / 2
  555. )
  556. band_2 = int(
  557. scatt_id - (band_1 * (2 * n_b1 + 1) - band_1 * band_1) / 2 + band_1 + 1
  558. )
  559. return band_1, band_2
  560. def idBandsToidScatt(band_1_id, band_2_id, n_bands):
  561. """Get scatter plot id from band ids."""
  562. if band_2_id < band_1_id:
  563. tmp = band_1_id
  564. band_1_id = band_2_id
  565. band_2_id = tmp
  566. n_b1 = n_bands - 1
  567. scatt_id = int(
  568. (band_1_id * (2 * n_b1 + 1) - band_1_id * band_1_id) / 2
  569. + band_2_id
  570. - band_1_id
  571. - 1
  572. )
  573. return scatt_id
  574. def GetRegion():
  575. ret, region, msg = RunCommand("g.region", flags="gp", getErrorMsg=True, read=True)
  576. if ret != 0:
  577. raise GException("g.region failed:\n%s" % msg)
  578. return _parseRegion(region)
  579. def _parseRegion(region_str):
  580. region = {}
  581. region_str = region_str.splitlines()
  582. for param in region_str:
  583. k, v = param.split("=")
  584. if k in ["rows", "cols", "cells"]:
  585. v = int(v)
  586. else:
  587. v = float(v)
  588. region[k] = v
  589. return region
  590. def GetRasterInfo(rast):
  591. ret, out, msg = RunCommand(
  592. "r.info", map=rast, flags="rg", getErrorMsg=True, read=True
  593. )
  594. if ret != 0:
  595. raise GException("r.info failed:\n%s" % msg)
  596. out = out.splitlines()
  597. raster_info = {}
  598. for b in out:
  599. if not b.strip():
  600. continue
  601. k, v = b.split("=")
  602. if k == "datatype":
  603. if v != "CELL":
  604. return None
  605. pass
  606. elif k in ["rows", "cols", "cells", "min", "max"]:
  607. v = int(v)
  608. else:
  609. v = float(v)
  610. raster_info[k] = v
  611. raster_info["range"] = raster_info["max"] - raster_info["min"] + 1
  612. return raster_info