Browse Source

Add the with statement also when opening a raster class in write mode

git-svn-id: https://svn.osgeo.org/grass/grass/trunk@57818 15284696-431f-4ddb-bdfa-cd5b030d7da7
Pietro Zambelli 11 years ago
parent
commit
e46fff0b42
2 changed files with 19 additions and 20 deletions
  1. 12 14
      lib/python/pygrass/raster/__init__.py
  2. 7 6
      lib/python/pygrass/raster/abstract.py

+ 12 - 14
lib/python/pygrass/raster/__init__.py

@@ -12,11 +12,9 @@ import numpy as np
 #
 from grass.script import fatal, warning
 from grass.script import core as grasscore
-#from grass.script import core
-#import grass.lib as grasslib
+
 import grass.lib.gis as libgis
 import grass.lib.raster as libraster
-import grass.lib.segment as libseg
 import grass.lib.rowio as librowio
 
 #
@@ -142,7 +140,7 @@ class RasterRow(RasterAbstractBase):
         """
         libraster.Rast_put_row(self._fd, row.p, self._gtype)
 
-    def open(self, mode='r', mtype='CELL', overwrite=False):
+    def open(self, mode=None, mtype=None, overwrite=None):
         """Open the raster if exist or created a new one.
 
         Parameters
@@ -165,9 +163,9 @@ class RasterRow(RasterAbstractBase):
             * self._gtype
             * self._rows and self._cols
         """
-        self.mode = mode
-        self.mtype = mtype
-        self.overwrite = overwrite
+        self.mode = mode if mode else self.mode
+        self.mtype = mtype if mtype else self.mtype
+        self.overwrite = overwrite if overwrite is not None else self.overwrite
 
         # check if exist and instantiate all the private attributes
         if self.exist():
@@ -206,7 +204,7 @@ class RasterRowIO(RasterRow):
         self.rowio = RowIO()
         super(RasterRowIO, self).__init__(name, *args, **kargs)
 
-    def open(self, mode='r', mtype='CELL', overwrite=False):
+    def open(self, mode=None, mtype=None, overwrite=False):
         super(RasterRowIO, self).open(mode, mtype, overwrite)
         self.rowio.open(self._fd, self._rows, self._cols, self.mtype)
 
@@ -259,7 +257,7 @@ class RasterSegment(RasterAbstractBase):
         return self._mode
 
     def _set_mode(self, mode):
-        if mode.lower() not in ('r', 'w', 'rw'):
+        if mode and mode.lower() not in ('r', 'w', 'rw'):
             str_err = _("Mode type: {0} not supported ('r', 'w','rw')")
             raise ValueError(str_err.format(mode))
         self._mode = mode
@@ -364,7 +362,7 @@ class RasterSegment(RasterAbstractBase):
         self.segment.val.value = val
         self.segment.put(row, col)
 
-    def open(self, mode='r', mtype='DCELL', overwrite=False):
+    def open(self, mode=None, mtype=None, overwrite=None):
         """Open the map, if the map already exist: determine the map type
         and copy the map to the segment files;
         else, open a new segment map.
@@ -384,14 +382,14 @@ class RasterSegment(RasterAbstractBase):
         self._rows = libraster.Rast_window_rows()
         self._cols = libraster.Rast_window_cols()
 
-        self.overwrite = overwrite
-        self.mode = mode
-        self.mtype = mtype
+        self.mode = mode if mode else self.mode
+        self.mtype = mtype if mtype else self.mtype
+        self.overwrite = overwrite if overwrite is not None else self.overwrite
 
         if self.exist():
             self.info = Info(self.name, self.mapset)
             if ((self.mode == "w" or self.mode == "rw") and
-                self.overwrite is False):
+                    self.overwrite is False):
                 str_err = _("Raster map <{0}> already exists. Use overwrite.")
                 fatal(str_err.format(self))
 

+ 7 - 6
lib/python/pygrass/raster/abstract.py

@@ -168,7 +168,8 @@ class RasterAbstractBase(object):
     * Implements color, history and category handling
     * Renaming, deletion, ...
     """
-    def __init__(self, name, mapset=""):
+    def __init__(self, name, mapset="",
+                 mode='r', mtype='FCELL', overwrite=False):
         """The constructor need at least the name of the map
         *optional* field is the `mapset`. ::
 
@@ -199,13 +200,13 @@ class RasterAbstractBase(object):
         self.hist = History()
         if self.exist():
             self.info = Info(self.name, self.mapset)
+        self.mode = mode
+        self.mtype = mtype
+        self.overwrite = overwrite
 
     def __enter__(self):
-        if self.exist():
-            self.open('r')
-            return self
-        else:
-            raise ValueError("Raster not found.")
+        self.open()
+        return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         self.close()